From 2556170c4c65c810b0bc739a8c4c51f4830b9048 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 11 Feb 2026 15:32:56 +0100 Subject: [PATCH 1/5] Patches for ViTModel --- _doc/patches.rst | 10 +-- _doc/status/patches_coverage.rst | 8 +- .../test_patch_module.py | 41 ++++++++++- .../test_patch_rewriting.py | 16 ++++ .../onnx_export_errors.py | 3 +- .../torch_export_patches/patch_module.py | 73 ++++++++++++++++++- .../patch_module_helper.py | 39 ++++++---- 7 files changed, 160 insertions(+), 30 deletions(-) diff --git a/_doc/patches.rst b/_doc/patches.rst index a397d11c..5e2b75b9 100644 --- a/_doc/patches.rst +++ b/_doc/patches.rst @@ -124,7 +124,7 @@ Here is the list of supported caches: .. _l-control-flow-rewriting: -Control flow rewriting +Control Flow Rewriting ====================== This is an attempt to automatically rewrite control flow using :mod:`ast`. @@ -217,17 +217,17 @@ The locations where it has to be done: import pprint from onnx_diagnostic.torch_export_patches.patch_module_helper import ( - known_transformers_rewritings_clamp_float16, + known_transformers_rewritings, ) - pprint.pprint(known_transformers_rewritings_clamp_float16()) + pprint.pprint(known_transformers_rewritings()) .. runpython:: :showcode: import pprint from onnx_diagnostic.torch_export_patches.patch_module_helper import ( - _rewrite_forward_clamp_float16, + _rewrite_forward, ) - pprint.pprint(_rewrite_forward_clamp_float16()) \ No newline at end of file + pprint.pprint(_rewrite_forward()) diff --git a/_doc/status/patches_coverage.rst b/_doc/status/patches_coverage.rst index c2cdc37f..9d48dbef 100644 --- a/_doc/status/patches_coverage.rst +++ b/_doc/status/patches_coverage.rst @@ -49,17 +49,17 @@ The number of fixes is much less than the number of classes to fix. import pprint from onnx_diagnostic.torch_export_patches.patch_module_helper import ( - known_transformers_rewritings_clamp_float16, + known_transformers_rewritings, ) - pprint.pprint(known_transformers_rewritings_clamp_float16()) + pprint.pprint(known_transformers_rewritings()) ..
runpython:: :showcode: import pprint from onnx_diagnostic.torch_export_patches.patch_module_helper import ( - _rewrite_forward_clamp_float16, + _rewrite_forward, ) - pprint.pprint(_rewrite_forward_clamp_float16()) \ No newline at end of file + pprint.pprint(_rewrite_forward()) \ No newline at end of file diff --git a/_unittests/ut_torch_export_patches/test_patch_module.py b/_unittests/ut_torch_export_patches/test_patch_module.py index e38d6678..8b8c1ca4 100644 --- a/_unittests/ut_torch_export_patches/test_patch_module.py +++ b/_unittests/ut_torch_export_patches/test_patch_module.py @@ -529,7 +529,6 @@ def test__find_loop_vars(self): @requires_torch("2.8") def test_rewrite_loop(self): - class Model(torch.nn.Module): def forward(self, x, y): z = torch.empty((x.shape[0], y.shape[0])) @@ -715,6 +714,46 @@ def forward(self, x): got = ep.module()(x) self.assertEqualArray(expected, got) + def test_test_raise(self): + class Model(torch.nn.Module): + def forward(self, x, y): + if x.shape[0] != y.shape[0]: + raise ValueError(f"Wrong shape {x.shape=} and {y.shape=}") + return x + y + + x, y = torch.rand((3, 4)), torch.rand((3, 4)) + expected, expected_ = Model()(x, y), Model()(-x, y) + + rewritten = transform_method(Model.forward) + self.assertIn("torch._check(", rewritten.code) + Model.forward = rewritten.func + self.assertEqualAny(expected, Model()(x, y)) + self.assertEqualAny(expected_, Model()(-x, y)) + + DYN = torch.export.Dim.DYNAMIC + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) + ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds) + self.assertEqualAny(expected, ep.module()(x, y)) + self.assertEqualAny(expected_, ep.module()(-x, y)) + + @hide_stdout() + def test_test_raise_rewrite(self): + class Model(torch.nn.Module): + def forward(self, x, y): + if x.shape[0] != y.shape[0]: + raise ValueError(f"Wrong shape {x.shape=} and {y.shape=}") + return x + y + + model = Model() + x, y = torch.rand((4, 5)), torch.rand((4, 5)) + expected = model(x, y) + DYN = torch.export.Dim.DYNAMIC + ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN}) + with torch_export_rewrite(rewrite=[(Model, "forward")], verbose=1): + ep = torch.export.export(model, (x, y), dynamic_shapes=ds) + got = ep.module()(x, y) + self.assertEqualArray(expected, got) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_patch_rewriting.py b/_unittests/ut_torch_export_patches/test_patch_rewriting.py index 4573180d..81fbba11 100644 --- a/_unittests/ut_torch_export_patches/test_patch_rewriting.py +++ b/_unittests/ut_torch_export_patches/test_patch_rewriting.py @@ -4,6 +4,7 @@ rewrite_loop_for_square_mask, ) from onnx_diagnostic.torch_export_patches.patch_module_helper import code_needing_rewriting +from onnx_diagnostic.torch_export_patches.patch_module import transform_method class TestPatchRewriting(ExtTestCase): @@ -38,6 +39,21 @@ def test_code_needing_rewriting(self): res = code_needing_rewriting("BartModel") self.assertEqual(len(res), 2) + def test_code_needing_rewriting_vit_patch_embedding(self): + res = code_needing_rewriting("ViTPatchEmbeddings") + self.assertEqual(len(res), 1) + + def test_code_needing_rewriting_vit_class(self): + import transformers + + res = code_needing_rewriting(transformers.models.vit.modeling_vit.ViTModel) + self.assertEqual(len(res), 1) + + def test_rewriting_vit_patch_embedding(self): + import transformers + + transform_method(transformers.models.vit.modeling_vit.ViTPatchEmbeddings.forward) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git 
a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 9f555771..f47bd068 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -899,7 +899,8 @@ def torch_export_patches( before being exported if the execution path depends on the inputs, this is done by function :func:`transform_method `, - its documentation provides possible values + its documentation provides possible values; if `rewrite==True`, then + all known methods to rewrite are added. :param dump_rewriting: dumps rewriting information in file beginning with that prefix, this only applies to the automated rewritings :param patch_details: if specified, this class is used to store every applied rewriting. diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py index d45bf2ef..82dd4cec 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module.py +++ b/onnx_diagnostic/torch_export_patches/patch_module.py @@ -148,7 +148,14 @@ def visit_FunctionDef(self, node): # Capture argument names for branch functions old_args = self.current_func_args self.current_func_args = [arg.arg for arg in node.args.args] - node.body = [self.visit(n) for n in node.body] + new_body = [] + for n in node.body: + visited = self.visit(n) + if isinstance(visited, list): + new_body.extend(visited) + else: + new_body.append(visited) + node.body = new_body self.current_func_args = old_args return node @@ -332,6 +339,36 @@ def _make_targets(self, node, then_assigns, else_assigns): tgt = d[0] if len(d) == 1 else ast.Tuple(d, ctx=ast.Load()) return tgt, tgt_mapping + def _rewrite_if_raise(self, cond_node, raise_node, known_local_variables=None): + assert known_local_variables is not None, "known_local_variables cannot be None" + + expr = ast.Expr( value=ast.Call( func=ast.Attribute( value=ast.Name(id="torch", ctx=ast.Load()), attr="_check", ctx=ast.Load(), ), args=[ ast.UnaryOp(op=ast.Not(), operand=cond_node), ast.Lambda( args=ast.arguments( posonlyargs=[], args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[], ), body=raise_node.exc.args[0], ), ], keywords=[], ) ) + return expr + def visit_If(self, node): if not self.filter_node(node): return [node] @@ -357,6 +394,33 @@ def visit_If(self, node): self._check(self.current_func_args is not None, node, "current_func_args is None") self.counter_test += 1 + if not has_then_return and not node.orelse: + then_raise = [n for n in node.body if isinstance(n, ast.Raise)] + if then_raise: + self._check( + len(then_raise) == 1 == len(node.body), node, "More than a simple raise." + ) + check_node = self._rewrite_if_raise( + node.test, then_raise[0], known_local_variables=known_local_variables + ) + ast.copy_location(check_node, node) + ast.fix_missing_locations(check_node) + return [self.post_rewriter(check_node)] + if ( + len(node.body) == 1 + and isinstance(node.body[0], ast.Expr) + and isinstance(node.body[0].value, ast.Call) + and isinstance(node.body[0].value.func, ast.Attribute) + and isinstance(node.body[0].value.func.value, ast.Name) + and node.body[0].value.func.value.id == "torch" + and node.body[0].value.func.attr == "_check" + ): + # We assume there is nothing to do, + # and this was rewritting by _rewrite_if_raise. + # Or maybe we can include that into the check itself. + return [node] + # Otherwise it is something else.
+ if not has_then_return: # Case 1: simple assignment in both branches then_assigns = [n for n in node.body if isinstance(n, ast.Assign)] @@ -861,11 +925,14 @@ def forward(self, x, y): mod = compile(new_tree, filename="", mode="exec") except TypeError as e: if 'required field "lineno" missing from stmt' in str(e): - # Could not find a way to avoid compilng a string. + # Could not find a way to avoid compiling a string. # The error message still pops up without indicating which node is not # properly set. code = ast.unparse(new_tree) - mod = compile(code, filename="", mode="exec") + try: + mod = compile(code, filename="", mode="exec") + except IndentationError as ee: + raise RuntimeError(f"Unable to compile\n{code}") from ee else: kws = dict(include_attributes=True, annotate_fields=True, indent=4) raise RuntimeError( diff --git a/onnx_diagnostic/torch_export_patches/patch_module_helper.py b/onnx_diagnostic/torch_export_patches/patch_module_helper.py index fe8c8b14..417360b3 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module_helper.py +++ b/onnx_diagnostic/torch_export_patches/patch_module_helper.py @@ -1,6 +1,6 @@ import ast import functools -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union class OrToBitOrTransformer(ast.NodeTransformer): @@ -21,8 +21,7 @@ def ast_or_into_bitor(node: "ast.Node") -> "ast.Node": @functools.lru_cache -def _rewrite_forward_clamp_float16() -> Dict[str, List[type]]: - +def _rewrite_forward() -> Dict[str, List[type]]: import transformers _known = { @@ -48,6 +47,9 @@ def _rewrite_forward_clamp_float16() -> Dict[str, List[type]]: "NllbMoeEncoderLayer": [ transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeEncoderLayer ], + "PatchEmbeddings": [ + transformers.models.vit.modeling_vit.ViTPatchEmbeddings, + ], "TimeSeriesTransformerEncoderLayer": [ transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoderLayer ], @@ -56,11 +58,11 @@ def _rewrite_forward_clamp_float16() -> Dict[str, List[type]]: @functools.lru_cache -def known_transformers_rewritings_clamp_float16() -> Dict[str, str]: +def known_transformers_rewritings() -> Dict[str, str]: """ This function returns the list of known classes to be rewritten in :epkg:`transformers`. Each class is mapped to an alias; this alias is then given to :func:`rewritings_transformers` to rewrite the encoder layers because of a specific control flow. ..
runpython:: import pprint from onnx_diagnostic.torch_export_patches.patch_module_helper import ( - known_transformers_rewritings_clamp_float16, + known_transformers_rewritings, ) - pprint.pprint(known_transformers_rewritings_clamp_float16()) + pprint.pprint(known_transformers_rewritings()) """ _alias = { "AutoformerEncoder": "AutoformerEncoderLayer", @@ -108,11 +110,14 @@ def known_transformers_rewritings_clamp_float16() -> Dict[str, str]: "PLBartForConditionalGeneration": "BartEncoderLayer", "TimeSeriesTransformerEncoderLayer": "TimeSeriesTransformerEncoderLayer", "TimeSeriesTransformerForPrediction": "TimeSeriesTransformerEncoderLayer", + "ViTPatchEmbeddings": "PatchEmbeddings", + "ViTForImageClassification": "PatchEmbeddings", + "ViTModel": "PatchEmbeddings", } return _alias -def rewritings_transformers_clamp_float16(cls_name) -> List[type]: +def rewritings_transformers(cls_name) -> List[type]: """ Rewrites known control flows equal to this: @@ -132,15 +137,15 @@ def rewritings_transformers_clamp_float16(cls_name) -> List[type]: import pprint from onnx_diagnostic.torch_export_patches.patch_module_helper import ( - _rewrite_forward_clamp_float16, + _rewrite_forward, ) - pprint.pprint(_rewrite_forward_clamp_float16()) + pprint.pprint(_rewrite_forward()) - Function `_rewrite_forward_clamp_float16` collects + Function `_rewrite_forward` collects all model classes using those layers. """ - _known = _rewrite_forward_clamp_float16() + _known = _rewrite_forward() assert cls_name in _known, f"cls_name={cls_name!r} unknown in {sorted(_known)}." @@ -159,10 +164,10 @@ def _add(f): return [_add(cls.forward) for cls in _known[cls_name]] -def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]: +def code_needing_rewriting(cls_name: Union[type, str]) -> Optional[List[Any]]: """ Returns a known list of classes mapped to known rewritings - because of control flow. See :func:`known_transformers_rewritings_clamp_float16`. + because of control flow. See :func:`known_transformers_rewritings`.
:param cls_name: class or name of the class :return: a list of rewritings .. runpython:: :showcode: import pprint from onnx_diagnostic.torch_export_patches.patch_module_helper import ( code_needing_rewriting, ) pprint.pprint(code_needing_rewriting("BartForConditionalGeneration")) """ - aliases = known_transformers_rewritings_clamp_float16() + if not isinstance(cls_name, str): + cls_name = cls_name.__name__ + aliases = known_transformers_rewritings() if cls_name in aliases: alias = aliases[cls_name] - return rewritings_transformers_clamp_float16(alias) + return rewritings_transformers(alias) return None From 47f844b078c0d9af2f09d512802d20c6ca4c9b4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 11 Feb 2026 16:49:12 +0100 Subject: [PATCH 2/5] fix --- CHANGELOGS.rst | 2 + .../torch_export_patches/patch_module.py | 61 ++++++++++++++++++--- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index e0f845c2..d1535b44 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,8 @@ Change Logs 0.9.2 +++++ +* pr:`412`: patches for ViTModel (through rewriting) + 0.9.1 +++++ diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py index 82dd4cec..80a744d3 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module.py +++ b/onnx_diagnostic/torch_export_patches/patch_module.py @@ -144,6 +144,39 @@ def visit_Name(self, node): self.local_variables.add(node.id) return node + def is_expression_context(self, node): + if not hasattr(node, "_parent") or node._parent is None: + return False + parent = node._parent + # Common expression contexts: + if isinstance( + parent, + ( + ast.BinOp, + ast.UnaryOp, + ast.BoolOp, + ast.Call, + ast.Subscript, + ast.Compare, + ast.Return, + ast.Expr, + ast.If, + ast.While, + ), + ): + return True + # RHS of assignment: parent is Assign and node is in value + if isinstance(parent, ast.Assign) and node in ast.walk(parent.value): + return True + return False + + def _attach_parents(self, node, parent=None): + node._parent = parent + if parent and not hasattr(node, "lineno"): + node.lineno = parent.lineno + for child in ast.iter_child_nodes(node): + self._attach_parents(child, node) + def visit_FunctionDef(self, node): # Capture argument names for branch functions old_args = self.current_func_args @@ -152,8 +185,11 @@ def visit_FunctionDef(self, node): for n in node.body: visited = self.visit(n) if isinstance(visited, list): + for m in visited: + self._attach_parents(m, node) new_body.extend(visited) else: + self._attach_parents(visited, node) new_body.append(visited) node.body = new_body self.current_func_args = old_args @@ -455,10 +491,10 @@ def visit_If(self, node): f"Inconsistencies between n_returned_values={n_returned_values}, " f"dropped={dropped}, tgt.elts={tgt.elts}, tgt_elts={tgt_elts}" ) - tgt = ast.Tuple(tgt_elts, ctx=ast.Store()) + tgt = ast.Tuple(list(tgt_elts), ctx=ast.Store()) added = {tgt.id} if isinstance(tgt, ast.Name) else set(t.id for t in tgt.elts) - assign = ast.Assign(targets=[tgt], value=call) + assign = ast.Assign(targets=[tgt], value=call, ctx=ast.Store()) ast.copy_location(assign, node) ast.fix_missing_locations(assign) self.local_variables = known_local_variables | added @@ -631,7 +667,7 @@ def visit_For(self, node): ), ], decorator_list=[], - ctx=ast.Store(), + ctx=ast.Load(), ) # final rewriting @@ -654,7 +690,7 @@ def visit_For(self, node): args=[ ast.Name(id=func_name, ctx=ast.Load()), ast.List( - elts=[ast.Name(id=v, ctx=ast.Load()) for v in init_vars], ctx=ast.Store()
+ elts=[ast.Name(id=v, ctx=ast.Load()) for v in init_vars], ctx=ast.Load() ), ast.List( elts=[ ast.Name(id=v, ctx=ast.Load()) for v in [*scan_shape_vars, *input_vars] ], @@ -700,13 +736,13 @@ def visit_For(self, node): ], *[ast.Name(id=v, ctx=ast.Load()) for v in scan_vars], ], - ctx=ast.Store(), + ctx=ast.Load(), ), ast.List( elts=[ ast.Name(id=v, ctx=ast.Load()) for v in [*scan_shape_vars, *input_vars] ], - ctx=ast.Store(), + ctx=ast.Load(), ), ], keywords=[], @@ -923,8 +959,12 @@ def forward(self, x, y): ) try: mod = compile(new_tree, filename="", mode="exec") - except TypeError as e: - if 'required field "lineno" missing from stmt' in str(e): + except (TypeError, ValueError) as e: + se = str(e) + if ( + 'required field "lineno" missing from' in se + or "expression must have Load context but has Store instead" in se + ): # Could not find a way to avoid compiling a string. # The error message still pops up without indicating which node is not # properly set. @@ -932,14 +972,15 @@ def forward(self, x, y): try: mod = compile(code, filename="", mode="exec") except IndentationError as ee: - raise RuntimeError(f"Unable to compile\n{code}") from ee + raise RuntimeError(f"Unable to compile due to {ee} (and {e})\n{code}") from ee else: kws = dict(include_attributes=True, annotate_fields=True, indent=4) raise RuntimeError( - f"Unable to compile code\n--CODE--\n" + f"Unable to compile code due to {e}\n--CODE--\n" f"{ast.unparse(new_tree)}\n--TREE--\n" f"{ast.dump(new_tree, **kws)}" ) from e + namespace: Dict[str, type] = {} globs = func.__globals__.copy() exec(mod, globs, namespace) From d8337dd6c1261f3b66658028b384a6c217bc13aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 11 Feb 2026 17:47:53 +0100 Subject: [PATCH 3/5] no issue --- onnx_diagnostic/torch_export_patches/patch_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_export_patches/patch_module.py b/onnx_diagnostic/torch_export_patches/patch_module.py index 80a744d3..9dbdc98f 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module.py +++ b/onnx_diagnostic/torch_export_patches/patch_module.py @@ -452,7 +452,7 @@ def visit_If(self, node): and node.body[0].value.func.attr == "_check" ): # We assume there is nothing to do, - # and this was rewritting by _rewrite_if_raise. + # and this was rewritten by _rewrite_if_raise. # Or maybe we can include that into the check itself. return [node] # Otherwise it is something else.
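Taken together, patches 1-3 make the AST rewriter replace a shape-dependent `if`/`raise` guard with a `torch._check` call that `torch.export.export` can trace through. A condensed usage sketch, following the `test_test_raise` unit test added in patch 1 (the guard shown in the comment is an approximation of what `_rewrite_if_raise` emits once unparsed):

import torch
from onnx_diagnostic.torch_export_patches.patch_module import transform_method

class Model(torch.nn.Module):
    def forward(self, x, y):
        if x.shape[0] != y.shape[0]:
            raise ValueError(f"Wrong shape {x.shape=} and {y.shape=}")
        return x + y

# transform_method returns the rewritten source and the recompiled function.
rewritten = transform_method(Model.forward)
assert "torch._check(" in rewritten.code
# The guard now reads roughly:
#   torch._check(not x.shape[0] != y.shape[0],
#                lambda: f"Wrong shape {x.shape=} and {y.shape=}")
Model.forward = rewritten.func

DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(
    Model(), (torch.rand((3, 4)), torch.rand((3, 4))),
    dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}),
)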
From f4f518b84a52837b1c907b9c536e28c486dc4def Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 11 Feb 2026 23:48:12 +0100 Subject: [PATCH 4/5] fix --- .../test_patch_torch.py | 28 +++++++++++ .../patches/patch_torch.py | 48 ++++++++++++------- 2 files changed, 59 insertions(+), 17 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py index 87838bf8..81d24c63 100644 --- a/_unittests/ut_torch_export_patches/test_patch_torch.py +++ b/_unittests/ut_torch_export_patches/test_patch_torch.py @@ -688,6 +688,34 @@ def forward(self, x, **kwargs): inputs = [n.name for n in ep.graph.nodes if n.op == "placeholder"] self.assertEqual(["x", "y"], inputs) + def test_mixed_named_and_unnamed_kwargs_2(self): + class Model(torch.nn.Module): + def forward(self, a, x, **kwargs): + return a - x + kwargs["y"] - kwargs["z"] + + args = (torch.randn((5, 6)),) + kwargs = dict(x=torch.randn((5, 6)), y=torch.randn((1, 6)), z=torch.randn((1, 6)) + 10) + model = Model() + expected = model(*args, **kwargs) + with torch_export_patches(patch_torch=True): + ep = torch.export.export( + model, + args, + kwargs=kwargs, + dynamic_shapes={ + "a": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}, + "x": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}, + "kwargs": { + "y": {1: torch.export.Dim.DYNAMIC}, + "z": {1: torch.export.Dim.DYNAMIC}, + }, + }, + ) + # ep.module()(**kwargs): raises NameError: name 'L' is not defined + self.assertEqualArray(args[0] - kwargs["x"] + kwargs["y"] - kwargs["z"], expected) + inputs = [n.name for n in ep.graph.nodes if n.op == "placeholder"] + self.assertEqual(["a", "x", "y", "z"], inputs) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index 4890e04e..b5f02c2c 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -134,6 +134,34 @@ def patched_infer_size(a, b): return tuple(expandedSizes) +def _combine_args(f, args, kwargs, preserve_order: bool = False) -> dict[str, Any]: + # combine args and kwargs following the signature of f, as it happens + # in the body of f when called with *args, **kwargs + # the exporter needs to preserve the original order of the arguments + # to match the dynamic shapes. 
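+ # For example, def f(a, b) called as f(b=t1, a=t0) gives + # {"a": t0, "b": t1} with preserve_order=False (signature order) and + # {"b": t1, "a": t0} with preserve_order=True (the order tracing saw them).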
+ if isinstance(f, torch.export.ExportedProgram): + f = f.module() + + signature = ( + inspect.signature(f.forward) + if isinstance(f, torch.nn.Module) + else inspect.signature(f) + ) + kwargs = kwargs if kwargs is not None else {} + combined_args = signature.bind(*args, **kwargs).arguments + if not preserve_order: + return combined_args + + combined_args_traced_order = dict(zip(signature.parameters, args)) + for arg in kwargs: + if arg in combined_args: + combined_args_traced_order[arg] = combined_args[arg] + for key in combined_args: + if key not in combined_args_traced_order: + combined_args_traced_order[key] = combined_args[key] + return combined_args_traced_order + + def patched__get_range_constraints( mod: torch.nn.Module, export_artifact: torch.export._trace.ExportArtifact, @@ -159,26 +187,12 @@ def patched__get_range_constraints( len(export_graph_signature.input_specs), ) - combined_args = torch.export._trace._combine_args(mod, args, kwargs) - - # This is because we trace based on the kwargs passed in from user + # preserve_order=True: + # this is because we trace based on the kwargs passed in from user # not based on the signature. I feel it would be better to just enforce # one ordering at the start of tracing to avoid confusions, but that is # bigger refactor, so do this to unblock for now. - assert isinstance( - combined_args, dict - ), f"unexpected type {type(combined_args)} for 'combined_args'" - - combined_args_traced_order = {} - for arg in kwargs: - if arg in combined_args: - combined_args_traced_order[arg] = combined_args[arg] - - for key in combined_args: - if key not in combined_args_traced_order: - combined_args_traced_order[key] = combined_args[key] - - combined_args = combined_args_traced_order + combined_args = _combine_args(mod, args, kwargs, preserve_order=True) range_constraints = torch._export.non_strict_utils.make_constraints( fake_mode, gm, combined_args, dynamic_shapes, num_lifted From 0b2d41f1e49408dd4bfac6da4042d7d3966addeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 12 Feb 2026 01:23:21 +0100 Subject: [PATCH 5/5] fix model builder --- onnx_diagnostic/helpers/model_builder_helper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnx_diagnostic/helpers/model_builder_helper.py b/onnx_diagnostic/helpers/model_builder_helper.py index fa6f4186..6a383e9b 100644 --- a/onnx_diagnostic/helpers/model_builder_helper.py +++ b/onnx_diagnostic/helpers/model_builder_helper.py @@ -40,6 +40,7 @@ def download_model_builder_to_cache( "gemma.py", "gptoss.py", "granite.py", + "internlm.py", "llama.py", "mistral.py", "nemotron.py",