Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,88 @@ def parse_cle_pairs(raw_pairs: list[str] | None) -> list[tuple[str, str]]:
return pairs


def _weights_share_storage(
left: torch.Tensor,
right: torch.Tensor,
) -> bool:
"""Return True if two weight tensors share the exact same storage slice."""
if left is right:
return True

if not isinstance(left, torch.Tensor) or not isinstance(right, torch.Tensor):
return False

if left.device != right.device:
return False

if left.device.type == "meta" or right.device.type == "meta":
return False

if left.numel() == 0 or right.numel() == 0:
return False

return (
left.untyped_storage().data_ptr() == right.untyped_storage().data_ptr()
and left.storage_offset() == right.storage_offset()
and tuple(left.shape) == tuple(right.shape)
and tuple(left.stride()) == tuple(right.stride())
)


def has_tied_input_output_embeddings(model: torch.nn.Module) -> bool:
    """Return True if the input embedding and LM head weights are tied.

    The model is probed through its `get_input_embeddings` and
    `get_output_embeddings` accessors. Whenever the answer cannot be
    determined (an accessor is missing, is not implemented, returns None,
    or the returned module has no `weight`), the weights are treated as
    not tied.

    Args:
        model: Model whose input embedding and output projection are
            inspected.

    Returns:
        True only when both accessors yield modules whose `weight`
        attributes share the exact same storage slice.
    """
    get_input_embeddings = getattr(model, "get_input_embeddings", None)
    get_output_embeddings = getattr(model, "get_output_embeddings", None)

    if not callable(get_input_embeddings) or not callable(get_output_embeddings):
        return False

    # Some model classes expose these accessors but raise NotImplementedError
    # when embedding access is not implemented (e.g. Hugging Face's base
    # implementation). Treat that like a missing accessor instead of crashing
    # the caller. Only this exception is caught so real errors still surface.
    try:
        input_embeddings = get_input_embeddings()
        output_embeddings = get_output_embeddings()
    except NotImplementedError:
        return False

    if input_embeddings is None or output_embeddings is None:
        return False

    input_weight = getattr(input_embeddings, "weight", None)
    output_weight = getattr(output_embeddings, "weight", None)

    if input_weight is None or output_weight is None:
        return False

    return _weights_share_storage(input_weight, output_weight)


def validate_tied_embedding_weight_bits(
    model: torch.nn.Module,
    args: argparse.Namespace,
) -> None:
    """
    Refuse mismatched embedding / LM head bit-widths when weights are tied.

    The tie check is only performed when the two requested bit-widths
    actually differ, so matching configurations return immediately without
    inspecting the model.

    Args:
        model: Model whose input embedding and output projection are inspected.
        args: Parsed command-line arguments; `embedding_weight_bits` and
            `lm_head_weight_bits` are read from it.

    Raises:
        ValueError: If the model ties input embedding and LM head weights while
            `--embedding_weight_bits` and `--lm_head_weight_bits` differ.
    """
    embedding_bits = args.embedding_weight_bits
    lm_head_bits = args.lm_head_weight_bits

    # Short-circuit: the model is not probed when the bit-widths agree.
    bits_match = embedding_bits == lm_head_bits
    if bits_match or not has_tied_input_output_embeddings(model):
        return

    message = (
        "Cannot use different bit-widths for tied input embedding and lm_head "
        "weights: "
        f"--embedding_weight_bits={embedding_bits}, "
        f"--lm_head_weight_bits={lm_head_bits}. "
        "Set both options to the same value or use a model with untied "
        "input/output embeddings."
    )
    raise ValueError(message)


def build_gptq_config(
args,
sensitivity: dict[str, torch.Tensor] | None = None,
Expand Down Expand Up @@ -1322,6 +1404,7 @@ def main():
print_config(args, device)

model, tokenizer = load_model_and_tokenizer(args, dtype)
validate_tied_embedding_weight_bits(model, args)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

configure_max_position_embeddings(model, args)

dataset_test = load_eval_dataset(args)
Expand Down
Loading