From a3d6b5c60e60f8b47bc7733d4b35e2a24979ea12 Mon Sep 17 00:00:00 2001 From: "s.malakhov" Date: Thu, 14 May 2026 10:49:32 +0300 Subject: [PATCH 1/3] [quantization] Introduce smse_for_gptq This PR introduces smse_for_gptq to improve accuracy. TICO-DCO-1.0-Signed-off-by: s.malakhov --- .../algorithm/fpi_gptq/fpi_gptq.py | 25 +----- tico/quantization/algorithm/fpi_gptq/util.py | 55 +++++++++++++ tico/quantization/algorithm/gptq/README.md | 3 + tico/quantization/algorithm/gptq/gptq.py | 4 +- tico/quantization/algorithm/gptq/quant.py | 77 ++++++++++++++++++- tico/quantization/algorithm/gptq/quantizer.py | 7 ++ tico/quantization/algorithm/gptq/utils.py | 11 ++- .../quantize_full_qmodel_with_gptq.py | 68 +++++++++++++++- 8 files changed, 219 insertions(+), 31 deletions(-) create mode 100644 tico/quantization/algorithm/fpi_gptq/util.py diff --git a/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py b/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py index cdd99ef7..641a59ae 100644 --- a/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +++ b/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py @@ -32,30 +32,7 @@ ) from tico.quantization.algorithm.gptq.quant import quantize, Quantizer - - -def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50): - - cur_weights = W.clone() - mults = torch.pow(torch.diag(Hinv), -1) - Hinv_U = torch.triu(Hinv, diagonal=1) - - init_weights = W.clone() - for _ in range(max_num_of_iters): - cur_Q = quantize(cur_weights, scale, zero, maxq) - - d_W = torch.mul((cur_weights - cur_Q), mults) - cur_weights = init_weights - torch.matmul(d_W, Hinv_U) - del d_W, cur_Q - d_W = cur_Q = None - - del init_weights - init_weights = None - - cur_Q = quantize(cur_weights, scale, zero, maxq) - - return cur_Q, cur_weights - +from tico.quantization.algorithm.fpi_gptq.util import quantize, iterate_GPTQ class FPI_GPTQ: def __init__(self, layer): diff --git a/tico/quantization/algorithm/fpi_gptq/util.py b/tico/quantization/algorithm/fpi_gptq/util.py new file mode 
100644 index 00000000..a1ed1f78 --- /dev/null +++ b/tico/quantization/algorithm/fpi_gptq/util.py @@ -0,0 +1,55 @@ +# Copyright IST-DASLab. 2025. (commit: 2d65066). GitHub repository. +# Retrieved from https://github.com/IST-DASLab/gptq. Licensed under the +# Apache License 2.0. + +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# https://github.com/IST-DASLab/gptq/blob/2d65066/quant.py + +import torch + +def quantize(x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + + +def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50): + + cur_weights = W.clone() + mults = torch.pow(torch.diag(Hinv), -1) + Hinv_U = torch.triu(Hinv, diagonal=1) + + init_weights = W.clone() + for _ in range(max_num_of_iters): + cur_Q = quantize(cur_weights, scale, zero, maxq) + + d_W = torch.mul((cur_weights - cur_Q), mults) + cur_weights = init_weights - torch.matmul(d_W, Hinv_U) + del d_W, cur_Q + d_W = cur_Q = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + del init_weights + init_weights = None + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + cur_Q = quantize(cur_weights, scale, zero, maxq) + + return cur_Q, cur_weights diff --git a/tico/quantization/algorithm/gptq/README.md b/tico/quantization/algorithm/gptq/README.md index 
ae443a13..db99dcf7 100644 --- a/tico/quantization/algorithm/gptq/README.md +++ b/tico/quantization/algorithm/gptq/README.md @@ -52,6 +52,9 @@ applied after _convert()_, the effectiveness of GPTQ may be diminished. There are two options : 1. `mse`- vanilla `mse`. Produce quantization parameters for GPTQ quantizer (`min`\`max`) which minimize mean squared error of quantization. $MSE_{MIN, MAX}(W) = argmin_{min, max}||W-Q_{min, max}(W)||^2$. 2. `smse` - sensitivity-based `mse`. Use sensitivity of some global feature (e.g. float model logits) to parameters change to minimize global effect of quantization. $SMSE_{MIN, MAX}(W) = argmin_{min, max}|(W-Q_{min, max}(W))^2*Sensitivity(W)|$. So we try to keep `important` parameters unchanged, while quantizing `unimportant` parameters more aggressively. +3. `smse_for_gptq` - `smse` adjusted for GPTQ. GPTQ modifies the weight matrix during the quantization process, so the most accurate approach is to find the quantizer that yields the smallest quantization error after GPTQ has been applied: $SMSE\_FOR\_GPTQ_{MIN, MAX}(W) = argmin_{min, max}|(W-Q_{min, max}(W_{GPTQ}))^2*Sensitivity(W)|$. Since this is computationally expensive, an accelerated approximation of GPTQ, FPI_GPTQ, is used instead: $SMSE\_FOR\_GPTQ_{MIN, MAX}(W) = argmin_{min, max}|(W-Q_{min, max}(W_{FPI\_GPTQ}))^2*Sensitivity(W)|$. This is slower than `mse`/`smse` but can provide better accuracy.
+ + You can turn this feature `on`/`off` by using `mse` parameter of `GPTQConfig`: ``` diff --git a/tico/quantization/algorithm/gptq/gptq.py b/tico/quantization/algorithm/gptq/gptq.py index ab01721f..e780cf4d 100644 --- a/tico/quantization/algorithm/gptq/gptq.py +++ b/tico/quantization/algorithm/gptq/gptq.py @@ -360,7 +360,9 @@ def fasterquant( H = torch.cholesky_inverse(H) H = torch.linalg.cholesky(H, upper=True).float() Hinv = H - + + self.quantizer.update(W, Hinv, perm) + assert isinstance(Hinv, torch.Tensor) for i1 in range(0, self.columns, blocksize): i2 = min(i1 + blocksize, self.columns) diff --git a/tico/quantization/algorithm/gptq/quant.py b/tico/quantization/algorithm/gptq/quant.py index 98e7731d..131eb7f5 100644 --- a/tico/quantization/algorithm/gptq/quant.py +++ b/tico/quantization/algorithm/gptq/quant.py @@ -21,6 +21,7 @@ import torch import torch.nn as nn +from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ def quantize(x, scale, zero, maxq): if maxq < 0: @@ -101,7 +102,7 @@ def find_params(self, x, weight=False): else: self.zero = torch.round(-xmin / self.scale) - if self.mse is not None: + if self.mse is not None and self.mse != "smse_for_gptq" and self.mse != "mse_for_gptq": best = torch.full([x.shape[0]], float("inf"), device=dev) for i in range(int(self.maxshrink * self.grid)): p = 1 - i / self.grid @@ -151,6 +152,80 @@ def find_params(self, x, weight=False): self.scale = self.scale.unsqueeze(0) self.zero = self.zero.unsqueeze(0) + def update(self, x, Hinv, perm): + if self.mse is None or self.mse != "smse_for_gptq": + return + + shape = x.shape + if self.perchannel: + x = x.flatten(1) + else: + x = x.flatten().unsqueeze(0) + + dev = x.device + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + 
xmin[tmp] = -1 + xmax[tmp] = +1 + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) # type: ignore[arg-type] + else: + self.zero = torch.round(-xmin / self.scale) + + sensitivity = None + if self.sensitivity is not None: + sensitivity = self.sensitivity.to(Hinv.dtype).to(dev) + if perm is not None: + sensitivity = sensitivity[:, perm.to(dev)] + + num_of_iters = 15 + best = torch.full([x.shape[0]], float("inf"), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q, _ = iterate_GPTQ( + scale1.unsqueeze(1), + zero1.unsqueeze(1), + self.maxq, + x, + Hinv, + max_num_of_iters=num_of_iters, + ) + assert sensitivity is not None + assert self.mse == "smse_for_gptq" + err = ((q - x) ** 2) * sensitivity.to(q.device) + + err = err + err = torch.sum(err, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + + del sensitivity + sensitivity = None + def quantize(self, x): if self.ready(): return quantize(x, self.scale, self.zero, self.maxq) diff --git a/tico/quantization/algorithm/gptq/quantizer.py b/tico/quantization/algorithm/gptq/quantizer.py index ba3d05b5..81b864f9 100644 --- a/tico/quantization/algorithm/gptq/quantizer.py +++ b/tico/quantization/algorithm/gptq/quantizer.py @@ -473,6 +473,12 @@ def _hook(_, inp, out): handles = [layer.register_forward_hook(add_batch())] # Run layer forward over all cached batches to build Hessian/statistics + old_device = device + model = model.to("cpu") + model.lm_head = model.lm_head.to(old_device) + if torch.cuda.is_available(): + 
torch.cuda.empty_cache() + device = next(layer.parameters()).device # in case lm_head is located on cpu for batch_idx in tqdm( range(batch_num), @@ -502,3 +508,4 @@ def _hook(_, inp, out): ) quantizers[f"lm_head"] = gptq.quantizer gptq.free() + model = model.to(old_device) diff --git a/tico/quantization/algorithm/gptq/utils.py b/tico/quantization/algorithm/gptq/utils.py index 1f4d0321..943f4dc5 100644 --- a/tico/quantization/algorithm/gptq/utils.py +++ b/tico/quantization/algorithm/gptq/utils.py @@ -163,7 +163,11 @@ def compute_sensitivity_info(self): if self.show_progress is True: print("Calibrating sensitivity") for inputs, targets in tqdm.tqdm(data_loader, disable=not self.show_progress): - model.zero_grad() + model.zero_grad(set_to_none=True) + if model.device.type != "cpu": + torch.cuda.empty_cache() + torch.cuda.synchronize() + if isinstance(inputs, torch.Tensor): inp_ids = inputs.squeeze(0) # remove redundant batch dimension logits = model(inp_ids.to(model.device)).logits @@ -219,6 +223,11 @@ def compute_sensitivity_info(self): for name in modules_to_process: sensitivity[name] /= len(data_loader) + model.zero_grad(set_to_none=True) + if model.device.type != "cpu": + torch.cuda.synchronize() + torch.cuda.empty_cache() + model = model.to(dtype) return sensitivity diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py index 4a55c978..732dc5de 100644 --- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -218,8 +218,8 @@ def parse_args(): "--gptq_mse", type=str, default=None, - choices=["mse", "smse"], - help="Whether and how to use mse in gptq (none/mse/smse/)", + choices=["mse", "smse", "smse_for_gptq"], + help="Whether and how to use mse in gptq (none/mse/smse/smse_for_gptq)", ) parser.add_argument( "--max_seq_len", @@ -390,6 +390,51 @@ def _print_sample(title, items): 
_print_sample("unused GPTQ entries", unused) +def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"): + if hasattr(model, "device") and model.device.type != device.type: + if hasattr(model, "to"): + model.to(device) + nlls = [] + with torch.no_grad(): + for batch in tqdm.tqdm(dataset): + if isinstance(batch, torch.Tensor): + batch = batch.to(device) + output = model( + batch.to(device), + ) + else: + raise RuntimeError("Unknown input in ppl_eval_on_dataset") + + if hasattr(output, "logits"): + lm_logits = output.logits + elif len(output) > 1: + lm_logits = torch.tensor(output[0]) + else: + lm_logits = torch.tensor(output) + + if torch.isfinite(lm_logits).all(): + shift_logits = lm_logits[:, :-1, :].contiguous() + if isinstance(batch, torch.Tensor): + shift_labels = batch[:, 1:].contiguous() + else: + assert isinstance(batch, tuple) + shift_labels = batch[0][:, 1:].contiguous() + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ) + nlls.append(loss) + del shift_logits, shift_labels + shift_logits = shift_labels = None # type: ignore[assignment] + + del batch, lm_logits, output + lm_logits = output = batch = None # noqa: F841 + torch.cuda.empty_cache() + + ppl = np.exp(torch.cat(nlls, dim=-1).mean().item()) + return ppl + # ------------------------------------------------------------------------- # Helper — clear gptq quantizers after injection # ------------------------------------------------------------------------- @@ -1285,9 +1330,9 @@ def build_calibration_inputs( def compute_or_load_sensitivity(model, calib_inputs, args): """ - Load or compute sensitivity information for SMSE GPTQ. + Load or compute sensitivity information for sensitivity-based GPTQ. 
""" - if args.gptq_mse != "smse": + if args.gptq_mse != "smse" and args.gptq_mse != "smse_for_gptq": return None if args.sensitivity_path is not None: @@ -1414,6 +1459,13 @@ def main(): evaluate_original_model(model, tokenizer, dataset_test, args, device) calib_inputs = build_calibration_inputs(model, tokenizer, args, device) + train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset( + model, calib_inputs, device=device + ) + print("\n┌── Wikitext-2 train perplexity ─────────────") + print(f"│ FP32 : {train_ppl_ioqdtype:8.2f}") + print("└───────────────────────────────────────────") + model = apply_spinquant(model, args) model = apply_cle(model, args) @@ -1422,6 +1474,14 @@ def main(): q_m = quantize_using_PTQ(model, calib_inputs, args) evaluate(q_m, tokenizer, dataset_test, args) + + train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset( + q_m, calib_inputs, device=device + ) + print("\n┌── Wikitext-2 train perplexity ─────────────") + print(f"│ int16 : {train_ppl_ioqdtype:8.2f}") + print("└───────────────────────────────────────────") + save_requested_artifacts(q_m, tokenizer, calib_inputs, args) From bcdb3a5de454eda14e1c78164abb21405edd1dac Mon Sep 17 00:00:00 2001 From: "s.malakhov" Date: Thu, 14 May 2026 11:54:33 +0300 Subject: [PATCH 2/3] refactor quant TICO-DCO-1.0-Signed-off-by: s.malakhov --- tico/quantization/algorithm/gptq/quant.py | 263 ++++++++++++++-------- 1 file changed, 169 insertions(+), 94 deletions(-) diff --git a/tico/quantization/algorithm/gptq/quant.py b/tico/quantization/algorithm/gptq/quant.py index 131eb7f5..b04ff063 100644 --- a/tico/quantization/algorithm/gptq/quant.py +++ b/tico/quantization/algorithm/gptq/quant.py @@ -60,10 +60,16 @@ def configure( if trits: self.maxq = torch.tensor(-1) - def find_params(self, x, weight=False): - dev = x.device - self.maxq = self.maxq.to(dev) - + def _prepare_tensor(self, x, weight=False): + """Prepare tensor for quantization by flattening according to per-channel setting. 
+ + Args: + x: Input tensor to prepare + weight: Whether the tensor is a weight (affects flattening for activations) + + Returns: + Tuple of (prepared tensor, original shape) + """ shape = x.shape if self.perchannel: if weight: @@ -78,7 +84,18 @@ def find_params(self, x, weight=False): x = x.t() else: x = x.flatten().unsqueeze(0) + return x, shape + def _compute_scale_zero_bounds(self, x): + """Compute scale and zero bounds from tensor values. + + Args: + x: Prepared tensor (flattened according to per-channel setting) + + Returns: + Tuple of (scale, zero, xmin, xmax) computed from tensor bounds + """ + dev = x.device tmp = torch.zeros(x.shape[0], device=dev) xmin = torch.minimum(x.min(1)[0], tmp) xmax = torch.maximum(x.max(1)[0], tmp) @@ -93,112 +110,182 @@ def find_params(self, x, weight=False): xmax[tmp] = +1 if self.maxq < 0: - self.scale = xmax - self.zero = xmin + scale = xmax + zero = xmin else: - self.scale = (xmax - xmin) / self.maxq + scale = (xmax - xmin) / self.maxq if self.sym: - self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) # type: ignore[arg-type] + zero = torch.full_like(scale, (self.maxq + 1) / 2) # type: ignore[arg-type] else: - self.zero = torch.round(-xmin / self.scale) + zero = torch.round(-xmin / scale) - if self.mse is not None and self.mse != "smse_for_gptq" and self.mse != "mse_for_gptq": - best = torch.full([x.shape[0]], float("inf"), device=dev) - for i in range(int(self.maxshrink * self.grid)): - p = 1 - i / self.grid - xmin1 = p * xmin - xmax1 = p * xmax - scale1 = (xmax1 - xmin1) / self.maxq - zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero - q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) - q -= x - q.abs_() - if self.mse == "smse": # senstitivity weighted mse - # in case senstitivity is a second order derivatives of some global loss - # (q**2) * self.sensitivity is just a global loss change due to quantization. 
- q = (q**2) * self.sensitivity.to( - q.device - ) # estimate global target change - else: - assert self.mse == "mse" - q.pow_(self.norm) - err = torch.sum(q, 1) - tmp = err < best - if torch.any(tmp): - best[tmp] = err[tmp] - self.scale[tmp] = scale1[tmp] - self.zero[tmp] = zero1[tmp] - if not self.perchannel: - if weight: - tmp = shape[0] - else: - tmp = shape[1] if len(shape) != 3 else shape[2] - assert isinstance(tmp, int) - self.scale = self.scale.repeat(tmp) - self.zero = self.zero.repeat(tmp) + return scale, zero, xmin, xmax + def _reshape_scale_zero(self, shape, weight=False): + """Reshape scale and zero tensors according to the original tensor shape. + + Args: + shape: Original tensor shape before preparation + weight: Whether the tensor is a weight (affects reshape for activations) + """ if weight: shape = [-1] + [1] * (len(shape) - 1) self.scale = self.scale.reshape(shape) self.zero = self.zero.reshape(shape) return + if len(shape) == 4: self.scale = self.scale.reshape((1, -1, 1, 1)) self.zero = self.zero.reshape((1, -1, 1, 1)) - if len(shape) == 3: + elif len(shape) == 3: self.scale = self.scale.reshape((1, 1, -1)) self.zero = self.zero.reshape((1, 1, -1)) - if len(shape) == 2: + elif len(shape) == 2: self.scale = self.scale.unsqueeze(0) self.zero = self.zero.unsqueeze(0) - def update(self, x, Hinv, perm): - if self.mse is None or self.mse != "smse_for_gptq": - return + def _expand_for_per_tensor(self, shape, weight=False): + """Expand scale and zero for per-tensor quantization. 
+ + Args: + shape: Original tensor shape before preparation + weight: Whether the tensor is a weight + """ + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + assert isinstance(tmp, int) + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) shape = x.shape - if self.perchannel: - x = x.flatten(1) - else: - x = x.flatten().unsqueeze(0) + x, shape = self._prepare_tensor(x, weight) + + self.scale, self.zero, xmin, xmax = self._compute_scale_zero_bounds(x) + + if self.mse is not None and self.mse != "smse_for_gptq" and self.mse != "mse_for_gptq": + self._optimize_mse(x, xmin, xmax) + + self._expand_for_per_tensor(shape, weight) + self._reshape_scale_zero(shape, weight) + + def _compute_shrink_params(self, p, xmin, xmax): + """Compute scale and zero for a shrink factor p. + + Args: + p: Shrink factor (1 - i / grid) + xmin: Minimum values per channel + xmax: Maximum values per channel + + Returns: + Tuple of (scale1, zero1) for the given shrink factor + """ + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + return scale1, zero1 + def _update_best_params(self, best, err, scale1, zero1): + """Update best parameters if current error is lower. + + Args: + best: Current best error values + err: Current iteration error values + scale1: Current iteration scale values + zero1: Current iteration zero values + + Returns: + Updated best error values + """ + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + return best + + def _grid_search(self, x, xmin, xmax, compute_error_fn): + """Common grid search loop for MSE optimization. 
+ + Args: + x: Prepared tensor + xmin: Minimum values per channel + xmax: Maximum values per channel + compute_error_fn: Function that takes (x, scale1, zero1) and returns error tensor + """ dev = x.device - tmp = torch.zeros(x.shape[0], device=dev) - xmin = torch.minimum(x.min(1)[0], tmp) - xmax = torch.maximum(x.max(1)[0], tmp) + best = torch.full([x.shape[0]], float("inf"), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + scale1, zero1 = self._compute_shrink_params(p, xmin, xmax) + err = compute_error_fn(x, scale1, zero1) + best = self._update_best_params(best, err, scale1, zero1) - if self.sym: - xmax = torch.maximum(torch.abs(xmin), xmax) - tmp = xmin < 0 - if torch.any(tmp): - xmin[tmp] = -xmax[tmp] - tmp = (xmin == 0) & (xmax == 0) - xmin[tmp] = -1 - xmax[tmp] = +1 - if self.maxq < 0: - self.scale = xmax - self.zero = xmin - else: - self.scale = (xmax - xmin) / self.maxq - if self.sym: - self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) # type: ignore[arg-type] + def _optimize_mse(self, x, xmin, xmax): + """Optimize scale and zero using MSE-based grid search. + + Args: + x: Prepared tensor + xmin: Minimum values per channel + xmax: Maximum values per channel + """ + def compute_error(x, scale1, zero1): + q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q -= x + q.abs_() + if self.mse == "smse": # sensitivity weighted mse + # in case sensitivity is a second order derivatives of some global loss + # (q**2) * self.sensitivity is just a global loss change due to quantization. 
+ q = (q**2) * self.sensitivity.to(q.device) # estimate global target change else: - self.zero = torch.round(-xmin / self.scale) + assert self.mse == "mse" + q.pow_(self.norm) + return torch.sum(q, 1) + + self._grid_search(x, xmin, xmax, compute_error) + + def update(self, x, Hinv, perm): + if self.mse is None or self.mse != "smse_for_gptq": + return + + shape = x.shape + x, shape = self._prepare_tensor(x, weight=True) + + self.scale, self.zero, xmin, xmax = self._compute_scale_zero_bounds(x) sensitivity = None if self.sensitivity is not None: - sensitivity = self.sensitivity.to(Hinv.dtype).to(dev) + sensitivity = self.sensitivity.to(Hinv.dtype).to(x.device) if perm is not None: - sensitivity = sensitivity[:, perm.to(dev)] + sensitivity = sensitivity[:, perm.to(x.device)] + + self._optimize_mse_for_gptq(x, Hinv, sensitivity, xmin, xmax) + + self._reshape_scale_zero(shape, weight=True) + + del sensitivity + sensitivity = None + def _optimize_mse_for_gptq(self, x, Hinv, sensitivity, xmin, xmax): + """Optimize scale and zero using GPTQ-aware MSE grid search. 
+ + Args: + x: Prepared tensor + Hinv: Inverse Hessian matrix + sensitivity: Sensitivity tensor for weighted MSE + xmin: Minimum values per channel + xmax: Maximum values per channel + """ num_of_iters = 15 - best = torch.full([x.shape[0]], float("inf"), device=dev) - for i in range(int(self.maxshrink * self.grid)): - p = 1 - i / self.grid - xmin1 = p * xmin - xmax1 = p * xmax - scale1 = (xmax1 - xmin1) / self.maxq - zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + + def compute_error(x, scale1, zero1): q, _ = iterate_GPTQ( scale1.unsqueeze(1), zero1.unsqueeze(1), @@ -210,21 +297,9 @@ def update(self, x, Hinv, perm): assert sensitivity is not None assert self.mse == "smse_for_gptq" err = ((q - x) ** 2) * sensitivity.to(q.device) - - err = err - err = torch.sum(err, 1) - tmp = err < best - if torch.any(tmp): - best[tmp] = err[tmp] - self.scale[tmp] = scale1[tmp] - self.zero[tmp] = zero1[tmp] - - shape = [-1] + [1] * (len(shape) - 1) - self.scale = self.scale.reshape(shape) - self.zero = self.zero.reshape(shape) + return torch.sum(err, 1) - del sensitivity - sensitivity = None + self._grid_search(x, xmin, xmax, compute_error) def quantize(self, x): if self.ready(): From 845c1ed1332c1874bfcfc2d6b7203a795460ab28 Mon Sep 17 00:00:00 2001 From: "s.malakhov" Date: Thu, 14 May 2026 13:33:46 +0300 Subject: [PATCH 3/3] format TICO-DCO-1.0-Signed-off-by: s.malakhov --- .../algorithm/fpi_gptq/fpi_gptq.py | 3 +- tico/quantization/algorithm/fpi_gptq/util.py | 3 +- tico/quantization/algorithm/gptq/gptq.py | 4 +- tico/quantization/algorithm/gptq/quant.py | 52 +++++++++++-------- tico/quantization/algorithm/gptq/quantizer.py | 2 +- tico/quantization/algorithm/gptq/utils.py | 2 +- .../quantize_full_qmodel_with_gptq.py | 16 +++--- 7 files changed, 46 insertions(+), 36 deletions(-) diff --git a/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py b/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py index 641a59ae..3ca13443 100644 --- 
a/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +++ b/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py @@ -32,7 +32,8 @@ ) from tico.quantization.algorithm.gptq.quant import quantize, Quantizer -from tico.quantization.algorithm.fpi_gptq.util import quantize, iterate_GPTQ +from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ, quantize + class FPI_GPTQ: def __init__(self, layer): diff --git a/tico/quantization/algorithm/fpi_gptq/util.py b/tico/quantization/algorithm/fpi_gptq/util.py index a1ed1f78..e2c9ad7d 100644 --- a/tico/quantization/algorithm/fpi_gptq/util.py +++ b/tico/quantization/algorithm/fpi_gptq/util.py @@ -20,6 +20,7 @@ import torch + def quantize(x, scale, zero, maxq): if maxq < 0: return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero @@ -49,7 +50,7 @@ def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50): if torch.cuda.is_available(): torch.cuda.empty_cache() - + cur_Q = quantize(cur_weights, scale, zero, maxq) return cur_Q, cur_weights diff --git a/tico/quantization/algorithm/gptq/gptq.py b/tico/quantization/algorithm/gptq/gptq.py index e780cf4d..17caa85c 100644 --- a/tico/quantization/algorithm/gptq/gptq.py +++ b/tico/quantization/algorithm/gptq/gptq.py @@ -360,9 +360,9 @@ def fasterquant( H = torch.cholesky_inverse(H) H = torch.linalg.cholesky(H, upper=True).float() Hinv = H - + self.quantizer.update(W, Hinv, perm) - + assert isinstance(Hinv, torch.Tensor) for i1 in range(0, self.columns, blocksize): i2 = min(i1 + blocksize, self.columns) diff --git a/tico/quantization/algorithm/gptq/quant.py b/tico/quantization/algorithm/gptq/quant.py index b04ff063..fb92817b 100644 --- a/tico/quantization/algorithm/gptq/quant.py +++ b/tico/quantization/algorithm/gptq/quant.py @@ -23,6 +23,7 @@ from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ + def quantize(x, scale, zero, maxq): if maxq < 0: return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero @@ -62,11 +63,11 @@ def configure( def 
_prepare_tensor(self, x, weight=False): """Prepare tensor for quantization by flattening according to per-channel setting. - + Args: x: Input tensor to prepare weight: Whether the tensor is a weight (affects flattening for activations) - + Returns: Tuple of (prepared tensor, original shape) """ @@ -88,10 +89,10 @@ def _prepare_tensor(self, x, weight=False): def _compute_scale_zero_bounds(self, x): """Compute scale and zero bounds from tensor values. - + Args: x: Prepared tensor (flattened according to per-channel setting) - + Returns: Tuple of (scale, zero, xmin, xmax) computed from tensor bounds """ @@ -123,17 +124,17 @@ def _compute_scale_zero_bounds(self, x): def _reshape_scale_zero(self, shape, weight=False): """Reshape scale and zero tensors according to the original tensor shape. - + Args: shape: Original tensor shape before preparation weight: Whether the tensor is a weight (affects reshape for activations) """ if weight: shape = [-1] + [1] * (len(shape) - 1) - self.scale = self.scale.reshape(shape) - self.zero = self.zero.reshape(shape) + self.scale = self.scale.reshape(shape) # type: ignore[has-type] + self.zero = self.zero.reshape(shape) # type: ignore[has-type] return - + if len(shape) == 4: self.scale = self.scale.reshape((1, -1, 1, 1)) self.zero = self.zero.reshape((1, -1, 1, 1)) @@ -146,7 +147,7 @@ def _reshape_scale_zero(self, shape, weight=False): def _expand_for_per_tensor(self, shape, weight=False): """Expand scale and zero for per-tensor quantization. 
- + Args: shape: Original tensor shape before preparation weight: Whether the tensor is a weight @@ -169,7 +170,11 @@ def find_params(self, x, weight=False): self.scale, self.zero, xmin, xmax = self._compute_scale_zero_bounds(x) - if self.mse is not None and self.mse != "smse_for_gptq" and self.mse != "mse_for_gptq": + if ( + self.mse is not None + and self.mse != "smse_for_gptq" + and self.mse != "mse_for_gptq" + ): self._optimize_mse(x, xmin, xmax) self._expand_for_per_tensor(shape, weight) @@ -177,12 +182,12 @@ def find_params(self, x, weight=False): def _compute_shrink_params(self, p, xmin, xmax): """Compute scale and zero for a shrink factor p. - + Args: p: Shrink factor (1 - i / grid) xmin: Minimum values per channel xmax: Maximum values per channel - + Returns: Tuple of (scale1, zero1) for the given shrink factor """ @@ -194,13 +199,13 @@ def _compute_shrink_params(self, p, xmin, xmax): def _update_best_params(self, best, err, scale1, zero1): """Update best parameters if current error is lower. - + Args: best: Current best error values err: Current iteration error values scale1: Current iteration scale values zero1: Current iteration zero values - + Returns: Updated best error values """ @@ -213,7 +218,7 @@ def _update_best_params(self, best, err, scale1, zero1): def _grid_search(self, x, xmin, xmax, compute_error_fn): """Common grid search loop for MSE optimization. - + Args: x: Prepared tensor xmin: Minimum values per channel @@ -230,12 +235,13 @@ def _grid_search(self, x, xmin, xmax, compute_error_fn): def _optimize_mse(self, x, xmin, xmax): """Optimize scale and zero using MSE-based grid search. 
- + Args: x: Prepared tensor xmin: Minimum values per channel xmax: Maximum values per channel """ + def compute_error(x, scale1, zero1): q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) q -= x @@ -243,12 +249,14 @@ def compute_error(x, scale1, zero1): if self.mse == "smse": # sensitivity weighted mse # in case sensitivity is a second order derivatives of some global loss # (q**2) * self.sensitivity is just a global loss change due to quantization. - q = (q**2) * self.sensitivity.to(q.device) # estimate global target change + q = (q**2) * self.sensitivity.to( + q.device + ) # estimate global target change else: assert self.mse == "mse" q.pow_(self.norm) return torch.sum(q, 1) - + self._grid_search(x, xmin, xmax, compute_error) def update(self, x, Hinv, perm): @@ -269,13 +277,13 @@ def update(self, x, Hinv, perm): self._optimize_mse_for_gptq(x, Hinv, sensitivity, xmin, xmax) self._reshape_scale_zero(shape, weight=True) - + del sensitivity sensitivity = None def _optimize_mse_for_gptq(self, x, Hinv, sensitivity, xmin, xmax): """Optimize scale and zero using GPTQ-aware MSE grid search. 
- + Args: x: Prepared tensor Hinv: Inverse Hessian matrix @@ -284,7 +292,7 @@ def _optimize_mse_for_gptq(self, x, Hinv, sensitivity, xmin, xmax): xmax: Maximum values per channel """ num_of_iters = 15 - + def compute_error(x, scale1, zero1): q, _ = iterate_GPTQ( scale1.unsqueeze(1), @@ -298,7 +306,7 @@ def compute_error(x, scale1, zero1): assert self.mse == "smse_for_gptq" err = ((q - x) ** 2) * sensitivity.to(q.device) return torch.sum(err, 1) - + self._grid_search(x, xmin, xmax, compute_error) def quantize(self, x): diff --git a/tico/quantization/algorithm/gptq/quantizer.py b/tico/quantization/algorithm/gptq/quantizer.py index 81b864f9..5b7025fb 100644 --- a/tico/quantization/algorithm/gptq/quantizer.py +++ b/tico/quantization/algorithm/gptq/quantizer.py @@ -478,7 +478,7 @@ def _hook(_, inp, out): model.lm_head = model.lm_head.to(old_device) if torch.cuda.is_available(): torch.cuda.empty_cache() - + device = next(layer.parameters()).device # in case lm_head is located on cpu for batch_idx in tqdm( range(batch_num), diff --git a/tico/quantization/algorithm/gptq/utils.py b/tico/quantization/algorithm/gptq/utils.py index 943f4dc5..b1af9980 100644 --- a/tico/quantization/algorithm/gptq/utils.py +++ b/tico/quantization/algorithm/gptq/utils.py @@ -227,7 +227,7 @@ def compute_sensitivity_info(self): if model.device.type != "cpu": torch.cuda.synchronize() torch.cuda.empty_cache() - + model = model.to(dtype) return sensitivity diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py index 732dc5de..eaab7c35 100644 --- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -390,7 +390,7 @@ def _print_sample(title, items): _print_sample("unused GPTQ entries", unused) -def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"): +def evaluate_ppl_of_model_on_dataset(model, dataset, 
device): if hasattr(model, "device") and model.device.type != device.type: if hasattr(model, "to"): model.to(device) @@ -435,6 +435,7 @@ def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"): ppl = np.exp(torch.cat(nlls, dim=-1).mean().item()) return ppl + # ------------------------------------------------------------------------- # Helper — clear gptq quantizers after injection # ------------------------------------------------------------------------- @@ -1460,12 +1461,11 @@ def main(): calib_inputs = build_calibration_inputs(model, tokenizer, args, device) train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset( - model, calib_inputs, device=device - ) + model, calib_inputs, device=device + ) print("\n┌── Wikitext-2 train perplexity ─────────────") print(f"│ FP32 : {train_ppl_ioqdtype:8.2f}") print("└───────────────────────────────────────────") - model = apply_spinquant(model, args) model = apply_cle(model, args) @@ -1474,14 +1474,14 @@ def main(): q_m = quantize_using_PTQ(model, calib_inputs, args) evaluate(q_m, tokenizer, dataset_test, args) - + train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset( - q_m, calib_inputs, device=device - ) + q_m, calib_inputs, device=device + ) print("\n┌── Wikitext-2 train perplexity ─────────────") print(f"│ int16 : {train_ppl_ioqdtype:8.2f}") print("└───────────────────────────────────────────") - + save_requested_artifacts(q_m, tokenizer, calib_inputs, args)