Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ Change Logs
0.9.2
+++++

* pr:`412`: patches for ViTModel (through rewriting)
* :pr:`413`: fix InputObserver in the generic case
* :pr:`412`: patches for ViTModel (through rewriting)

0.9.1
+++++
Expand Down
3 changes: 2 additions & 1 deletion _unittests/ut_export/test_cf_simple_loop_for.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from typing import Tuple
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers
from onnx_diagnostic.export.control_flow_onnx import (
enable_code_export_control_flow,
)
Expand Down Expand Up @@ -332,6 +332,7 @@ def forward(self, n_iter, x):
self.assertEqualArray(model(n, -x), ep.module()(n, -x))

@requires_torch("2.9.99")
@requires_transformers("4.50")
def test_simple_loop_for_phi4(self):
_IMAGE_SPECIAL_TOKEN_ID = 200010
vocab_size = 200064
Expand Down
74 changes: 74 additions & 0 deletions _unittests/ut_investigate/test_input_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,80 @@ def forward(self, x, **kwargs):
)
torch.export.export(model, (), kwargs=args, dynamic_shapes=ds)

def test_io_captured_kwargs_kwargs_with_args(self):
    # Generic signature forward(a, *args, **kwargs): the observer must
    # split the flat dynamic shapes back into `a`, `args` and `kwargs`.
    class Model(torch.nn.Module):
        def forward(self, a, *args, **kwargs):
            return a - args[0] * args[1] + kwargs["x"] - kwargs["y"]

    inputs = [
        (
            (torch.randn((5, 6)), torch.randn((5, 6)), torch.randn((5, 6))),
            dict(x=torch.randn((5, 6)), y=torch.randn((1, 6))),
        ),
        (
            (torch.randn((7, 7)), torch.randn((7, 7)), torch.randn((7, 7))),
            dict(x=torch.randn((7, 7)), y=torch.randn((1, 7))),
        ),
    ]

    model = Model()
    expected = [model(*a, **kw) for a, kw in inputs]
    observer = InputObserver()
    with observer(model):
        for a, kw in inputs:
            model(*a, **kw)
    self.assertEqual(len(observer.info), 2)
    for i, exp in enumerate(expected):
        self.assertEqual(len(observer.info.flat_outputs[i]), 1)
        torch.testing.assert_close(exp, observer.info.flat_outputs[i][0])

    cst = torch.export.Dim.DYNAMIC
    ds = observer.infer_dynamic_shapes()
    self.assertEqual(
        {
            "a": {0: cst, 1: cst},
            "args": ({0: cst, 1: cst}, {0: cst, 1: cst}),
            "kwargs": {"x": {0: cst, 1: cst}, "y": {1: cst}},
        },
        ds,
    )

    # Compare with what torch.export.AdditionalInputs infers.
    dynamic_shapes = torch.export.AdditionalInputs()
    for a, kw in inputs:
        dynamic_shapes.add(a, kw)
    dss = dynamic_shapes.dynamic_shapes(model, *inputs[0])
    self.assertEqual(
        {
            "a": (cst, cst),
            "args": ((cst, cst), (cst, cst)),
            "kwargs": {"x": (cst, cst), "y": (None, cst)},
        },
        dss,
    )

    # A single tuple or a single dict cannot represent *args + **kwargs.
    with self.assertRaises(RuntimeError):
        observer.infer_arguments()

    args, kwargs = observer.infer_arguments(as_args_kwargs=True)
    self.assertIsInstance(kwargs, dict)
    self.assertEqual(["x", "y"], list(kwargs))
    self.assertIsInstance(args, tuple)
    self.assertEqual(len(args), 3)

    # _get_range_constraints
    with torch_export_patches(patch_torch=True):
        torch.export.export(
            model,
            args,
            kwargs=kwargs,
            dynamic_shapes={
                "a": {0: cst, 1: cst},
                "args": ({0: cst, 1: cst}, {0: cst, 1: cst}),
                "kwargs": {"x": {0: cst, 1: cst}, "y": {1: cst}},
            },
        )
    torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=ds)


# Run the test suite verbosely when this file is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
35 changes: 34 additions & 1 deletion _unittests/ut_torch_export_patches/test_patch_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def forward(self, x, **kwargs):
inputs = [n.name for n in ep.graph.nodes if n.op == "placeholder"]
self.assertEqual(["x", "y"], inputs)

def test_mixed_named_and_unnamed_kwargs_2(self):
def test_mixed_named_and_unnamed_kwargs_with_args(self):
class Model(torch.nn.Module):
def forward(self, a, x, **kwargs):
return a - x + kwargs["y"] - kwargs["z"]
Expand Down Expand Up @@ -716,6 +716,39 @@ def forward(self, a, x, **kwargs):
inputs = [n.name for n in ep.graph.nodes if n.op == "placeholder"]
self.assertEqual(["a", "x", "y", "z"], inputs)

