diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index d1535b44..9c1150b1 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,7 +4,8 @@ Change Logs
 0.9.2
 +++++
 
-* pr:`412`: patches for ViTModel (through rewriting)
+* :pr:`413`: fix InputObserver in the generic case
+* :pr:`412`: patches for ViTModel (through rewriting)
 
 0.9.1
 +++++
diff --git a/_unittests/ut_export/test_cf_simple_loop_for.py b/_unittests/ut_export/test_cf_simple_loop_for.py
index eb99cc58..fee638ee 100644
--- a/_unittests/ut_export/test_cf_simple_loop_for.py
+++ b/_unittests/ut_export/test_cf_simple_loop_for.py
@@ -1,7 +1,7 @@
 import unittest
 from typing import Tuple
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers
 from onnx_diagnostic.export.control_flow_onnx import (
     enable_code_export_control_flow,
 )
@@ -332,6 +332,7 @@ def forward(self, n_iter, x):
         self.assertEqualArray(model(n, -x), ep.module()(n, -x))
 
     @requires_torch("2.9.99")
+    @requires_transformers("4.50")
     def test_simple_loop_for_phi4(self):
         _IMAGE_SPECIAL_TOKEN_ID = 200010
         vocab_size = 200064
diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py
index 9c802321..6383c7c5 100644
--- a/_unittests/ut_investigate/test_input_observer.py
+++ b/_unittests/ut_investigate/test_input_observer.py
@@ -993,6 +993,80 @@ def forward(self, x, **kwargs):
         )
         torch.export.export(model, (), kwargs=args, dynamic_shapes=ds)
 
+    def test_io_captured_kwargs_kwargs_with_args(self):
+        class Model(torch.nn.Module):
+            def forward(self, a, *args, **kwargs):
+                return a - args[0] * args[1] + kwargs["x"] - kwargs["y"]
+
+        inputs = [
+            (
+                (torch.randn((5, 6)), torch.randn((5, 6)), torch.randn((5, 6))),
+                dict(x=torch.randn((5, 6)), y=torch.randn((1, 6))),
+            ),
+            (
+                (torch.randn((7, 7)), torch.randn((7, 7)), torch.randn((7, 7))),
+                dict(x=torch.randn((7, 7)), y=torch.randn((1, 7))),
+            ),
+        ]
+
+        model = Model()
+        expected = [model(*args, **kwargs) for args, kwargs in inputs]
+        observer = InputObserver()
+        with observer(model):
+            for args, kwargs in inputs:
+                model(*args, **kwargs)
+        self.assertEqual(len(observer.info), 2)
+        for i in range(2):
+            self.assertEqual(len(observer.info.flat_outputs[i]), 1)
+            torch.testing.assert_close(expected[i], observer.info.flat_outputs[i][0])
+
+        cst = torch.export.Dim.DYNAMIC
+        ds = observer.infer_dynamic_shapes()
+        self.assertEqual(
+            {
+                "a": {0: cst, 1: cst},
+                "args": ({0: cst, 1: cst}, {0: cst, 1: cst}),
+                "kwargs": {"x": {0: cst, 1: cst}, "y": {1: cst}},
+            },
+            ds,
+        )
+
+        dynamic_shapes = torch.export.AdditionalInputs()
+        for args, kwargs in inputs:
+            dynamic_shapes.add(args, kwargs)
+        dss = dynamic_shapes.dynamic_shapes(model, *inputs[0])
+        self.assertEqual(
+            {
+                "a": (cst, cst),
+                "args": ((cst, cst), (cst, cst)),
+                "kwargs": {"x": (cst, cst), "y": (None, cst)},
+            },
+            dss,
+        )
+
+        with self.assertRaises(RuntimeError):
+            observer.infer_arguments()
+
+        args, kwargs = observer.infer_arguments(as_args_kwargs=True)
+        self.assertIsInstance(kwargs, dict)
+        self.assertEqual(["x", "y"], list(kwargs))
+        self.assertIsInstance(args, tuple)
+        self.assertEqual(len(args), 3)
+
+        # _get_range_constraints
+        with torch_export_patches(patch_torch=True):
+            torch.export.export(
+                model,
+                args,
+                kwargs=kwargs,
+                dynamic_shapes={
+                    "a": {0: cst, 1: cst},
+                    "args": ({0: cst, 1: cst}, {0: cst, 1: cst}),
+                    "kwargs": {"x": {0: cst, 1: cst}, "y": {1: cst}},
+                },
+            )
+        torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=ds)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py
index 81d24c63..33905bfa 100644
--- a/_unittests/ut_torch_export_patches/test_patch_torch.py
+++ b/_unittests/ut_torch_export_patches/test_patch_torch.py
@@ -688,7 +688,7 @@ def forward(self, x, **kwargs):
         inputs = [n.name for n in ep.graph.nodes if n.op == "placeholder"]
         self.assertEqual(["x", "y"], inputs)
 
-    def test_mixed_named_and_unnamed_kwargs_2(self):
+    def test_mixed_named_and_unnamed_kwargs_with_args(self):
         class Model(torch.nn.Module):
             def forward(self, a, x, **kwargs):
                 return a - x + kwargs["y"] - kwargs["z"]
@@ -716,6 +716,39 @@ def forward(self, a, x, **kwargs):
         inputs = [n.name for n in ep.graph.nodes if n.op == "placeholder"]
         self.assertEqual(["a", "x", "y", "z"], inputs)
 
+    def test_mixed_named_and_unnamed_kwargs_with_generic_args(self):
+        class Model(torch.nn.Module):
+            def forward(self, a, *args, **kwargs):
+                return a - args[0] * args[1] + kwargs["y"] - kwargs["z"]
+
+        args = (torch.randn((5, 6)), torch.randn((5, 6)), torch.randn((5, 6)))
+        kwargs = dict(y=torch.randn((1, 6)), z=torch.randn((1, 6)) + 10)
+        model = Model()
+        expected = model(*args, **kwargs)
+        with torch_export_patches(patch_torch=True):
+            ep = torch.export.export(
+                model,
+                args,
+                kwargs=kwargs,
+                dynamic_shapes={
+                    "a": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},
+                    "args": (
+                        {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},
+                        {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},
+                    ),
+                    "kwargs": {
+                        "y": {1: torch.export.Dim.DYNAMIC},
+                        "z": {1: torch.export.Dim.DYNAMIC},
+                    },
+                },
+            )
+        # ep.module()(**kwargs): raises NameError: name 'L' is not defined
+        self.assertEqualArray(
+            args[0] - args[1] * args[2] + kwargs["y"] - kwargs["z"], expected
+        )
+        inputs = [n.name for n in ep.graph.nodes if n.op == "placeholder"]
+        self.assertEqual(["a", "args_0", "args_1", "y", "z"], inputs)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py
index 34264e95..2fb32d71 100644
--- a/onnx_diagnostic/investigate/input_observer.py
+++ b/onnx_diagnostic/investigate/input_observer.py
@@ -304,6 +304,8 @@ class InputObserverInfo:
         is added (such as `past_key_values`). The values are only
         to infer dynamic shapes and arguments, not to run the model.
 
+    args_name_and_position:
+        Name of parameter `*args` and its position if it exists.
     kwargs_name:
         Name of parameter `**kwargs` if it exists.
 
@@ -315,6 +317,7 @@ def __init__(
         self,
         signature_names: list[str],
         default_values: dict[str, int | bool | str | float],
         missing: dict[str, Any],
+        args_name_and_position: tuple[str, int] | None,
         kwargs_name: str | None,
     ):
         self.default_values = default_values
@@ -323,6 +326,7 @@
         self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
         self.flat_outputs: list[list[torch.Tensor | None]] = []
         self.latencies: list[float] = []
+        self.args_name_and_position = args_name_and_position
         self.kwargs_name = kwargs_name
         self.signature_names = signature_names
         self._best_candidate: InputCandidate | None = None
@@ -491,11 +495,17 @@ def _set_batch_dimension_for_flat_index(index):
         flat_dynamic_shapes = [dict.fromkeys(dims, cst) for dims in dynamic_shapes]
         if return_flat:
             return tuple(flat_dynamic_shapes)
+
+        # Let's regroup.
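+        # The flat shapes below are regrouped to mirror the forward
+        # signature: named positional parameters first, then the variadic
+        # ``*args`` tuple when the signature defines one, then the keyword
+        # arguments.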
         if len(flat_dynamic_shapes) == len(self._best_candidate.args) + len(
             self._best_candidate.kwargs
         ):
             # It means forward method is called with tensors only.
-            if not self._best_candidate.kwargs and not self._best_candidate.cst_kwargs:
+            if (
+                not self._best_candidate.kwargs
+                and not self._best_candidate.cst_kwargs
+                and not self.args_name_and_position
+            ):
                 # only positional arguments
                 return tuple(flat_dynamic_shapes)
             if not self._best_candidate.args:
@@ -504,14 +514,32 @@ def _set_batch_dimension_for_flat_index(index):
                 return self._post_process_for_kwargs(
                     {**ds, **dict.fromkeys(self._best_candidate.cst_kwargs, None)}
                 )
+            if not self.args_name_and_position:
+                # positional arguments need to be moved to the named arguments
+                n_args = len(self._best_candidate.args)
+                pos_names = self.signature_names[:n_args]
+                return self._post_process_for_kwargs(
+                    {
+                        **dict(zip(pos_names, flat_dynamic_shapes[:n_args])),
+                        **dict(
+                            zip(
+                                list(self._best_candidate.kwargs), flat_dynamic_shapes[n_args:]
+                            )
+                        ),
+                        **dict.fromkeys(self._best_candidate.cst_kwargs, None),
+                    }
+                )
             # positional arguments needs to be moved to the named arguments
-            n_args = len(self._best_candidate.args)
+            n_args = min(len(self._best_candidate.args), self.args_name_and_position[1])
+            i_kwargs = max(len(self._best_candidate.args), self.args_name_and_position[1])
+            var_pos = self.args_name_and_position[0]
             pos_names = self.signature_names[:n_args]
             return self._post_process_for_kwargs(
                 {
                     **dict(zip(pos_names, flat_dynamic_shapes[:n_args])),
+                    var_pos: tuple(flat_dynamic_shapes[n_args:i_kwargs]),
                     **dict(
-                        zip(list(self._best_candidate.kwargs), flat_dynamic_shapes[n_args:])
+                        zip(list(self._best_candidate.kwargs), flat_dynamic_shapes[i_kwargs:])
                     ),
                     **dict.fromkeys(self._best_candidate.cst_kwargs, None),
                 }
             )
@@ -552,16 +580,38 @@ def change_function(t):
             )
         if self._best_candidate.cst_kwargs:
             ds_kwargs = {**ds_kwargs, **dict.fromkeys(self._best_candidate.cst_kwargs, None)}
-        if not ds_kwargs:
+        if not ds_kwargs and not self.args_name_and_position:
             return tuple(ds_args)
         if not ds_args:
             return self._post_process_for_kwargs(ds_kwargs)
-        pos_names = self.signature_names[: len(ds_args)]
-        return self._post_process_for_kwargs({**dict(zip(pos_names, ds_args)), **ds_kwargs})
+
+        if not self.args_name_and_position:
+            pos_names = self.signature_names[: len(ds_args)]
+            return self._post_process_for_kwargs(
+                {**dict(zip(pos_names, ds_args)), **ds_kwargs}
+            )
+
+        n_args = min(len(ds_args), self.args_name_and_position[1])
+        pos_names = self.signature_names[:n_args]
+        return self._post_process_for_kwargs(
+            {
+                **dict(zip(pos_names, ds_args[:n_args])),
+                self.args_name_and_position[0]: tuple(ds_args[n_args:]),
+                **ds_kwargs,
+            }
+        )
 
     def infer_arguments(
-        self, index_or_candidate: InputCandidate | int | None = None, flat: bool = False
-    ) -> list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
+        self,
+        index_or_candidate: InputCandidate | int | None = None,
+        flat: bool = False,
+        as_args_kwargs: bool = False,
+    ) -> (
+        list[torch.Tensor]
+        | tuple[torch.Tensor, ...]
+        | dict[str, torch.Tensor]
+        | tuple[list[torch.Tensor] | tuple[torch.Tensor, ...], dict[str, torch.Tensor]]
+    ):
         """Infers arguments based on the collected tensors."""
         # This is already checked by _build_inputs_completed_with_none_values
         # but this is not always well captured by tools checking types.
@@ -649,6 +699,7 @@ def infer_arguments(
             assert all(t is not None for t in aligned_flat_list)
             # pyrefly: ignore[bad-return]
             return aligned_flat_list
 
+        # type checking
         assert candidate is not None
         assert candidate.aligned_spec is not None
@@ -658,13 +709,26 @@ def infer_arguments(
 
         if self._best_candidate.cst_kwargs:
             kwargs = {**kwargs, **self._best_candidate.cst_kwargs}
-        if not kwargs:
-            return args
-        if not args:
-            return kwargs
-        # We need to move args to kwargs
-        pos_names = self.signature_names[: len(args)]
-        return {**dict(zip(pos_names, args)), **kwargs}
+        if not as_args_kwargs:
+            if not kwargs:
+                return args
+            if not args:
+                return kwargs
+
+            # We need to move args to kwargs
+            if self.args_name_and_position:
+                raise RuntimeError(
+                    "Cannot return arguments "
+                    "as a single tuple or a single dictionary "
+                    "because of '*args' in the function signature. "
+                    "You need to set `as_args_kwargs=True`."
+                )
+            n_args = len(args)
+            pos_names = self.signature_names[:n_args]
+            return {**dict(zip(pos_names, args[:n_args])), **kwargs}
+
+        # Generic case.
+        return tuple(args), kwargs
 
     def _post_process_for_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
         """
@@ -783,15 +847,10 @@ def __call__(
             if sig.parameters[p].kind == inspect.Parameter.VAR_KEYWORD
         ]
         args_names = [
-            p
-            for p in sig.parameters
+            (p, i)
+            for (i, p) in enumerate(sig.parameters)
             if sig.parameters[p].kind == inspect.Parameter.VAR_POSITIONAL
         ]
-        if args_names:
-            raise RuntimeError(
-                f"Inference is not implemented "
-                f"when the signature includes '*{args_names[0]}'."
-            )
         self.info = InputObserverInfo(
             signature_names=list(sig.parameters),
             default_values={
@@ -801,6 +860,7 @@ def __call__(
                 and isinstance(p.default, (int, bool, str, float))
             },
             missing=self.missing,
+            args_name_and_position=args_names[0] if args_names else None,
             kwargs_name=kwargs_names[0] if kwargs_names else None,
         )
         n_already_stored = len(self.info)
@@ -852,7 +912,13 @@ def infer_arguments(
         self,
         index_or_args_or_kwargs: tuple[Any] | dict[str, Any] | int | None = None,
         flat: bool = False,
-    ) -> list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
+        as_args_kwargs: bool = False,
+    ) -> (
+        list[torch.Tensor]
+        | tuple[torch.Tensor, ...]
+        | dict[str, torch.Tensor]
+        | tuple[list[torch.Tensor] | tuple[torch.Tensor, ...], dict[str, torch.Tensor]]
+    ):
         """Infers arguments based on the collected tensors.
 
         Args:
@@ -865,7 +931,9 @@ def infer_arguments(
             flat: If True, it returns a flattened list of tensors,
                 if False, it returns a tuple or a dictionary preserving
                 the nested structures.
-
+            as_args_kwargs: If True, the method always returns `(args, kwargs)`,
+                otherwise, it returns either a tuple (only args) or a dictionary
+                (only kwargs) or raises an exception if it cannot do so.
         Returns:
             Inferred arguments, every optional tensor is replaced by an empty tensor.
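+
+        Example (a sketch mirroring the unit tests; ``model`` and ``inputs``
+        stand for any module and the calls recorded on it)::
+
+            observer = InputObserver()
+            with observer(model):
+                for args, kwargs in inputs:
+                    model(*args, **kwargs)
+            # with ``*args`` in the signature, both parts must be requested
+            args, kwargs = observer.infer_arguments(as_args_kwargs=True)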
""" @@ -908,7 +976,9 @@ def infer_arguments( self.info._captured_inputs, self.info.signature_names, ) - return self.info.infer_arguments(index_or_candidate=index_or_candidate, flat=flat) + return self.info.infer_arguments( + index_or_candidate=index_or_candidate, flat=flat, as_args_kwargs=as_args_kwargs + ) def check_discrepancies( self, diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index b5f02c2c..22297493 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -152,7 +152,25 @@ def _combine_args(f, args, kwargs, preserve_order: bool = False) -> dict[str, An if not preserve_order: return combined_args - combined_args_traced_order = dict(zip(signature.parameters, args)) + var_position_parameters = [ + name + for name, p in signature.parameters.items() + if p.kind == inspect.Parameter.VAR_POSITIONAL + ] + if var_position_parameters: + n_positional_only = max( + [ + i + for i, p in enumerate(signature.parameters.values()) + if p.kind == inspect.Parameter.VAR_POSITIONAL + ] + ) + combined_args_traced_order = dict(zip(signature.parameters, args[:n_positional_only])) + combined_args_traced_order[var_position_parameters[0]] = tuple( + args[n_positional_only:] + ) + else: + combined_args_traced_order = dict(zip(signature.parameters, args)) for arg in kwargs: if arg in combined_args: combined_args_traced_order[arg] = combined_args[arg]