1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
*.data
*.pyc
*.pyd
*.dylib
1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.9.1
+++++

+* :pr:`410`: add patch for `_get_range_constraints`
* :pr:`409`: improves ModelBuilder wrapper
* :pr:`408`: fix torch_deepcopy for empty DynamicCache and transformers==5.1.0, 5.2.0 (see https://github.com/huggingface/transformers/pull/43765/)

3 changes: 2 additions & 1 deletion _doc/conf.py
@@ -141,6 +141,7 @@ def linkcode_resolve(domain, info):
("py:class", "torch.fx.proxy.TracerBase"),
("py:class", "torch.FloatTensor"),
("py:class", "torch.LongTensor"),
("py:class", "torch.export._trace.ExportArtifact"),
("py:class", "torch.utils._pytree.Context"),
("py:class", "torch.utils._pytree.KeyEntry"),
("py:class", "torch.utils._pytree.TreeSpec"),
@@ -211,7 +212,7 @@ def linkcode_resolve(domain, info):

if int(os.environ.get("UNITTEST_GOING", "0")):
sphinx_gallery_conf["ignore_pattern"] = (
".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)).*"
".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)|(optimind)).*"
)
elif pv.Version(torch.__version__) < pv.Version("2.8"):
sphinx_gallery_conf["ignore_pattern"] = ".*((_oe_)|(dort)|(draft_mode)).*"
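The ignore pattern is a single regular expression: sphinx-gallery skips any example whose file name matches it. A quick sanity check that the new example is excluded when UNITTEST_GOING is set (illustrative snippet, not part of the diff):

    import re

    pattern = ".*((tiny_llm)|(dort)|(draft_mode)|(hub_codellama.py)|(whisper)|(optimind)).*"
    assert re.match(pattern, "plot_export_optimind_input_observer.py")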
136 changes: 136 additions & 0 deletions _doc/final/plot_export_optimind_input_observer.py
@@ -0,0 +1,136 @@
"""
.. _l-plot-optimind-export-input-observer:

Export OptiMind-SFT with InputObserver
======================================

This example reuses the recipe introduced in :ref:`l-plot-tiny-llm-export-input-observer`
for the model `microsoft/OptiMind-SFT <https://huggingface.co/microsoft/OptiMind-SFT>`_.
We only export the ``GptOssExperts`` class.

Let's create a random model
+++++++++++++++++++++++++++
"""

import pandas
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import (
register_additional_serialization_functions,
torch_export_patches,
)
from onnx_diagnostic.investigate.input_observer import InputObserver

device = "cuda"
model_id = "microsoft/OptiMind-SFT"
print(f"get tokenizer {model_id!r}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
print(f"get config {model_id!r}")
config = AutoConfig.from_pretrained(model_id)
config.num_hidden_layers = 2
config.layer_types = config.layer_types[:2]
print(f"create model from config for {model_id!r}")
model = AutoModelForCausalLM.from_config(config)
print(f"the model is created with {len(list(model.named_modules()))} subdmodules.")
model = model.to(device)

# %%
# Export only the ``GptOssExperts`` class
# ++++++++++++++++++++++++++++++++++++++++++


export_module = None
for _name, sub in model.named_modules():
if sub.__class__.__name__ == "GptOssExperts":
export_module = sub

assert export_module is not None, (
f"Unable to find a submodule from class GptOssExperts in "
f"{set(sub.__class__.__name__ for _, sub in model.named_modules())}"
)
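
# %%
# ``named_modules`` walks the whole module tree, so with two hidden layers
# the loop above keeps the last ``GptOssExperts`` instance it finds; any
# instance should work equally since they all share the same structure.
# To stop at the first match one could equivalently write (a hypothetical
# variant, not used below):
#
# ::
#
#     export_module = next(
#         sub for _, sub in model.named_modules()
#         if sub.__class__.__name__ == "GptOssExperts"
#     )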

# %%
# Let's run the model and capture inputs and outputs


def generate_text(
prompt,
model,
tokenizer,
max_length=50,
temperature=0.01,
top_k=50,
top_p=0.95,
do_sample=True,
):
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
)

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return generated_text


prompt = "Continue: it rains, what should I do?"
observer = InputObserver()
with (
register_additional_serialization_functions(patch_transformers=True),
observer(export_module),
):
generate_text(prompt, model, tokenizer)
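
# %%
# ``InputObserver`` records the arguments and outputs of every call to the
# wrapped submodule while generation runs. Conceptually this is close to a
# forward hook; a minimal sketch of the idea (an assumption about the
# mechanism, not the actual implementation):
#
# ::
#
#     captured = []
#
#     def hook(module, args, kwargs, output):
#         captured.append((args, kwargs, output))
#
#     handle = export_module.register_forward_hook(hook, with_kwargs=True)
#     generate_text(prompt, model, tokenizer)  # run the model
#     handle.remove()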


# %%
# Export
# ++++++
#
# First, what was inferred.

args = observer.infer_arguments()
dynamic_shapes = observer.infer_dynamic_shapes()
print(f"args={string_type(args, with_shape=True, with_device=True)}")
print(f"dynamic_shapes={dynamic_shapes}")

# %%
# Next, the export.


filename = "plot_export_optimind_experts_input_observer.onnx"
with torch_export_patches(patch_transformers=True):
to_onnx(
export_module,
args=args,
filename=filename,
dynamic_shapes=dynamic_shapes,
exporter="custom",
verbose=1,
)

# %%
# Let's measure the discrepancies.
data = observer.check_discrepancies(filename, progress_bar=True, atol=1e-2, include_io=True)
df = pandas.DataFrame(data)
df.to_excel("plot_export_optimind_input_observer.xlsx")
print(df)
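
# %%
# ``check_discrepancies`` replays every captured call through the exported
# model and compares its outputs with the recorded torch outputs against
# ``atol``. A minimal sketch of that comparison (``expected`` and ``got``
# are hypothetical names for one pair of outputs):
#
# ::
#
#     import numpy as np
#
#     abs_err = float(np.abs(expected - got).max())
#     success = abs_err <= 1e-2  # the atol passed above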

# %%
# Let's show the errors.
for row in data:
if not row["SUCCESS"] and "error" in row:
print(row["error"])


# %%
doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)
140 changes: 140 additions & 0 deletions _doc/final/plot_export_tiny_llm_attention_input_observer.py
@@ -0,0 +1,140 @@
"""
.. _l-plot-tiny-llm-attention-export-input-observer:

Export attention from arnir0/Tiny-LLM with InputObserver
========================================================

This shows how to export only the attention part of model
`arnir0/Tiny-LLM <https://huggingface.co/arnir0/Tiny-LLM>`_.
It builds on the recipe shown in example
:ref:`l-plot-tiny-llm-export-input-observer`.

Let's create a random model
+++++++++++++++++++++++++++
"""

import pandas
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import (
register_additional_serialization_functions,
torch_export_patches,
)
from onnx_diagnostic.investigate.input_observer import InputObserver

device = "cuda"
model_id = "arnir0/Tiny-LLM"
print(f"get tokenizer {model_id!r}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
print(f"get config {model_id!r}")
config = AutoConfig.from_pretrained(model_id)
print(f"create model from config for {model_id!r}")
model = AutoModelForCausalLM.from_config(config)
print(f"the model is created with {len(list(model.named_modules()))} subdmodules.")
model = model.to(device).to(torch.float16)

# %%
# Export only the ``LlamaAttention`` class
# +++++++++++++++++++++++++++++++++++++++++++


export_module = None
for _name, sub in model.named_modules():
if sub.__class__.__name__ == "LlamaAttention":
export_module = sub

assert export_module is not None, (
f"Unable to find a submodule from class LlamaAttention in "
f"{set(sub.__class__.__name__ for _, sub in model.named_modules())}"
)

# %%
# Let's run the model and capture the inputs and outputs of the attention part.


def generate_text(
prompt,
model,
tokenizer,
max_length=50,
temperature=0.01,
top_k=50,
top_p=0.95,
do_sample=True,
):
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
)

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return generated_text


prompt = "Continue: it rains, what should I do?"
observer = InputObserver()
with (
register_additional_serialization_functions(patch_transformers=True),
observer(export_module),
):
generate_text(prompt, model, tokenizer)


# %%
# Export
# ++++++
#
# First, what was inferred.

kwargs = observer.infer_arguments()
dynamic_shapes = observer.infer_dynamic_shapes()
print("attention type:", type(export_module))
print(f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}")
print(f"dynamic_shapes={dynamic_shapes}")

# %%
# Next, the export.


filename = "plot_export_tiny_llm_attention_input_observer.onnx"
with torch_export_patches(patch_torch=True, patch_transformers=True):
to_onnx(
export_module,
args=(),
kwargs=kwargs,
filename=filename,
dynamic_shapes=dynamic_shapes,
exporter="custom",
verbose=1,
)

# %%
# Let's measure the discrepancies.
data = observer.check_discrepancies(
filename, progress_bar=True, atol=1e-2, include_io=True, skip_none=True
)
df = pandas.DataFrame(data)
df.to_excel("plot_export_tiny_llm_attention_input_observer.xlsx")
print(df)
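
# %%
# ``skip_none=True`` presumably drops captured inputs that were ``None``
# (for instance an absent attention mask), which the exported ONNX model
# cannot consume; the remaining rows are compared as in the previous
# example.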

# %%
# Let's show the errors.
for row in data:
if not row["SUCCESS"] and "error" in row:
print(row["error"])


# %%
doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)