def test_mixed_named_and_unnamed_kwargs_with_generic_args(self):
    # Exports forward(a, *args, **kwargs) with dynamic shapes given
    # for the positional name, the *args tuple and the named kwargs.
    class Model(torch.nn.Module):
        def forward(self, a, *args, **kwargs):
            return a - args[0] * args[1] + kwargs["y"] - kwargs["z"]

    args = (torch.randn((5, 6)), torch.randn((5, 6)), torch.randn((5, 6)))
    kwargs = dict(y=torch.randn((1, 6)), z=torch.randn((1, 6)) + 10)
    model = Model()
    expected = model(*args, **kwargs)
    dyn = torch.export.Dim.DYNAMIC
    with torch_export_patches(patch_torch=True):
        ep = torch.export.export(
            model,
            args,
            kwargs=kwargs,
            dynamic_shapes={
                "a": {0: dyn, 1: dyn},
                "args": ({0: dyn, 1: dyn}, {0: dyn, 1: dyn}),
                "kwargs": {"y": {1: dyn}, "z": {1: dyn}},
            },
        )
    # ep.module()(**kwargs): raises NameError: name 'L' is not defined
    self.assertEqualArray(
        args[0] - args[1] * args[2] + kwargs["y"] - kwargs["z"], expected
    )
    placeholders = [n.name for n in ep.graph.nodes if n.op == "placeholder"]
    self.assertEqual(["a", "args_0", "args_1", "y", "z"], placeholders)


# Run the test suite verbosely when this file is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
120 changes: 95 additions & 25 deletions onnx_diagnostic/investigate/input_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ class InputObserverInfo:
is added (such as `past_key_values`).
The values are only to infer dynamic shapes and arguments,
not to run the model.
args_name_and_position:
Name of parameter `*args` and its position if it exists.
kwargs_name:
Name of parameter `**kwargs` if it exists.

