[tx] Implement context parallelism in tx with ring attention using ppermute #1149

Draft
tanmaysachan wants to merge 8 commits into NovaSky-AI:main from tanmaysachan:tanmay/context_parallelism

Conversation


@tanmaysachan tanmaysachan commented Feb 16, 2026

Addresses #1056

This PR implements Context Parallelism (CP).
CP is now an axis alongside fsdp/ep/tp, with parity coverage against CP=1.

[AI wrote the description]

What’s included

  • Added context_parallel_size to model/backend configs.
  • Added CLI support via --cp-size in tx/run/train.py.
  • Extended mesh construction to 4D: ("fsdp", "ep", "tp", "cp").
  • Implemented CP-aware attention path:
    • ring-style KV exchange via ppermute
    • streaming softmax accumulation across CP shards
    • CP-aware default token positions.
  • Updated model attention callsites (Qwen3/Llama3/DeepSeek) to pass positions and use CP-aware attention helper.
  • Updated KV-cache layer updates for CP ownership/local position handling.
  • Updated sequence-length bucketing utility to support CP-aware padding (round_up_seq_len(..., cp=...)).
  • Refactored JAX backend train/sample paths for CP:
    • shard-map based model pass with CP partition specs
    • CP-specific sample pass with cached compilation
    • unified generate() path with optional decode runner override.
  • Added/updated tests for CP config and runtime behavior:
    • CP forward/backward execution
    • CP sample execution
    • CP parity (sample outputs, prompt logprobs, forward loss/logprobs, backward gradients)
    • attention test updates for positions API.
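The core of the CP-aware attention path, a ring-style KV exchange combined with streaming softmax accumulation, can be illustrated with a single-process NumPy sketch. The function names, shapes, and the in-process rotation standing in for ppermute are illustrative assumptions, not the actual tx implementation:

```python
import numpy as np

def full_attention(q, k, v):
    """Plain (non-causal) softmax attention, used as the reference."""
    D = q.shape[-1]
    s = q @ k.transpose(0, 2, 1) / np.sqrt(D)
    p = np.exp(s - s.max(-1, keepdims=True))
    return (p / p.sum(-1, keepdims=True)) @ v

def ring_attention_reference(q, k, v, num_shards):
    """Simulate ring attention: K/V blocks rotate around the ring while each
    shard accumulates an online (streaming) softmax over its local queries.
    Single-process sketch; the real code exchanges blocks via ppermute."""
    B, S, D = q.shape
    assert S % num_shards == 0
    blk = S // num_shards
    ks = [k[:, i * blk:(i + 1) * blk] for i in range(num_shards)]
    vs = [v[:, i * blk:(i + 1) * blk] for i in range(num_shards)]
    outs = []
    for shard in range(num_shards):
        ql = q[:, shard * blk:(shard + 1) * blk]  # this shard's local queries
        m = np.full((B, blk, 1), -np.inf)         # running row max
        l = np.zeros((B, blk, 1))                 # running softmax normalizer
        o = np.zeros((B, blk, D))                 # unnormalized output
        for step in range(num_shards):
            src = (shard + step) % num_shards     # KV block arriving this step
            scores = ql @ ks[src].transpose(0, 2, 1) / np.sqrt(D)
            m_new = np.maximum(m, scores.max(-1, keepdims=True))
            scale = np.exp(m - m_new)             # rescale old accumulators
            p = np.exp(scores - m_new)
            l = l * scale + p.sum(-1, keepdims=True)
            o = o * scale + p @ vs[src]
            m = m_new
        outs.append(o / l)                        # normalize once at the end
    return np.concatenate(outs, axis=1)
```

For the non-causal case this matches full attention exactly, which is the property the CP=1 parity tests rely on.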

Validation

  • Added targeted CP tests in tests/tinker/test_jax_backend.py.
  • Updated tests/models/test_qwen3_config.py and tests/gpu/test_attention.py for new CP/positions behavior.

Tested on multi-CPU configs.
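The CP-aware bucketing mentioned in the changelist (round_up_seq_len(..., cp=...)) plausibly requires the padded length to be divisible by both the bucket size and the CP degree, so every shard gets an equal, aligned slice. A minimal sketch; the signature and default bucket size are assumptions, not the tx API:

```python
import math

def round_up_seq_len(seq_len: int, multiple: int = 32, cp: int = 1) -> int:
    """Pad seq_len up to a length divisible by both the bucket multiple and
    cp. Illustrative sketch only; the real tx signature may differ."""
    step = math.lcm(multiple, cp)
    # ceil to the next multiple of step; empty input still gets one bucket
    return max(step, math.ceil(seq_len / step) * step)
```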

Tasks

  • Further simplify decode/prefill runner interfaces and reduce remaining wrapper boilerplate.


@pcmoritz pcmoritz added the tx label Feb 16, 2026
@tanmaysachan tanmaysachan marked this pull request as ready for review February 20, 2026 12:44
@tanmaysachan tanmaysachan changed the title [tx][WIP] Implement context parallelism in tx [tx] Implement context parallelism in tx with ring attention using ppermute Feb 20, 2026
@tanmaysachan (Contributor, Author) commented:

PR ready for review.

One ugliness is the wrapping of prefill_and_decode in a shard_map in JaxBackendImpl. Let me know if a cleaner design is preferred.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces Context Parallelism (CP) to the tx framework, adding support for 4D mesh construction (fsdp, ep, tp, cp) and CP-aware attention mechanisms. The changes include modifications to model configurations, CLI arguments, attention layer logic (including a new _ring_attention function), KV-cache updates, and sequence length bucketing. New tests have been added to validate CP functionality, including forward/backward execution, sampling, and parity checks against CP=1. The changes are well-structured and address the stated objective of implementing context parallelism.

Comment on lines +133 to +134
# TODO: constraints for running ring attention
if cp > 1 and (is_causal or q.shape[1] == 1):

high

The TODO: constraints for running ring attention indicates incomplete functionality or known limitations. It's important to either address these constraints or document them clearly for future development and users. If there are specific conditions under which _ring_attention should not be used, they should be explicitly handled or raised as errors.

Comment on lines +41 to +42
positions = jnp.broadcast_to(jnp.arange(q.shape[1], dtype=jnp.int32), (q.shape[0], q.shape[1]))
result = dot_product_attention(q, k, v, mask, is_causal=is_causal, head_dim=head_dim, positions=positions)

medium

The positions array is created using jnp.broadcast_to which might not be necessary if jnp.arange already produces the desired shape. Consider simplifying this to positions = jnp.arange(q.shape[1], dtype=jnp.int32)[None, :] if the batch dimension is always 1 for this test case, or ensuring broadcast_to is truly needed for varying batch sizes in other test scenarios.


Comment on lines +30 to +33
kv_repeat = q.shape[2] // k.shape[2]
k_block = jnp.repeat(k, kv_repeat, axis=2)
v_block = jnp.repeat(v, kv_repeat, axis=2)
mask_block = attention_mask

medium

The comment GQA handling: expand KV heads to match query heads. is slightly misleading. jnp.repeat is used here to repeat the KV heads to match the number of query heads, which is a common practice in GQA to effectively broadcast KV to Q. However, the comment could be more precise by stating that it's expanding the KV heads to match the query heads for broadcasting purposes.
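The jnp.repeat pattern under discussion can be seen in isolation with a NumPy stand-in, assuming the [batch, seq, heads, head_dim] layout of the quoted diff:

```python
import numpy as np

def expand_kv_heads(kv: np.ndarray, num_q_heads: int) -> np.ndarray:
    """GQA broadcast sketch: repeat each KV head so that KV head i serves the
    contiguous block of query heads [i*r, (i+1)*r). NumPy stand-in for the
    jnp.repeat call in the diff; layout is [batch, seq, heads, head_dim]."""
    r = num_q_heads // kv.shape[2]   # query heads per KV head
    return np.repeat(kv, r, axis=2)  # head order: [h0, h0, ..., h1, h1, ...]
```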

max_len = round_up_seq_len(max((len(seq) for seq in batch_prompts), default=0))
input_ids = pad_batch(batch_prompts, max_len, np.int32)
attention_mask = pad_batch([[1] * len(seq) for seq in batch_prompts], max_len, np.int32)
positions = np.arange(max_len, dtype=np.int32)[None, :]

medium

The positions array is created using np.arange(max_len, dtype=np.int32)[None, :]. While this works for the current use case, it might be more robust to use default_positions(input_ids) from tx.layers.attention to ensure consistency with how positions are generated elsewhere, especially if default_positions handles CP-specific offsets.

positions = default_positions(input_ids)

@tanmaysachan (Contributor, Author) replied:

Same as the other: we need to be under a shard_map for the cp axis_index to exist. positions can also be left as None here, to be filled by the generator.

Comment on lines +64 to +69
cp = jax.sharding.get_abstract_mesh().shape.get("cp", 1)
local_capacity = k_cache.shape[1]
update_positions = positions[:, 0] % local_capacity
owners = positions[:, 0] // local_capacity
axis_idx = jax.lax.axis_index("cp") if cp > 1 else 0
should_update = owners == axis_idx

medium

The update_at_pos function uses jnp.where(do_update, updated, cache_slice). This conditional update is correct for CP, but it introduces a potential performance overhead if do_update is frequently False for a large portion of the batch. While necessary for correctness, it's worth noting as a potential area for optimization if profiling reveals it to be a bottleneck, perhaps by using jax.lax.cond or jax.lax.select if the conditions are static enough.
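The ownership arithmetic in the quoted lines maps a global decode position to (owning CP shard, slot within that shard's local cache), with only the owning shard performing the write. A NumPy sketch of that mapping; the helper name and shapes are illustrative:

```python
import numpy as np

def kv_cache_owner(global_pos: np.ndarray, local_capacity: int, cp_rank: int):
    """Sketch of the CP KV-cache ownership math: a global position maps to a
    local slot and an owning shard; a shard writes only when it is the owner.
    Mirrors update_positions / owners / should_update in the quoted diff."""
    local_slot = global_pos % local_capacity  # slot within this shard's cache
    owner = global_pos // local_capacity      # which CP shard owns the slot
    return local_slot, owner == cp_rank       # (position, write-enable mask)
```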

Comment on lines +319 to +320
if positions is None:
positions = jnp.arange(prompt_length, dtype=jnp.int32)[None, :]

medium

The positions array is initialized with jnp.arange(prompt_length, dtype=jnp.int32)[None, :] if positions is None. To maintain consistency and leverage the CP-aware logic, it would be better to use default_positions(input_ids) from tx.layers.attention here, as it correctly handles CP shard offsets.

if positions is None:
            positions = default_positions(input_ids)

@tanmaysachan (Contributor, Author) replied:

The axis_index for cp might not exist here; default_positions is used when we are under a shard_map.



@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 1 new potential issue.

View 11 additional findings in Devin Review.



🔴 Decode loop initialization uses local attention_mask sum instead of global under CP, producing wrong positions and logits on non-owning shards

When _prefill_and_decode runs inside shard_map with CP > 1, the attention_mask is sharded along the CP axis. Each shard only sees its local portion. The line last_token_idx = attention_mask.sum(axis=1) - 1 computes the LOCAL count of valid tokens in each shard's slice, not the global count. This causes two critical issues:

Root Cause and Impact

Consider a prompt of length 20, padded to max_len=32, with cp=2. Each CP shard gets 16 positions:

  • Shard 0 (positions 0–15): attention_mask.sum() = 16, last_token_idx = 15 → extracts logits from global position 15 (NOT the actual last valid token at position 19)
  • Shard 1 (positions 16–31): attention_mask.sum() = 4, last_token_idx = 3 → extracts logits from global position 19 (correct)

Shard 0 extracts logits from the wrong position, samples a potentially different token, and starts the decode loop at position 16 instead of 20. Meanwhile, shard 1 produces the correct output. Since the generated tokens output spec is P("fsdp", None) (replicated across CP) with check_vma=False, the inconsistency is silently ignored and the result is undefined — whichever shard's values happen to be read will determine the output.

For short prompts fitting entirely within shard 0 (as in the PR's parity tests with length-5 prompts padded to 32), shard 0 holds all valid tokens and coincidentally produces the correct result, masking this bug. But any prompt spanning multiple CP shards will produce incorrect generation.

Additionally, decode positions derived from last_token_idx (s.last_positions + 1) diverge across shards, causing wrong RoPE embeddings and KV cache writes on the non-owning shards.

(Refers to line 199)

Prompt for agents
In skyrl-tx/tx/utils/generator.py, the _prefill_and_decode function at line 199 computes last_token_idx = attention_mask.sum(axis=1) - 1. Under CP (context parallelism via shard_map), attention_mask is the local shard's slice, so this gives the local count rather than the global last-token index.

To fix this, _prefill_and_decode needs to derive the global last valid token position and ensure all CP shards agree on the same decode starting state (logits, positions, etc.). Possible approaches:

1. Before entering the decode loop, use a collective (e.g., jax.lax.pmax across the 'cp' axis) to find the global last token position, then broadcast the correct initial logits and positions to all shards.

2. Alternatively, restructure the CP sample path so that prefill runs under shard_map (with ring attention), but the decode loop runs outside shard_map with a single-device KV cache. The prefill outputs (last hidden state, KV cache) would be gathered before entering the decode loop.

The fix must ensure that (a) the correct logits from the global last valid token are used for the first sampled token, (b) all CP shards use the same decode positions for RoPE and KV cache updates, and (c) the generated tokens are consistent across CP shards to satisfy the replicated output spec.
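A pure-NumPy sketch of approach (1), where the final max over shards stands in for jax.lax.pmax across the 'cp' axis; the shard layout and function name are assumptions:

```python
import numpy as np

def global_last_token_idx(attention_mask: np.ndarray, cp: int) -> np.ndarray:
    """Each CP shard sees only its slice of the mask, so its local token count
    yields a shard-local last-token index that must be offset into global
    coordinates. A shard holding no valid tokens contributes -1 and loses the
    max-reduction, so every shard recovers the same global index."""
    B, S = attention_mask.shape
    blk = S // cp
    per_shard = []
    for r in range(cp):  # each iteration plays the role of one CP shard
        count = attention_mask[:, r * blk:(r + 1) * blk].sum(axis=1)
        per_shard.append(np.where(count > 0, r * blk + count - 1, -1))
    return np.stack(per_shard).max(axis=0)  # stands in for pmax over "cp"
```

On the worked example above (length-20 prompt, max_len=32, cp=2), shard 0 proposes index 15 and shard 1 proposes 19; the reduction makes both agree on 19.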


@tanmaysachan (Contributor, Author) replied:

Yeah, finding a way to do this cleanly.

@tanmaysachan tanmaysachan marked this pull request as draft February 20, 2026 13:55

2 participants