diff --git a/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py b/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py index cdd99ef7..3ca13443 100644 --- a/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +++ b/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py @@ -32,29 +32,7 @@ ) from tico.quantization.algorithm.gptq.quant import quantize, Quantizer - - -def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50): - - cur_weights = W.clone() - mults = torch.pow(torch.diag(Hinv), -1) - Hinv_U = torch.triu(Hinv, diagonal=1) - - init_weights = W.clone() - for _ in range(max_num_of_iters): - cur_Q = quantize(cur_weights, scale, zero, maxq) - - d_W = torch.mul((cur_weights - cur_Q), mults) - cur_weights = init_weights - torch.matmul(d_W, Hinv_U) - del d_W, cur_Q - d_W = cur_Q = None - - del init_weights - init_weights = None - - cur_Q = quantize(cur_weights, scale, zero, maxq) - - return cur_Q, cur_weights +from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ, quantize class FPI_GPTQ: diff --git a/tico/quantization/algorithm/fpi_gptq/util.py b/tico/quantization/algorithm/fpi_gptq/util.py new file mode 100644 index 00000000..e2c9ad7d --- /dev/null +++ b/tico/quantization/algorithm/fpi_gptq/util.py @@ -0,0 +1,56 @@ +# Copyright IST-DASLab. 2025. (commit: 2d65066). GitHub repository. +# Retrieved from https://github.com/IST-DASLab/gptq. Licensed under the +# Apache License 2.0. + +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# https://github.com/IST-DASLab/gptq/blob/2d65066/quant.py + +import torch + + +def quantize(x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + + +def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50): + + cur_weights = W.clone() + mults = torch.pow(torch.diag(Hinv), -1) + Hinv_U = torch.triu(Hinv, diagonal=1) + + init_weights = W.clone() + for _ in range(max_num_of_iters): + cur_Q = quantize(cur_weights, scale, zero, maxq) + + d_W = torch.mul((cur_weights - cur_Q), mults) + cur_weights = init_weights - torch.matmul(d_W, Hinv_U) + del d_W, cur_Q + d_W = cur_Q = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + del init_weights + init_weights = None + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + cur_Q = quantize(cur_weights, scale, zero, maxq) + + return cur_Q, cur_weights diff --git a/tico/quantization/algorithm/gptq/README.md b/tico/quantization/algorithm/gptq/README.md index ae443a13..db99dcf7 100644 --- a/tico/quantization/algorithm/gptq/README.md +++ b/tico/quantization/algorithm/gptq/README.md @@ -52,6 +52,9 @@ applied after _convert()_, the effectiveness of GPTQ may be diminished. There are two options : 1. `mse`- vanilla `mse`. Produce quantization parameters for GPTQ quantizer (`min`\`max`) which minimize mean squared error of quantization. $MSE_{MIN, MAX}(W) = argmin_{min, max}||W-Q_{min, max}(W)||^2$. 2. `smse` - sensitivity-based `mse`. Use sensitivity of some global feature (e.g. float model logits) to parameters change to minimize global effect of quantization. $SMSE_{MIN, MAX}(W) = argmin_{min, max}|(W-Q_{min, max}(W))^2*Sensitivity(W)|$. So we try to keep `important` parameters unchanged, while quantizing `unimportant` parameters more aggressively. +3. `smse_for_gptq` - `smse` adjusted for GPTQ. GPTQ modifies the matrix during the quantization process, so the most accurate method would consist in finding a quantizer that yields the smallest quantization error after the GPTQ method has been applied $SMSE\_FOR\_GPTQ_{MIN, MAX}(W) = argmin_{min, max}|(W-Q_{min, max}(W_{GPTQ}))^2*Sensitivity(W)|$. Since this would be quite computationally expensive, we can use an accelerated approximate GPTQ method — FPI_GPTQ $SMSE\_FOR\_GPTQ_{MIN, MAX}(W) = argmin_{min, max}|(W-Q_{min, max}(W_{FPI\_GPTQ}))^2*Sensitivity(W)|$. This is slower than `mse`/`smse` but can provide better accuracy. + + You can turn this feature `on`/`off` by using `mse` parameter of `GPTQConfig`: ``` diff --git a/tico/quantization/algorithm/gptq/gptq.py b/tico/quantization/algorithm/gptq/gptq.py index ab01721f..17caa85c 100644 --- a/tico/quantization/algorithm/gptq/gptq.py +++ b/tico/quantization/algorithm/gptq/gptq.py @@ -361,6 +361,8 @@ def fasterquant( H = torch.linalg.cholesky(H, upper=True).float() Hinv = H + self.quantizer.update(W, Hinv, perm) + assert isinstance(Hinv, torch.Tensor) for i1 in range(0, self.columns, blocksize): i2 = min(i1 + blocksize, self.columns) diff --git a/tico/quantization/algorithm/gptq/quant.py b/tico/quantization/algorithm/gptq/quant.py index 98e7731d..fb92817b 100644 --- a/tico/quantization/algorithm/gptq/quant.py +++ b/tico/quantization/algorithm/gptq/quant.py @@ -21,6 +21,8 @@ import torch import torch.nn as nn +from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ + def quantize(x, scale, zero, maxq): if maxq < 0: @@ -59,10 +61,16 @@ def configure( if trits: self.maxq = torch.tensor(-1) - def find_params(self, x, weight=False): - dev = x.device - self.maxq = self.maxq.to(dev) + def _prepare_tensor(self, x, weight=False): + """Prepare tensor for quantization by flattening according to per-channel setting. + + Args: + x: Input tensor to prepare + weight: Whether the tensor is a weight (affects flattening for activations) + Returns: + Tuple of (prepared tensor, original shape) + """ shape = x.shape if self.perchannel: if weight: @@ -77,7 +85,18 @@ def find_params(self, x, weight=False): x = x.t() else: x = x.flatten().unsqueeze(0) + return x, shape + def _compute_scale_zero_bounds(self, x): + """Compute scale and zero bounds from tensor values. + + Args: + x: Prepared tensor (flattened according to per-channel setting) + + Returns: + Tuple of (scale, zero, xmin, xmax) computed from tensor bounds + """ + dev = x.device tmp = torch.zeros(x.shape[0], device=dev) xmin = torch.minimum(x.min(1)[0], tmp) xmax = torch.maximum(x.max(1)[0], tmp) @@ -92,65 +111,204 @@ def find_params(self, x, weight=False): xmax[tmp] = +1 if self.maxq < 0: - self.scale = xmax - self.zero = xmin + scale = xmax + zero = xmin else: - self.scale = (xmax - xmin) / self.maxq + scale = (xmax - xmin) / self.maxq if self.sym: - self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) # type: ignore[arg-type] + zero = torch.full_like(scale, (self.maxq + 1) / 2) # type: ignore[arg-type] else: - self.zero = torch.round(-xmin / self.scale) - - if self.mse is not None: - best = torch.full([x.shape[0]], float("inf"), device=dev) - for i in range(int(self.maxshrink * self.grid)): - p = 1 - i / self.grid - xmin1 = p * xmin - xmax1 = p * xmax - scale1 = (xmax1 - xmin1) / self.maxq - zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero - q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) - q -= x - q.abs_() - if self.mse == "smse": # senstitivity weighted mse - # in case senstitivity is a second order derivatives of some global loss - # (q**2) * self.sensitivity is just a global loss change due to quantization. - q = (q**2) * self.sensitivity.to( - q.device - ) # estimate global target change - else: - assert self.mse == "mse" - q.pow_(self.norm) - err = torch.sum(q, 1) - tmp = err < best - if torch.any(tmp): - best[tmp] = err[tmp] - self.scale[tmp] = scale1[tmp] - self.zero[tmp] = zero1[tmp] - if not self.perchannel: - if weight: - tmp = shape[0] - else: - tmp = shape[1] if len(shape) != 3 else shape[2] - assert isinstance(tmp, int) - self.scale = self.scale.repeat(tmp) - self.zero = self.zero.repeat(tmp) + zero = torch.round(-xmin / scale) + + return scale, zero, xmin, xmax + def _reshape_scale_zero(self, shape, weight=False): + """Reshape scale and zero tensors according to the original tensor shape. + + Args: + shape: Original tensor shape before preparation + weight: Whether the tensor is a weight (affects reshape for activations) + """ if weight: shape = [-1] + [1] * (len(shape) - 1) - self.scale = self.scale.reshape(shape) - self.zero = self.zero.reshape(shape) + self.scale = self.scale.reshape(shape) # type: ignore[has-type] + self.zero = self.zero.reshape(shape) # type: ignore[has-type] return + if len(shape) == 4: self.scale = self.scale.reshape((1, -1, 1, 1)) self.zero = self.zero.reshape((1, -1, 1, 1)) - if len(shape) == 3: + elif len(shape) == 3: self.scale = self.scale.reshape((1, 1, -1)) self.zero = self.zero.reshape((1, 1, -1)) - if len(shape) == 2: + elif len(shape) == 2: self.scale = self.scale.unsqueeze(0) self.zero = self.zero.unsqueeze(0) + def _expand_for_per_tensor(self, shape, weight=False): + """Expand scale and zero for per-tensor quantization. + + Args: + shape: Original tensor shape before preparation + weight: Whether the tensor is a weight + """ + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + assert isinstance(tmp, int) + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + x, shape = self._prepare_tensor(x, weight) + + self.scale, self.zero, xmin, xmax = self._compute_scale_zero_bounds(x) + + if ( + self.mse is not None + and self.mse != "smse_for_gptq" + and self.mse != "mse_for_gptq" + ): + self._optimize_mse(x, xmin, xmax) + + self._expand_for_per_tensor(shape, weight) + self._reshape_scale_zero(shape, weight) + + def _compute_shrink_params(self, p, xmin, xmax): + """Compute scale and zero for a shrink factor p. + + Args: + p: Shrink factor (1 - i / grid) + xmin: Minimum values per channel + xmax: Maximum values per channel + + Returns: + Tuple of (scale1, zero1) for the given shrink factor + """ + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + return scale1, zero1 + + def _update_best_params(self, best, err, scale1, zero1): + """Update best parameters if current error is lower. + + Args: + best: Current best error values + err: Current iteration error values + scale1: Current iteration scale values + zero1: Current iteration zero values + + Returns: + Updated best error values + """ + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + return best + + def _grid_search(self, x, xmin, xmax, compute_error_fn): + """Common grid search loop for MSE optimization. + + Args: + x: Prepared tensor + xmin: Minimum values per channel + xmax: Maximum values per channel + compute_error_fn: Function that takes (x, scale1, zero1) and returns error tensor + """ + dev = x.device + best = torch.full([x.shape[0]], float("inf"), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + scale1, zero1 = self._compute_shrink_params(p, xmin, xmax) + err = compute_error_fn(x, scale1, zero1) + best = self._update_best_params(best, err, scale1, zero1) + + def _optimize_mse(self, x, xmin, xmax): + """Optimize scale and zero using MSE-based grid search. + + Args: + x: Prepared tensor + xmin: Minimum values per channel + xmax: Maximum values per channel + """ + + def compute_error(x, scale1, zero1): + q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q -= x + q.abs_() + if self.mse == "smse": # sensitivity weighted mse + # in case sensitivity is a second order derivatives of some global loss + # (q**2) * self.sensitivity is just a global loss change due to quantization. + q = (q**2) * self.sensitivity.to( + q.device + ) # estimate global target change + else: + assert self.mse == "mse" + q.pow_(self.norm) + return torch.sum(q, 1) + + self._grid_search(x, xmin, xmax, compute_error) + + def update(self, x, Hinv, perm): + if self.mse is None or self.mse != "smse_for_gptq": + return + + shape = x.shape + x, shape = self._prepare_tensor(x, weight=True) + + self.scale, self.zero, xmin, xmax = self._compute_scale_zero_bounds(x) + + sensitivity = None + if self.sensitivity is not None: + sensitivity = self.sensitivity.to(Hinv.dtype).to(x.device) + if perm is not None: + sensitivity = sensitivity[:, perm.to(x.device)] + + self._optimize_mse_for_gptq(x, Hinv, sensitivity, xmin, xmax) + + self._reshape_scale_zero(shape, weight=True) + + del sensitivity + sensitivity = None + + def _optimize_mse_for_gptq(self, x, Hinv, sensitivity, xmin, xmax): + """Optimize scale and zero using GPTQ-aware MSE grid search. + + Args: + x: Prepared tensor + Hinv: Inverse Hessian matrix + sensitivity: Sensitivity tensor for weighted MSE + xmin: Minimum values per channel + xmax: Maximum values per channel + """ + num_of_iters = 15 + + def compute_error(x, scale1, zero1): + q, _ = iterate_GPTQ( + scale1.unsqueeze(1), + zero1.unsqueeze(1), + self.maxq, + x, + Hinv, + max_num_of_iters=num_of_iters, + ) + assert sensitivity is not None + assert self.mse == "smse_for_gptq" + err = ((q - x) ** 2) * sensitivity.to(q.device) + return torch.sum(err, 1) + + self._grid_search(x, xmin, xmax, compute_error) + def quantize(self, x): if self.ready(): return quantize(x, self.scale, self.zero, self.maxq) diff --git a/tico/quantization/algorithm/gptq/quantizer.py b/tico/quantization/algorithm/gptq/quantizer.py index ba3d05b5..5b7025fb 100644 --- a/tico/quantization/algorithm/gptq/quantizer.py +++ b/tico/quantization/algorithm/gptq/quantizer.py @@ -473,6 +473,12 @@ def _hook(_, inp, out): handles = [layer.register_forward_hook(add_batch())] # Run layer forward over all cached batches to build Hessian/statistics + old_device = device + model = model.to("cpu") + model.lm_head = model.lm_head.to(old_device) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + device = next(layer.parameters()).device # in case lm_head is located on cpu for batch_idx in tqdm( range(batch_num), @@ -502,3 +508,4 @@ def _hook(_, inp, out): ) quantizers[f"lm_head"] = gptq.quantizer gptq.free() + model = model.to(old_device) diff --git a/tico/quantization/algorithm/gptq/utils.py b/tico/quantization/algorithm/gptq/utils.py index 1f4d0321..b1af9980 100644 --- a/tico/quantization/algorithm/gptq/utils.py +++ b/tico/quantization/algorithm/gptq/utils.py @@ -163,7 +163,11 @@ def compute_sensitivity_info(self): if self.show_progress is True: print("Calibrating sensitivity") for inputs, targets in tqdm.tqdm(data_loader, disable=not self.show_progress): - model.zero_grad() + model.zero_grad(set_to_none=True) + if model.device.type != "cpu": + torch.cuda.empty_cache() + torch.cuda.synchronize() + if isinstance(inputs, torch.Tensor): inp_ids = inputs.squeeze(0) # remove redundant batch dimension logits = model(inp_ids.to(model.device)).logits @@ -219,6 +223,11 @@ def compute_sensitivity_info(self): for name in modules_to_process: sensitivity[name] /= len(data_loader) + model.zero_grad(set_to_none=True) + if model.device.type != "cpu": + torch.cuda.synchronize() + torch.cuda.empty_cache() + model = model.to(dtype) return sensitivity diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py index 4a55c978..eaab7c35 100644 --- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -218,8 +218,8 @@ def parse_args(): "--gptq_mse", type=str, default=None, - choices=["mse", "smse"], - help="Whether and how to use mse in gptq (none/mse/smse/)", + choices=["mse", "smse", "smse_for_gptq"], + help="Whether and how to use mse in gptq (none/mse/smse/smse_for_gptq)", ) parser.add_argument( "--max_seq_len", @@ -390,6 +390,52 @@ def _print_sample(title, items): _print_sample("unused GPTQ entries", unused) +def evaluate_ppl_of_model_on_dataset(model, dataset, device): + if hasattr(model, "device") and model.device.type != device.type: + if hasattr(model, "to"): + model.to(device) + nlls = [] + with torch.no_grad(): + for batch in tqdm.tqdm(dataset): + if isinstance(batch, torch.Tensor): + batch = batch.to(device) + output = model( + batch.to(device), + ) + else: + raise RuntimeError("Unknown input in ppl_eval_on_dataset") + + if hasattr(output, "logits"): + lm_logits = output.logits + elif len(output) > 1: + lm_logits = torch.tensor(output[0]) + else: + lm_logits = torch.tensor(output) + + if torch.isfinite(lm_logits).all(): + shift_logits = lm_logits[:, :-1, :].contiguous() + if isinstance(batch, torch.Tensor): + shift_labels = batch[:, 1:].contiguous() + else: + assert isinstance(batch, tuple) + shift_labels = batch[0][:, 1:].contiguous() + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ) + nlls.append(loss) + del shift_logits, shift_labels + shift_logits = shift_labels = None # type: ignore[assignment] + + del batch, lm_logits, output + lm_logits = output = batch = None # noqa: F841 + torch.cuda.empty_cache() + + ppl = np.exp(torch.cat(nlls, dim=-1).mean().item()) + return ppl + + # ------------------------------------------------------------------------- # Helper — clear gptq quantizers after injection # ------------------------------------------------------------------------- @@ -1285,9 +1331,9 @@ def build_calibration_inputs( def compute_or_load_sensitivity(model, calib_inputs, args): """ - Load or compute sensitivity information for SMSE GPTQ. + Load or compute sensitivity information for sensitivity-based GPTQ. """ - if args.gptq_mse != "smse": + if args.gptq_mse != "smse" and args.gptq_mse != "smse_for_gptq": return None if args.sensitivity_path is not None: @@ -1414,6 +1460,12 @@ def main(): evaluate_original_model(model, tokenizer, dataset_test, args, device) calib_inputs = build_calibration_inputs(model, tokenizer, args, device) + train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset( + model, calib_inputs, device=device + ) + print("\n┌── Wikitext-2 train perplexity ─────────────") + print(f"│ FP32 : {train_ppl_ioqdtype:8.2f}") + print("└───────────────────────────────────────────") model = apply_spinquant(model, args) model = apply_cle(model, args) @@ -1422,6 +1474,14 @@ def main(): q_m = quantize_using_PTQ(model, calib_inputs, args) evaluate(q_m, tokenizer, dataset_test, args) + + train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset( + q_m, calib_inputs, device=device + ) + print("\n┌── Wikitext-2 train perplexity ─────────────") + print(f"│ int16 : {train_ppl_ioqdtype:8.2f}") + print("└───────────────────────────────────────────") + save_requested_artifacts(q_m, tokenizer, calib_inputs, args)