Enable Cohere Command-R (CohereForCausalLM / Cohere2ForCausalLM) on ATOM #675
Open
jatseng-ai wants to merge 7 commits into main from …
Conversation
…e2 arch)

CohereLabs/c4ai-command-r7b-12-2024 reports architecture Cohere2ForCausalLM in its HF config, not CohereForCausalLM. Without this registration the model fell through to native vLLM instead of using the ATOM model wrapper.

Add the Cohere2ForCausalLM -> CohereForCausalLM mapping in:
- atom/model_engine/model_runner.py
- atom/plugin/vllm/model_wrapper.py
- atom/plugin/vllm/register.py

Validated on ctr-cx63-mi300x-21 (MI300X) — all 6 functional smoke tests pass.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
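A minimal sketch of the kind of alias mapping this commit describes; the registry name `ATOM_MODEL_REGISTRY` and the import path are assumptions for illustration, not ATOM's actual identifiers:

```python
# Hypothetical sketch: both HF architecture strings resolve to the same
# ATOM wrapper class, so Cohere2 checkpoints no longer fall through to
# native vLLM.
from atom.models.cohere import CohereForCausalLM  # assumed module path

ATOM_MODEL_REGISTRY = {
    "CohereForCausalLM": CohereForCausalLM,
    # Cohere2 checkpoints (e.g. c4ai-command-r7b-12-2024) report this name:
    "Cohere2ForCausalLM": CohereForCausalLM,
}
```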
…weights as loaded

Cohere2 checkpoints omit input_layernorm.bias and norm.bias (they are always zero). ATOM's LayerNorm creates these as nn.Parameters, but the checkpoint has no matching keys, causing vLLM's weight-completeness check to fail with ValueError.

- Store atom_config on CohereModel so load_weights() can access it
- Implement load_weights() using load_model_in_plugin_mode (same pattern as Qwen3ForCausalLM)
- Explicitly mark *.layernorm.bias and norm.bias as loaded so vLLM's default_loader completeness check passes

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
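A hedged sketch of what such a load_weights() might look like; load_model_in_plugin_mode is named in the commit, but its signature and return type here are assumptions:

```python
# from atom.plugin.vllm.model_wrapper import load_model_in_plugin_mode  # assumed location

def load_weights(self, weights):
    # Delegate real loading (the same pattern the commit cites for
    # Qwen3ForCausalLM); assumed to return the set of loaded weight names.
    loaded = load_model_in_plugin_mode(self, weights, self.atom_config)
    # The checkpoint ships no keys for these biases (always zero), so
    # report them as loaded to satisfy vLLM's completeness check.
    for name, _ in self.named_parameters():
        if name.endswith("layernorm.bias") or name.endswith("norm.bias"):
            loaded.add(name)
    return loaded
```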
…teness error

Cohere2 checkpoints store only input_layernorm.weight, not bias (the bias is always zero). ATOM's LayerNorm registers bias as an nn.Parameter, causing vLLM's default_loader completeness check to fail.

Introduce CohereLayerNorm: identical kernel path (layernorm2d_fwd_ / layernorm2d_fwd_with_add_), but bias is registered as a buffer instead of a parameter. Buffers are not tracked by named_parameters(), so the completeness check never sees them.

Replace LayerNorm -> CohereLayerNorm for input_layernorm and norm in CohereDecoderLayer and CohereModel.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
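A minimal sketch of the buffer trick; the real implementation dispatches to the layernorm2d_fwd_ kernels, for which plain torch layer_norm stands in here:

```python
import torch
import torch.nn as nn

class CohereLayerNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Always zero in Cohere2 checkpoints. Registered as a buffer, not a
        # parameter, so it never appears in named_parameters() and vLLM's
        # weight-completeness check never expects it in the checkpoint.
        self.register_buffer("bias", torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.layer_norm(
            x, (x.shape[-1],), self.weight, self.bias, self.variance_epsilon
        )
```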
… models

per_tensor_scale was only set inside forward_impl_plugin_mode when kv_cache_dtype == "fp8". On bf16 models (like Cohere Command-R), the rope_cache_plugin_mode path could access self.per_tensor_scale before it was initialized, causing AttributeError during CUDA graph warmup.

Initialize per_tensor_scale = self.kv_scale at construction time (= 1.0 for bf16). The fp8 path overwrites it later with the actual scale.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
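A stripped-down sketch of the construction-time fix; the real __init__ takes many more arguments, which are elided here:

```python
class PagedAttentionImpl:
    def __init__(self, kv_scale: float = 1.0):  # other args elided
        self.kv_scale = kv_scale
        # Previously this was only assigned inside forward_impl_plugin_mode
        # when kv_cache_dtype == "fp8", so bf16 models hit AttributeError in
        # rope_cache_plugin_mode during CUDA graph warmup. For bf16 this is
        # 1.0; the fp8 path later overwrites it with the actual scale.
        self.per_tensor_scale = self.kv_scale
```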
When per_layer_sliding_window=None is passed to vLLM's Attention layer, vLLM falls back to cache_config.sliding_window, which picks up the global sliding_window=4096 from Cohere2Config. This causes rope_cache_plugin_mode to set use_triton_attn=True (since sw=4096 != -1), then fail on qkv.view() when qkv=None (Cohere passes q, k, v separately, not fused).

Fix: use -1 (ATOM's "no sliding window" sentinel) as the explicit default. -1 is non-None, so vLLM won't fall back to cache_config.sliding_window. Sliding-attention layers continue to use config.sliding_window correctly.
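A sketch of the sentinel logic, assuming config.layer_types distinguishes sliding from full attention as in the Changes list below; the helper name is hypothetical:

```python
ATOM_NO_SLIDING_WINDOW = -1  # non-None, so vLLM's Attention layer will not
                             # fall back to cache_config.sliding_window

def per_layer_sliding_window(config, layer_idx: int) -> int:
    # Sliding layers keep the real window; global layers get the sentinel
    # instead of None, which is what triggered the bad fallback.
    if config.layer_types[layer_idx] == "sliding_attention":
        return config.sliding_window
    return ATOM_NO_SLIDING_WINDOW
```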
…re sliding attention

Cohere passes separate q, k, v tensors (not a fused qkv). The triton branch in rope_cache_plugin_mode assumed a fused qkv and crashed with:

AttributeError: 'NoneType' object has no attribute 'view'

For sliding-attention layers (sliding_window != -1), use_triton_attn=True triggered the triton branch. Guard it with `qkv is not None` so that when q/k/v are passed separately (qkv=None), the path falls through to the else branch, which correctly applies RoPE via self.rotary_emb(position, q, k).

Also remove debug logging added during investigation.
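An illustrative reconstruction of the guarded branch; the fused triton path is deliberately elided, and any name not in the commit message is an assumption:

```python
def apply_rope_plugin_mode(rotary_emb, position, q, k, qkv=None,
                           use_triton_attn=False):
    # Guard the fused branch with `qkv is not None`: for Cohere, qkv is
    # None because q, k, v arrive as separate tensors.
    if use_triton_attn and qkv is not None:
        qkv = qkv.view(qkv.shape[0], -1)  # the .view() that crashed on None
        raise NotImplementedError("fused triton RoPE path elided in sketch")
    # qkv is None: fall through and apply RoPE on the separate q/k tensors,
    # exactly as the fix describes.
    return rotary_emb(position, q, k)
```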
Summary
Adds full support for `CohereLabs/c4ai-command-r7b-12-2024` and Cohere2-family models on the ATOM AMD MI300X backend.

Changes
- Register `Cohere2ForCausalLM` as an alias of `CohereForCausalLM` (`atom/plugin/vllm/register.py`, `atom/plugin/vllm/model_wrapper.py`)
- `atom/models/cohere.py` — Cohere/Cohere2 model implementation with:
  - `CohereLayerNorm` using `register_buffer` for the zero bias (avoids vLLM's weight-completeness check)
  - `CohereMLP`, `CohereAttention`, `CohereDecoderLayer`, `CohereModel`, `CohereForCausalLM`
  - Per-layer attention selection via `config.layer_types`, driven by `sliding_window_pattern=4` in Cohere2Config (see the sketch after this list)
  - QK normalization (`use_qk_norm`)
- Sliding-window sentinel fix (`atom/models/cohere.py`) — use `-1` (ATOM's no-sliding-window sentinel) instead of `None` for global attention layers, preventing vLLM from inheriting `cache_config.sliding_window=4096` via its fallback path
- `qkv=None` crash in triton branch (`atom/plugin/attention_mha.py`) — guard the fused-qkv triton path with `qkv is not None`; when Cohere passes separate q/k/v (`qkv=None`), fall through to the else branch, which correctly applies RoPE on separate tensors
- `per_tensor_scale` AttributeError (`atom/model_ops/attention_mha.py`) — initialize `self.per_tensor_scale = self.kv_scale` in `PagedAttentionImpl.__init__` so bf16 models don't crash on first forward
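One way `layer_types` could be derived from `sliding_window_pattern=4`; the every-fourth-layer-global reading follows HF's Cohere2 modeling convention, but treat the exact rule as an assumption rather than ATOM's verbatim logic:

```python
def build_layer_types(num_layers: int, sliding_window_pattern: int = 4) -> list[str]:
    # With pattern=4: three sliding-attention layers, then one
    # full-attention layer, repeating across the stack.
    return [
        "full_attention" if (i + 1) % sliding_window_pattern == 0
        else "sliding_attention"
        for i in range(num_layers)
    ]

# build_layer_types(8) ->
# ['sliding_attention', 'sliding_attention', 'sliding_attention',
#  'full_attention'] * 2
```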
Validation

Functional smoke tests on AMD MI300X (8x GPU, ctr-cx63-mi300x-21, ATOM `rocm/atom-dev:vllm-latest` image), including the `/v1/models` API check.

All 6 tests passed with `--enforce-eager --tensor-parallel-size 8 --max-model-len 8192`.