def update_loss(self, out, out_actual, alpha):
    """Add the mean squared error between two outputs to the loss for ``alpha``.

    Each of ``out`` and ``out_actual`` is reduced to its first element when it
    is a tuple, and is then replaced by the result of its ``to_local()``
    method when it exposes one. The mean of the squared, float-cast difference
    is moved to the device of ``self.awq_lite.loss[alpha]`` and accumulated
    into that entry.

    Args:
        out: Output to evaluate, or a tuple whose first element is that output.
        out_actual: Reference output, or a tuple whose first element is it.
        alpha: Key of the ``self.awq_lite.loss`` entry to update.
    """

    def _plain_tensor(value):
        # Tuple outputs carry the value of interest in position 0.
        if isinstance(value, tuple):
            value = value[0]
        # Anything exposing ``to_local`` is swapped for its local view so both
        # operands have the same kind before subtracting.
        # NOTE(review): presumably these are distributed tensors - confirm.
        if hasattr(value, "to_local"):
            value = value.to_local()
        return value

    predicted = _plain_tensor(out)
    reference = _plain_tensor(out_actual)

    residual = predicted - reference
    mse = (residual.float() ** 2).mean()

    # Accumulate in place on whatever device the running total lives on.
    accumulator_device = self.awq_lite.loss[alpha].device
    self.awq_lite.loss[alpha] += mse.to(accumulator_device)