[tx] Implement context parallelism in tx with ring attention using ppermute #1149

tanmaysachan wants to merge 8 commits into NovaSky-AI:main

Conversation
PR ready for review. One ugliness is the wrapping of prefill_and_decode in a shard_map in JaxBackendImpl. Do let me know if a cleaner design is preferred.
Code Review
The pull request introduces Context Parallelism (CP) to the tx framework, adding support for 4D mesh construction (fsdp, ep, tp, cp) and CP-aware attention mechanisms. The changes include modifications to model configurations, CLI arguments, attention layer logic (including a new _ring_attention function), KV-cache updates, and sequence length bucketing. New tests have been added to validate CP functionality, including forward/backward execution, sampling, and parity checks against CP=1. The changes are well-structured and address the stated objective of implementing context parallelism.
```python
# TODO: constraints for running ring attention
if cp > 1 and (is_causal or q.shape[1] == 1):
```
The `TODO: constraints for running ring attention` comment indicates incomplete functionality or known limitations. These constraints should either be addressed or documented clearly for future development and users. If there are specific conditions under which _ring_attention must not be used, they should be explicitly handled or raised as errors.
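If the intent is to fail fast rather than silently fall back, a guard along these lines could make the constraints explicit. This is a hedged sketch: the function name and the specific conditions are hypothetical, not part of tx.

```python
def validate_ring_attention_inputs(cp: int, is_causal: bool, q_len: int, seq_len: int) -> None:
    """Hypothetical precondition check for ring attention; tx's actual rules may differ."""
    if cp <= 1:
        return  # no context parallelism, nothing to constrain
    if seq_len % cp != 0:
        # ring attention rotates equal-sized KV blocks, so the sequence
        # must split evenly across cp shards
        raise ValueError(f"seq_len={seq_len} must divide evenly across cp={cp} shards")
    if q_len == 1:
        # single-token decode reads the KV cache instead of rotating blocks
        raise ValueError("single-token decode should take the cached-KV path, not ring attention")
```

A guard like this would run just before dispatching to _ring_attention, turning the TODO into an explicit contract.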
```python
positions = jnp.broadcast_to(jnp.arange(q.shape[1], dtype=jnp.int32), (q.shape[0], q.shape[1]))
result = dot_product_attention(q, k, v, mask, is_causal=is_causal, head_dim=head_dim, positions=positions)
```
The positions array is created using jnp.broadcast_to which might not be necessary if jnp.arange already produces the desired shape. Consider simplifying this to positions = jnp.arange(q.shape[1], dtype=jnp.int32)[None, :] if the batch dimension is always 1 for this test case, or ensuring broadcast_to is truly needed for varying batch sizes in other test scenarios.
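Since jnp follows NumPy broadcasting rules, the equivalence the review is pointing at can be checked with a plain NumPy sketch (shapes here are illustrative):

```python
import numpy as np

batch, seq = 2, 4
explicit = np.broadcast_to(np.arange(seq, dtype=np.int32), (batch, seq))  # shape (2, 4)
implicit = np.arange(seq, dtype=np.int32)[None, :]                        # shape (1, 4)

# In elementwise ops, the (1, seq) form broadcasts against (batch, seq)
# and behaves identically to the materialized array.
assert (explicit == implicit).all()
```

So the explicit broadcast_to only matters if downstream code requires the positions array itself to have the full (batch, seq) shape rather than relying on broadcasting.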
```python
kv_repeat = q.shape[2] // k.shape[2]
k_block = jnp.repeat(k, kv_repeat, axis=2)
v_block = jnp.repeat(v, kv_repeat, axis=2)
mask_block = attention_mask
```
The comment `GQA handling: expand KV heads to match query heads.` is slightly imprecise. jnp.repeat tiles each KV head so the number of KV heads matches the number of query heads, which is the standard way to broadcast KV to Q in GQA. The comment could state explicitly that the expansion exists for broadcasting purposes.
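The shape effect of the repeat can be illustrated in NumPy, which shares jnp.repeat's semantics. The dimension layout (batch, seq, heads, head_dim) and the sizes are assumptions for the sake of the example:

```python
import numpy as np

batch, seq, q_heads, kv_heads, head_dim = 1, 3, 8, 2, 4
k = np.arange(batch * seq * kv_heads * head_dim).reshape(batch, seq, kv_heads, head_dim)

kv_repeat = q_heads // kv_heads            # here: 4 query heads share each KV head
k_block = np.repeat(k, kv_repeat, axis=2)  # (1, 3, 2, 4) -> (1, 3, 8, 4)

# After the repeat, query head i lines up with original KV head i // kv_repeat.
assert k_block.shape == (batch, seq, q_heads, head_dim)
assert (k_block[:, :, 5] == k[:, :, 5 // kv_repeat]).all()
```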
```python
max_len = round_up_seq_len(max((len(seq) for seq in batch_prompts), default=0))
input_ids = pad_batch(batch_prompts, max_len, np.int32)
attention_mask = pad_batch([[1] * len(seq) for seq in batch_prompts], max_len, np.int32)
positions = np.arange(max_len, dtype=np.int32)[None, :]
```
The positions array is created using np.arange(max_len, dtype=np.int32)[None, :]. While this works for the current use case, it might be more robust to use default_positions(input_ids) from tx.layers.attention to ensure consistency with how positions are generated elsewhere, especially if default_positions handles CP-specific offsets.
```python
positions = default_positions(input_ids)
```
Same as the other one: we need to be under a shard_map for the cp axis_index to exist. This `positions` can also optionally be left as None, to be filled in by the generator.
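A NumPy emulation can show what a CP-aware position default would compute under shard_map. Here `axis_idx` stands in for `jax.lax.axis_index("cp")`, which only exists inside the shard_map; the helper name is hypothetical:

```python
import numpy as np

def cp_positions(axis_idx: int, local_len: int, batch: int) -> np.ndarray:
    # Each cp shard owns a contiguous slice of the global sequence, so its
    # local positions start at axis_idx * local_len.
    offset = axis_idx * local_len
    return np.broadcast_to(offset + np.arange(local_len, dtype=np.int32), (batch, local_len))

# shard 1 of cp=2 over a length-8 sequence owns global positions 4..7
assert cp_positions(1, 4, 2).tolist() == [[4, 5, 6, 7], [4, 5, 6, 7]]
```

This is exactly why calling such a helper outside shard_map fails: without the collective axis in scope there is no axis_idx to offset by.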
```python
cp = jax.sharding.get_abstract_mesh().shape.get("cp", 1)
local_capacity = k_cache.shape[1]
update_positions = positions[:, 0] % local_capacity
owners = positions[:, 0] // local_capacity
axis_idx = jax.lax.axis_index("cp") if cp > 1 else 0
should_update = owners == axis_idx
```
The update_at_pos function uses jnp.where(do_update, updated, cache_slice). This conditional update is correct for CP, but it introduces a potential performance overhead if do_update is frequently False for a large portion of the batch. While necessary for correctness, it's worth noting as a potential area for optimization if profiling reveals it to be a bottleneck, perhaps by using jax.lax.cond or jax.lax.select if the conditions are static enough.
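The owner/slot arithmetic in the hunk above can be sanity-checked in plain NumPy. The concrete values assume cp=2 over max_len=32, so each shard's KV cache holds 16 slots:

```python
import numpy as np

local_capacity = 16  # per-shard KV cache length (cp=2 over max_len=32)
positions = np.array([[19], [3]], dtype=np.int32)  # first decode position per batch row

update_positions = positions[:, 0] % local_capacity  # slot within the local cache
owners = positions[:, 0] // local_capacity           # which cp shard owns that slot

assert update_positions.tolist() == [3, 3]
assert owners.tolist() == [1, 0]  # position 19 lives on shard 1, position 3 on shard 0
```

Only the shard whose axis_index matches `owners` writes the new KV entry; the others keep their cache slice unchanged via the `jnp.where` the review mentions.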
```python
if positions is None:
    positions = jnp.arange(prompt_length, dtype=jnp.int32)[None, :]
```
The positions array is initialized with jnp.arange(prompt_length, dtype=jnp.int32)[None, :] if positions is None. To maintain consistency and leverage the CP-aware logic, it would be better to use default_positions(input_ids) from tx.layers.attention here, as it correctly handles CP shard offsets.
```python
if positions is None:
    positions = default_positions(input_ids)
```
The axis_index for cp might not exist here. default_positions is only used when we are under a shard_map.
🔴 Decode loop initialization uses local attention_mask sum instead of global under CP, producing wrong positions and logits on non-owning shards
When _prefill_and_decode runs inside shard_map with CP > 1, the attention_mask is sharded along the CP axis. Each shard only sees its local portion. The line last_token_idx = attention_mask.sum(axis=1) - 1 computes the LOCAL count of valid tokens in each shard's slice, not the global count. This causes two critical issues:
Root Cause and Impact
Consider a prompt of length 20, padded to max_len=32, with cp=2. Each CP shard gets 16 positions:

- Shard 0 (positions 0–15): `attention_mask.sum() = 16`, `last_token_idx = 15` → extracts logits from global position 15 (NOT the actual last valid token at position 19)
- Shard 1 (positions 16–31): `attention_mask.sum() = 4`, `last_token_idx = 3` → extracts logits from global position 19 (correct)
Shard 0 extracts logits from the wrong position, samples a potentially different token, and starts the decode loop at position 16 instead of 20. Meanwhile, shard 1 produces the correct output. Since the generated tokens output spec is P("fsdp", None) (replicated across CP) with check_vma=False, the inconsistency is silently ignored and the result is undefined — whichever shard's values happen to be read will determine the output.
For short prompts fitting entirely within shard 0 (as in the PR's parity tests with length-5 prompts padded to 32), shard 0 holds all valid tokens and coincidentally produces the correct result, masking this bug. But any prompt spanning multiple CP shards will produce incorrect generation.
Additionally, decode positions derived from last_token_idx (s.last_positions + 1) diverge across shards, causing wrong RoPE embeddings and KV cache writes on the non-owning shards.
(Refers to line 199)
Prompt for agents
In skyrl-tx/tx/utils/generator.py, the _prefill_and_decode function at line 199 computes last_token_idx = attention_mask.sum(axis=1) - 1. Under CP (context parallelism via shard_map), attention_mask is the local shard's slice, so this gives the local count rather than the global last-token index.
To fix this, _prefill_and_decode needs to derive the global last valid token position and ensure all CP shards agree on the same decode starting state (logits, positions, etc.). Possible approaches:
1. Before entering the decode loop, use a collective (e.g., jax.lax.pmax across the 'cp' axis) to find the global last token position, then broadcast the correct initial logits and positions to all shards.
2. Alternatively, restructure the CP sample path so that prefill runs under shard_map (with ring attention), but the decode loop runs outside shard_map with a single-device KV cache. The prefill outputs (last hidden state, KV cache) would be gathered before entering the decode loop.
The fix must ensure that (a) the correct logits from the global last valid token are used for the first sampled token, (b) all CP shards use the same decode positions for RoPE and KV cache updates, and (c) the generated tokens are consistent across CP shards to satisfy the replicated output spec.
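Approach 1 can be sketched outside JAX. The NumPy reduction below emulates a `jax.lax.pmax` over the cp axis: each shard proposes its candidate global last-token index (or -1 if it holds no valid tokens), and the maximum across shards is the answer. The function name and shapes are illustrative only:

```python
import numpy as np

def global_last_token_idx(local_masks: list[np.ndarray]) -> np.ndarray:
    # local_masks[i] is shard i's (batch, local_len) slice of attention_mask.
    local_len = local_masks[0].shape[1]
    candidates = []
    for axis_idx, mask in enumerate(local_masks):
        local_count = mask.sum(axis=1)
        # shard offset + local last valid slot; -1 if the shard has no valid tokens
        local_last = axis_idx * local_len + local_count - 1
        candidates.append(np.where(local_count > 0, local_last, -1))
    # stands in for jax.lax.pmax over the "cp" axis
    return np.maximum.reduce(candidates)

# prompt of length 20, max_len=32, cp=2: shard 0 sees 16 valid tokens, shard 1 sees 4
shard0 = np.ones((1, 16), dtype=np.int32)
shard1 = np.concatenate([np.ones((1, 4), np.int32), np.zeros((1, 12), np.int32)], axis=1)
assert global_last_token_idx([shard0, shard1]).tolist() == [19]
```

In the real kernel each shard would compute only its own candidate and the pmax collective would perform the reduction, leaving every shard with the same global index to seed the decode loop.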
Yeah, finding a way to do this cleanly.
Addresses #1056
This PR implements Context Parallelism (CP).
CP is now an axis alongside fsdp/ep/tp, with parity coverage against CP=1.
[AI wrote the description]
What’s included
Validation
Tested over multi-CPU configs.
Tasks