Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mcore_bridge/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
res['add_qkv_bias'] = False
elif llm_model_type == 'olmoe':
res['qk_layernorm'] = True
elif llm_model_type in {'olmo2', 'olmo3'}:
res['qk_layernorm'] = True
if llm_model_type == 'olmo3' and window_size is not None and layer_types is not None:
res['window_size'] = f'{window_size - 1},0'
window_attn_skip_freq = ','.join(['1' if lt == 'sliding_attention' else '0' for lt in layer_types])
res['window_attn_skip_freq'] = f'[{window_attn_skip_freq}]'
elif hf_model_type == 'llama4':
qk_layernorm = res.pop('qk_layernorm', False)
if qk_layernorm:
Expand Down
2 changes: 2 additions & 0 deletions src/mcore_bridge/model/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ class LLMModelType:

qwen3_next = 'qwen3_next'
olmoe = 'olmoe'
olmo2 = 'olmo2'
olmo3 = 'olmo3'
glm4 = 'glm4'
minimax_m2 = 'minimax_m2'
hy_v3 = 'hy_v3'
Expand Down
3 changes: 2 additions & 1 deletion src/mcore_bridge/model/gpts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from . import bailing_hybrid, bailing_moe, deepseek_v4, glm4, hunyuan, llm, minimax_m2, olmoe, qwen3_emb, qwen3_next
from . import (bailing_hybrid, bailing_moe, deepseek_v4, glm4, hunyuan, llm, minimax_m2, olmo2, olmoe, qwen3_emb,
qwen3_next)
136 changes: 136 additions & 0 deletions src/mcore_bridge/model/gpts/olmo2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, apply_swiglu_sharded_factory
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.utils import sharded_state_dict_default
from typing import Optional

from mcore_bridge.config import ModelConfig

from ..constant import ModelType
from ..register import ModelLoader, ModelMeta, register_model
from .olmoe import OLMoEBridge, OLMoESelfAttention


class Olmo2SelfAttention(OLMoESelfAttention):
"""OLMo-2/3 attention.

Inherits OLMoE-style full-channel q/k RMSNorm, and additionally applies
a post-attention RMSNorm on the o_proj output (before the residual add),
matching the HF post-norm architecture (no input layernorm in HF).
"""

