diff --git a/skyrl-tx/tests/gpu/test_attention.py b/skyrl-tx/tests/gpu/test_attention.py index d5014ebe64..f522b48f8f 100644 --- a/skyrl-tx/tests/gpu/test_attention.py +++ b/skyrl-tx/tests/gpu/test_attention.py @@ -38,7 +38,8 @@ def assert_attention_match(q, k, v, mask, is_causal, head_dim, seq_lengths=None) If None, compare all positions. """ scale = 1.0 / head_dim**0.5 - result = dot_product_attention(q, k, v, mask, is_causal=is_causal, head_dim=head_dim) + positions = jnp.broadcast_to(jnp.arange(q.shape[1], dtype=jnp.int32), (q.shape[0], q.shape[1])) + result = dot_product_attention(q, k, v, mask, is_causal=is_causal, head_dim=head_dim, positions=positions) expected = jax.nn.dot_product_attention( q, k, v, scale=scale, mask=mask[:, None, None, :].astype(bool), is_causal=is_causal ) diff --git a/skyrl-tx/tests/models/test_qwen3_config.py b/skyrl-tx/tests/models/test_qwen3_config.py index fc2783b994..2ca45d7017 100644 --- a/skyrl-tx/tests/models/test_qwen3_config.py +++ b/skyrl-tx/tests/models/test_qwen3_config.py @@ -13,6 +13,7 @@ def test_config_wraps_pretrained_config(): assert config.max_lora_adapters == 8 assert config.max_lora_rank == 16 assert config.shard_attention_heads is False + assert config.context_parallel_size == 1 # Check base config attributes were copied assert config.vocab_size > 0 @@ -27,3 +28,17 @@ def test_config_preserves_moe_config(): # Check that MoE-specific attributes are preserved assert config.num_experts > 0 + + +def test_config_sets_context_parallel_size(): + """Test that context parallel size is configurable.""" + hf_config = PretrainedConfig.from_pretrained("Qwen/Qwen3-0.6B") + config = Qwen3Config( + hf_config, + max_lora_adapters=8, + max_lora_rank=16, + shard_attention_heads=True, + context_parallel_size=2, + ) + + assert config.context_parallel_size == 2 diff --git a/skyrl-tx/tests/tinker/test_jax_backend.py b/skyrl-tx/tests/tinker/test_jax_backend.py index ab6710ee62..0fb329f0c5 100644 --- 
a/skyrl-tx/tests/tinker/test_jax_backend.py +++ b/skyrl-tx/tests/tinker/test_jax_backend.py @@ -25,6 +25,149 @@ def create_backend(max_lora_adapters: int = MAX_LORA_ADAPTERS): return JaxBackend(BASE_MODEL, config) +def test_context_parallel_forward_backward_runs(): + """Training path should run with CP>1 via shard_map + ppermute attention.""" + config = JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=2) + backend = JaxBackend(BASE_MODEL, config) + model_id = "cp_model" + create_model(backend, model_id) + reqs = {"1": (model_id, make_fwd_bwd_input([[1, 2, 3, 4]]))} + + results = backend.forward_backward(prepare_model_pass_batch(reqs)) + assert "1" in results + + +def test_context_parallel_sample_runs(): + """Sampling should run under CP>1 by disabling CP collectives in decode path.""" + config = JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=2) + backend = JaxBackend(BASE_MODEL, config) + reqs = {"1": ("", make_sample_input([1, 2, 3], max_tokens=4))} + + results = backend.sample(prepare_sample_batch(reqs)) + assert "1" in results + assert len(results["1"].sequences) == 1 + + +def test_context_parallel_sample_parity(): + """Sampling outputs should match between CP=1 and CP=2 for identical seeds/prompts.""" + backend_cp1 = JaxBackend( + BASE_MODEL, JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=1) + ) + backend_cp2 = JaxBackend( + BASE_MODEL, JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=2) + ) + + req1 = make_sample_input([1, 2, 3, 4, 5], max_tokens=6) + req1.sampling_params = api.SamplingParams(temperature=0.0, max_tokens=6, seed=7).to_types() + req2 = make_sample_input([6, 7, 8, 9], max_tokens=5) + req2.sampling_params = api.SamplingParams(temperature=1.0, max_tokens=5, seed=11).to_types() + reqs = {"r1": ("", req1), "r2": ("", req2)} + + out_cp1 = 
backend_cp1.sample(prepare_sample_batch(reqs)) + out_cp2 = backend_cp2.sample(prepare_sample_batch(reqs)) + + assert set(out_cp1.keys()) == set(out_cp2.keys()) == {"r1", "r2"} + for req_id in ("r1", "r2"): + seqs1 = out_cp1[req_id].sequences + seqs2 = out_cp2[req_id].sequences + assert len(seqs1) == len(seqs2) + for s1, s2 in zip(seqs1, seqs2): + assert s1.stop_reason == s2.stop_reason + assert s1.tokens == s2.tokens + np.testing.assert_allclose(np.asarray(s1.logprobs), np.asarray(s2.logprobs), rtol=1e-6, atol=1e-6) + + +def test_context_parallel_sample_prompt_logprobs_parity(): + """CP=2 sampling should match CP=1 for prompt logprobs and token outputs.""" + backend_cp1 = JaxBackend( + BASE_MODEL, JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=1) + ) + backend_cp2 = JaxBackend( + BASE_MODEL, JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=2) + ) + + req = make_sample_input([1, 2, 3, 4, 5, 6], prompt_logprobs=True, max_tokens=4) + req.sampling_params = api.SamplingParams(temperature=0.0, max_tokens=4, seed=17).to_types() + reqs = {"r": ("", req)} + + out_cp1 = backend_cp1.sample(prepare_sample_batch(reqs)) + out_cp2 = backend_cp2.sample(prepare_sample_batch(reqs)) + + seq1 = out_cp1["r"].sequences[0] + seq2 = out_cp2["r"].sequences[0] + assert seq1.stop_reason == seq2.stop_reason + assert seq1.tokens == seq2.tokens + np.testing.assert_allclose(np.asarray(seq1.logprobs), np.asarray(seq2.logprobs), rtol=1e-6, atol=1e-6) + + assert out_cp1["r"].prompt_logprobs is not None + assert out_cp2["r"].prompt_logprobs is not None + np.testing.assert_allclose( + np.asarray(out_cp1["r"].prompt_logprobs), + np.asarray(out_cp2["r"].prompt_logprobs), + rtol=1e-6, + atol=1e-6, + ) + + +def _extract_loss_and_logprobs(results: dict, request_id: str) -> tuple[list[np.ndarray], list[np.ndarray]]: + output = results[request_id] + losses = [] + logprobs = [] + for item in output.loss_fn_outputs: + 
losses.append(np.asarray(item["elementwise_loss"]["data"], dtype=np.float32)) + logprobs.append(np.asarray(item["logprobs"]["data"], dtype=np.float32)) + return losses, logprobs + + +def test_context_parallel_parity_forward_and_backward(): + """CP=2 should match CP=1 numerically for prefill forward/backward.""" + model_id = "cp_parity_model" + reqs = { + "1": ( + model_id, + make_fwd_bwd_input( + [ + [1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + ] + ), + ) + } + + backend_cp1 = JaxBackend( + BASE_MODEL, JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=1) + ) + backend_cp2 = JaxBackend( + BASE_MODEL, JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=2) + ) + create_model(backend_cp1, model_id) + create_model(backend_cp2, model_id) + + prepared_batch = prepare_model_pass_batch(reqs) + + # Forward parity + forward_cp1 = backend_cp1.forward(prepared_batch) + forward_cp2 = backend_cp2.forward(prepared_batch) + losses_cp1, logprobs_cp1 = _extract_loss_and_logprobs(forward_cp1, "1") + losses_cp2, logprobs_cp2 = _extract_loss_and_logprobs(forward_cp2, "1") + + assert len(losses_cp1) == len(losses_cp2) + assert len(logprobs_cp1) == len(logprobs_cp2) + for a, b in zip(losses_cp1, losses_cp2): + np.testing.assert_allclose(a, b, atol=1e-4, rtol=1e-4) + for a, b in zip(logprobs_cp1, logprobs_cp2): + np.testing.assert_allclose(a, b, atol=1e-4, rtol=1e-4) + + # Backward parity (compare per-adapter mean gradients) + backend_cp1.forward_backward(prepared_batch) + backend_cp2.forward_backward(prepared_batch) + adapter_idx_cp1 = backend_cp1.models[model_id].adapter_index + adapter_idx_cp2 = backend_cp2.models[model_id].adapter_index + mean_grads_cp1 = backend_cp1.accumulated_grads.get_mean(jnp.int32(adapter_idx_cp1)) + mean_grads_cp2 = backend_cp2.accumulated_grads.get_mean(jnp.int32(adapter_idx_cp2)) + _assert_tree_allclose(mean_grads_cp1, mean_grads_cp2, rtol=1e-3, atol=1e-3, 
min_match_pct=99.0) + + def create_model(backend: JaxBackend, model_id: str) -> int: """Create a model and return its adapter index.""" lora_config = LoraConfig(rank=LORA_RANK, alpha=16, seed=0) diff --git a/skyrl-tx/tx/layers/attention.py b/skyrl-tx/tx/layers/attention.py index ddb5be42d4..700f5802f9 100644 --- a/skyrl-tx/tx/layers/attention.py +++ b/skyrl-tx/tx/layers/attention.py @@ -2,12 +2,102 @@ import jax import jax.numpy as jnp +from jax.sharding import get_abstract_mesh # cuDNN flash attention supported dtypes # https://github.com/jax-ml/jax/blob/8b1f782540f71fbe230a2dccd331975faafc6c83/jax/_src/cudnn/fused_attention_stablehlo.py#L290 _CUDNN_SUPPORTED_DTYPES = (jnp.float16, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e5m2) +def _ring_attention( + q: jax.Array, + k: jax.Array, + v: jax.Array, + attention_mask: jax.Array, + positions: jax.Array, + scale: float, +) -> jax.Array: + """Streaming causal attention with ring KV exchange via ppermute.""" + cp = get_abstract_mesh().shape.get("cp", 1) + axis_idx = jax.lax.axis_index("cp") + local_len = k.shape[1] + + # qh: [B, H, Tq, D] + qh = jnp.transpose(q, (0, 2, 1, 3)) + + # Expand KV heads to match query heads (happens in GQA) + # k/v: [B, Tk, H_kv, D] -> [B, Tk, H, D] + kv_repeat = q.shape[2] // k.shape[2] + k_block = jnp.repeat(k, kv_repeat, axis=2) + v_block = jnp.repeat(v, kv_repeat, axis=2) + mask_block = attention_mask + + # Online softmax state (kept per [B, H, Tq]): + # carry_max = running max score + # denom = running denominator sum(exp(score - carry_max)) + # acc = running numerator sum(exp(score - m) * value), shape [B, H, Tq, D] + B, H, Tq, D = qh.shape + carry_max = jnp.full((B, H, Tq), -jnp.inf, dtype=q.dtype) + denom = jnp.zeros((B, H, Tq), dtype=q.dtype) + acc = jnp.zeros((B, H, Tq, D), dtype=q.dtype) + neg_large = jnp.array(jnp.finfo(q.dtype).min, dtype=q.dtype) + + # Ring exchange: source i -> destination (i + 1) % cp. 
+ perm = [(i, (i + 1) % cp) for i in range(cp)] + + for step in range(cp): + source_shard = (axis_idx - step) % cp + # Absolute token positions for the current KV block, shape [Tk]. + key_positions = source_shard * local_len + jnp.arange(local_len, dtype=jnp.int32) + + # vh: [B, H, Tk, D] + # kht: [B, H, D, Tk] (K transposed for Q @ K^T) + kht = jnp.transpose(k_block, (0, 2, 3, 1)) + vh = jnp.transpose(v_block, (0, 2, 1, 3)) + scores = jnp.matmul(qh, kht) * scale + + # Mask invalid keys (future tokens + padding) before softmax update. + causal = key_positions[None, None, None, :] <= positions[:, None, :, None] + padding = mask_block[:, None, None, :].astype(bool) + valid = causal & padding + scores = jnp.where(valid, scores, neg_large) + + # Numerically stable online softmax merge: + # merge previous state (carry_max, denom, acc) with current block scores/values. + m_block = jnp.max(scores, axis=-1) + carry_max_new = jnp.maximum(carry_max, m_block) + prev_scale = jnp.where(jnp.isfinite(carry_max), jnp.exp(carry_max - carry_max_new), 0.0) + p = jnp.exp(scores - carry_max_new[..., None]) + p = jnp.where(valid, p, 0.0) + denom_new = prev_scale * denom + jnp.sum(p, axis=-1) + acc_new = prev_scale[..., None] * acc + jnp.matmul(p, vh) + carry_max, denom, acc = carry_max_new, denom_new, acc_new + + # Rotate KV/mask so the next iteration sees the next shard's block. 
+ if step < cp - 1: + k_block = jax.lax.ppermute(k_block, axis_name="cp", perm=perm) + v_block = jax.lax.ppermute(v_block, axis_name="cp", perm=perm) + mask_block = jax.lax.ppermute(mask_block, axis_name="cp", perm=perm) + + # Final normalize and restore [B, Tq, H, D] + out = jnp.where( + denom[..., None] > 0, + acc / jnp.maximum(denom[..., None], jnp.asarray(1e-9, dtype=denom.dtype)), + 0.0, + ) + return jnp.transpose(out, (0, 2, 1, 3)) + + +def default_positions(input_ids: jax.Array) -> jax.Array: + """Build token positions from input token shape, with CP shard offset.""" + start, local_len = 0, input_ids.shape[1] + cp = get_abstract_mesh().shape.get("cp", 1) + if cp > 1: + axis_idx = jax.lax.axis_index("cp") + start = axis_idx * local_len + return (start + jnp.arange(local_len, dtype=jnp.int32))[None, :] + + def dot_product_attention( q: jax.Array, k: jax.Array, @@ -15,6 +105,9 @@ def dot_product_attention( attention_mask: jax.Array, is_causal: bool, head_dim: int, + *, + positions: jax.Array, + scale: float | None = None, ) -> jax.Array: """Compute dot-product attention with automatic backend selection. @@ -27,12 +120,19 @@ def dot_product_attention( attention_mask: Mask of shape [batch, kv_len] where 1 = valid, 0 = masked. Sequences must be right-padded (valid tokens first, then padding). 
is_causal: Whether to apply causal masking (for prefill/training) - head_dim: Dimension of each attention head (for scaling) + head_dim: Dimension of each attention head (for scaling when scale is not provided) + positions: Query positions, shape [batch, q_len], used for causal masking + scale: Optional explicit scale factor for attention logits Returns: Attention output of shape [batch, q_len, num_heads, head_dim] """ - scale = 1.0 / head_dim**0.5 + scale = scale if scale is not None else 1.0 / head_dim**0.5 + cp = get_abstract_mesh().shape.get("cp", 1) + + # TODO: constraints for running ring attention + if cp > 1 and (is_causal or q.shape[1] == 1): + return _ring_attention(q, k, v, attention_mask, positions, scale) if jax.default_backend() == "gpu" and q.dtype in _CUDNN_SUPPORTED_DTYPES: kv_seq_lengths = attention_mask.sum(axis=1).astype(jnp.int32) diff --git a/skyrl-tx/tx/models/configs.py b/skyrl-tx/tx/models/configs.py index 398d8c042b..ab83485c5d 100644 --- a/skyrl-tx/tx/models/configs.py +++ b/skyrl-tx/tx/models/configs.py @@ -22,6 +22,7 @@ class ModelConfig(PretrainedConfig): max_lora_adapters: int max_lora_rank: int shard_attention_heads: bool + context_parallel_size: int loss_chunk_size: int gradient_checkpointing: bool @@ -32,6 +33,7 @@ def __init__( max_lora_adapters: int, max_lora_rank: int, shard_attention_heads: bool, + context_parallel_size: int = 1, loss_chunk_size: int = 0, gradient_checkpointing: bool = False, ): @@ -42,6 +44,7 @@ def __init__( self.max_lora_adapters = max_lora_adapters self.max_lora_rank = max_lora_rank self.shard_attention_heads = shard_attention_heads + self.context_parallel_size = context_parallel_size self.loss_chunk_size = loss_chunk_size self.gradient_checkpointing = gradient_checkpointing diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py index 80ac5f1c97..8c7185ee6c 100644 --- a/skyrl-tx/tx/models/deepseekv3.py +++ b/skyrl-tx/tx/models/deepseekv3.py @@ -3,6 +3,7 @@ from jax import numpy as 
jnp from jax.sharding import get_abstract_mesh +from tx.layers.attention import dot_product_attention, default_positions from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear from tx.layers.rotary_embedding import get_rope from tx.layers.util import Param, prepare_routing, shard_map_ep @@ -168,13 +169,15 @@ def __call__( # Jax attention expects v to have the same shape as k v = jnp.pad(v, ((0, 0), (0, 0), (0, 0), (0, self.qk_head_dim - self.v_head_dim))) - attn_output = jax.nn.dot_product_attention( + attn_output = dot_product_attention( q, k, v, - scale=self.scaling, - mask=attention_mask[:, None, None, :].astype(bool), + attention_mask, is_causal=kv_cache is None, + head_dim=self.qk_head_dim, + positions=positions, + scale=self.scaling, ) attn_output = attn_output[:, :, :, : self.v_head_dim].reshape(B, T, self.num_heads * self.v_head_dim) @@ -575,7 +578,7 @@ def __call__( is_training: bool = False, ) -> CausalLMOutput: if positions is None: - positions = jnp.arange(attention_mask.shape[1])[None, :] + positions = default_positions(input_ids) outputs = self.model( input_ids, diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py index 124bc7a09f..3e2182bbd9 100644 --- a/skyrl-tx/tx/models/llama3.py +++ b/skyrl-tx/tx/models/llama3.py @@ -4,7 +4,7 @@ from jax.sharding import get_abstract_mesh from transformers import LlamaConfig -from tx.layers.attention import dot_product_attention +from tx.layers.attention import dot_product_attention, default_positions from tx.layers.lora import LoRAEmbed, LoRALinear from tx.layers.rotary_embedding import apply_rope from tx.layers.layernorm import RMSNorm @@ -108,7 +108,15 @@ def __call__( updated_cache = (k, v) is_causal = kv_cache is None - attn_output = dot_product_attention(q, k, v, attention_mask, is_causal, self.head_dim) + attn_output = dot_product_attention( + q, + k, + v, + attention_mask, + is_causal, + self.head_dim, + positions=positions, + ) output = attn_output.reshape(B, T, self.num_heads * 
self.head_dim) return self.o_proj(output, adapter_indices=adapter_indices), updated_cache @@ -301,7 +309,7 @@ def __call__( is_training: bool = False, ) -> CausalLMOutput: if positions is None: - positions = jnp.arange(attention_mask.shape[1])[None, :] + positions = default_positions(input_ids) outputs = self.model( input_ids, diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py index 6c786bb390..e9712bb5f8 100644 --- a/skyrl-tx/tx/models/qwen3.py +++ b/skyrl-tx/tx/models/qwen3.py @@ -3,7 +3,7 @@ from jax import numpy as jnp from jax.sharding import get_abstract_mesh -from tx.layers.attention import dot_product_attention +from tx.layers.attention import dot_product_attention, default_positions from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear from tx.layers.util import prepare_routing, shard_map_ep from tx.layers.rotary_embedding import apply_rope @@ -109,7 +109,15 @@ def __call__( updated_cache = (k, v) is_causal = kv_cache is None - attn_output = dot_product_attention(q, k, v, attention_mask, is_causal, self.head_dim) + attn_output = dot_product_attention( + q, + k, + v, + attention_mask, + is_causal, + self.head_dim, + positions=positions, + ) output = attn_output.reshape(B, T, self.num_heads * self.head_dim) return self.o_proj(output, adapter_indices=adapter_indices), updated_cache @@ -419,7 +427,7 @@ def __call__( is_training: bool = False, ) -> CausalLMOutput: if positions is None: - positions = jnp.arange(attention_mask.shape[1])[None, :] + positions = default_positions(input_ids) outputs = self.model( input_ids, diff --git a/skyrl-tx/tx/run/train.py b/skyrl-tx/tx/run/train.py index 398c51a597..280c20086d 100644 --- a/skyrl-tx/tx/run/train.py +++ b/skyrl-tx/tx/run/train.py @@ -55,6 +55,7 @@ def train( parser=json.loads, ), tp_size: int = typer.Option(1, "--tp-size", help="Tensor parallelism degree to use for the model"), + cp_size: int = typer.Option(1, "--cp-size", help="Context parallelism degree to use for the model"), 
tracker_name: ExperimentTracker | None = typer.Option( None, "--tracker", help="Experiment tracker to report results to" ), @@ -66,7 +67,7 @@ def train( ), ) -> None: if not jax._src.xla_bridge.backends_are_initialized(): # ty: ignore - jax.config.update("jax_num_cpu_devices", tp_size) + jax.config.update("jax_num_cpu_devices", tp_size * cp_size) # If you want to debug NaNs, add the following: # jax.config.update("jax_debug_nans", True) @@ -77,13 +78,23 @@ def train( train_dataset = load_dataset(dataset, split=split) assert isinstance(train_dataset, Dataset) base_config = AutoConfig.from_pretrained(model_name) - model_config = Qwen3Config(base_config, max_lora_adapters=0, max_lora_rank=0, shard_attention_heads=True) + model_config = Qwen3Config( + base_config, + max_lora_adapters=0, + max_lora_rank=0, + shard_attention_heads=True, + context_parallel_size=cp_size, + ) tokenizer = AutoTokenizer.from_pretrained(model_name) tracker = get_tracker(tracker_name, base_config, **tracker_args) loader = get_loader(loader_name) model_class = get_model_class(base_config) - mesh = jax.make_mesh((1, 1, tp_size), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3) + mesh = jax.make_mesh( + (1, 1, tp_size, cp_size), + ("fsdp", "ep", "tp", "cp"), + axis_types=(jax.sharding.AxisType.Auto,) * 4, + ) with jax.set_mesh(mesh): model = model_class(model_config, dtype=get_dtype(model_config.dtype), rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, get_optimizer(optimizer_name, optimizer_args), wrt=nnx.Param) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index c629433ec6..f54010658f 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -23,6 +23,7 @@ import time from contextlib import contextmanager from dataclasses import dataclass +import functools from typing import Any, Callable, get_type_hints from cloudpathlib import AnyPath @@ -68,6 +69,9 @@ class JaxBackendConfig(BaseModel, extra="forbid"): 
max_lora_rank: int = Field(default=32, description="Maximum LoRA rank") tensor_parallel_size: int = Field(default=1, description="Tensor parallelism degree to use for the model") expert_parallel_size: int = Field(default=1, description="Expert parallelism degree for MoE layers") + context_parallel_size: int = Field( + default=1, ge=1, description="Context parallelism degree across sequence length" + ) fully_sharded_data_parallel_size: int = Field( default=1, description="Fully sharded data parallelism degree for the model" ) @@ -159,7 +163,7 @@ class JaxBackendImpl(AbstractBackend): This backend: - Uses jax.value_and_grad for gradient computation - - Uses 2D mesh (fsdp, tp) for fully sharded data parallelism and tensor parallelism + - Uses mesh axes (fsdp, ep, tp, cp) for fully sharded data, expert, tensor, and context parallelism - Supports multiple LoRA adapters via AccumulatedGradients with counts array - Supports both FORWARD and FORWARD_BACKWARD request types """ @@ -179,6 +183,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig): max_lora_adapters=config.max_lora_adapters, max_lora_rank=config.max_lora_rank, shard_attention_heads=config.shard_attention_heads, + context_parallel_size=config.context_parallel_size, loss_chunk_size=config.loss_chunk_size, gradient_checkpointing=config.gradient_checkpointing, ) @@ -191,9 +196,10 @@ def __init__(self, base_model: str, config: JaxBackendConfig): config.fully_sharded_data_parallel_size, config.expert_parallel_size, config.tensor_parallel_size, + config.context_parallel_size, ), - ("fsdp", "ep", "tp"), - axis_types=(jax.sharding.AxisType.Auto,) * 3, + ("fsdp", "ep", "tp", "cp"), + axis_types=(jax.sharding.AxisType.Auto,) * 4, ) with jax.set_mesh(self.mesh), nnx.use_eager_sharding(True): self.model = model_class(self.model_config, dtype=get_dtype(self.model_config.dtype), rngs=nnx.Rngs(0)) @@ -235,6 +241,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig): ) self._create_loss_and_grad_fn() 
+ self._create_cp_sample_pass() def _micro_batch_size(self, total: int) -> int: """Return effective micro-batch size; 0/absent => disabled (use full fused batch).""" @@ -426,27 +433,19 @@ def forward_backward_and_accumulate( self._forward = forward_only else: - # Retrieve the sharding of lora and non_lora params and compute the sharding of inputs and outputs - lora_shardings = jax.tree.map( - lambda spec: jax.NamedSharding(self.mesh, spec), nnx.get_partition_spec(self.lora_params) - ) - non_lora_shardings = jax.tree.map( - lambda spec: jax.NamedSharding(self.mesh, spec), nnx.get_partition_spec(self.non_lora_params) - ) - # Get sharding for AccumulatedGradients - accumulated_grads_shardings = jax.tree.map( - lambda spec: jax.NamedSharding(self.mesh, spec), nnx.get_partition_spec(self.accumulated_grads) - ) - - # Shard batch inputs along the FSDP axis (batch, seq_len) - batch_sharded_2d = jax.NamedSharding(self.mesh, jax.P("fsdp", None)) + # Shared sharding specs for model pass functions. + lora_specs = nnx.get_partition_spec(self.lora_params) + non_lora_specs = nnx.get_partition_spec(self.non_lora_params) + accumulated_grads_specs = nnx.get_partition_spec(self.accumulated_grads) # JIT the fused function # Input order: input_ids, attention_mask, adapter_indices, target_ids, # loss_mask, loss_fn_types, sampling_logprobs, advantages, # loss_fn_config - # All batch arrays are sharded along batch dimension - batch_sharded_1d = jax.NamedSharding(self.mesh, jax.P("fsdp")) + # All batch arrays are sharded along batch dimension (and CP for sequence-shaped inputs). + # `shard_map` expects PartitionSpec (`jax.P`) for in/out specs. 
+ batch_sharded_1d = jax.P("fsdp") + batch_sharded_2d = jax.P("fsdp", "cp") loss_fn_config_shardings = LossFnConfig( clip_low_threshold=batch_sharded_1d, clip_high_threshold=batch_sharded_1d, @@ -461,21 +460,28 @@ def forward_backward_and_accumulate( batch_sharded_2d, # sampling_logprobs batch_sharded_2d, # advantages ) - self._forward_backward_and_accumulate = jax.jit( - forward_backward_and_accumulate, - in_shardings=(accumulated_grads_shardings, lora_shardings, non_lora_shardings) - + input_shardings - + (loss_fn_config_shardings,), - out_shardings=(accumulated_grads_shardings, batch_sharded_2d, batch_sharded_2d), - donate_argnames=("accumulated_grads",), - ) - self._forward = jax.jit( - forward_only, - in_shardings=(accumulated_grads_shardings, lora_shardings, non_lora_shardings) - + input_shardings - + (loss_fn_config_shardings,), - out_shardings=(accumulated_grads_shardings, batch_sharded_2d, batch_sharded_2d), - ) + def make_sharded_model_pass(model_pass_impl: Callable): + @jax.shard_map( + mesh=self.mesh, + in_specs=(accumulated_grads_specs, lora_specs, non_lora_specs) + + input_shardings + + (loss_fn_config_shardings,), + out_specs=(accumulated_grads_specs, batch_sharded_2d, batch_sharded_2d), + axis_names=set(self.mesh.axis_names), + check_vma=False, + ) + def _sharded( + accumulated_grads: AccumulatedGradients, + lora_params: nnx.State, + non_lora_params: nnx.State, + *batch_inputs, + ): + return model_pass_impl(accumulated_grads, lora_params, non_lora_params, *batch_inputs) + + return jax.jit(_sharded) + + self._forward_backward_and_accumulate = make_sharded_model_pass(forward_backward_and_accumulate) + self._forward = make_sharded_model_pass(forward_only) # JIT-compiled function to compute full gradients and apply optimizer update def compute_grads_and_update( @@ -584,8 +590,9 @@ def _model_pass( # Convert model_ids to adapter_indices all_adapter_indices = [self.models[model_id].adapter_index for model_id in prepared_batch.all_model_ids] + cp_size = 
self.mesh.shape["cp"] # Pad sequences to same length. Also bin it so the JIT has to compile fewer kernels. - max_len = round_up_seq_len(max(len(seq) for seq in all_input_ids)) + max_len = round_up_seq_len(max(len(seq) for seq in all_input_ids), cp=cp_size) input_ids = pad_batch(all_input_ids, max_len, np.int32) target_ids = pad_batch(all_targets, max_len, np.int32) @@ -609,7 +616,7 @@ def _model_pass( seq_len = input_ids.shape[1] # Sharding specs for batch inputs - sharding_2d = jax.NamedSharding(self.mesh, jax.P("fsdp", None)) + sharding_2d = jax.NamedSharding(self.mesh, jax.P("fsdp", "cp")) sharding_1d = jax.NamedSharding(self.mesh, jax.P("fsdp")) fsdp_size = self.mesh.shape["fsdp"] @@ -762,6 +769,65 @@ def optim_step(self, model_id: str, request_data: types.OptimStepInput) -> types logger.info(f"Applied optimizer step for model {model_id} (adapter {adapter_index}), grad_norm={grad_norm}") return types.OptimStepOutput(metrics={"skyrl.ai/grad_norm": grad_norm, "skyrl.ai/learning_rate": learning_rate}) + def _create_cp_sample_pass(self) -> None: + """ + Create a single CP sample pass. + Wraps _prefill_and_decode in shard_map to enable cp axis. 
+ """ + lora_specs = nnx.get_partition_spec(self.lora_params) + non_lora_specs = nnx.get_partition_spec(self.non_lora_params) + batch_sharded_1d = jax.P("fsdp") + batch_sharded_2d = jax.P("fsdp", "cp") + positions_sharded = jax.P(None, "cp") + batch_sharded_2d_rep = jax.P("fsdp", None) + prefill_decode_in_specs = ( + batch_sharded_2d, # input_ids + batch_sharded_2d, # attention_mask + positions_sharded, # positions + None, # max_length (static) + None, # max_new_tokens (static) + batch_sharded_1d, # adapter_indices + batch_sharded_1d, # temperatures + batch_sharded_2d_rep, # rngs + batch_sharded_2d_rep, # stop_tokens + batch_sharded_1d, # top_k_values + batch_sharded_1d, # top_p_values + None, # max_top_k (static) + None, # use_top_p (static) + None, # prompt_logprobs (static) + ) + sample_output_shardings = ( + batch_sharded_2d_rep, # generated tokens + batch_sharded_2d_rep, # generated logprobs + batch_sharded_1d, # stop positions + batch_sharded_2d, # prompt logprobs + ) + static_argnums = (5, 6, 13, 14, 15) + + @jax.shard_map( + mesh=self.mesh, + in_specs=(lora_specs, non_lora_specs) + prefill_decode_in_specs, + out_specs=sample_output_shardings, + axis_names=set(self.mesh.axis_names), + check_vma=False, + ) + def _cp_sample_pass( + lora_params: nnx.State, + non_lora_params: nnx.State, + *decode_args, + ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: + model = nnx.merge(self.graphdef, lora_params, non_lora_params) + new_tokens, new_logprobs, stop_pos, prompt_logprobs_array = model._prefill_and_decode(model, *decode_args) + if prompt_logprobs_array is None: + # Satisfy shard_map output specs + prompt_logprobs_array = jnp.zeros((decode_args[0].shape[0], 0), dtype=jnp.float32) + return new_tokens, new_logprobs, stop_pos, prompt_logprobs_array + + self._cp_sample_pass = functools.partial(jax.jit, static_argnums=static_argnums)(_cp_sample_pass) + + def _cp_decode_runner(self, _model, *decode_args) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: + 
return self._cp_sample_pass(self.lora_params, self.non_lora_params, *decode_args) + def sample( self, prepared_batch: types.PreparedSampleBatch, @@ -797,8 +863,13 @@ def sample( all_prompt_logprobs: list[list[float]] = [] # Sharding specs for sampling inputs - sharding_2d = jax.NamedSharding(self.mesh, jax.P("fsdp", None)) - sharding_1d = jax.NamedSharding(self.mesh, jax.P("fsdp")) + cp_size = self.mesh.shape["cp"] + batch_sharded_1d = jax.P("fsdp") + batch_sharded_2d = jax.P("fsdp", "cp") + sharding_2d = jax.NamedSharding(self.mesh, batch_sharded_2d) + positions_sharding = jax.NamedSharding(self.mesh, jax.P(None, "cp")) + sharding_1d = jax.NamedSharding(self.mesh, batch_sharded_1d) + decode_runner = self._cp_decode_runner if cp_size > 1 else None with jax.set_mesh(self.mesh): model = nnx.merge(self.graphdef, self.lora_params, self.non_lora_params) @@ -810,26 +881,26 @@ def sample( all_sampling_params[batch_start:batch_end], max_batch_size, fill=all_sampling_params[batch_start] ) - # Pad sequences to same length within the batch to minimize memory usage. - # Also bin it so the JIT has to compile fewer kernels. 
max_len = round_up_seq_len(max((len(seq) for seq in batch_prompts), default=0)) input_ids = pad_batch(batch_prompts, max_len, np.int32) attention_mask = pad_batch([[1] * len(seq) for seq in batch_prompts], max_len, np.int32) + positions = np.arange(max_len, dtype=np.int32)[None, :] - # Shard inputs along FSDP axis (already padded to max_batch_size) - input_ids, attention_mask, adapter_indices = jax.device_put( - (input_ids, attention_mask, np.array(batch_adapter_indices, dtype=np.int32)), - (sharding_2d, sharding_2d, sharding_1d), + input_ids, attention_mask, positions, adapter_indices = jax.device_put( + (input_ids, attention_mask, positions, np.array(batch_adapter_indices, dtype=np.int32)), + (sharding_2d, sharding_2d, positions_sharding, sharding_1d), ) with self._jit_timing_context(max_len, mode="sample"): result = model.generate( input_ids, attention_mask, + positions=positions, sampling_params=sampling_params, adapter_indices=adapter_indices, prompt_logprobs=needs_prompt_logprobs, tokenizer=self.tokenizer, + decode_runner=decode_runner, ) # Only take the actual results, not the padded ones batch_size = batch_end - batch_start diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py index b407140c6f..392dd1d0e4 100644 --- a/skyrl-tx/tx/utils/generator.py +++ b/skyrl-tx/tx/utils/generator.py @@ -1,6 +1,7 @@ """Generator mixin for autoregressive text generation with KV caching.""" from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass import functools @@ -60,11 +61,19 @@ def update_layer(kv_cache, k, v, positions): """ k_cache, v_cache = kv_cache - def update_at_pos(cache_slice, new_val_slice, pos): - return jax.lax.dynamic_update_slice(cache_slice, new_val_slice, (pos, 0, 0)) + cp = jax.sharding.get_abstract_mesh().shape.get("cp", 1) + local_capacity = k_cache.shape[1] + update_positions = positions[:, 0] % local_capacity + owners = positions[:, 0] // local_capacity + axis_idx = 
jax.lax.axis_index("cp") if cp > 1 else 0 + should_update = owners == axis_idx - k = jax.vmap(update_at_pos)(k_cache, k, positions[:, 0]) - v = jax.vmap(update_at_pos)(v_cache, v, positions[:, 0]) + def update_at_pos(cache_slice, new_val_slice, pos, do_update): + updated = jax.lax.dynamic_update_slice(cache_slice, new_val_slice, (pos, 0, 0)) + return jnp.where(do_update, updated, cache_slice) + + k = jax.vmap(update_at_pos)(k_cache, k, update_positions, should_update) + v = jax.vmap(update_at_pos)(v_cache, v, update_positions, should_update) return k, v def pad_to_length(self, max_length: int) -> KVCache: @@ -164,6 +173,7 @@ def _prefill_and_decode( model, input_ids: jax.Array, attention_mask: jax.Array, + positions: jax.Array, max_length: int, max_new_tokens: int, adapter_indices: jax.Array | None, @@ -181,6 +191,7 @@ def _prefill_and_decode( outputs = model( input_ids, attention_mask=attention_mask, + positions=positions, adapter_indices=adapter_indices, ) @@ -286,21 +297,27 @@ def generate( input_ids: jax.Array, attention_mask: jax.Array, *, + positions: jax.Array | None = None, sampling_params: list[types.SamplingParams], adapter_indices: jax.Array | None = None, prompt_logprobs: bool = False, tokenizer=None, + decode_runner: Callable | None = None, ) -> GenerateOutput: """Generate text autoregressively with KV caching. Args: tokenizer: Optional tokenizer for string stop sequence detection. Required if any sampling_params has stop_strings set. + decode_runner: Optional callable override for `_prefill_and_decode` + (for example, CP shard_map path). Returns: GenerateOutput containing generated_ids, stop_reasons, and optionally logprobs. 
""" batch_size, prompt_length = input_ids.shape + if positions is None: + positions = jnp.arange(prompt_length, dtype=jnp.int32)[None, :] assert len(sampling_params) == batch_size max_new_tokens = max(sampling_param.max_tokens for sampling_param in sampling_params) max_length = tx.utils.models.round_up_seq_len(prompt_length + max_new_tokens) @@ -327,10 +344,12 @@ def generate( max_top_k = max((sp.top_k for sp in sampling_params if sp.top_k > 0), default=0) use_top_p = any(sp.top_p < 1.0 for sp in sampling_params) - new_tokens, new_logprobs, stop_pos, prompt_logprobs_array = self._prefill_and_decode( + decode_runner = decode_runner or self._prefill_and_decode + new_tokens, new_logprobs, stop_pos, prompt_logprobs_array = decode_runner( self, input_ids, attention_mask, + positions, max_length, max_new_tokens, adapter_indices, @@ -341,7 +360,7 @@ def generate( top_p_values, max_top_k, use_top_p, - prompt_logprobs=prompt_logprobs, + prompt_logprobs, ) max_tokens = jnp.array([sp.max_tokens for sp in sampling_params]) diff --git a/skyrl-tx/tx/utils/models.py b/skyrl-tx/tx/utils/models.py index f938fe69fa..8a4c8d0f3b 100644 --- a/skyrl-tx/tx/utils/models.py +++ b/skyrl-tx/tx/utils/models.py @@ -351,24 +351,35 @@ def insert_state(path: tuple, p: jax.Array, new: jax.Array): nnx.update(lora_params, updated) -def round_up_seq_len(seq_len: int) -> int: +def round_up_seq_len(seq_len: int, cp: int | None = None) -> int: """ Rounds a sequence length up to roughly two significant binary digits. We do this to pad sequences, so the Jax JIT compiler needs to compile fewer different shapes. + + If Context parallelism is enabled (cp > 1), the sequence length is + padded to be a multiple of shard count for even distribution. """ if seq_len <= 32: - return 32 - - # Find the position of the most significant bit. - msb_pos = seq_len.bit_length() - 1 - # Create a mask for the two most significant bits. 
- mask = (1 << msb_pos) | (1 << (msb_pos - 1)) - # Round down to the nearest value with at most two significant bits. - result = seq_len & mask - - # If we rounded down, round up to the next bucket boundary. - if result < seq_len: - result += 1 << (msb_pos - 1) - - return result + rounded = 32 + else: + # Find the position of the most significant bit. + msb_pos = seq_len.bit_length() - 1 + # Create a mask for the two most significant bits. + mask = (1 << msb_pos) | (1 << (msb_pos - 1)) + # Round down to the nearest value with at most two significant bits. + rounded = seq_len & mask + + # If we rounded down, round up to the next bucket boundary. + if rounded < seq_len: + rounded += 1 << (msb_pos - 1) + + if cp is None: + # Try to infer the context-parallel size from the mesh if not explicitly provided. + mesh = jax.sharding.get_abstract_mesh() + cp = mesh.shape.get("cp", 1) if mesh is not None else 1 + + if cp > 1: + rounded = ((rounded + cp - 1) // cp) * cp + + return rounded