From 8eb70ec67c685abdfd1ec7cfb32c8bc6da810ca4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 9 Feb 2026 10:46:47 +0100 Subject: [PATCH 01/10] add example to export experts part --- .../plot_export_optimind_input_observer.py | 133 ++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 _doc/final/plot_export_optimind_input_observer.py diff --git a/_doc/final/plot_export_optimind_input_observer.py b/_doc/final/plot_export_optimind_input_observer.py new file mode 100644 index 00000000..47316542 --- /dev/null +++ b/_doc/final/plot_export_optimind_input_observer.py @@ -0,0 +1,133 @@ +""" +.. _l-plot-optimind-export-input-observer: + +Export OptiMind-SFT with InputObserver +====================================== + +This reuses the recipe introduced by example :ref:`l-plot-tiny-llm-export-input-observer` +for model `microsoft/OptiMind-SFT `_. +We only export class ``GptOssExperts``. + +Let's create a random model ++++++++++++++++++++++++++++ +""" + +import pandas +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from onnx_diagnostic import doc +from onnx_diagnostic.export.api import to_onnx +from onnx_diagnostic.helpers import string_type +from onnx_diagnostic.torch_export_patches import ( + register_additional_serialization_functions, + torch_export_patches, +) +from onnx_diagnostic.investigate.input_observer import InputObserver + +device = "cuda" +model_id = "microsoft/OptiMind-SFT" +print(f"get tokenizer {model_id!r}") +tokenizer = AutoTokenizer.from_pretrained(model_id) +print(f"get config {model_id!r}") +config = AutoConfig.from_pretrained(model_id) +config.num_hidden_layers = 2 +config.layer_types = config.layer_types[:2] +print(f"create model from config for {model_id!r}") +model = AutoModelForCausalLM.from_config(config) +print(f"the model is created with {len(list(model.named_modules()))} subdmodules.") +model = model.to(device) + +# %% +# We need to only export class GptOssExperts +# ++++++++++++++++++++++++++++++++++++++++++ + + +def generate_text( + prompt, + model, + tokenizer, + max_length=50, + temperature=0.01, + top_k=50, + top_p=0.95, + do_sample=True, +): + inputs = tokenizer(prompt, return_tensors="pt") + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + do_sample=do_sample, + ) + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + return generated_text + + +export_module = None +for _name, sub in model.named_modules(): + if sub.__class__.__name__ == "GptOssExperts": + export_module = sub + +assert export_module is not None, ( + f"Unable to find a submodule from class GptOssExperts in " + f"{set(sub.__class__.__name__ for _, sub in model.named_modules())}" +) + +# Define your prompt +prompt = "Continue: it rains, what should I do?" +observer = InputObserver() +with ( + register_additional_serialization_functions(patch_transformers=True), + observer(export_module), +): + generate_text(prompt, model, tokenizer) + + +# %% +# Export +# ++++++ +# +# First, what was inferred. + +args = observer.infer_arguments() +dynamic_shapes = observer.infer_dynamic_shapes() +print(f"kwargs={string_type(args, with_shape=True)}") +print(f"dynamic_shapes={dynamic_shapes}") + +# %% +# Next, the export. 
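+#
+# Before running the ONNX conversion, a quick sanity check with
+# :func:`torch.export.export` can help validate the inferred arguments and
+# dynamic shapes. This is only a sketch (it assumes ``import torch`` and the
+# same patches as below); it is not required for the conversion itself::
+#
+#     with torch_export_patches(patch_transformers=True):
+#         torch.export.export(
+#             export_module, (), kwargs=args, dynamic_shapes=dynamic_shapes
+#         )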
+ + +filename = "plot_export_optimind_experts_input_observer.onnx" +with torch_export_patches(patch_transformers=True): + to_onnx( + export_module, + args=args, + filename=filename, + dynamic_shapes=dynamic_shapes, + exporter="custom", + verbose=1, + ) + +# %% +# Let's measure the discrepancies. +data = observer.check_discrepancies(filename, progress_bar=True, atol=1e-2, include_io=True) +df = pandas.DataFrame(data) +df.to_excel("plot_export_optimind_input_observer.xlsx") +print(df) + +# %% +# Let's show the errors. +for row in data: + if not row["SUCCESS"] and "error" in row: + print(row["error"]) + + +# %% +doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400) From de619b92015b8760e6a1cf8701a3d1b6f433189c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Mon, 9 Feb 2026 17:42:57 +0100 Subject: [PATCH 02/10] add one more example --- .gitignore | 1 + _doc/conf.py | 2 +- .../plot_export_optimind_input_observer.py | 27 ++-- ...xport_tiny_llm_attention_input_observer.py | 137 ++++++++++++++++++ .../ut_investigate/test_input_observer.py | 46 ++++++ onnx_diagnostic/investigate/input_observer.py | 56 ++++++- 6 files changed, 248 insertions(+), 21 deletions(-) create mode 100644 _doc/final/plot_export_tiny_llm_attention_input_observer.py diff --git a/.gitignore b/.gitignore index 7294213c..0f87726f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +*.data *.pyc *.pyd *.dylib diff --git a/_doc/conf.py b/_doc/conf.py index 2a5b8928..2fddac5e 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -211,7 +211,7 @@ def linkcode_resolve(domain, info): if int(os.environ.get("UNITTEST_GOING", "0")): sphinx_gallery_conf["ignore_pattern"] = ( - ".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)).*" + ".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)|(optimind)).*" ) elif pv.Version(torch.__version__) < pv.Version("2.8"): sphinx_gallery_conf["ignore_pattern"] = ".*((_oe_)|(dort)|(draft_mode)).*" diff --git a/_doc/final/plot_export_optimind_input_observer.py b/_doc/final/plot_export_optimind_input_observer.py index 47316542..ae6da045 100644 --- a/_doc/final/plot_export_optimind_input_observer.py +++ b/_doc/final/plot_export_optimind_input_observer.py @@ -41,6 +41,20 @@ # ++++++++++++++++++++++++++++++++++++++++++ +export_module = None +for _name, sub in model.named_modules(): + if sub.__class__.__name__ == "GptOssExperts": + export_module = sub + +assert export_module is not None, ( + f"Unable to find a submodule from class GptOssExperts in " + f"{set(sub.__class__.__name__ for _, sub in model.named_modules())}" +) + +# %% +# Let's run the model and capture inputs and outputs + + def generate_text( prompt, model, @@ -69,17 +83,6 @@ def generate_text( return generated_text -export_module = None -for _name, sub in model.named_modules(): - if sub.__class__.__name__ == "GptOssExperts": - export_module = sub - -assert export_module is not None, ( - f"Unable to find a submodule from class GptOssExperts in " - f"{set(sub.__class__.__name__ for _, sub in model.named_modules())}" -) - -# Define your prompt prompt = "Continue: it rains, what should I do?" 
observer = InputObserver() with ( @@ -97,7 +100,7 @@ def generate_text( args = observer.infer_arguments() dynamic_shapes = observer.infer_dynamic_shapes() -print(f"kwargs={string_type(args, with_shape=True)}") +print(f"args={string_type(args, with_shape=True, with_device=True)}") print(f"dynamic_shapes={dynamic_shapes}") # %% diff --git a/_doc/final/plot_export_tiny_llm_attention_input_observer.py b/_doc/final/plot_export_tiny_llm_attention_input_observer.py new file mode 100644 index 00000000..5c8eae21 --- /dev/null +++ b/_doc/final/plot_export_tiny_llm_attention_input_observer.py @@ -0,0 +1,137 @@ +""" +.. _l-plot-tiny-llm-attention-export-input-observer: + +Export attention from arnir0/Tiny-LLM with InputObserver +======================================================== + +This shows how to only export attention from model +`arnir0/Tiny-LLM `_. +It uses what was shown in example +:ref:`l-plot-tiny-llm-export-input-observer`. + +Let's create a random model ++++++++++++++++++++++++++++ +""" + +import pandas +import torch +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from onnx_diagnostic import doc +from onnx_diagnostic.export.api import to_onnx +from onnx_diagnostic.helpers import string_type +from onnx_diagnostic.torch_export_patches import ( + register_additional_serialization_functions, + torch_export_patches, +) +from onnx_diagnostic.investigate.input_observer import InputObserver + +device = "cuda" +model_id = "arnir0/Tiny-LLM" +print(f"get tokenizer {model_id!r}") +tokenizer = AutoTokenizer.from_pretrained(model_id) +print(f"get config {model_id!r}") +config = AutoConfig.from_pretrained(model_id) +print(f"create model from config for {model_id!r}") +model = AutoModelForCausalLM.from_config(config) +print(f"the model is created with {len(list(model.named_modules()))} subdmodules.") +model = model.to(device).to(torch.float16) + +# %% +# We need to only export class LlamaAttention +# +++++++++++++++++++++++++++++++++++++++++++ + + +export_module = None +for _name, sub in model.named_modules(): + if sub.__class__.__name__ == "LlamaAttention": + export_module = sub + +assert export_module is not None, ( + f"Unable to find a submodule from class LlamaAttention in " + f"{set(sub.__class__.__name__ for _, sub in model.named_modules())}" +) + +# %% +# Let's run the model and capture the inputs and outputs of the attention part. + + +def generate_text( + prompt, + model, + tokenizer, + max_length=50, + temperature=0.01, + top_k=50, + top_p=0.95, + do_sample=True, +): + inputs = tokenizer(prompt, return_tensors="pt") + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + do_sample=do_sample, + ) + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + return generated_text + + +prompt = "Continue: it rains, what should I do?" +observer = InputObserver() +with ( + register_additional_serialization_functions(patch_transformers=True), + observer(export_module), +): + generate_text(prompt, model, tokenizer) + + +# %% +# Export +# ++++++ +# +# First, what was inferred. + +kwargs = observer.infer_arguments() +dynamic_shapes = observer.infer_dynamic_shapes() +print(f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}") +print(f"dynamic_shapes={dynamic_shapes}") + +# %% +# Next, the export. 
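+#
+# Before that, a quick look at what was captured (a sketch, assuming
+# ``observer.info`` keeps one entry per observed call, as exercised in the
+# unit tests of ``InputObserver``).
+
+print(f"number of captured calls: {len(observer.info)}")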
+ + +filename = "plot_export_tiny_llm_attention_input_observer.onnx" +with torch_export_patches(patch_transformers=True): + to_onnx( + export_module, + args=(), + kwargs=kwargs, + filename=filename, + dynamic_shapes=dynamic_shapes, + exporter="custom", + verbose=1, + ) + +# %% +# Let's measure the discrepancies. +data = observer.check_discrepancies(filename, progress_bar=True, atol=1e-2, include_io=True) +df = pandas.DataFrame(data) +df.to_excel("plot_export_tiny_llm_attention_input_observer.xlsx") +print(df) + +# %% +# Let's show the errors. +for row in data: + if not row["SUCCESS"] and "error" in row: + print(row["error"]) + + +# %% +doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400) diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index 645becd3..1ef0bfb5 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -938,6 +938,52 @@ def forward( } self.assertEqual(expected, shapes) + def test_io_captured_kwargs_kwargs(self): + class Model(torch.nn.Module): + def forward(self, x, **kwargs): + return x + kwargs["y"] + + inputs = [ + dict(x=torch.randn((5, 6)), y=torch.randn((1, 6))), + dict(x=torch.randn((7, 7)), y=torch.randn((1, 7))), + dict(x=torch.randn((7, 8)), y=torch.randn((1, 8))), + dict(x=torch.randn((7, 9)), y=torch.randn((1, 9))), + ] + + model = Model() + expected = [model(**kwargs) for kwargs in inputs] + observer = InputObserver() + with observer(model): + for kwargs in inputs: + model(**kwargs) + self.assertEqual(len(observer.info), 3) + for i in range(3): + self.assertEqual(len(observer.info.flat_outputs[i]), 1) + torch.testing.assert_close(expected[i], observer.info.flat_outputs[i][0]) + + cst = torch.export.Dim.DYNAMIC + ds = observer.infer_dynamic_shapes() + self.assertEqual(dict(x={0: cst, 1: cst}, kwargs=dict(y={1: cst})), ds) + args = observer.infer_arguments() + self.assertIsInstance(args, dict) + self.assertEqual(2, len(args)) + self.assertEqual(["x", "y"], list(args)) + + dynamic_shapes = torch.export.AdditionalInputs() + for kwargs in inputs: + dynamic_shapes.add((), kwargs) + dss = dynamic_shapes.dynamic_shapes(model, (), inputs[0]) + self.assertEqual({"x": (cst, cst), "kwargs": {"y": (None, cst)}}, dss) + + # _get_range_constraints + torch.export.export( + model, + (), + kwargs=args, + dynamic_shapes={"x": {0: cst, 1: cst}, "kwargs": {"y": {1: cst}}}, + ) + torch.export.export(model, (), kwargs=args, dynamic_shapes=ds) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index e642f623..f4d8a3f4 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -301,6 +301,7 @@ class InputObserverInfo: is added (such as `past_key_values`). The values are only to infer dynamic shapes and arguments, not to run the model. + kwargs_name: Name of parameter **kwargs if it exists. 
""" def __init__( @@ -308,6 +309,7 @@ def __init__( signature_names: list[str], default_values: dict[str, int | bool | str | float], missing: dict[str, Any], + kwargs_name: str | None, ): self.default_values = default_values self.missing = missing @@ -315,6 +317,7 @@ def __init__( self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = [] self.flat_outputs: list[list[torch.Tensor | None]] = [] self.latencies: list[float] = [] + self.kwargs_name = kwargs_name self.signature_names = signature_names self._best_candidate: InputCandidate | None = None self._captured_inputs: dict[int | str, int] | None = None @@ -492,15 +495,21 @@ def _set_batch_dimension_for_flat_index(index): if not self._best_candidate.args: # only named arguments ds = dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes)) - return {**ds, **dict.fromkeys(self._best_candidate.cst_kwargs, None)} + return self._post_process_for_kwargs( + {**ds, **dict.fromkeys(self._best_candidate.cst_kwargs, None)} + ) # positional arguments needs to be moved to the named arguments n_args = len(self._best_candidate.args) pos_names = self.signature_names[:n_args] - return { - **dict(zip(pos_names, flat_dynamic_shapes[:n_args])), - **dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes[n_args:])), - **dict.fromkeys(self._best_candidate.cst_kwargs, None), - } + return self._post_process_for_kwargs( + { + **dict(zip(pos_names, flat_dynamic_shapes[:n_args])), + **dict( + zip(list(self._best_candidate.kwargs), flat_dynamic_shapes[n_args:]) + ), + **dict.fromkeys(self._best_candidate.cst_kwargs, None), + } + ) # nested types, here comes the fun part because the shapes cannot be unflattened, # custom classes must appear in their flattened shape. @@ -540,9 +549,9 @@ def change_function(t): if not ds_kwargs: return tuple(ds_args) if not ds_args: - return ds_kwargs + return self._post_process_for_kwargs(ds_kwargs) pos_names = self.signature_names[: len(ds_args)] - return {**dict(zip(pos_names, ds_args)), **ds_kwargs} + return self._post_process_for_kwargs({**dict(zip(pos_names, ds_args)), **ds_kwargs}) def infer_arguments( self, index_or_candidate: InputCandidate | int | None = None, flat: bool = False @@ -651,6 +660,21 @@ def infer_arguments( pos_names = self.signature_names[: len(args)] return {**dict(zip(pos_names, args)), **kwargs} + def _post_process_for_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: + """ + :func:`torch.export.export` requires to have dynamic shapes and keyword arguments + wrapped into `'kwargs': { 'param': shape or tensor }` if 'param' is not part + of the signature but is caught through `**kwargs`. + This function ensures this is the case. + """ + if not self.kwargs_name: + # Nothing to do here. + return kwargs + to_be_moved = {k for k in kwargs if k not in self.signature_names} + keywords = {k: v for k, v in kwargs.items() if k in to_be_moved} + new_kwargs = {k: v for k, v in kwargs.items() if k not in to_be_moved} + return {**new_kwargs, self.kwargs_name: keywords} + class InputObserver: """Steals forward method to collect inputs and outputs. 
@@ -745,6 +769,21 @@ def __call__(
         captured_method = getattr(model, method_name)
         sig = inspect.signature(captured_method)
         if self.info is None:
+            kwargs_names = [
+                p
+                for p in sig.parameters
+                if sig.parameters[p].kind == inspect.Parameter.VAR_KEYWORD
+            ]
+            args_names = [
+                p
+                for p in sig.parameters
+                if sig.parameters[p].kind == inspect.Parameter.VAR_POSITIONAL
+            ]
+            if args_names:
+                raise RuntimeError(
+                    f"Inference is not implemented "
+                    f"when the signature includes '*{args_names[0]}'."
+                )
             self.info = InputObserverInfo(
                 signature_names=list(sig.parameters),
                 default_values={
@@ -754,6 +793,7 @@ def __call__(
                     and isinstance(p.default, (int, bool, str, float))
                 },
                 missing=self.missing,
+                kwargs_name=kwargs_names[0] if kwargs_names else None,
             )
             n_already_stored = len(self.info)
             lambda_method = lambda *args, _cm=captured_method, _snc=(  # noqa: E731

From cf03499728338f0dd8086aafb5816c8bdeb05d50 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Tue, 10 Feb 2026 12:13:20 +0100
Subject: [PATCH 03/10] add example

---
 _doc/technical/plot_histc.py | 215 +++++++++++++++++++++++++++++++++++
 1 file changed, 215 insertions(+)
 create mode 100644 _doc/technical/plot_histc.py

diff --git a/_doc/technical/plot_histc.py b/_doc/technical/plot_histc.py
new file mode 100644
index 00000000..f04d0349
--- /dev/null
+++ b/_doc/technical/plot_histc.py
@@ -0,0 +1,215 @@
+"""
+.. _l-plot-histc:
+
+================================
+Converting torch.histc into ONNX
+================================
+
+:func:`torch.histc` computes a histogram of a tensor:
+it counts the number of elements falling into each bin.
+There are many ways to do this. If the number of bins
+is not too high, we can use something based on broadcasting.
+This method implies the creation of an :math:`N \times B` matrix,
+where *N* is the number of elements in the tensor and *B* the number
+of bins. To avoid this, the best way is to use a tree.
+Before doing that, let's first study :func:`torch.histc`.
+See `issue 174668 `_.
+
+float32 and float16
+===================
+"""
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def create_input(dtype, hmin, hmax):
+    inf = torch.tensor(torch.inf, dtype=torch.float16)
+    buffer = torch.tensor([hmin], dtype=torch.float16)
+    res = []
+    while buffer[0] <= hmax:
+        buffer = torch.nextafter(buffer, inf)
+        res.append(buffer[0])
+    return torch.tensor(res, dtype=dtype)
+
+
+hbins, hmin, hmax = 20, -5, 5
+dtype = torch.float16
+tensor = create_input(dtype, hmin, hmax)
+print(f"There are {tensor.shape[0]} elements in [{hmin}, {hmax}] of type {torch.float16}.")
+
+# %%
+# histc
+
+hist = torch.histc(tensor, hbins, hmin, hmax)
+print(f"{hist=}")
+
+# %%
+# We can see there are more elements in the center.
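+
+# %%
+# The equivalent implementation below relies on broadcasting: every element is
+# compared against every bin boundary, which builds an intermediate matrix of
+# shape ``(bins + 1, N)``. A toy sketch of the counting trick (illustrative
+# values only):
+
+toy = torch.tensor([0.1, 0.4, 0.9])
+bounds = torch.tensor([0.0, 0.5, 1.0])
+above = (bounds.unsqueeze(1) < toy.unsqueeze(0)).sum(dim=1)
+print(above[:-1] - above[1:])  # 2 elements fall in [0, 0.5), 1 in [0.5, 1.0)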
+
+
+def torch_histc_equivalent(tensor, bins, fmin, fmax, thresholds=None):
+    # thresholds
+    if thresholds is None:
+        delta = (float(fmax) - float(fmin)) / float(bins)
+        inf = torch.tensor(torch.inf, dtype=tensor.dtype)
+        delta = torch.tensor(delta, dtype=tensor.dtype)
+        min = torch.tensor(fmin, dtype=tensor.dtype)
+        max = torch.tensor(fmax, dtype=tensor.dtype)
+        bins = int(bins)
+        thresholds = torch.zeros((bins + 1,), dtype=tensor.dtype)
+        halfway = bins + 1 - (bins + 1) // 2
+        for i in range(halfway):
+            thresholds[i] = min + delta * i
+        for i in range(halfway, bins + 1):
+            thresholds[i] = max - delta * (bins - i)
+        thresholds[-1] = torch.nextafter(thresholds[-1], inf)
+
+    # computation
+    value = thresholds.unsqueeze(1) < tensor.reshape((-1,)).unsqueeze(0)
+    value = value.sum(dim=1).squeeze()
+    res = value[:-1] - value[1:]
+    res = res.to(torch.float16)
+    return res
+
+
+hist_equiv = torch_histc_equivalent(tensor, hbins, hmin, hmax)
+print(f"{hist_equiv=}")
+print(f"delta={(hist_equiv - hist).to(int)}")
+
+# %%
+
+diff = torch.abs(hist_equiv - hist).sum()
+print(f"sum of differences {diff} with {dtype=}.")
+
+# %%
+# This is not really satisfactory.
+# Let's check with float32.
+
+hist32 = torch.histc(tensor.to(torch.float32), hbins, hmin, hmax)
+hist32_equiv = torch_histc_equivalent(tensor.to(torch.float32), hbins, hmin, hmax)
+diff32 = hist32_equiv - hist32
+print(f"{diff32.abs().sum()} are misplaced: {diff32=}.")
+
+# %%
+# Is histc an increasing function?
+# ++++++++++++++++++++++++++++++++
+
+histc_index = torch.empty(tensor.shape, dtype=torch.float64)
+buffer = torch.empty((1,), dtype=tensor.dtype)
+for i in range(tensor.shape[0]):
+    buffer[0] = tensor[i]
+    histc_value = torch.histc(buffer, hbins, hmin, hmax)
+    histc_index[i] = (
+        histc_value.argmax() if histc_value.max().item() > 0 else histc_index.max()
+    )
+
+
+fig, ax = plt.subplots(1, 1)
+ax.plot(list(range(tensor.shape[0])), histc_index.tolist(), "-", label="histc_index")
+ax.legend()
+fig.savefig("plot_histc_index.png")
+ax
+
+# %%
+# It seems to be increasing. Let's check.
+
+
+diff = histc_index[1:] - histc_index[:-1]
+print(f"min={diff.min()}, max={diff.max()}")
+
+# %%
+# It is, so we can look for thresholds that make our implementation match.
+#
+# Better thresholds
+# =================
+
+
+def tune_threshold_histc(
+    dtype: torch.dtype, hbins: int, hmin: float, hmax: float
+) -> torch.Tensor:
+    possible_values = create_input(dtype, hmin, hmax)
+    buffer = torch.empty((1,), dtype=dtype)
+    previous_index = None
+    thresholds = []
+    for i in range(possible_values.shape[0]):
+        buffer[0] = possible_values[i]
+        histc_value = torch.histc(buffer, hbins, hmin, hmax)
+        if histc_value.max().item() > 0:
+            index = histc_value.argmax()
+            if previous_index is None or index != previous_index:
+                previous_index = index
+                thresholds.append(possible_values[i])
+
+    thresholds.append(
+        torch.nextafter(torch.tensor(hmax, dtype=dtype), torch.tensor(torch.inf, dtype=dtype))
+    )
+    return torch.tensor(thresholds, dtype=dtype)
+
+
+thresholds = tune_threshold_histc(torch.float16, hbins, hmin, hmax)
+print(f"shape={thresholds.shape}: {thresholds=}")
+
+# %%
+# Let's check that it works.
+
+hist_equiv = torch_histc_equivalent(tensor, hbins, hmin, hmax, thresholds=thresholds)
+print(f"{hist_equiv=}")
+print(f"delta={(hist_equiv - hist).to(int)}")
+diff = torch.abs(hist_equiv - hist).sum()
+print(f"sum of differences {diff} with {dtype=}.")
+
+# %%
+# That's not really working.
+# Let's do another verification.
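+# (To see which bins disagree, a quick sketch: ``(hist_equiv - hist).nonzero()``
+# returns the indices of the mismatching bins.)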
+# We first start again by comparing the number of differences between +# histograms for the the whole tensor. + +histc_value = torch.histc(tensor, hbins, hmin, hmax) +histc_equiv = torch_histc_equivalent(tensor, hbins, hmin, hmax, thresholds=thresholds) +diff = (histc_value - histc_equiv).abs() +print(f"with {tensor.shape[0]} elements, there {diff.sum()} differences.") + +# %% +# We now take the elements with an even position. + + +histc_value = torch.histc(tensor[::2], hbins, hmin, hmax) +histc_equiv = torch_histc_equivalent(tensor[::2], hbins, hmin, hmax, thresholds=thresholds) +diff = (histc_value - histc_equiv).abs() +print( + f"with {tensor[::2].shape[0]} elements at even position, there {diff.sum()} differences." +) + + +# %% +# We now take the elements with an odd position. + +histc_value = torch.histc(tensor[1::2], hbins, hmin, hmax) +histc_equiv = torch_histc_equivalent(tensor[1::2], hbins, hmin, hmax, thresholds=thresholds) +diff = (histc_value - histc_equiv).abs() +print( + f"with {tensor[1::2].shape[0]} elements at odd position, there {diff.sum()} differences." +) + +# %% +# This does not add up. Let's proove now :func:`torch.histc` is really confusing. +# The following sum should be null but it is not. + +diff = torch.histc(tensor, hbins, hmin, hmax) - ( + torch.histc(tensor[::2], hbins, hmin, hmax) + torch.histc(tensor[1::2], hbins, hmin, hmax) +) +print(f"torch.histc: {tensor.dtype=}, number of differences: {diff.abs().sum()}: {diff}") + + +# %% +# This does not add up. Our implementation is more reliable. + +diff = torch_histc_equivalent(tensor, hbins, hmin, hmax) - ( + torch_histc_equivalent(tensor[::2], hbins, hmin, hmax) + + torch_histc_equivalent(tensor[1::2], hbins, hmin, hmax) +) +print( + f"torch_histc_equivalent: {tensor.dtype=}, " + f"number of differences: {diff.abs().sum()}: {diff}" +) From 8b2303d1d3a646dc7c251668784fcd3e4e15e143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 10 Feb 2026 15:15:21 +0100 Subject: [PATCH 04/10] more improvments --- ...xport_tiny_llm_attention_input_observer.py | 3 +- _unittests/ut_export/test_dynamic_shapes.py | 15 ++++++++ .../ut_investigate/test_input_observer.py | 16 +++++---- .../test_patch_torch.py | 24 +++++++++++++ onnx_diagnostic/export/dynamic_shapes.py | 19 ++++++---- .../onnx_export_errors.py | 23 +++++++++--- .../patches/_patch_transformers_attention.py | 3 ++ .../patches/patch_torch.py | 35 ++++++++++++++++++- 8 files changed, 118 insertions(+), 20 deletions(-) diff --git a/_doc/final/plot_export_tiny_llm_attention_input_observer.py b/_doc/final/plot_export_tiny_llm_attention_input_observer.py index 5c8eae21..c12c32d3 100644 --- a/_doc/final/plot_export_tiny_llm_attention_input_observer.py +++ b/_doc/final/plot_export_tiny_llm_attention_input_observer.py @@ -100,6 +100,7 @@ def generate_text( kwargs = observer.infer_arguments() dynamic_shapes = observer.infer_dynamic_shapes() +print("attention type:", type(export_module)) print(f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}") print(f"dynamic_shapes={dynamic_shapes}") @@ -108,7 +109,7 @@ def generate_text( filename = "plot_export_tiny_llm_attention_input_observer.onnx" -with torch_export_patches(patch_transformers=True): +with torch_export_patches(patch_torch=True, patch_transformers=True): to_onnx( export_module, args=(), diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index e8b2a178..cf28fe95 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ 
b/_unittests/ut_export/test_dynamic_shapes.py @@ -1107,6 +1107,21 @@ def test_dynamic_batch_dynamic(self): ds, ) + def test_weird_case_kwargs_kwargs(self): + import torch + + ags = tuple() + kws = { + "x": torch.zeros((1, 2), dtype=torch.float32), + "y": torch.zeros((1, 2), dtype=torch.float32), + } + ds = {"x": {0: "batch"}, "kwargs": {"y": {0: "batch"}}} + + cpl = CoupleInputsDynamicShapes(ags, kws, ds) + backed_size_oblivious = cpl.invalid_dimensions_for_export() + self.assertTrue(backed_size_oblivious) + self.assertEqual({"x": {0: "batch"}, "kwargs": {"y": {0: "batch"}}}, ds) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index 1ef0bfb5..3eacbdf1 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -8,6 +8,7 @@ _infer_dynamic_dimensions, ) from onnx_diagnostic.export.api import to_onnx +from onnx_diagnostic.torch_export_patches import torch_export_patches class TestInputObserver(ExtTestCase): @@ -976,13 +977,14 @@ def forward(self, x, **kwargs): self.assertEqual({"x": (cst, cst), "kwargs": {"y": (None, cst)}}, dss) # _get_range_constraints - torch.export.export( - model, - (), - kwargs=args, - dynamic_shapes={"x": {0: cst, 1: cst}, "kwargs": {"y": {1: cst}}}, - ) - torch.export.export(model, (), kwargs=args, dynamic_shapes=ds) + with torch_export_patches(patch_torch=True): + torch.export.export( + model, + (), + kwargs=args, + dynamic_shapes={"x": {0: cst, 1: cst}, "kwargs": {"y": {1: cst}}}, + ) + torch.export.export(model, (), kwargs=args, dynamic_shapes=ds) if __name__ == "__main__": diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py index e1427c52..87838bf8 100644 --- a/_unittests/ut_torch_export_patches/test_patch_torch.py +++ b/_unittests/ut_torch_export_patches/test_patch_torch.py @@ -664,6 +664,30 @@ def forward(self, patch_attention_mask, position_ids, boundaries): got = ep.module()(*inputs) self.assertEqualArray(expected, got) + def test_mixed_named_and_unnamed_kwargs(self): + # see https://github.com/pytorch/pytorch/pull/174593 + class Model(torch.nn.Module): + def forward(self, x, **kwargs): + return x + kwargs["y"] + + kwargs = dict(x=torch.randn((5, 6)), y=torch.randn((1, 6))) + model = Model() + expected = model(**kwargs) + with torch_export_patches(patch_torch=True): + ep = torch.export.export( + model, + (), + kwargs=kwargs, + dynamic_shapes={ + "x": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}, + "kwargs": {"y": {1: torch.export.Dim.DYNAMIC}}, + }, + ) + # ep.module()(**kwargs): raises NameError: name 'L' is not defined + self.assertEqualArray(kwargs["x"] + kwargs["y"], expected) + inputs = [n.name for n in ep.graph.nodes if n.op == "placeholder"] + self.assertEqual(["x", "y"], inputs) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index f02614a6..da0847fd 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ b/onnx_diagnostic/export/dynamic_shapes.py @@ -44,6 +44,10 @@ class CoupleInputsDynamicShapes: dynamic shapes must be a dictionary, and positional must be added to the named arguments. Arguments names or a module must be given in that case. + + .. note:: + If the parameters ``**kwargs`` is not named, ``**kwargs``, + this could raise exceptions. 
""" def __init__( @@ -361,18 +365,21 @@ def _generic_walker_step( f"not_in_ds={not_in_ds}, not_in_inputs={not_in_inputs}" ) # Tweak... - kws = ds["kwargs"] - del ds["kwargs"] - ds.update(kws) + keys_inputs = set(inputs) + keys_ds = {k for k in ds if k != "kwargs"} | set(ds["kwargs"]) + else: + keys_inputs = set(inputs) + keys_ds = set(ds) - assert set(inputs) == set(ds), ( - f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, " + assert keys_inputs == keys_ds, ( + f"Keys mismatch between inputs {keys_inputs} and ds={keys_ds}, " f"inputs={string_type(inputs, with_shape=True)}, ds={ds}" ) dvalue = {} for k, v in inputs.items(): + s = ds[k] if k in ds else ds["kwargs"][k] t = cls._generic_walker_step( - processor, v, ds[k], flatten_unflatten=flatten_unflatten + processor, v, s, flatten_unflatten=flatten_unflatten ) if t is not None: dvalue[k] = t diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 30fa9727..9f555771 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -227,14 +227,15 @@ def _patch_torch( import torch._export.non_strict_utils # produce_guards_and_solve_constraints from torch.fx.experimental.symbolic_shapes import ShapeEnv from .patches.patch_torch import ( + _catch_produce_guards_and_solve_constraints, patched_infer_size, patched_vmap, - patched__broadcast_shapes, - patched__constrain_user_specified_dimhint_range, - _catch_produce_guards_and_solve_constraints, - patch__check_input_constraints_for_graph, patched__broadcast_in_dim_meta, patched__broadcast_in_dim_meta_level_2, + patched__broadcast_shapes, + patched__check_input_constraints_for_graph, + patched__constrain_user_specified_dimhint_range, + patched__get_range_constraints, patched__maybe_broadcast, patched_ShapeEnv, ) @@ -259,6 +260,7 @@ def _patch_torch( f_shape_env__log_guard = None f_shape_env__set_replacement = None f_vmap = None + f__get_range_constraints = None if verbose: print(f"[torch_export_patches] torch.__version__={torch.__version__!r}") @@ -294,6 +296,12 @@ def _patch_torch( if patch_details: patch_details.append("torch", f_infer_size, patched_infer_size) + # torch.export._trace._get_range_constraints + f__get_range_constraints = torch.export._trace._get_range_constraints + torch.export._trace._get_range_constraints = patched__get_range_constraints + if patch_details: + patch_details.append("torch", f__get_range_constraints, patched__get_range_constraints) + # torch._refs._broadcast_shapes f__broadcast_shapes = torch._refs._broadcast_shapes torch._refs._broadcast_shapes = patched__broadcast_shapes @@ -358,7 +366,7 @@ def _patch_torch( ) ) torch._export.utils._check_input_constraints_for_graph = ( - lambda *args, **kwargs: patch__check_input_constraints_for_graph( + lambda *args, **kwargs: patched__check_input_constraints_for_graph( f__check_input_constraints_for_graph, *args, verbose=verbose, **kwargs ) ) @@ -410,6 +418,7 @@ def _patch_torch( f_shape_env__set_replacement, f_vmap, f__print_symbol, + f__get_range_constraints, ) @@ -435,6 +444,7 @@ def _unpatch_torch( f_shape_env__set_replacement: Optional[Callable], f_vmap: Optional[Callable], f__print_symbol: Optional[Callable], + f__get_range_constraints: Optional[Callable], ): import torch import torch.jit @@ -460,6 +470,7 @@ def _unpatch_torch( torch._prims.broadcast_in_dim = f_broadcast_in_dim torch._refs._maybe_broadcast = f__maybe_broadcast ShapeEnv._evaluate_expr = 
f_shape_env__evaluate_expr + torch.export._trace._get_range_constraints = f__get_range_constraints if verbose: print("[torch_export_patches] restored pytorch functions") @@ -1035,6 +1046,7 @@ def torch_export_patches( f_shape_env__set_replacement, f_vmap, f__print_Symbol, + f__get_range_constraints, ) = _patch_torch( verbose, patch_details, patch_torch, catch_constraints, stop_if_static ) @@ -1111,6 +1123,7 @@ def torch_export_patches( f_shape_env__set_replacement, f_vmap, f__print_Symbol, + f__get_range_constraints, ) if patch_transformers: diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py index 98cdabde..316a6dba 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py @@ -160,6 +160,9 @@ def patched_sdpa_attention_forward( # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool # is_causal=torch.tensor(query.shape[2] > 1) # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() attn_output = torch.cond( query.shape[2] > 1, # distinction between prefill and decoding steps lambda query, key, value: torch.nn.functional.scaled_dot_product_attention( diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index 30bb8521..8f0fbf75 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -7,6 +7,7 @@ from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union import sympy import torch +import torch.export._trace from torch._subclasses.fake_tensor import FakeTensorMode @@ -61,7 +62,7 @@ def _catch_produce_guards_and_solve_constraints( torch._dynamo.reset() -def patch__check_input_constraints_for_graph( +def patched__check_input_constraints_for_graph( previous_function: Callable, input_placeholders: list[torch.fx.Node], flat_args_with_path, @@ -133,6 +134,38 @@ def patched_infer_size(a, b): return tuple(expandedSizes) +def patched__get_range_constraints( + mod: torch.nn.Module, + export_artifact: torch.export._trace.ExportArtifact, + args, + kwargs, + dynamic_shapes, +): + """ + Patches ``torch.export._trace._get_range_constraints``. + See PR `#174593 whttps://github.com/pytorch/pytorch/pull/174593>`_. 
+ """ + gm: torch.fx.GraphModule = export_artifact.aten.gm + export_graph_signature: torch.export.graph_signature.ExportGraphSignature = ( + export_artifact.aten.sig + ) + fake_mode: FakeTensorMode = export_artifact.fake_mode + num_lifted = next( + ( + i + for i, s in enumerate(export_graph_signature.input_specs) + if s.kind == torch.export.graph_signature.InputKind.USER_INPUT + ), + len(export_graph_signature.input_specs), + ) + combined_args = torch.export._trace._combine_args(mod, args, kwargs) + + range_constraints = torch._export.non_strict_utils.make_constraints( + fake_mode, gm, combined_args, dynamic_shapes, num_lifted + ) + return range_constraints + + def patched__broadcast_shapes(*_shapes): """Patches ``torch._refs._broadcast_shapes``.""" from functools import reduce From 460db5c12c2032cf6fe6e2339caedec69e68ef95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 10 Feb 2026 16:07:50 +0100 Subject: [PATCH 05/10] doc --- _doc/conf.py | 1 + ...xport_tiny_llm_attention_input_observer.py | 4 +++- _doc/technical/plot_histc.py | 2 +- onnx_diagnostic/helpers/dot_helper.py | 2 ++ onnx_diagnostic/helpers/helper.py | 9 +++++++++ onnx_diagnostic/investigate/input_observer.py | 19 ++++++++++++++----- .../patches/patch_torch.py | 2 +- 7 files changed, 31 insertions(+), 8 deletions(-) diff --git a/_doc/conf.py b/_doc/conf.py index 2fddac5e..86aa5461 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -141,6 +141,7 @@ def linkcode_resolve(domain, info): ("py:class", "torch.fx.proxy.TracerBase"), ("py:class", "torch.FloatTensor"), ("py:class", "torch.LongTensor"), + ("py:class", "torch.export._trace.ExportArtifact"), ("py:class", "torch.utils._pytree.Context"), ("py:class", "torch.utils._pytree.KeyEntry"), ("py:class", "torch.utils._pytree.TreeSpec"), diff --git a/_doc/final/plot_export_tiny_llm_attention_input_observer.py b/_doc/final/plot_export_tiny_llm_attention_input_observer.py index c12c32d3..68958c17 100644 --- a/_doc/final/plot_export_tiny_llm_attention_input_observer.py +++ b/_doc/final/plot_export_tiny_llm_attention_input_observer.py @@ -122,7 +122,9 @@ def generate_text( # %% # Let's measure the discrepancies. -data = observer.check_discrepancies(filename, progress_bar=True, atol=1e-2, include_io=True) +data = observer.check_discrepancies( + filename, progress_bar=True, atol=1e-2, include_io=True, skip_none=True +) df = pandas.DataFrame(data) df.to_excel("plot_export_tiny_llm_attention_input_observer.xlsx") print(df) diff --git a/_doc/technical/plot_histc.py b/_doc/technical/plot_histc.py index f04d0349..09791d05 100644 --- a/_doc/technical/plot_histc.py +++ b/_doc/technical/plot_histc.py @@ -193,7 +193,7 @@ def tune_threshold_histc( ) # %% -# This does not add up. Let's proove now :func:`torch.histc` is really confusing. +# This does not add up. Let's prove now :func:`torch.histc` is really confusing. # The following sum should be null but it is not. 
diff = torch.histc(tensor, hbins, hmin, hmax) - ( diff --git a/onnx_diagnostic/helpers/dot_helper.py b/onnx_diagnostic/helpers/dot_helper.py index 09d16f29..00c2634c 100644 --- a/onnx_diagnostic/helpers/dot_helper.py +++ b/onnx_diagnostic/helpers/dot_helper.py @@ -203,6 +203,8 @@ def _mkn(obj: object) -> int: if att.type == onnx.AttributeProto.GRAPH: unique |= get_hidden_inputs(att.g) for i in unique: + if i in tiny_inits: + continue edge = name_to_ids[i], _mkn(node) # type: ignore[assignment] if edge in done: continue diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index 69229bc2..334db3ca 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -1071,6 +1071,7 @@ def max_diff( _index: int = 0, allow_unique_tensor_with_list_of_one_element: bool = True, hist: Optional[Union[bool, List[float]]] = None, + skip_none: bool = False, ) -> Dict[str, Union[float, int, Tuple[Any, ...]]]: """ Returns the maximum discrepancy. @@ -1087,6 +1088,7 @@ def max_diff( :param allow_unique_tensor_with_list_of_one_element: allow a comparison between a single tensor and a list of one tensor :param hist: compute an histogram of the discrepancies + :param skip_none: skips none value :return: dictionary with many values * abs: max absolute error @@ -1112,6 +1114,7 @@ def max_diff( end=end, _index=_index, hist=hist, + skip_none=skip_none, ) _dkws = {**_dkws_, "flatten": flatten} _dkwsf = {**_dkws_, "flatten": False} @@ -1129,6 +1132,7 @@ def max_diff( debug_info=debug_info, allow_unique_tensor_with_list_of_one_element=False, hist=hist, + skip_none=skip_none, ) return max_diff( expected, @@ -1142,6 +1146,7 @@ def max_diff( _index=_index, allow_unique_tensor_with_list_of_one_element=False, hist=hist, + skip_none=skip_none, ) if expected.__class__.__name__ == "CausalLMOutputWithPast": @@ -1269,6 +1274,7 @@ def max_diff( _index=_index + ip, flatten=flatten, hist=hist, + skip_none=skip_none, ) am = max(am, d["abs"]) dn = max(dn, d["dnan"]) @@ -1793,6 +1799,9 @@ def max_diff( **_dkws, ) + if skip_none and (expected is None or got is None): + return {"abs": 0, "rel": 0, "dnan": 0, "n": 0, "sum": 0} + raise AssertionError( f"Not implemented with implemented with expected=" f"{string_type(expected)} ({type(expected)}), got={string_type(got)},\n" diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index f4d8a3f4..f8e1f8b8 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -286,22 +286,28 @@ class InputObserverInfo: and the arguments to send to :func:`torch.export.export`. Args: - signature_names: Names of the arguments of the method + signature_names: + Names of the arguments of the method the collector tensors come from. They are used if it becomes necessary to move positional arguments to named ones. They are used a second time because :func:`torch.export.export` cares about the order in kwargs and dynamic shapes, it needs to be the same in the ordered dictionaries `add_inputs` receive. - default_values: Default values defined by the signature of the function, + default_values: + Default values defined by the signature of the function, any value equal to that is ignore to simplify the export. 
- missing: If a named argument (in kwargs) is missing, + missing: + If a named argument (in kwargs) is missing, a default value will be taken in this dictionary, this is used when after the prefill step, an argument disappears (such as `pixel_values`) and another one is added (such as `past_key_values`). The values are only to infer dynamic shapes and arguments, not to run the model. - kwargs_name: Name of parameter **kwargs if it exists. + kwargs_name: + Name of parameter `**kwargs` if it exists. + + This is used by class :class:`InputObserver`. """ def __init__( @@ -910,6 +916,7 @@ def check_discrepancies( hist=(0.1, 0.01), progress_bar: bool = False, include_io: bool = True, + skip_none: bool = True, ) -> list[dict[str, str | int | float | bool]]: """Computes the discrepancies between the saved inputs and outputs with the saved onnx model. @@ -929,6 +936,8 @@ def check_discrepancies( include_io: Shows inputs/outputs shapes in the summary returned by this function. + skip_none: + Dooes not check discrepancies when an output is None. Returns: A list of dictionaries, ready to be consumed by a dataframe. @@ -982,7 +991,7 @@ def check_discrepancies( if isinstance(outputs, list) and isinstance(ort_outputs, list): while len(ort_outputs) > len(outputs) and ort_outputs[-1].numel() == 0: ort_outputs.pop() - diff = max_diff(outputs, ort_outputs, hist=lhist) # type: ignore[assignment] + diff = max_diff(outputs, ort_outputs, hist=lhist, skip_none=skip_none) # type: ignore[assignment] if "rep" in diff and isinstance(diff["rep"], dict): diff.update(diff["rep"]) del diff["rep"] diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index 8f0fbf75..eefa5265 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -143,7 +143,7 @@ def patched__get_range_constraints( ): """ Patches ``torch.export._trace._get_range_constraints``. - See PR `#174593 whttps://github.com/pytorch/pytorch/pull/174593>`_. + See PR `#174593 `_. 
""" gm: torch.fx.GraphModule = export_artifact.aten.gm export_graph_signature: torch.export.graph_signature.ExportGraphSignature = ( From f63ae3909367510b20abd9b6e8d298e5101dc613 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 10 Feb 2026 16:08:29 +0100 Subject: [PATCH 06/10] fix --- CHANGELOGS.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 789673d2..83179caa 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.9.1 +++++ +* :pr:`410`: add patch for `_get_range_constraints` * :pr:`409`: improves ModelBuilder wrapper * :pr:`408`: fix torch_deepcopy for empty DynamicCache and transformers==5.1.0, 5.2.0 (see https://github.com/huggingface/transformers/pull/43765/) From 0aa4ce151e4bb15cba5d179df5d066f71decdce9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 10 Feb 2026 17:28:55 +0100 Subject: [PATCH 07/10] fix patch --- ...st_tasks_zero_shot_image_classification.py | 2 +- .../test_patch_transformers.py | 229 +++++++++--------- onnx_diagnostic/ext_test_case.py | 15 +- .../patches/patch_torch.py | 17 ++ 4 files changed, 150 insertions(+), 113 deletions(-) diff --git a/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py b/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py index ef5337c0..98a44835 100644 --- a/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py +++ b/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py @@ -17,7 +17,7 @@ def test_zero_shot_image_classification(self): model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] expected = model(**inputs) model(**data["inputs2"]) - with torch_export_patches(patch_transformers=True, verbose=10): + with torch_export_patches(patch_torch=True, patch_transformers=True, verbose=10): ep = torch.export.export( model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False ) diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 0651090e..06a2e388 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -677,157 +677,164 @@ def _get_seqlen(cls) -> torch.Tensor: @requires_cuda() def test_plug_multi_head_attention_qwen25_packed_float16(self): from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( - qwen_sdpa_attention_packed_versatile, + qwen_sdpa_attention_versatile as qwen_sdpa_attention_packed_versatile, ) - inputs = ( - torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"), - torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"), - torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"), - self._get_seqlen().to("cuda"), - ) + with self.set_env("QWEN25ATTENTION", "PACKED"): + inputs = ( + torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"), + torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"), + torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"), + self._get_seqlen().to("cuda"), + ) - results = qwen_sdpa_attention_packed_versatile.verify( - *inputs, scaling=0.5, num_heads=16 - ) - self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) - self.assertEqual(len(results.eager_outputs), len(results.diffs)) - self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01) - self.assertLess(results.diffs[0]["abs"], 0.01) + results = 
qwen_sdpa_attention_packed_versatile.verify( + *inputs, scaling=0.5, num_heads=16 + ) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01) + self.assertLess(results.diffs[0]["abs"], 0.01) - results = qwen_sdpa_attention_packed_versatile.verify( - *inputs, scaling=0.11180339887498948, num_heads=16 - ) - self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) - self.assertEqual(len(results.eager_outputs), len(results.diffs)) - self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01) - self.assertLess(results.diffs[0]["abs"], 0.01) + results = qwen_sdpa_attention_packed_versatile.verify( + *inputs, scaling=0.11180339887498948, num_heads=16 + ) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01) + self.assertLess(results.diffs[0]["abs"], 0.01) @requires_onnxruntime("1.25") @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") def test_plug_multi_head_attention_qwen25_loopmha_float16(self): from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( - qwen_sdpa_attention_loopmha_versatile, + qwen_sdpa_attention_versatile as qwen_sdpa_attention_loopmha_versatile, ) - inputs = ( - torch.rand((1, 16, 1292, 80), dtype=torch.float16), - torch.rand((1, 16, 1292, 80), dtype=torch.float16), - torch.rand((1, 16, 1292, 80), dtype=torch.float16), - self._get_seqlen(), - ) + with self.set_env("QWEN25ATTENTION", "LOOPMHA"): + inputs = ( + torch.rand((1, 16, 1292, 80), dtype=torch.float16), + torch.rand((1, 16, 1292, 80), dtype=torch.float16), + torch.rand((1, 16, 1292, 80), dtype=torch.float16), + self._get_seqlen(), + ) - results = qwen_sdpa_attention_loopmha_versatile.verify( - *inputs, - scaling=0.5, - num_heads=16, - dump_onnx_model=self.get_dump_file( - "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx" - ), - ) - self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) - self.assertEqual(len(results.eager_outputs), len(results.diffs)) - self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01) - self.assertLess(results.diffs[0]["abs"], 0.01) + results = qwen_sdpa_attention_loopmha_versatile.verify( + *inputs, + scaling=0.5, + num_heads=16, + dump_onnx_model=self.get_dump_file( + "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx" + ), + ) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01) + self.assertLess(results.diffs[0]["abs"], 0.01) - results = qwen_sdpa_attention_loopmha_versatile.verify( - *inputs, scaling=0.11180339887498948, num_heads=16 - ) - self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) - self.assertEqual(len(results.eager_outputs), len(results.diffs)) - self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01) - self.assertLess(results.diffs[0]["abs"], 0.01) + results = qwen_sdpa_attention_loopmha_versatile.verify( + *inputs, scaling=0.11180339887498948, num_heads=16 + ) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + 
self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01) + self.assertLess(results.diffs[0]["abs"], 0.01) @requires_onnxruntime("1.25") @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") def test_plug_multi_head_attention_qwen25_loopmha_float32(self): from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( - qwen_sdpa_attention_loopmha_versatile, + qwen_sdpa_attention_versatile as qwen_sdpa_attention_loopmha_versatile, ) - inputs = ( - torch.rand((1, 16, 1292, 80), dtype=torch.float32), - torch.rand((1, 16, 1292, 80), dtype=torch.float32), - torch.rand((1, 16, 1292, 80), dtype=torch.float32), - self._get_seqlen(), - ) + with self.set_env("QWEN25ATTENTION", "LOOPMHA"): + inputs = ( + torch.rand((1, 16, 1292, 80), dtype=torch.float32), + torch.rand((1, 16, 1292, 80), dtype=torch.float32), + torch.rand((1, 16, 1292, 80), dtype=torch.float32), + self._get_seqlen(), + ) - results = qwen_sdpa_attention_loopmha_versatile.verify( - *inputs, - scaling=0.5, - num_heads=16, - dump_onnx_model=self.get_dump_file( - "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx" - ), - ) - self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) - self.assertEqual(len(results.eager_outputs), len(results.diffs)) - self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) - self.assertLess(results.diffs[0]["abs"], 1e-5) + results = qwen_sdpa_attention_loopmha_versatile.verify( + *inputs, + scaling=0.5, + num_heads=16, + dump_onnx_model=self.get_dump_file( + "test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx" + ), + ) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) + self.assertLess(results.diffs[0]["abs"], 1e-5) - results = qwen_sdpa_attention_loopmha_versatile.verify( - *inputs, scaling=0.11180339887498948, num_heads=16 - ) - self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) - self.assertEqual(len(results.eager_outputs), len(results.diffs)) - self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) - self.assertLess(results.diffs[0]["abs"], 1e-5) + results = qwen_sdpa_attention_loopmha_versatile.verify( + *inputs, scaling=0.11180339887498948, num_heads=16 + ) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) + self.assertLess(results.diffs[0]["abs"], 1e-5) @requires_onnxruntime("1.25") @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") def test_plug_multi_head_attention_qwen25_loopa24_float16(self): from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( - qwen_sdpa_attention_loopa24_versatile, + qwen_sdpa_attention_versatile as qwen_sdpa_attention_loopa24_versatile, ) - inputs = ( - torch.rand((1, 16, 1292, 80), dtype=torch.float16), - torch.rand((1, 16, 1292, 80), dtype=torch.float16), - torch.rand((1, 16, 1292, 80), dtype=torch.float16), - self._get_seqlen(), - ) + with self.set_env("QWEN25ATTENTION", "LOOO24"): + inputs = ( + torch.rand((1, 16, 1292, 80), dtype=torch.float16), + torch.rand((1, 16, 1292, 80), dtype=torch.float16), + 
torch.rand((1, 16, 1292, 80), dtype=torch.float16), + self._get_seqlen(), + ) - results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5) - self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) - self.assertEqual(len(results.eager_outputs), len(results.diffs)) - self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-2) - self.assertLess(results.diffs[0]["abs"], 1e-2) + results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-2) + self.assertLess(results.diffs[0]["abs"], 1e-2) - results = qwen_sdpa_attention_loopa24_versatile.verify( - *inputs, scaling=0.11180339887498948 - ) - self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) - self.assertEqual(len(results.eager_outputs), len(results.diffs)) - self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.005) - self.assertLess(results.diffs[0]["abs"], 0.005) + results = qwen_sdpa_attention_loopa24_versatile.verify( + *inputs, scaling=0.11180339887498948 + ) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray( + results.eager_outputs[0], results.onnx_outputs[0], atol=0.005 + ) + self.assertLess(results.diffs[0]["abs"], 0.005) @requires_onnxruntime("1.25") @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") def test_plug_multi_head_attention_qwen25_loopa24_float32(self): from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import ( - qwen_sdpa_attention_loopa24_versatile, + qwen_sdpa_attention_versatile as qwen_sdpa_attention_loopa24_versatile, ) - inputs = ( - torch.rand((1, 16, 1292, 80), dtype=torch.float32), - torch.rand((1, 16, 1292, 80), dtype=torch.float32), - torch.rand((1, 16, 1292, 80), dtype=torch.float32), - self._get_seqlen(), - ) + with self.set_env("QWEN25ATTENTION", "LOOO24"): + inputs = ( + torch.rand((1, 16, 1292, 80), dtype=torch.float32), + torch.rand((1, 16, 1292, 80), dtype=torch.float32), + torch.rand((1, 16, 1292, 80), dtype=torch.float32), + self._get_seqlen(), + ) - results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5) - self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) - self.assertEqual(len(results.eager_outputs), len(results.diffs)) - self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) - self.assertLess(results.diffs[0]["abs"], 1e-5) + results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) + self.assertLess(results.diffs[0]["abs"], 1e-5) - results = qwen_sdpa_attention_loopa24_versatile.verify( - *inputs, scaling=0.11180339887498948 - ) - self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) - self.assertEqual(len(results.eager_outputs), len(results.diffs)) - self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) - self.assertLess(results.diffs[0]["abs"], 1e-5) + results = qwen_sdpa_attention_loopa24_versatile.verify( + *inputs, 
scaling=0.11180339887498948 + ) + self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs)) + self.assertEqual(len(results.eager_outputs), len(results.diffs)) + self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5) + self.assertLess(results.diffs[0]["abs"], 1e-5) @unittest.skipIf(not patch_funnel, "Funnel not part of this transformers") def test_model_funnel(self): diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py index a4707098..2dfb95bb 100644 --- a/onnx_diagnostic/ext_test_case.py +++ b/onnx_diagnostic/ext_test_case.py @@ -13,7 +13,7 @@ import sys import unittest import warnings -from contextlib import redirect_stderr, redirect_stdout +from contextlib import redirect_stderr, redirect_stdout, contextmanager from io import StringIO from timeit import Timer from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -1465,3 +1465,16 @@ def subloop(self, *args, verbose: int = 0): if verbose: print(f"[subloop] it={it!r}") yield it + + @contextmanager + def set_env(self, varname: str, value: str): + """ + Sets environment variable `varname` to `value` + and sets it back. + """ + old_value = os.environ.get(varname, None) + os.environ[varname] = value + try: + yield + finally: + os.environ[varname] = old_value or "" diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index eefa5265..0be0026c 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -158,8 +158,25 @@ def patched__get_range_constraints( ), len(export_graph_signature.input_specs), ) + combined_args = torch.export._trace._combine_args(mod, args, kwargs) + # _combine_args does not preserve the order. 
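# --- Illustrative sketch (not part of the patch): the set_env context manager
# added to ExtTestCase above, rewritten as a free function to show its
# behaviour in isolation. It overrides an environment variable for the
# duration of the block and restores the previous value (or an empty string
# when the variable was not defined before) on exit, which is how the tests
# select a Qwen2.5 attention implementation through QWEN25ATTENTION.
import os
from contextlib import contextmanager


@contextmanager
def set_env(varname: str, value: str):
    old_value = os.environ.get(varname, None)
    os.environ[varname] = value
    try:
        yield
    finally:
        os.environ[varname] = old_value or ""


with set_env("QWEN25ATTENTION", "LOOPMHA"):
    assert os.environ["QWEN25ATTENTION"] == "LOOPMHA"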
+ if isinstance(combined_args, dict): + input_names = [ + s.arg.name + for s in export_graph_signature.input_specs + if s.kind == torch.export.graph_signature.InputKind.USER_INPUT + ] + new_args = {} + for k in input_names: + if k in combined_args: + new_args[k] = combined_args[k] + for k in combined_args: + if k not in new_args: + new_args[k] = combined_args[k] + combined_args = new_args + range_constraints = torch._export.non_strict_utils.make_constraints( fake_mode, gm, combined_args, dynamic_shapes, num_lifted ) From 50350929d632e99c1a9d15a9d4df6f97910d1e07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 10 Feb 2026 17:35:21 +0100 Subject: [PATCH 08/10] fix --- .../ut_investigate/test_input_observer.py | 9 +++++- .../test_patch_transformers.py | 2 ++ .../patches/patch_torch.py | 30 ++++++++++--------- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index 3eacbdf1..9c802321 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -2,7 +2,12 @@ import unittest import pandas import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + requires_torch, + hide_stdout, + ignore_warnings, +) from onnx_diagnostic.investigate.input_observer import ( InputObserver, _infer_dynamic_dimensions, @@ -816,6 +821,8 @@ def forward(self, x=None, y=None): self.assertEqual(2, len(args)) self.assertEqual(len([v for v in args.values() if v is not None]), 2) + @hide_stdout() + @ignore_warnings(FutureWarning) def test_io_int_kwargs(self): class Model(torch.nn.Module): def forward(self, x=None, y=None, option=1): diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 06a2e388..e909e23a 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -12,6 +12,7 @@ requires_torch, ignore_warnings, has_onnxscript, + requires_onnxscript, ) from onnx_diagnostic.helpers.torch_helper import torch_deepcopy, fake_torchdynamo_exporting from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions @@ -352,6 +353,7 @@ def forward(self, query, key, value): self.assertEqualArray(expected, got) @requires_transformers("4.55") + @requires_onnxscript("0.6.2") @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") def test_qwen_apply_multimodal_rotary_pos_emb(self): apply_multimodal_rotary_pos_emb = ( diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index 0be0026c..c7371d3b 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -162,20 +162,22 @@ def patched__get_range_constraints( combined_args = torch.export._trace._combine_args(mod, args, kwargs) # _combine_args does not preserve the order. 
- if isinstance(combined_args, dict): - input_names = [ - s.arg.name - for s in export_graph_signature.input_specs - if s.kind == torch.export.graph_signature.InputKind.USER_INPUT - ] - new_args = {} - for k in input_names: - if k in combined_args: - new_args[k] = combined_args[k] - for k in combined_args: - if k not in new_args: - new_args[k] = combined_args[k] - combined_args = new_args + assert isinstance( + combined_args, dict + ), f"unexpected type {type(combined_args)} for 'combined_args'" + input_names = [ + s.arg.name + for s in export_graph_signature.input_specs + if s.kind == torch.export.graph_signature.InputKind.USER_INPUT + ] + new_args = {} + for k in input_names: + if k in combined_args: + new_args[k] = combined_args[k] + for k in combined_args: + if k not in new_args: + new_args[k] = combined_args[k] + combined_args = new_args range_constraints = torch._export.non_strict_utils.make_constraints( fake_mode, gm, combined_args, dynamic_shapes, num_lifted From 02eb10c05f2219469df90f3f52699e96b1b5ae38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 10 Feb 2026 18:03:03 +0100 Subject: [PATCH 09/10] fix --- onnx_diagnostic/investigate/input_observer.py | 2 ++ .../torch_export_patches/patches/patch_torch.py | 7 +------ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index f8e1f8b8..34264e95 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -677,6 +677,8 @@ def _post_process_for_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: # Nothing to do here. return kwargs to_be_moved = {k for k in kwargs if k not in self.signature_names} + if not to_be_moved: + return kwargs keywords = {k: v for k, v in kwargs.items() if k in to_be_moved} new_kwargs = {k: v for k, v in kwargs.items() if k not in to_be_moved} return {**new_kwargs, self.kwargs_name: keywords} diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index c7371d3b..79c6a96c 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -165,13 +165,8 @@ def patched__get_range_constraints( assert isinstance( combined_args, dict ), f"unexpected type {type(combined_args)} for 'combined_args'" - input_names = [ - s.arg.name - for s in export_graph_signature.input_specs - if s.kind == torch.export.graph_signature.InputKind.USER_INPUT - ] new_args = {} - for k in input_names: + for k in kwargs: if k in combined_args: new_args[k] = combined_args[k] for k in combined_args: From 5906a9053d6b921dc429dd0f2fc18634cf4067da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 10 Feb 2026 18:06:54 +0100 Subject: [PATCH 10/10] fix again --- .../patches/patch_torch.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index 79c6a96c..4890e04e 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -161,18 +161,24 @@ def patched__get_range_constraints( combined_args = torch.export._trace._combine_args(mod, args, kwargs) - # _combine_args does not preserve the order. 
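# --- Illustrative sketch (not part of the patch): the behaviour of
# InputObserver._post_process_for_kwargs with the early return added in this
# commit. Keyword arguments that are not part of the observed signature are
# folded under the **kwargs parameter name; the new fast path returns the
# kwargs untouched when every key is already a named parameter. The free
# function below is illustrative only; the real method reads
# self.signature_names and self.kwargs_name.
from typing import Any, Dict, Set


def post_process_for_kwargs(
    kwargs: Dict[str, Any], signature_names: Set[str], kwargs_name: str
) -> Dict[str, Any]:
    to_be_moved = {k for k in kwargs if k not in signature_names}
    if not to_be_moved:
        # fast path: nothing needs to be regrouped
        return kwargs
    keywords = {k: v for k, v in kwargs.items() if k in to_be_moved}
    new_kwargs = {k: v for k, v in kwargs.items() if k not in to_be_moved}
    return {**new_kwargs, kwargs_name: keywords}


assert post_process_for_kwargs({"x": 1, "opt": 2}, {"x"}, "kwargs") == {
    "x": 1,
    "kwargs": {"opt": 2},
}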
+ # This is because we trace based on the kwargs passed in from user + # not based on the signature. I feel it would be better to just enforce + # one ordering at the start of tracing to avoid confusions, but that is + # bigger refactor, so do this to unblock for now. assert isinstance( combined_args, dict ), f"unexpected type {type(combined_args)} for 'combined_args'" - new_args = {} - for k in kwargs: - if k in combined_args: - new_args[k] = combined_args[k] - for k in combined_args: - if k not in new_args: - new_args[k] = combined_args[k] - combined_args = new_args + + combined_args_traced_order = {} + for arg in kwargs: + if arg in combined_args: + combined_args_traced_order[arg] = combined_args[arg] + + for key in combined_args: + if key not in combined_args_traced_order: + combined_args_traced_order[key] = combined_args[key] + + combined_args = combined_args_traced_order range_constraints = torch._export.non_strict_utils.make_constraints( fake_mode, gm, combined_args, dynamic_shapes, num_lifted
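# --- Illustrative sketch (not part of the patch): the reordering performed at
# the end of patched__get_range_constraints above. _combine_args does not
# preserve the order in which the user passed keyword arguments, so the dict
# is rebuilt with the traced (user) order first and any remaining
# combined_args entries appended afterwards, before calling make_constraints.
from typing import Any, Dict


def reorder_to_traced_order(
    combined_args: Dict[str, Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    reordered = {}
    for key in kwargs:  # user-provided (traced) order first
        if key in combined_args:
            reordered[key] = combined_args[key]
    for key in combined_args:  # then whatever else _combine_args produced
        if key not in reordered:
            reordered[key] = combined_args[key]
    return reordered


assert list(
    reorder_to_traced_order({"a": 0, "b": 1, "c": 2}, {"c": None, "a": None})
) == ["c", "a", "b"]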