Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/transformers/core_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ def reverse_op(self) -> ConversionOps:
raise NotImplementedError


class _IdentityOp(ConversionOps):
"""Pass-through reverse op for dequantize operations.

Dequantized weights are already in their target dtype and should be
saved as-is without any conversion.
"""

def convert(self, input_dict: dict[str, Any], **kwargs) -> dict[str, Any]:
return input_dict


class Chunk(ConversionOps):
"""Split a tensor along ``dim`` into equally sized chunks."""

Expand Down
6 changes: 5 additions & 1 deletion src/transformers/integrations/finegrained_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch.nn import functional as F

from ..activations import ACT2FN
from ..core_model_loading import ConversionOps
from ..core_model_loading import ConversionOps, _IdentityOp
from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_kernels_available, is_torch_available, logging
from .hub_kernels import get_kernel
Expand Down Expand Up @@ -752,3 +752,7 @@ def convert(
return {
full_layer_name: dequantized.reshape(quantized.shape),
}

@property
def reverse_op(self) -> "ConversionOps":
return _IdentityOp()
6 changes: 5 additions & 1 deletion src/transformers/integrations/metal_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
which computes ``y = x @ dequant(weight).T``, identical to ``nn.Linear``.
"""

from ..core_model_loading import ConversionOps
from ..core_model_loading import ConversionOps, _IdentityOp
from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_torch_available, logging

Expand Down Expand Up @@ -294,3 +294,7 @@ def convert(self, input_dict: dict, full_layer_name: str | None = None, **kwargs

w_deq = _affine_dequantize_tensor(quantized, scales, qbiases, group_size, bits)
return {full_layer_name: w_deq.to(scales.dtype)}

@property
def reverse_op(self) -> "ConversionOps":
return _IdentityOp()
6 changes: 5 additions & 1 deletion src/transformers/integrations/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch import nn
from contextlib import contextmanager

from ..core_model_loading import ConversionOps
from ..core_model_loading import ConversionOps, _IdentityOp
from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module


Expand Down Expand Up @@ -145,6 +145,10 @@ def convert(
dequantized = dequantize_convertops(param_data[f"{proj}_blocks"], param_data[f"{proj}_scales"])
return {full_layer_name: dequantized}

@property
def reverse_op(self) -> "ConversionOps":
return _IdentityOp()


class Mxfp4Deserialize(ConversionOps):
def __init__(self, hf_quantizer):
Expand Down
Loading