Expand All @@ -315,6 +317,7 @@ def __init__(
signature_names: list[str],
default_values: dict[str, int | bool | str | float],
missing: dict[str, Any],
args_name_and_position: tuple[str, int] | None,
kwargs_name: str | None,
):
self.default_values = default_values
Expand All @@ -323,6 +326,7 @@ def __init__(
self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
self.flat_outputs: list[list[torch.Tensor | None]] = []
self.latencies: list[float] = []
self.args_name_and_position = args_name_and_position
self.kwargs_name = kwargs_name
self.signature_names = signature_names
self._best_candidate: InputCandidate | None = None
Expand Down Expand Up @@ -491,11 +495,17 @@ def _set_batch_dimension_for_flat_index(index):
flat_dynamic_shapes = [dict.fromkeys(dims, cst) for dims in dynamic_shapes]
if return_flat:
return tuple(flat_dynamic_shapes)

# Let's regroup.
if len(flat_dynamic_shapes) == len(self._best_candidate.args) + len(
self._best_candidate.kwargs
):
# It means forward method is called with tensors only.
if not self._best_candidate.kwargs and not self._best_candidate.cst_kwargs:
if (
not self._best_candidate.kwargs
and not self._best_candidate.cst_kwargs
and not self.args_name_and_position
):
# only positional arguments
return tuple(flat_dynamic_shapes)
if not self._best_candidate.args:
Expand All @@ -504,14 +514,32 @@ def _set_batch_dimension_for_flat_index(index):
return self._post_process_for_kwargs(
{**ds, **dict.fromkeys(self._best_candidate.cst_kwargs, None)}
)
if not self.args_name_and_position:
# positional arguments needs to be moved to the named arguments
n_args = len(self._best_candidate.args)
pos_names = self.signature_names[:n_args]
return self._post_process_for_kwargs(
{
**dict(zip(pos_names, flat_dynamic_shapes[:n_args])),
**dict(
zip(
list(self._best_candidate.kwargs), flat_dynamic_shapes[n_args:]
)
),
**dict.fromkeys(self._best_candidate.cst_kwargs, None),
}
)
# positional arguments needs to be moved to the named arguments
n_args = len(self._best_candidate.args)
n_args = min(len(self._best_candidate.args), self.args_name_and_position[1])
i_kwargs = max(len(self._best_candidate.args), self.args_name_and_position[1])
var_pos = self.args_name_and_position[0]
pos_names = self.signature_names[:n_args]
return self._post_process_for_kwargs(
{
**dict(zip(pos_names, flat_dynamic_shapes[:n_args])),
var_pos: tuple(flat_dynamic_shapes[n_args:i_kwargs]),
**dict(
zip(list(self._best_candidate.kwargs), flat_dynamic_shapes[n_args:])
zip(list(self._best_candidate.kwargs), flat_dynamic_shapes[i_kwargs:])
),
**dict.fromkeys(self._best_candidate.cst_kwargs, None),
}
Expand Down Expand Up @@ -552,16 +580,38 @@ def change_function(t):
)
if self._best_candidate.cst_kwargs:
ds_kwargs = {**ds_kwargs, **dict.fromkeys(self._best_candidate.cst_kwargs, None)}
if not ds_kwargs:
if not ds_kwargs and not self.args_name_and_position:
return tuple(ds_args)
if not ds_args:
return self._post_process_for_kwargs(ds_kwargs)
pos_names = self.signature_names[: len(ds_args)]
return self._post_process_for_kwargs({**dict(zip(pos_names, ds_args)), **ds_kwargs})

if not self.args_name_and_position:
pos_names = self.signature_names[: len(ds_args)]
return self._post_process_for_kwargs(
{**dict(zip(pos_names, ds_args)), **ds_kwargs}
)

n_args = min(len(ds_args), self.args_name_and_position[1])
pos_names = self.signature_names[:n_args]
return self._post_process_for_kwargs(
{
**dict(zip(pos_names, ds_args[:n_args])),
self.args_name_and_position[0]: tuple(ds_args[n_args:]),
**ds_kwargs,
}
)

def infer_arguments(
self, index_or_candidate: InputCandidate | int | None = None, flat: bool = False
) -> list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
self,
index_or_candidate: InputCandidate | int | None = None,
flat: bool = False,
as_args_kwargs: bool = False,
) -> (
list[torch.Tensor]
| tuple[torch.Tensor, ...]
| dict[str, torch.Tensor]
| tuple[list[torch.Tensor] | tuple[torch.Tensor, ...], dict[str, torch.Tensor]]
):
"""Infers arguments based on the collected tensors."""
# This is already checked by _build_inputs_completed_with_none_values
# but this is not always well captured by tools checking types.
Expand Down Expand Up @@ -649,6 +699,7 @@ def infer_arguments(
assert all(t is not None for t in aligned_flat_list)
# pyrefly: ignore[bad-return]
return aligned_flat_list

# type checking
assert candidate is not None
assert candidate.aligned_spec is not None
Expand All @@ -658,13 +709,26 @@ def infer_arguments(
if self._best_candidate.cst_kwargs:
kwargs = {**kwargs, **self._best_candidate.cst_kwargs}

if not kwargs:
return args
if not args:
return kwargs
# We need to move args to kwargs
pos_names = self.signature_names[: len(args)]
return {**dict(zip(pos_names, args)), **kwargs}
if not as_args_kwargs:
if not kwargs:
return args
if not args:
return kwargs

# We need to move args to kwargs
if self.args_name_and_position:
raise RuntimeError(
"Cannot return arguments "
"as a single tuple or a single dictionary "
"because of '*args' in the function signature. "
"You need to set `as_args_kwargs=True`."
)
n_args = len(args)
pos_names = self.signature_names[:n_args]
return {**dict(zip(pos_names, args[:n_args])), **kwargs}

# Generic case.
return tuple(args), kwargs

def _post_process_for_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -783,15 +847,10 @@ def __call__(
if sig.parameters[p].kind == inspect.Parameter.VAR_KEYWORD
]
args_names = [
p
for p in sig.parameters
(p, i)
for (i, p) in enumerate(sig.parameters)
if sig.parameters[p].kind == inspect.Parameter.VAR_POSITIONAL
]
if args_names:
raise RuntimeError(
f"Inference is not implemented "
f"when the signature includes '*{args_names[0]}'."
)
self.info = InputObserverInfo(
signature_names=list(sig.parameters),
default_values={
Expand All @@ -801,6 +860,7 @@ def __call__(
and isinstance(p.default, (int, bool, str, float))
},
missing=self.missing,
args_name_and_position=args_names[0] if args_names else None,
kwargs_name=kwargs_names[0] if kwargs_names else None,
)
n_already_stored = len(self.info)
Expand Down Expand Up @@ -852,7 +912,13 @@ def infer_arguments(
self,
index_or_args_or_kwargs: tuple[Any] | dict[str, Any] | int | None = None,
flat: bool = False,
) -> list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
as_args_kwargs: bool = False,
) -> (
list[torch.Tensor]
| tuple[torch.Tensor, ...]
| dict[str, torch.Tensor]
| tuple[list[torch.Tensor] | tuple[torch.Tensor, ...], dict[str, torch.Tensor]]
):
"""Infers arguments based on the collected tensors.

Args:
Expand All @@ -865,7 +931,9 @@ def infer_arguments(
flat: If True, it returns a flattened list of tensors,
if False, it returns a tuple or a dictionary preserving
the nested structures.

as_args_kwargs: If True, the method always returns `(args, kwargs)`,
otherwise, it returns either a tuple (only args) or a dictionary
(only kwargs) or raises an exception if it cannot do so.
Returns:
Inferred arguments, every optional tensor is replaced by a empty tensor.
"""
Expand Down Expand Up @@ -908,7 +976,9 @@ def infer_arguments(
self.info._captured_inputs,
self.info.signature_names,
)
return self.info.infer_arguments(index_or_candidate=index_or_candidate, flat=flat)
return self.info.infer_arguments(
index_or_candidate=index_or_candidate, flat=flat, as_args_kwargs=as_args_kwargs
)

def check_discrepancies(
self,
Expand Down
Loading
Loading