Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion skyrl-tx/tests/gpu/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def assert_attention_match(q, k, v, mask, is_causal, head_dim, seq_lengths=None)
If None, compare all positions.
"""
scale = 1.0 / head_dim**0.5
result = dot_product_attention(q, k, v, mask, is_causal=is_causal, head_dim=head_dim)
positions = jnp.broadcast_to(jnp.arange(q.shape[1], dtype=jnp.int32), (q.shape[0], q.shape[1]))
result = dot_product_attention(q, k, v, mask, is_causal=is_causal, head_dim=head_dim, positions=positions)
Comment on lines +41 to +42
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

`jnp.broadcast_to` is only needed here if downstream code requires a materialized `[batch, seq]` positions array. If a broadcastable `[1, seq]` view suffices, this simplifies to `positions = jnp.arange(q.shape[1], dtype=jnp.int32)[None, :]`; otherwise keep `broadcast_to` so the array explicitly matches the batch dimension for varying batch sizes.

positions = jnp.broadcast_to(jnp.arange(q.shape[1], dtype=jnp.int32), (q.shape[0], q.shape[1]))
    result = dot_product_attention(q, k, v, mask, is_causal=is_causal, head_dim=head_dim, positions=positions)

expected = jax.nn.dot_product_attention(
q, k, v, scale=scale, mask=mask[:, None, None, :].astype(bool), is_causal=is_causal
)
Expand Down
15 changes: 15 additions & 0 deletions skyrl-tx/tests/models/test_qwen3_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_config_wraps_pretrained_config():
assert config.max_lora_adapters == 8
assert config.max_lora_rank == 16
assert config.shard_attention_heads is False
assert config.context_parallel_size == 1

# Check base config attributes were copied
assert config.vocab_size > 0
Expand All @@ -27,3 +28,17 @@ def test_config_preserves_moe_config():

# Check that MoE-specific attributes are preserved
assert config.num_experts > 0


def test_config_sets_context_parallel_size():
    """Verify that an explicit context_parallel_size is stored on the config as-is."""
    base = PretrainedConfig.from_pretrained("Qwen/Qwen3-0.6B")
    cfg = Qwen3Config(
        base,
        max_lora_adapters=8,
        max_lora_rank=16,
        shard_attention_heads=True,
        context_parallel_size=2,
    )
    assert cfg.context_parallel_size == 2
143 changes: 143 additions & 0 deletions skyrl-tx/tests/tinker/test_jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,149 @@ def create_backend(max_lora_adapters: int = MAX_LORA_ADAPTERS):
return JaxBackend(BASE_MODEL, config)


def test_context_parallel_forward_backward_runs():
    """Training path should run with CP>1 via shard_map + ppermute attention."""
    backend = JaxBackend(
        BASE_MODEL,
        JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=2),
    )
    model_id = "cp_model"
    create_model(backend, model_id)
    batch = prepare_model_pass_batch({"1": (model_id, make_fwd_bwd_input([[1, 2, 3, 4]]))})

    # The request id must survive a full forward/backward pass under CP=2.
    assert "1" in backend.forward_backward(batch)


def test_context_parallel_sample_runs():
    """Sampling should run under CP>1 by disabling CP collectives in decode path."""
    backend = JaxBackend(
        BASE_MODEL,
        JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=2),
    )
    outputs = backend.sample(prepare_sample_batch({"1": ("", make_sample_input([1, 2, 3], max_tokens=4))}))

    assert "1" in outputs
    assert len(outputs["1"].sequences) == 1


def test_context_parallel_sample_parity():
    """Sampling outputs should match between CP=1 and CP=2 for identical seeds/prompts."""

    def _backend(cp: int) -> JaxBackend:
        # Same backend configuration apart from the context-parallel degree.
        cfg = JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=cp)
        return JaxBackend(BASE_MODEL, cfg)

    greedy = make_sample_input([1, 2, 3, 4, 5], max_tokens=6)
    greedy.sampling_params = api.SamplingParams(temperature=0.0, max_tokens=6, seed=7).to_types()
    stochastic = make_sample_input([6, 7, 8, 9], max_tokens=5)
    stochastic.sampling_params = api.SamplingParams(temperature=1.0, max_tokens=5, seed=11).to_types()
    reqs = {"r1": ("", greedy), "r2": ("", stochastic)}

    out_a = _backend(1).sample(prepare_sample_batch(reqs))
    out_b = _backend(2).sample(prepare_sample_batch(reqs))

    assert set(out_a.keys()) == set(out_b.keys()) == {"r1", "r2"}
    for req_id in ("r1", "r2"):
        seqs_a = out_a[req_id].sequences
        seqs_b = out_b[req_id].sequences
        assert len(seqs_a) == len(seqs_b)
        for a, b in zip(seqs_a, seqs_b):
            assert a.stop_reason == b.stop_reason
            assert a.tokens == b.tokens
            np.testing.assert_allclose(np.asarray(a.logprobs), np.asarray(b.logprobs), rtol=1e-6, atol=1e-6)


def test_context_parallel_sample_prompt_logprobs_parity():
    """CP=2 sampling should match CP=1 for prompt logprobs and token outputs."""

    def _backend(cp: int) -> JaxBackend:
        return JaxBackend(
            BASE_MODEL,
            JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=cp),
        )

    req = make_sample_input([1, 2, 3, 4, 5, 6], prompt_logprobs=True, max_tokens=4)
    req.sampling_params = api.SamplingParams(temperature=0.0, max_tokens=4, seed=17).to_types()
    reqs = {"r": ("", req)}

    out_a = _backend(1).sample(prepare_sample_batch(reqs))
    out_b = _backend(2).sample(prepare_sample_batch(reqs))

    seq_a, seq_b = out_a["r"].sequences[0], out_b["r"].sequences[0]
    assert seq_a.stop_reason == seq_b.stop_reason
    assert seq_a.tokens == seq_b.tokens
    np.testing.assert_allclose(np.asarray(seq_a.logprobs), np.asarray(seq_b.logprobs), rtol=1e-6, atol=1e-6)

    # Prompt logprobs must be present under both configurations and agree numerically.
    assert out_a["r"].prompt_logprobs is not None
    assert out_b["r"].prompt_logprobs is not None
    np.testing.assert_allclose(
        np.asarray(out_a["r"].prompt_logprobs),
        np.asarray(out_b["r"].prompt_logprobs),
        rtol=1e-6,
        atol=1e-6,
    )


def _extract_loss_and_logprobs(results: dict, request_id: str) -> tuple[list[np.ndarray], list[np.ndarray]]:
output = results[request_id]
losses = []
logprobs = []
for item in output.loss_fn_outputs:
losses.append(np.asarray(item["elementwise_loss"]["data"], dtype=np.float32))
logprobs.append(np.asarray(item["logprobs"]["data"], dtype=np.float32))
return losses, logprobs


def test_context_parallel_parity_forward_and_backward():
    """CP=2 should match CP=1 numerically for prefill forward/backward."""
    model_id = "cp_parity_model"
    reqs = {
        "1": (
            model_id,
            make_fwd_bwd_input(
                [
                    [1, 2, 3, 4, 5, 6],
                    [7, 8, 9, 10, 11, 12],
                ]
            ),
        )
    }

    backend_a = JaxBackend(
        BASE_MODEL, JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=1)
    )
    backend_b = JaxBackend(
        BASE_MODEL, JaxBackendConfig(max_lora_adapters=MAX_LORA_ADAPTERS, max_lora_rank=32, context_parallel_size=2)
    )
    for backend in (backend_a, backend_b):
        create_model(backend, model_id)

    batch = prepare_model_pass_batch(reqs)

    # Forward parity: per-item losses and logprobs must agree elementwise.
    losses_a, logprobs_a = _extract_loss_and_logprobs(backend_a.forward(batch), "1")
    losses_b, logprobs_b = _extract_loss_and_logprobs(backend_b.forward(batch), "1")

    assert len(losses_a) == len(losses_b)
    assert len(logprobs_a) == len(logprobs_b)
    for got, want in zip(losses_a, losses_b):
        np.testing.assert_allclose(got, want, atol=1e-4, rtol=1e-4)
    for got, want in zip(logprobs_a, logprobs_b):
        np.testing.assert_allclose(got, want, atol=1e-4, rtol=1e-4)

    # Backward parity: compare per-adapter mean gradients across CP degrees.
    backend_a.forward_backward(batch)
    backend_b.forward_backward(batch)
    grads_a = backend_a.accumulated_grads.get_mean(jnp.int32(backend_a.models[model_id].adapter_index))
    grads_b = backend_b.accumulated_grads.get_mean(jnp.int32(backend_b.models[model_id].adapter_index))
    _assert_tree_allclose(grads_a, grads_b, rtol=1e-3, atol=1e-3, min_match_pct=99.0)


def create_model(backend: JaxBackend, model_id: str) -> int:
"""Create a model and return its adapter index."""
lora_config = LoraConfig(rank=LORA_RANK, alpha=16, seed=0)
Expand Down
104 changes: 102 additions & 2 deletions skyrl-tx/tx/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,112 @@

import jax
import jax.numpy as jnp
from jax.sharding import get_abstract_mesh

# cuDNN flash attention supported dtypes
# https://github.com/jax-ml/jax/blob/8b1f782540f71fbe230a2dccd331975faafc6c83/jax/_src/cudnn/fused_attention_stablehlo.py#L290
_CUDNN_SUPPORTED_DTYPES = (jnp.float16, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e5m2)


def _ring_attention(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    attention_mask: jax.Array,
    positions: jax.Array,
    scale: float,
) -> jax.Array:
    """Streaming causal attention with ring KV exchange via ppermute.

    Runs inside a shard_map over the "cp" mesh axis: each device holds a local
    sequence shard and, over `cp` steps, attends its queries against every KV
    shard while rotating the KV blocks around the ring.

    Args:
        q: Local queries, shape [B, Tq, H, D].
        k: Local keys, shape [B, Tk, H_kv, D].
        v: Local values, shape [B, Tk, H_kv, D].
        attention_mask: [B, Tk] mask for the local KV block (1 = valid, 0 = padding).
        positions: [B, Tq] absolute positions of the local query tokens,
            used for causal masking across shards.
        scale: Multiplier applied to the raw attention logits.

    Returns:
        Attention output of shape [B, Tq, H, D].
    """
    cp = get_abstract_mesh().shape.get("cp", 1)
    axis_idx = jax.lax.axis_index("cp")
    local_len = k.shape[1]

    # qh: [B, H, Tq, D]
    qh = jnp.transpose(q, (0, 2, 1, 3))

    # GQA: materialize each KV head kv_repeat times so KV tensors line up
    # with the query heads. k/v: [B, Tk, H_kv, D] -> [B, Tk, H, D]
    kv_repeat = q.shape[2] // k.shape[2]
    k_block = jnp.repeat(k, kv_repeat, axis=2)
    v_block = jnp.repeat(v, kv_repeat, axis=2)
    mask_block = attention_mask

    # Online softmax state (kept per [B, H, Tq]):
    # carry_max = running max score
    # denom = running denominator sum(exp(score - carry_max))
    # acc = running numerator sum(exp(score - m) * value), shape [B, H, Tq, D]
    B, H, Tq, D = qh.shape
    carry_max = jnp.full((B, H, Tq), -jnp.inf, dtype=q.dtype)
    denom = jnp.zeros((B, H, Tq), dtype=q.dtype)
    acc = jnp.zeros((B, H, Tq, D), dtype=q.dtype)
    neg_large = jnp.array(jnp.finfo(q.dtype).min, dtype=q.dtype)

    # Ring exchange: source i -> destination (i + 1) % cp.
    perm = [(i, (i + 1) % cp) for i in range(cp)]

    for step in range(cp):
        # After `step` rotations this device holds the KV block that
        # originated on shard (axis_idx - step) % cp.
        source_shard = (axis_idx - step) % cp
        # Absolute token positions for the current KV block, shape [Tk].
        key_positions = source_shard * local_len + jnp.arange(local_len, dtype=jnp.int32)

        # vh: [B, H, Tk, D]
        # kht: [B, H, D, Tk] (K transposed for Q @ K^T)
        kht = jnp.transpose(k_block, (0, 2, 3, 1))
        vh = jnp.transpose(v_block, (0, 2, 1, 3))
        scores = jnp.matmul(qh, kht) * scale

        # Mask invalid keys (future tokens + padding) before softmax update.
        causal = key_positions[None, None, None, :] <= positions[:, None, :, None]
        padding = mask_block[:, None, None, :].astype(bool)
        valid = causal & padding
        scores = jnp.where(valid, scores, neg_large)

        # Numerically stable online softmax merge:
        # merge previous state (carry_max, denom, acc) with current block scores/values.
        m_block = jnp.max(scores, axis=-1)
        carry_max_new = jnp.maximum(carry_max, m_block)
        # prev_scale rescales the old state to the new running max; the isfinite
        # guard keeps the first step (carry_max == -inf) from producing NaN.
        prev_scale = jnp.where(jnp.isfinite(carry_max), jnp.exp(carry_max - carry_max_new), 0.0)
        p = jnp.exp(scores - carry_max_new[..., None])
        p = jnp.where(valid, p, 0.0)
        denom_new = prev_scale * denom + jnp.sum(p, axis=-1)
        acc_new = prev_scale[..., None] * acc + jnp.matmul(p, vh)
        carry_max, denom, acc = carry_max_new, denom_new, acc_new

        # Rotate KV/mask so the next iteration sees the next shard's block.
        if step < cp - 1:
            k_block = jax.lax.ppermute(k_block, axis_name="cp", perm=perm)
            v_block = jax.lax.ppermute(v_block, axis_name="cp", perm=perm)
            mask_block = jax.lax.ppermute(mask_block, axis_name="cp", perm=perm)

    # Final normalize and restore [B, Tq, H, D]. Rows with an all-masked key
    # set (denom == 0) are zeroed rather than divided by ~0.
    out = jnp.where(
        denom[..., None] > 0,
        acc / jnp.maximum(denom[..., None], jnp.asarray(1e-9, dtype=denom.dtype)),
        0.0,
    )
    return jnp.transpose(out, (0, 2, 1, 3))


def default_positions(input_ids: jax.Array) -> jax.Array:
"""Build token positions from input token shape, with CP shard offset."""
start, local_len = 0, input_ids.shape[1]
cp = get_abstract_mesh().shape.get("cp", 1)
if cp > 1:
axis_idx = jax.lax.axis_index("cp")
start = axis_idx * local_len
return (start + jnp.arange(local_len, dtype=jnp.int32))[None, :]


def dot_product_attention(
q: jax.Array,
k: jax.Array,
v: jax.Array,
attention_mask: jax.Array,
is_causal: bool,
head_dim: int,
*,
positions: jax.Array,
scale: float | None = None,
) -> jax.Array:
"""Compute dot-product attention with automatic backend selection.

Expand All @@ -27,12 +120,19 @@ def dot_product_attention(
attention_mask: Mask of shape [batch, kv_len] where 1 = valid, 0 = masked.
Sequences must be right-padded (valid tokens first, then padding).
is_causal: Whether to apply causal masking (for prefill/training)
head_dim: Dimension of each attention head (for scaling)
head_dim: Dimension of each attention head (for scaling when scale is not provided)
positions: Query positions, shape [batch, q_len], used for causal masking
scale: Optional explicit scale factor for attention logits

Returns:
Attention output of shape [batch, q_len, num_heads, head_dim]
"""
scale = 1.0 / head_dim**0.5
scale = scale if scale is not None else 1.0 / head_dim**0.5
cp = get_abstract_mesh().shape.get("cp", 1)

# TODO: constraints for running ring attention
if cp > 1 and (is_causal or q.shape[1] == 1):
Comment on lines +133 to +134
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The TODO: constraints for running ring attention indicates incomplete functionality or known limitations. It's important to either address these constraints or document them clearly for future development and users. If there are specific conditions under which _ring_attention should not be used, they should be explicitly handled or raised as errors.

return _ring_attention(q, k, v, attention_mask, positions, scale)

if jax.default_backend() == "gpu" and q.dtype in _CUDNN_SUPPORTED_DTYPES:
kv_seq_lengths = attention_mask.sum(axis=1).astype(jnp.int32)
Expand Down
3 changes: 3 additions & 0 deletions skyrl-tx/tx/models/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ModelConfig(PretrainedConfig):
max_lora_adapters: int
max_lora_rank: int
shard_attention_heads: bool
context_parallel_size: int
loss_chunk_size: int
gradient_checkpointing: bool

Expand All @@ -32,6 +33,7 @@ def __init__(
max_lora_adapters: int,
max_lora_rank: int,
shard_attention_heads: bool,
context_parallel_size: int = 1,
loss_chunk_size: int = 0,
gradient_checkpointing: bool = False,
):
Expand All @@ -42,6 +44,7 @@ def __init__(
self.max_lora_adapters = max_lora_adapters
self.max_lora_rank = max_lora_rank
self.shard_attention_heads = shard_attention_heads
self.context_parallel_size = context_parallel_size
self.loss_chunk_size = loss_chunk_size
self.gradient_checkpointing = gradient_checkpointing

Expand Down
11 changes: 7 additions & 4 deletions skyrl-tx/tx/models/deepseekv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from jax import numpy as jnp
from jax.sharding import get_abstract_mesh

from tx.layers.attention import dot_product_attention, default_positions
from tx.layers.lora import LoRAEmbed, LoRAExpert, LoRALinear
from tx.layers.rotary_embedding import get_rope
from tx.layers.util import Param, prepare_routing, shard_map_ep
Expand Down Expand Up @@ -168,13 +169,15 @@ def __call__(
# Jax attention expects v to have the same shape as k
v = jnp.pad(v, ((0, 0), (0, 0), (0, 0), (0, self.qk_head_dim - self.v_head_dim)))

attn_output = jax.nn.dot_product_attention(
attn_output = dot_product_attention(
q,
k,
v,
scale=self.scaling,
mask=attention_mask[:, None, None, :].astype(bool),
attention_mask,
is_causal=kv_cache is None,
head_dim=self.qk_head_dim,
positions=positions,
scale=self.scaling,
)

attn_output = attn_output[:, :, :, : self.v_head_dim].reshape(B, T, self.num_heads * self.v_head_dim)
Expand Down Expand Up @@ -575,7 +578,7 @@ def __call__(
is_training: bool = False,
) -> CausalLMOutput:
if positions is None:
positions = jnp.arange(attention_mask.shape[1])[None, :]
positions = default_positions(input_ids)

outputs = self.model(
input_ids,
Expand Down
Loading
Loading