From 3edafe95c59d8b81c0bdc69b5c39a5c29c6739ae Mon Sep 17 00:00:00 2001 From: seongwoo Date: Thu, 21 May 2026 10:07:16 +0900 Subject: [PATCH] [quantization] Add tied embedding bit-width validation This commit adds tied embedding bit-width validation. TICO-DCO-1.0-Signed-off-by: seongwoo --- .../quantize_full_qmodel_with_gptq.py | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py index c71f7680..80968a18 100644 --- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -437,6 +437,88 @@ def parse_cle_pairs(raw_pairs: list[str] | None) -> list[tuple[str, str]]: return pairs +def _weights_share_storage( + left: torch.Tensor, + right: torch.Tensor, +) -> bool: + """Return True if two weight tensors share the exact same storage slice.""" + if left is right: + return True + + if not isinstance(left, torch.Tensor) or not isinstance(right, torch.Tensor): + return False + + if left.device != right.device: + return False + + if left.device.type == "meta" or right.device.type == "meta": + return False + + if left.numel() == 0 or right.numel() == 0: + return False + + return ( + left.untyped_storage().data_ptr() == right.untyped_storage().data_ptr() + and left.storage_offset() == right.storage_offset() + and tuple(left.shape) == tuple(right.shape) + and tuple(left.stride()) == tuple(right.stride()) + ) + + +def has_tied_input_output_embeddings(model: torch.nn.Module) -> bool: + """Return True if the input embedding and LM head weights are tied.""" + get_input_embeddings = getattr(model, "get_input_embeddings", None) + get_output_embeddings = getattr(model, "get_output_embeddings", None) + + if not callable(get_input_embeddings) or not callable(get_output_embeddings): + return False + + input_embeddings = get_input_embeddings() + output_embeddings = get_output_embeddings() + + if input_embeddings is None or output_embeddings is None: + return False + + input_weight = getattr(input_embeddings, "weight", None) + output_weight = getattr(output_embeddings, "weight", None) + + if input_weight is None or output_weight is None: + return False + + return _weights_share_storage(input_weight, output_weight) + + +def validate_tied_embedding_weight_bits( + model: torch.nn.Module, + args: argparse.Namespace, +) -> None: + """ + Reject different embedding and LM head bit-widths for tied weights. + + Args: + model: Model whose input embedding and output projection are inspected. + args: Parsed command-line arguments. + + Raises: + ValueError: If the model ties input embedding and LM head weights while + `--embedding_weight_bits` and `--lm_head_weight_bits` differ. + """ + if args.embedding_weight_bits == args.lm_head_weight_bits: + return + + if not has_tied_input_output_embeddings(model): + return + + raise ValueError( + "Cannot use different bit-widths for tied input embedding and lm_head " + "weights: " + f"--embedding_weight_bits={args.embedding_weight_bits}, " + f"--lm_head_weight_bits={args.lm_head_weight_bits}. " + "Set both options to the same value or use a model with untied " + "input/output embeddings." + ) + + def build_gptq_config( args, sensitivity: dict[str, torch.Tensor] | None = None, @@ -1322,6 +1404,7 @@ def main(): print_config(args, device) model, tokenizer = load_model_and_tokenizer(args, dtype) + validate_tied_embedding_weight_bits(model, args) configure_max_position_embeddings(model, args) dataset_test = load_eval_dataset(args)