diff --git a/src/liger_kernel/transformers/model/glm4v.py b/src/liger_kernel/transformers/model/glm4v.py
index 0dd3cda7f..de1bb1649 100644
--- a/src/liger_kernel/transformers/model/glm4v.py
+++ b/src/liger_kernel/transformers/model/glm4v.py
@@ -5,10 +5,30 @@
 import torch
 
+from packaging import version
+from transformers import PretrainedConfig
+from transformers import __version__ as transformers_version
+
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
 from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
+_TRANSFORMERS_V5_OR_LATER: bool = version.parse(transformers_version) >= version.parse("5.0.0")
+
+
+def _get_hidden_size(config: PretrainedConfig) -> int:
+    """Get hidden_size from Glm4vConfig in a version-aware manner."""
+    if _TRANSFORMERS_V5_OR_LATER:
+        return config.text_config.hidden_size
+    return config.hidden_size
+
+
+def _get_vocab_size(config: PretrainedConfig) -> int:
+    """Get vocab_size from Glm4vConfig in a version-aware manner."""
+    if _TRANSFORMERS_V5_OR_LATER:
+        return config.text_config.vocab_size
+    return config.vocab_size
+
 
 def lce_forward(
     self,
@@ -130,7 +150,7 @@ def lce_forward(
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
-            hidden_size=self.config.hidden_size,
+            hidden_size=_get_hidden_size(self.config),
             **kwargs,
         )
         loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
@@ -142,7 +162,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             shift_labels=shift_labels,
-            vocab_size=self.config.vocab_size,
+            vocab_size=_get_vocab_size(self.config),
             **kwargs,
         )
diff --git a/src/liger_kernel/transformers/model/glm4v_moe.py b/src/liger_kernel/transformers/model/glm4v_moe.py
index 3203958f8..caa62bd64 100644
--- a/src/liger_kernel/transformers/model/glm4v_moe.py
+++ b/src/liger_kernel/transformers/model/glm4v_moe.py
@@ -4,10 +4,30 @@
 import torch
 
+from packaging import version
+from transformers import PretrainedConfig
+from transformers import __version__ as transformers_version
+
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
 from liger_kernel.transformers.model.output_classes import LigerGlm4vMoeCausalLMOutputWithPast
 
+_TRANSFORMERS_V5_OR_LATER: bool = version.parse(transformers_version) >= version.parse("5.0.0")
+
+
+def _get_hidden_size(config: PretrainedConfig) -> int:
+    """Get hidden_size from Glm4vMoeConfig in a version-aware manner."""
+    if _TRANSFORMERS_V5_OR_LATER:
+        return config.text_config.hidden_size
+    return config.hidden_size
+
+
+def _get_vocab_size(config: PretrainedConfig) -> int:
+    """Get vocab_size from Glm4vMoeConfig in a version-aware manner."""
+    if _TRANSFORMERS_V5_OR_LATER:
+        return config.text_config.vocab_size
+    return config.vocab_size
+
 
 def lce_forward(
     self,
@@ -133,7 +153,7 @@ def lce_forward(
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
-            hidden_size=self.config.hidden_size,
+            hidden_size=_get_hidden_size(self.config),
             **kwargs,
         )
         loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
@@ -145,7 +165,7 @@ def lce_forward(
             logits=logits,
             labels=labels,
             shift_labels=shift_labels,
-            vocab_size=self.config.vocab_size,
+            vocab_size=_get_vocab_size(self.config),
             **kwargs,
         )
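
For context, a minimal, self-contained sketch of the version dispatch the new _get_hidden_size/_get_vocab_size helpers implement: transformers >= 5.0 nests the text model's fields under config.text_config, while earlier releases expose them on the top-level config. The SimpleNamespace objects and the concrete sizes below are illustrative stand-ins, not real Glm4vConfig instances.

    from types import SimpleNamespace

    from packaging import version
    from transformers import __version__ as transformers_version

    # Same gate the patch introduces: compare the installed version against 5.0.0.
    _TRANSFORMERS_V5_OR_LATER = version.parse(transformers_version) >= version.parse("5.0.0")


    def _get_hidden_size(config) -> int:
        # v5+: fields live on the nested text_config; pre-v5: on the config itself.
        if _TRANSFORMERS_V5_OR_LATER:
            return config.text_config.hidden_size
        return config.hidden_size


    # Hypothetical configs mimicking the two layouts (values are illustrative).
    flat_config = SimpleNamespace(hidden_size=4096, vocab_size=151552)  # pre-5.0 layout
    nested_config = SimpleNamespace(
        text_config=SimpleNamespace(hidden_size=4096, vocab_size=151552)  # 5.0+ layout
    )

    config = nested_config if _TRANSFORMERS_V5_OR_LATER else flat_config
    print(_get_hidden_size(config))  # -> 4096 under either layout

Module-level gating keeps the per-call cost at a single boolean check, since the installed transformers version cannot change at runtime.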