Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Change Logs
0.9.2
+++++

* :pr:`412`: patches for ViTModel (through rewriting)

0.9.1
+++++

Expand Down
10 changes: 5 additions & 5 deletions _doc/patches.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ Here is the list of supported caches:

.. _l-control-flow-rewriting:

Control flow rewriting
Control Flow Rewriting
======================

This is an attempt to automatically rewrite control flow using :mod:`ast`.
Expand Down Expand Up @@ -217,17 +217,17 @@ The locations where it has to be done:

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
known_transformers_rewritings_clamp_float16,
known_transformers_rewritings,
)

pprint.pprint(known_transformers_rewritings_clamp_float16())
pprint.pprint(known_transformers_rewritings())

.. runpython::
:showcode:

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
_rewrite_forward_clamp_float16,
_rewrite_forward,
)

pprint.pprint(_rewrite_forward_clamp_float16())
pprint.pprint(_rewrite_forward())
8 changes: 4 additions & 4 deletions _doc/status/patches_coverage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ The number of fixes is much less than the number of classes to fix.

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
known_transformers_rewritings_clamp_float16,
known_transformers_rewritings,
)

pprint.pprint(known_transformers_rewritings_clamp_float16())
pprint.pprint(known_transformers_rewritings())

.. runpython::
:showcode:

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
_rewrite_forward_clamp_float16,
_rewrite_forward,
)

pprint.pprint(_rewrite_forward_clamp_float16())
pprint.pprint(_rewrite_forward())
41 changes: 40 additions & 1 deletion _unittests/ut_torch_export_patches/test_patch_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,6 @@ def test__find_loop_vars(self):

@requires_torch("2.8")
def test_rewrite_loop(self):

class Model(torch.nn.Module):
def forward(self, x, y):
z = torch.empty((x.shape[0], y.shape[0]))
Expand Down Expand Up @@ -715,6 +714,46 @@ def forward(self, x):
got = ep.module()(x)
self.assertEqualArray(expected, got)

def test_test_raise(self):
class Model(torch.nn.Module):
def forward(self, x, y):
if x.shape[0] != y.shape[0]:
raise ValueError(f"Wrong shape {x.shape=} and {y.shape=}")
return x + y

x, y = torch.rand((3, 4)), torch.rand((3, 4))
expected, expected_ = Model()(x, y), Model()(-x, y)

rewritten = transform_method(Model.forward)
self.assertIn("torch._check(", rewritten.code)
Model.forward = rewritten.func
self.assertEqualAny(expected, Model()(x, y))
self.assertEqualAny(expected_, Model()(-x, y))

DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
self.assertEqualAny(expected, ep.module()(x, y))
self.assertEqualAny(expected_, ep.module()(-x, y))

@hide_stdout()
def test_test_raise_rewrite(self):
class Model(torch.nn.Module):
def forward(self, x, y):
if x.shape[0] != y.shape[0]:
raise ValueError(f"Wrong shape {x.shape=} and {y.shape=}")
return x + y

model = Model()
x, y = torch.rand((4, 5)), torch.rand((4, 5))
expected = model(x, y)
DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
with torch_export_rewrite(rewrite=[(Model, "forward")], verbose=1):
ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
got = ep.module()(x, y)
self.assertEqualArray(expected, got)


if __name__ == "__main__":
unittest.main(verbosity=2)
16 changes: 16 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
rewrite_loop_for_square_mask,
)
from onnx_diagnostic.torch_export_patches.patch_module_helper import code_needing_rewriting
from onnx_diagnostic.torch_export_patches.patch_module import transform_method


class TestPatchRewriting(ExtTestCase):
Expand Down Expand Up @@ -38,6 +39,21 @@ def test_code_needing_rewriting(self):
res = code_needing_rewriting("BartModel")
self.assertEqual(len(res), 2)

def test_code_needing_rewriting_vit_patch_embedding(self):
res = code_needing_rewriting("ViTPatchEmbeddings")
self.assertEqual(len(res), 1)

def test_code_needing_rewriting_vit_class(self):
import transformers

res = code_needing_rewriting(transformers.models.vit.modeling_vit.ViTModel)
self.assertEqual(len(res), 1)

def test_rewriting_vit_patch_embedding(self):
import transformers

transform_method(transformers.models.vit.modeling_vit.ViTPatchEmbeddings.forward)


if __name__ == "__main__":
unittest.main(verbosity=2)
28 changes: 28 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,34 @@ def forward(self, x, **kwargs):
inputs = [n.name for n in ep.graph.nodes if n.op == "placeholder"]
self.assertEqual(["x", "y"], inputs)

def test_mixed_named_and_unnamed_kwargs_2(self):
class Model(torch.nn.Module):
def forward(self, a, x, **kwargs):
return a - x + kwargs["y"] - kwargs["z"]

args = (torch.randn((5, 6)),)
kwargs = dict(x=torch.randn((5, 6)), y=torch.randn((1, 6)), z=torch.randn((1, 6)) + 10)
model = Model()
expected = model(*args, **kwargs)
with torch_export_patches(patch_torch=True):
ep = torch.export.export(
model,
args,
kwargs=kwargs,
dynamic_shapes={
"a": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},
"x": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},
"kwargs": {
"y": {1: torch.export.Dim.DYNAMIC},
"z": {1: torch.export.Dim.DYNAMIC},
},
},
)
# ep.module()(**kwargs): raises NameError: name 'L' is not defined
self.assertEqualArray(args[0] - kwargs["x"] + kwargs["y"] - kwargs["z"], expected)
inputs = [n.name for n in ep.graph.nodes if n.op == "placeholder"]
self.assertEqual(["a", "x", "y", "z"], inputs)


if __name__ == "__main__":
unittest.main(verbosity=2)
1 change: 1 addition & 0 deletions onnx_diagnostic/helpers/model_builder_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def download_model_builder_to_cache(
"gemma.py",
"gptoss.py",
"granite.py",
"internlm.py",
"llama.py",
"mistral.py",
"nemotron.py",
Expand Down
3 changes: 2 additions & 1 deletion onnx_diagnostic/torch_export_patches/onnx_export_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,8 @@ def torch_export_patches(
before being exported if the execution path depends on the inputs,
this is done by function :func:`transform_method
<onnx_diagnostic.torch_export_patches.patch_module.transform_method>`,
its documentation provides possible values. If `rewrite==True`, then
all known methods to rewrite are added.
:param dump_rewriting: dumps rewriting information in a file beginning with that prefix;
this only applies to the automated rewritings
:param patch_details: if specified, this class is used to stored every applied rewriting.
Expand Down
Loading
Loading