def __init__(self, config: ModelConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.post_self_attn_layernorm = build_module(
TENorm,
hidden_size=self.config.hidden_size,
config=self.config,
eps=self.config.layernorm_epsilon,
)

def forward(self, hidden_states, *args, **kwargs):
output, bias = super().forward(hidden_states, *args, **kwargs)
assert bias is None, 'OLMo-2/3 self attention does not support bias.'

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Avoid using assert statements for runtime validation or correctness checks, as they can be globally disabled in Python when run with optimization flags (e.g., python -O). Instead, raise an explicit ValueError or RuntimeError.

Suggested change
assert bias is None, 'OLMo-2/3 self attention does not support bias.'
if bias is not None:
raise ValueError('OLMo-2/3 self attention does not support bias.')

output = self.post_self_attn_layernorm(output)
return output, bias


class Olmo2MLP(MLP):
"""OLMo-2/3 MLP: applies a post-MLP RMSNorm before the residual add."""

def __init__(self, config: ModelConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.post_mlp_layernorm = build_module(
TENorm,
hidden_size=self.config.hidden_size,
config=self.config,
eps=self.config.layernorm_epsilon,
)

def forward(self, hidden_states, *args, **kwargs):
output, bias = super().forward(hidden_states, *args, **kwargs)
assert bias is None, 'OLMo-2/3 MLP does not support bias.'

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Avoid using assert statements for runtime validation or correctness checks, as they can be globally disabled in Python when run with optimization flags (e.g., python -O). Instead, raise an explicit ValueError or RuntimeError.

Suggested change
assert bias is None, 'OLMo-2/3 MLP does not support bias.'
if bias is not None:
raise ValueError('OLMo-2/3 MLP does not support bias.')

output = self.post_mlp_layernorm(output)
return output, bias

def sharded_state_dict(self,
prefix: str = '',
sharded_offsets: tuple = (),
metadata: Optional[dict] = None) -> ShardedStateDict:
sharded_state_dict = {}
singleton_local_shards = (metadata or {}).get('singleton_local_shards', False)
for name, module in self._modules.items():
sub_sd = sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata)
if self.config.gated_linear_unit and name == 'linear_fc1':
for k, v in sub_sd.items():
if k in (f'{prefix}{name}.weight', f'{prefix}{name}.bias'):
sub_sd[k] = apply_swiglu_sharded_factory(v, sharded_offsets, singleton_local_shards)
sharded_state_dict.update(sub_sd)
return sharded_state_dict
Comment on lines +59 to +72

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of duplicating the entire sharding logic of the base MLP class (including the gated linear unit / SwiGLU sharding factory), you can call super().sharded_state_dict(...) and then simply update it with the sharded state dict of post_mlp_layernorm. This is much more maintainable and robust against future changes in the base class.

    def sharded_state_dict(self,
                           prefix: str = '',
                           sharded_offsets: tuple = (),
                           metadata: Optional[dict] = None) -> ShardedStateDict:
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        post_norm_sd = sharded_state_dict_default(
            self.post_mlp_layernorm, f'{prefix}post_mlp_layernorm.', sharded_offsets, metadata
        )
        sharded_state_dict.update(post_norm_sd)
        return sharded_state_dict



class Olmo2Bridge(OLMoEBridge):
"""OLMo-2/3 bridge.

OLMo-2/3 is a post-norm only architecture: there is no `input_layernorm`
nor `pre_feedforward_layernorm` on the HF side. Each layer instead has:
* `post_attention_layernorm.weight` -- after self-attn, before residual
* `post_feedforward_layernorm.weight` -- after MLP, before residual
Together with OLMoE-style full-channel q/k_norm.
"""

def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool):
mg_attn = None if mg_layer is None else mg_layer.self_attention
# q/k/v/o + full-channel q_norm/k_norm via the inherited OLMoE path.
hf_state_dict.update(
self._set_attn_state(mg_attn, hf_state_dict, f'{self.hf_attn_prefix}.', layer_idx, to_mcore))
# No HF `input_layernorm.weight` exists; map the HF post-attn norm
# to the post_self_attn_layernorm we attach in Olmo2SelfAttention.
self._set_state_dict(mg_layer, 'self_attention.post_self_attn_layernorm.weight', hf_state_dict,
'post_attention_layernorm.weight', to_mcore)
return hf_state_dict

def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool, is_mtp: bool = False):
mg_mlp = None if mg_layer is None else mg_layer.mlp
hf_state_dict.update(
self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore))
# No HF `pre_feedforward_layernorm.weight` exists; map the HF
# post-MLP norm to the post_mlp_layernorm we attach in Olmo2MLP.
self._set_state_dict(mg_layer, 'mlp.post_mlp_layernorm.weight', hf_state_dict,
'post_feedforward_layernorm.weight', to_mcore)
return hf_state_dict


class Olmo2Loader(ModelLoader):

def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
transformer_layer_spec = super().get_transformer_layer_spec(vp_stage)
for layer_spec in transformer_layer_spec.layer_specs:
# OLMo-2/3 has no pre-norm: drop the layernorm fused into linear_qkv/linear_fc1
# and explicitly mark input_layernorm / pre_mlp_layernorm as identity ops.
layer_spec.submodules.input_layernorm = IdentityOp
layer_spec.submodules.pre_mlp_layernorm = IdentityOp
layer_spec.submodules.self_attention.submodules.linear_qkv = TEColumnParallelLinear
layer_spec.submodules.mlp.submodules.linear_fc1 = TEColumnParallelLinear
# Attach post-norms via custom SelfAttention / MLP modules.
layer_spec.submodules.self_attention.module = Olmo2SelfAttention
self._set_mlp_spec(layer_spec.submodules, Olmo2MLP)
return transformer_layer_spec


register_model(ModelMeta(
ModelType.olmo2,
['olmo2'],
bridge_cls=Olmo2Bridge,
loader=Olmo2Loader,
))

register_model(ModelMeta(
ModelType.olmo3,
['olmo3'],
bridge_cls=Olmo2Bridge,
loader=Olmo2Loader,
))
Loading