From 0af1a6f49215a7b20412fa3a559da2cf208df5cb Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Sat, 21 Feb 2026 07:55:28 -0800 Subject: [PATCH] Fix test_transformers_tp for torch 2.10 env Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- modelopt/torch/quantization/model_calib.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 9d95a0651..61e934f0e 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1038,6 +1038,8 @@ def get_scale(x_max, w_max, alpha, tensor_parallel_group=None): def update_loss(self, out, out_actual, alpha): out_actual = out_actual[0] if isinstance(out_actual, tuple) else out_actual out = out[0] if isinstance(out, tuple) else out + out = out.to_local() if hasattr(out, "to_local") else out + out_actual = out_actual.to_local() if hasattr(out_actual, "to_local") else out_actual loss = (out - out_actual).float().pow(2).mean() self.awq_lite.loss[alpha] += loss.to(self.awq_lite.loss[alpha].device)