diff --git a/test/modules/op/to.py b/test/modules/op/to.py index 06cad3d0..4d6fd1ba 100644 --- a/test/modules/op/to.py +++ b/test/modules/op/to.py @@ -85,3 +85,14 @@ def forward(self, x): def get_example_inputs(self): return (torch.randn(1, 3),), {} + + +class SimpleToForCast(TestModuleBase): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.to(torch.int32) + + def get_example_inputs(self): + return (torch.randn(1, 3),), {} diff --git a/test/modules/op/top_k.py b/test/modules/op/top_k.py new file mode 100644 index 00000000..255d0965 --- /dev/null +++ b/test/modules/op/top_k.py @@ -0,0 +1,34 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from test.modules.base import TestModuleBase +from test.utils.tag import use_onert + +# luci-interpreter doesn't support TopK operator yet +@use_onert +class SimpleTopK(TestModuleBase): + def __init__(self): + super().__init__() + + def forward(self, x): + values, indices = torch.topk(x, 2) + return values, indices + + def get_example_inputs(self): + batch_size = 1 + seq_len = 63 + num_experts = 8 + return (torch.randn(batch_size * seq_len, num_experts),), {} diff --git a/tico/passes/legalize_predefined_layout_operators.py b/tico/passes/legalize_predefined_layout_operators.py index 5e9bd690..e8e00f97 100644 --- a/tico/passes/legalize_predefined_layout_operators.py +++ b/tico/passes/legalize_predefined_layout_operators.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: import torch.fx +from operator import getitem + import torch from torch.export import ExportedProgram @@ -26,7 +28,7 @@ from tico.utils.graph import create_node from tico.utils.passes import PassBase, PassResult from tico.utils.trace_decorators import trace_graph_diff_on_pass -from tico.utils.utils import is_target_node +from tico.utils.utils import is_target_node, set_new_meta_val from tico.utils.validate_args_kwargs import ( AvgPool2dArgs, Conv2DArgs, @@ -35,6 +37,7 @@ DequantizePerTensorArgs, InstanceNormArgs, MaxPool2dWithIndicesArgs, + TopKArgs, ) @@ -434,6 +437,49 @@ def legalize_avg_pool2d(self, exported_program, node) -> bool: modified = True return modified + def legalize_top_k(self, exported_program, node) -> bool: + logger = logging.getLogger(__name__) + modified = False + + graph_module = exported_program.graph_module + graph = graph_module.graph + + args = TopKArgs(*node.args, **node.kwargs) # type: ignore[arg-type] + input_ = args.input + k = args.k + dim = args.dim + + if not (dim == -1 or dim == len(extract_shape(input_)) - 1): + raise NotYetSupportedError("Only support dim = -1 (last dimension)") + + with graph.inserting_after(input_): + circle_topk = create_node( + graph, + torch.ops.circle_custom.top_k, + args=(input_, k), + origin=input_, + ) + + with graph.inserting_after(circle_topk): + topk_values = create_node(graph, getitem, args=(circle_topk, 0)) + topk_indices = create_node(graph, getitem, args=(circle_topk, 1)) + with graph.inserting_after(topk_indices): + topk_indices_int64 = create_node( + graph, + torch.ops.aten._to_copy.default, + args=(topk_indices,), + kwargs={"dtype": torch.int64}, + ) + + get_item, get_item_1 = node.users.keys() + get_item.replace_all_uses_with(topk_values, propagate_meta=True) + get_item_1.replace_all_uses_with(topk_indices_int64, propagate_meta=True) + + logger.debug(f"{node.name} is replaced with {circle_topk.name}") + modified = True + + return modified + def call(self, exported_program: ExportedProgram) -> PassResult: target_to_legalize_func = { torch.ops.aten.conv2d.default: self.legalize_conv2d, @@ -442,6 +488,7 @@ def call(self, exported_program: ExportedProgram) -> PassResult: torch.ops.aten.max_pool2d_with_indices.default: self.legalize_max_pool2d_with_indices, torch.ops.aten.avg_pool2d.default: self.legalize_avg_pool2d, torch.ops.aten.instance_norm.default: self.legalize_instance_norm, + torch.ops.aten.topk.default: self.legalize_top_k, } graph_module = exported_program.graph_module diff --git a/tico/serialize/circle_serializer.py b/tico/serialize/circle_serializer.py index 5dd697dc..e1f8057a 100644 --- a/tico/serialize/circle_serializer.py +++ b/tico/serialize/circle_serializer.py @@ -32,6 +32,8 @@ multiple_output_ops = [ torch.ops.aten.split_with_sizes.default, torch.ops.aten.max.dim, + torch.ops.aten.topk.default, + torch.ops.circle_custom.top_k, ] diff --git a/tico/serialize/operators/op_to_copy.py b/tico/serialize/operators/op_to_copy.py index e7561d1b..c2c958cf 100644 --- a/tico/serialize/operators/op_to_copy.py +++ b/tico/serialize/operators/op_to_copy.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, TYPE_CHECKING +from typing import Dict, List, TYPE_CHECKING, Union if TYPE_CHECKING: import torch._ops @@ -20,6 +20,8 @@ import torch from circle_schema import circle +from tico.passes import ops + from tico.serialize.circle_mapping import ( extract_circle_dtype, extract_torch_dtype, @@ -29,12 +31,12 @@ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor from tico.serialize.operators.utils import create_builtin_operator, get_op_index from tico.utils.errors import NotYetSupportedError -from tico.utils.validate_args_kwargs import ToCopyArgs +from tico.utils.validate_args_kwargs import ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs @register_node_visitor class ToCopyVisitor(NodeVisitor): - target: List[torch._ops.OpOverload] = [torch.ops.aten._to_copy.default] + target: List[torch._ops.OpOverload] = ops.aten.to_copy def __init__(self, op_codes: Dict[OpCode, int], graph): super().__init__(op_codes, graph) @@ -60,42 +62,55 @@ def define_cast_node( return operator + def parse_args(self, op: torch._ops.OpOverload, args, kwargs): + ret: Union[ToCopyArgs, ToDtypeArgs, ToDtypeLayoutArgs] + if op is torch.ops.aten._to_copy.default: + ret = ToCopyArgs(*args, **kwargs) + elif op is torch.ops.aten.to.dtype: + ret = ToDtypeArgs(*args, **kwargs) + elif op is torch.ops.aten.to.dtype_layout: + ret = ToDtypeLayoutArgs(*args, **kwargs) + else: + raise NotImplementedError(f"Unsupported to_copy/to operator: {op}") + + return ret + def define_node( self, node: torch.fx.Node, ) -> circle.Operator.OperatorT: - supported_kwargs = ["dtype", "device", "layout"] - if not all(k in supported_kwargs for k in node.kwargs): - unsupported_node_kargs = list(node.kwargs.keys()) - for supported_key in supported_kwargs: - if supported_key in node.kwargs: - unsupported_node_kargs.remove(supported_key) - raise NotYetSupportedError( - f"Support only {supported_kwargs} kwargs now. Do not support {unsupported_node_kargs}" - ) - - args = ToCopyArgs(*node.args, **node.kwargs) # type: ignore[arg-type, call-arg] + args = ToCopyArgs(*node.args, **node.kwargs) # type: ignore[arg-type] input = args.input dtype = args.dtype + layout = args.layout + # device is meaningless in circle + + pin_memory = args.pin_memory + non_blocking = args.non_blocking + memory_format = args.memory_format + + if pin_memory is not None: + raise NotYetSupportedError("Do not support pin_memory yet") + if non_blocking is True: + raise NotYetSupportedError("Do not support non_blocking yet") + if memory_format is not None: + raise NotYetSupportedError("Do not support memory_format yet") input_meta = input.meta["val"] # https://pytorch.org/docs/stable/tensor_attributes.html#torch-layout # layout is two types: torch.strided(dense Tensors), torch.sparse_coo(sparse COO Tensors) if "layout" in input.kwargs and input.kwargs["layout"] != input_meta: raise NotYetSupportedError( - f"Only support when node and its input have same layout: (input layout: {input_meta}), (node layout: {node.kwargs['layout']})." + f"Only support when node and its input have same layout: (input layout: {input_meta}), (node layout: {layout})." ) - if dtype is not None: - target_type = node.kwargs["dtype"] - else: - # device and layout are meaningless - target_type = extract_torch_dtype(node) - assert isinstance(target_type, torch.dtype), type(target_type) + if dtype is None: + dtype = extract_torch_dtype(node) + assert isinstance(dtype, torch.dtype), type(dtype) # define cast node in_type: int = extract_circle_dtype(input) - out_type: int = to_circle_dtype(target_type) + out_type: int = to_circle_dtype(dtype) inputs = [input] outputs = [node] operator = self.define_cast_node(inputs, outputs, in_type, out_type) diff --git a/tico/serialize/operators/op_topk.py b/tico/serialize/operators/op_topk.py new file mode 100644 index 00000000..d966423f --- /dev/null +++ b/tico/serialize/operators/op_topk.py @@ -0,0 +1,79 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, TYPE_CHECKING + +if TYPE_CHECKING: + import torch.fx +import torch +from circle_schema import circle + +from tico.serialize.circle_graph import CircleSubgraph +from tico.serialize.circle_mapping import ( + circle_legalize_dtype_to, + extract_circle_shape, + extract_shape, + extract_torch_dtype, +) +from tico.serialize.operators.hashable_opcode import OpCode +from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor +from tico.serialize.operators.utils import create_builtin_operator, get_op_index +from tico.utils.validate_args_kwargs import TopKArgs + + +@register_node_visitor +class TopkVisitor(NodeVisitor): + """ """ + + target: List[torch._ops.OpOverload] = [ + torch.ops.circle_custom.top_k, + ] + + def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): + super().__init__(op_codes, graph) + + def define_topk_node( + self, inputs: List, outputs: List + ) -> circle.Operator.OperatorT: + op_index = get_op_index( + circle.BuiltinOperator.BuiltinOperator.TOPK_V2, self._op_codes + ) + + operator = create_builtin_operator(self.graph, op_index, inputs, outputs) + + operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.TopKV2Options + option = circle.TopKV2Options.TopKV2OptionsT() + operator.builtinOptions = option + + return operator + + def define_node( + self, + node: torch.fx.Node, + ) -> circle.Operator.OperatorT: + args = TopKArgs(*node.args, **node.kwargs) # type: ignore[arg-type] + input = args.input + k = args.k + + input_shape = extract_circle_shape(input) + k_i32 = circle_legalize_dtype_to(k, dtype=torch.int32) + assert args.dim == -1 or args.dim == len(input_shape) - 1 + + inputs = [input, k_i32] + + outputs = [i for i in node.users.keys()] + + topk_node: circle.Operator.OperatorT = self.define_topk_node(inputs, outputs) + + return topk_node diff --git a/tico/utils/graph.py b/tico/utils/graph.py index f0905334..d5d30d36 100644 --- a/tico/utils/graph.py +++ b/tico/utils/graph.py @@ -16,10 +16,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING +from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING if TYPE_CHECKING: import torch.fx +from operator import getitem + import torch from torch.export import ExportedProgram from torch.export.exported_program import InputKind, InputSpec, TensorArgument @@ -238,7 +240,7 @@ def get_module_name_chain(node: Optional[torch.fx.Node]) -> str: def create_node( graph: torch.fx.Graph, - target: torch._ops.OpOverload, + target: Callable, args: Optional[Tuple[Any, ...]] = None, kwargs: Optional[Dict[str, Any]] = None, *, @@ -252,7 +254,7 @@ def create_node( graph : torch.fx.Graph The graph that will own the newly-created node. - target : torch._ops.OpOverload + target : Callable The op to call (e.g. `torch.add` or "call_function" target). args : Tuple[Any, ...], optional diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py index 48372e0c..abceb0e9 100644 --- a/tico/utils/register_custom_op.py +++ b/tico/utils/register_custom_op.py @@ -19,6 +19,7 @@ from torch.library import custom_op, register_fake from tico.utils.mx.mx_ops import _quantize_mx +from tico.utils.validate_args_kwargs import TopKArgs # Note that an operator assumes input tensor has NHWC format. def CircleResizeNearestNeighbor(): @@ -662,6 +663,50 @@ def _( return input.new_empty(input.size()) +def CircleTopK(): + @custom_op( + "circle_custom::top_k", + mutates_args=(), + schema="(Tensor input, int k) -> (Tensor, Tensor)", + ) + def top_k( + input: torch.Tensor, + k: int, + dim: int = -1, + largest: bool = True, + sorted: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert (dim == -1) or (dim == len(input.size()) - 1) + assert largest is True + assert sorted is True + + topk_out_0, topk_out_1 = torch.ops.aten.topk.default(input, k, dim) + topk_out_1_int32 = torch.ops.aten.to.dtype(topk_out_1, dtype=torch.int32) + + return ( + topk_out_0, + topk_out_1_int32, + ) + + @register_fake("circle_custom::top_k") + def _( + input: FakeTensor, + k: int, + dim: int = -1, + largest: bool = True, + sorted: bool = True, + ) -> tuple[FakeTensor, FakeTensor]: + assert (dim == -1) or (dim == len(input.size()) - 1) + assert largest is True + assert sorted is True + topk_out0, topk_out1 = torch.ops.aten.topk.default(input, k, dim) + + return ( + topk_out0, + topk_out1.new_empty(size=topk_out1.size(), dtype=torch.int32), + ) + + def CircleQuantizeMX(): # This operator conducts fake-quantization of microscaling # NOTE Why using "quantize"_mx not "fake_quantize"_mx? @@ -738,3 +783,4 @@ def RegisterOps(): CircleInstanceNorm() CircleQuantizeMX() CircleRMSNorm() + CircleTopK() diff --git a/tico/utils/validate_args_kwargs.py b/tico/utils/validate_args_kwargs.py index f6badfff..a2567da6 100644 --- a/tico/utils/validate_args_kwargs.py +++ b/tico/utils/validate_args_kwargs.py @@ -1174,6 +1174,24 @@ class ToDtypeLayoutArgs: memory_format: Optional[torch.memory_format] = None +@enforce_type +@dataclass +class TopKArgs: + """ + topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) + """ + + input: torch.fx.Node + k: int + dim: int = -1 + largest: bool = True + sorted: bool = True + + def __post_init__(self): + assert self.largest is True, "Only support largest=True" + assert self.sorted is True, "Only support sorted=True" + + @enforce_type @dataclass class UnSqueezeArgs: