
Enable Cohere Command-R (CohereForCausalLM / Cohere2ForCausalLM) on ATOM #675

Open

jatseng-ai wants to merge 7 commits into main from jatseng2/enable-cohere-command-r

Conversation

@jatseng-ai

Summary

Adds full support for CohereLabs/c4ai-command-r7b-12-2024 and Cohere2-family models on the ATOM AMD MI300X backend.

Changes

  • Register Cohere2ForCausalLM in ATOM model registry and plugin model wrapper (atom/plugin/vllm/register.py, atom/plugin/vllm/model_wrapper.py)
  • New atom/models/cohere.py — Cohere/Cohere2 model implementation with:
    • CohereLayerNorm using register_buffer for zero bias (avoids vLLM weight-completeness check)
    • CohereMLP, CohereAttention, CohereDecoderLayer, CohereModel, CohereForCausalLM
    • Parallel residual (attn + MLP share the same pre-norm)
    • Per-layer sliding window via config.layer_types (driven by sliding_window_pattern=4 in Cohere2Config)
    • Q/K layer normalization support (use_qk_norm)
    • Non-neox RoPE (interleaved/repeat_interleave style; see the sketch after this list)
  • Fix: sliding window sentinel (atom/models/cohere.py) — use -1 (ATOM's no-sliding-window sentinel) instead of None for global attention layers, preventing vLLM from inheriting cache_config.sliding_window=4096 via its fallback path
  • Fix: qkv=None crash in triton branch (atom/plugin/attention_mha.py) — guard the fused-qkv triton path with qkv is not None; when Cohere passes separate q/k/v (qkv=None), fall through to the else branch which correctly applies RoPE on separate tensors
  • Fix: per_tensor_scale AttributeError (atom/model_ops/attention_mha.py) — initialize self.per_tensor_scale = self.kv_scale in PagedAttentionImpl.__init__ so bf16 models don't crash on first forward
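
Regarding the non-neox RoPE bullet above, a minimal sketch of the interleaved (GPT-J-style) rotary application, as opposed to the neox style that rotates the two contiguous halves of each head. The function name and tensor shapes are illustrative only, not ATOM's actual kernel interface:

```python
import torch

def apply_rope_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Non-neox RoPE: rotate adjacent feature pairs (x0, x1), (x2, x3), ...
    Illustrative shapes: x is [..., head_dim], cos/sin are [..., head_dim // 2]."""
    x_even = x[..., 0::2]                  # first element of each pair
    x_odd = x[..., 1::2]                   # second element of each pair
    rot_even = x_even * cos - x_odd * sin  # 2-D rotation applied per pair
    rot_odd = x_odd * cos + x_even * sin
    # Re-interleave so the output layout matches the input
    return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)
```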

Validation

Functional smoke tests on AMD MI300X (8x GPU, ctr-cx63-mi300x-21, ATOM rocm/atom-dev:vllm-latest image):

| Test | Result |
| --- | --- |
| Basic chat completion | PASS |
| Math reasoning (7×8=56) | PASS |
| Sliding window (long prompt) | PASS |
| System message | PASS |
| Multi-turn conversation | PASS |
| /v1/models API | PASS |

All 6 tests passed with --enforce-eager --tensor-parallel-size 8 --max-model-len 8192.
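
For reference, a sketch of how one of these smoke tests can be reproduced against the OpenAI-compatible endpoint that vLLM exposes; the port and any ATOM-specific launch flags beyond those listed above are assumptions.

```python
# Hypothetical client-side check against a server launched with the flags above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # assumed port
resp = client.chat.completions.create(
    model="CohereLabs/c4ai-command-r7b-12-2024",
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "What is 7*8?"},
    ],
    max_tokens=32,
)
print(resp.choices[0].message.content)  # the math-reasoning check expects "56"
```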

jatseng-ai and others added 7 commits April 14, 2026 11:09
…e2 arch)

CohereLabs/c4ai-command-r7b-12-2024 reports architecture Cohere2ForCausalLM
in its HF config, not CohereForCausalLM. Without this registration the model
fell through to native vLLM instead of using the ATOM model wrapper.

Add Cohere2ForCausalLM -> CohereForCausalLM mapping in:
- atom/model_engine/model_runner.py
- atom/plugin/vllm/model_wrapper.py
- atom/plugin/vllm/register.py

Validated on ctr-cx63-mi300x-21 (MI300X) — all 6 functional smoke tests pass.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
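
A rough sketch of what the added mapping amounts to; the actual registration goes through the helpers in atom/plugin/vllm/register.py and atom/model_engine/model_runner.py, so the dict below is illustrative only.

```python
# Illustrative only: real registration uses ATOM's plugin/model-runner APIs.
from atom.models.cohere import CohereForCausalLM  # assumed import path

# Both HF architecture strings resolve to the same ATOM implementation, so a
# checkpoint reporting Cohere2ForCausalLM no longer falls through to native vLLM.
ATOM_MODEL_REGISTRY = {
    "CohereForCausalLM": CohereForCausalLM,
    "Cohere2ForCausalLM": CohereForCausalLM,
}
```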
…weights as loaded

Cohere2 checkpoints omit input_layernorm.bias and norm.bias (they are
always zero). ATOM's LayerNorm creates these as nn.Parameters but the
checkpoint has no matching keys, causing vLLM's weight-completeness
check to fail with ValueError.

- Store atom_config on CohereModel so load_weights() can access it
- Implement load_weights() using load_model_in_plugin_mode (same
  pattern as Qwen3ForCausalLM)
- Explicitly mark *.layernorm.bias and norm.bias as loaded so vLLM's
  default_loader completeness check passes

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
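
A hedged sketch of the load_weights shape described above; load_model_in_plugin_mode comes from the commit message, while the return-a-set-of-loaded-names convention and the exact suffix matching are assumptions.

```python
def load_weights(self, weights):
    # Delegate actual loading to the plugin-mode loader (same pattern as Qwen3ForCausalLM).
    loaded = load_model_in_plugin_mode(self, weights, self.atom_config)
    # The checkpoint has no *.input_layernorm.bias / norm.bias entries (they are all-zero),
    # so report them as loaded to satisfy vLLM's weight-completeness check.
    for name, _ in self.named_parameters():
        if name.endswith("layernorm.bias") or name.endswith("norm.bias"):
            loaded.add(name)
    return loaded
```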
…teness error

Cohere2 checkpoints store only input_layernorm.weight, not bias (bias is
always zero). ATOM's LayerNorm registers bias as nn.Parameter, causing
vLLM's default_loader completeness check to fail.

Introduce CohereLayerNorm: identical kernel path (layernorm2d_fwd_ /
layernorm2d_fwd_with_add_) but bias is registered as a buffer instead
of a parameter. Buffers are not tracked by named_parameters() so the
completeness check never sees them.

Replace LayerNorm -> CohereLayerNorm for input_layernorm and norm in
CohereDecoderLayer and CohereModel.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
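
A minimal sketch of the buffer-bias trick; the real class dispatches to the layernorm2d_fwd_ / layernorm2d_fwd_with_add_ kernels rather than the eager math below.

```python
import torch
import torch.nn as nn

class CohereLayerNorm(nn.Module):
    """Bias lives in a buffer, so named_parameters() never reports it and
    vLLM's weight-completeness check ignores it."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))     # present in the checkpoint
        self.register_buffer("bias", torch.zeros(hidden_size))  # always zero, absent from the checkpoint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Eager-mode equivalent of the fused kernel path.
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
        x = (x - mean) * torch.rsqrt(var + self.variance_epsilon)
        return x * self.weight + self.bias
```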
… models

per_tensor_scale was only set inside forward_impl_plugin_mode when
kv_cache_dtype == "fp8". On bf16 models (like Cohere Command-R), the
rope_cache_plugin_mode path could access self.per_tensor_scale before
it was initialized, causing AttributeError during CUDA graph warmup.

Initialize per_tensor_scale = self.kv_scale at construction time (= 1.0
for bf16). The fp8 path overwrites it later with the actual scale.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
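
The constructor change amounts to something like the following; the argument list is trimmed to the relevant fields and is not the real signature.

```python
class PagedAttentionImpl:
    def __init__(self, *, kv_scale: float = 1.0, kv_cache_dtype: str = "auto", **kwargs):
        self.kv_scale = kv_scale
        self.kv_cache_dtype = kv_cache_dtype
        # Previously only assigned inside the fp8 branch of forward_impl_plugin_mode;
        # set it up front so bf16 models reach rope_cache_plugin_mode without an AttributeError.
        self.per_tensor_scale = self.kv_scale  # = 1.0 for bf16; the fp8 path overwrites it later
```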
When per_layer_sliding_window=None is passed to vLLM's Attention layer,
vLLM falls back to cache_config.sliding_window which picks up the global
sliding_window=4096 from Cohere2Config. This causes rope_cache_plugin_mode
to set use_triton_attn=True (since sw=4096 != -1), then fail on
qkv.view() when qkv=None (Cohere passes q,k,v separately, not fused).

Fix: use -1 (ATOM's "no sliding window" sentinel) as the explicit default.
-1 is non-None, so vLLM won't fall back to cache_config.sliding_window.
Sliding attention layers continue to use config.sliding_window correctly.
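
Roughly, the layer construction now looks like the sketch below; the config.layer_types handling and the Attention argument list are simplified and partly assumed.

```python
# Illustrative: global-attention layers get the explicit -1 sentinel instead of None.
is_sliding = config.layer_types[layer_idx] == "sliding_attention"  # assumed layer_types values
per_layer_sliding_window = config.sliding_window if is_sliding else -1

self.attn = Attention(
    num_heads,
    head_dim,
    scaling,
    num_kv_heads=num_kv_heads,
    # Never None, so vLLM skips the cache_config.sliding_window fallback.
    per_layer_sliding_window=per_layer_sliding_window,
    prefix=f"{prefix}.attn",
)
```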
…re sliding attention

Cohere passes separate q, k, v tensors (not a fused qkv). The triton branch
in rope_cache_plugin_mode assumed a fused qkv and crashed with:
  AttributeError: 'NoneType' object has no attribute 'view'

For sliding attention layers (sliding_window != -1), use_triton_attn=True
triggered the triton branch. Guard it with `qkv is not None` so that when
q/k/v are passed separately (qkv=None), the path falls through to the else
branch which correctly applies RoPE via self.rotary_emb(position, q, k).

Also remove debug logging added during investigation.
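
In sketch form, the guard looks like this; the helper names and surrounding control flow are illustrative, and only the `qkv is not None` condition reflects the actual change.

```python
def _rope_cache_plugin_mode(self, positions, q, k, qkv, use_triton_attn):
    """Sketch of the guarded branch; cache writes and other arguments are elided."""
    if use_triton_attn and qkv is not None:
        # Fused-qkv path: only valid when the caller packed q/k/v into one tensor.
        q, k = self._triton_fused_rope(positions, qkv)  # hypothetical helper
    else:
        # Cohere passes q, k, v separately (qkv is None): apply RoPE on the split tensors.
        q, k = self.rotary_emb(positions, q, k)
    return q, k
```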