diff --git a/src/xtc/graphs/xtc/builder.py b/src/xtc/graphs/xtc/builder.py index 1f8580c1b..b66952dd7 100644 --- a/src/xtc/graphs/xtc/builder.py +++ b/src/xtc/graphs/xtc/builder.py @@ -6,6 +6,10 @@ from .graph import XTCGraph from .context import XTCGraphContext +from .expr import XTCTensorExpr +from .operators import XTCOperator +from . import op_factory +from yaml import safe_load class graph_builder: @@ -25,3 +29,51 @@ def __exit__(self, *_: Any) -> None: def graph(self) -> XTCGraph: assert self._graph is not None, "can't get graph inside builder context" return self._graph + + @classmethod + def from_dict(cls, graph_dict: dict[str, Any]) -> Any: + def tuplify(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: tuplify(v) for k, v in obj.items()} + elif isinstance(obj, list): + return tuple(tuplify(v) for v in obj) + else: + return obj + + expr_uid_map = {} + if "name" in graph_dict: + XTCGraphContext.name(graph_dict["name"]) + + for inp in graph_dict["inputs"]: + expr_uid_map[inp["uid"]] = XTCTensorExpr.from_dict(inp["expr"]) + + for node in graph_dict["nodes"]: + expr = node["expr"] + args = [expr_uid_map.get(arg) for arg in expr["args"]] + if "name" in node: + args.append(node["name"]) + if not hasattr(op_factory, expr["op"]["name"]): + version_mismatch = ( + "version mismatch detected, " + if XTCOperator.version_string() != graph_dict["ops_version"] + else "" + ) + raise ValueError( + version_mismatch + + f"serialized op {expr['op']['name']} is not implemented." + ) + op_func = getattr(op_factory, expr["op"]["name"]) + expr_uid_map[node["uid"]] = op_func(*args, **tuplify(expr["op"]["attrs"])) + + outputs = [expr_uid_map[out["uid"]] for out in graph_dict["outputs"]] + XTCGraphContext.outputs(*outputs) + return graph_dict + + @classmethod + def loads(cls, yaml_str: str) -> None: + cls.from_dict(safe_load(yaml_str)) + + @classmethod + def load(cls, file_name: str) -> None: + with open(file_name, "r") as f: + cls.from_dict(safe_load(f)) diff --git a/src/xtc/graphs/xtc/context.py b/src/xtc/graphs/xtc/context.py index e7e263955..47be81f1e 100644 --- a/src/xtc/graphs/xtc/context.py +++ b/src/xtc/graphs/xtc/context.py @@ -119,5 +119,8 @@ def outputs(self, *outs: XTCExpr) -> None: def inputs(self, *inps: XTCExpr) -> None: return self.current.add_inputs(*inps) + def name(self, name_str: str) -> None: + self.current._graph_kwargs["name"] = name_str + XTCGraphContext = XTCGraphScopes() diff --git a/src/xtc/graphs/xtc/data.py b/src/xtc/graphs/xtc/data.py index 2451de8a9..0e763d08a 100644 --- a/src/xtc/graphs/xtc/data.py +++ b/src/xtc/graphs/xtc/data.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024-2026 The XTC Project Authors # -from typing_extensions import override -from typing import cast, Any +from typing_extensions import override, Self +from typing import cast, Any, Iterable import functools import operator import numpy as np @@ -108,6 +108,23 @@ def __eq__(self, other: object) -> bool: return NotImplemented return self.dtype == other.dtype and self.shape == other.shape + @override + def to_dict(self) -> dict[str, Any]: + tensor_dict: dict[str, Any] = {} + if self.shape: + tensor_dict["shape"] = list(cast(Iterable[int], self.shape)) + if self.dtype: + tensor_dict["dtype"] = self.dtype + return tensor_dict + + @override + @classmethod + def from_dict(cls, tensor_dict: dict[str, Any]) -> Self: + return cls( + shape=tuple(tensor_dict["shape"]) if "shape" in tensor_dict else None, + dtype=tensor_dict["dtype"] if "dtype" in tensor_dict else None, + ) + class XTCConstantTensorType(XTCTensorType, ConstantTensorType): def __init__(self, shape: ConstantShapeType, dtype: ConstantDataType): diff --git a/src/xtc/graphs/xtc/expr.py b/src/xtc/graphs/xtc/expr.py index e4ce0accd..3a58bc928 100644 --- a/src/xtc/graphs/xtc/expr.py +++ b/src/xtc/graphs/xtc/expr.py @@ -4,7 +4,7 @@ # from abc import ABC, abstractmethod from collections.abc import Sequence -from typing_extensions import override +from typing_extensions import override, Self from typing import Any, TypeAlias import threading @@ -94,6 +94,10 @@ def uid(self) -> str: def __str__(self) -> str: return f"{self.uid} = ?" + @abstractmethod + def to_dict(self) -> dict[str, Any]: + return {"uid": self.uid} + class XTCValueExpr(XTCExpr): @property @@ -180,6 +184,15 @@ def __str__(self) -> str: args = ", ".join([arg.uid for arg in self.args]) return f"{self.uid} = {self.op_name}({args})" + @override + def to_dict(self) -> dict[str, Any]: + return {"type": self.type.to_dict()} + + @classmethod + def from_dict(cls, op_dict: dict[str, Any]) -> Self: + type = XTCTensorType.from_dict(op_dict["type"]) + return cls(tensor=type) + class XTCOpExpr(XTCExpr): def __init__(self, op: XTCOperator, args: ArgumentsType) -> None: @@ -220,6 +233,10 @@ def __str__(self) -> str: args = ", ".join(params) return f"{self.uid} = {self.op_name}({args})" + @override + def to_dict(self) -> dict[str, Any]: + return {"op": self._op.to_dict(), "args": [a.uid for a in self.args]} + class XTCMatmulExpr(XTCOpExpr): def __init__(self, x: XTCExpr, y: XTCExpr, **attrs: Any) -> None: diff --git a/src/xtc/graphs/xtc/graph.py b/src/xtc/graphs/xtc/graph.py index b8c54d056..8ce42170c 100644 --- a/src/xtc/graphs/xtc/graph.py +++ b/src/xtc/graphs/xtc/graph.py @@ -4,7 +4,7 @@ # from typing_extensions import override from collections.abc import Sequence, Mapping -from typing import TypeAlias, cast +from typing import TypeAlias, cast, Any from xtc.itf.graph import Graph from xtc.itf.data import TensorType, Tensor @@ -12,6 +12,8 @@ from .node import XTCNode from .utils import XTCGraphUtils from .data import XTCTensor, XTCTensorType +from .operators import XTCOperator +from yaml import dump as yaml_dump, safe_dump __all__ = [ "XTCGraph", @@ -152,3 +154,31 @@ def __str__(self) -> str: else: graph_str += " nodes: {}\n" return graph_str + + def to_dict(self) -> dict[str, Any]: + graph_dict: dict[str, Any] = {} + if self._name: + graph_dict["name"] = self._name + graph_dict["inputs"] = [i.to_dict() for i in self._inputs] + graph_dict["outputs"] = [{"uid": o.to_dict()["uid"]} for o in self._outputs] + graph_dict["nodes"] = [n.to_dict() for n in self._nodes] + graph_dict["ops_version"] = XTCOperator.version_string() + lowest_uid = int(graph_dict["inputs"][0]["uid"][1:]) + + def compact_uids(obj: Any) -> Any: + if isinstance(obj, dict): + obj = {k: compact_uids(v) for k, v in obj.items()} + elif isinstance(obj, list): + obj = [compact_uids(v) for v in obj] + elif isinstance(obj, str) and obj[0] == "%": + obj = f"%{int(obj[1:]) - lowest_uid}" + return obj + + return compact_uids(graph_dict) + + def dumps(self) -> str: + return safe_dump(self.to_dict()) + + def dump(self, file_name: str) -> None: + with open(file_name, "w") as f: + yaml_dump(self.to_dict(), f, sort_keys=False) diff --git a/src/xtc/graphs/xtc/node.py b/src/xtc/graphs/xtc/node.py index 97491171a..c2ab160fa 100644 --- a/src/xtc/graphs/xtc/node.py +++ b/src/xtc/graphs/xtc/node.py @@ -4,7 +4,7 @@ # from typing_extensions import override from collections.abc import Sequence -from typing import cast +from typing import cast, Any from xtc.itf.graph import Node from xtc.itf.operator import Operator @@ -156,3 +156,12 @@ def __str__(self) -> str: if self.inputs_types is not None and self.outputs_types is not None: type_str = f" : {self.inputs_types} -> {self.outputs_types}" return str(self._expr).split("=", 1)[1].strip() + attrs_str + type_str + + @override + def to_dict(self) -> dict[str, Any]: + node_dict: dict[str, Any] = {} + if self.uid != self.name: + node_dict["name"] = self.name + node_dict["uid"] = self.uid + node_dict["expr"] = self._expr.to_dict() + return node_dict diff --git a/src/xtc/graphs/xtc/operators.py b/src/xtc/graphs/xtc/operators.py index 1f65bd17e..d157db596 100644 --- a/src/xtc/graphs/xtc/operators.py +++ b/src/xtc/graphs/xtc/operators.py @@ -9,6 +9,7 @@ import functools import operator import numpy as np +import hashlib from xtc.itf.operator import Operator from xtc.itf.data import Tensor, TensorType @@ -86,6 +87,26 @@ def _get_operation( outs_maps=outs_maps, ) + @override + def to_dict(self) -> dict[str, Any]: + def listify(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: listify(v) for k, v in obj.items()} + elif isinstance(obj, tuple): + return [listify(v) for v in obj] + else: + return obj + + op_dict: dict[str, Any] = {"name": self._name} + op_dict["attrs"] = {k: listify(v) for k, v in self._attrs.__dict__.items()} + return op_dict + + @classmethod + def version_string(cls) -> str: + names = sorted(f"{c.__module__}.{c.__qualname__}" for c in cls.__subclasses__()) + data = "|".join(names).encode() + return hashlib.sha256(data).hexdigest()[:16] + class XTCOperTensor(XTCOperator): def __init__(self) -> None: diff --git a/src/xtc/itf/data/tensor.py b/src/xtc/itf/data/tensor.py index 9695a0fd8..5144c0033 100644 --- a/src/xtc/itf/data/tensor.py +++ b/src/xtc/itf/data/tensor.py @@ -4,7 +4,7 @@ # from abc import ABC, abstractmethod from typing import Any, TypeAlias -from typing_extensions import override +from typing_extensions import override, Self import numpy.typing @@ -53,6 +53,13 @@ def ndim(self) -> int: """ ... + @abstractmethod + def to_dict(self) -> dict[str, Any]: ... + + @classmethod + @abstractmethod + def from_dict(cls, tensor_dict: dict[str, Any]) -> Self: ... + class ConstantTensorType(TensorType): @property diff --git a/src/xtc/itf/graph/node.py b/src/xtc/itf/graph/node.py index 32c958c48..4bd31abd4 100644 --- a/src/xtc/itf/graph/node.py +++ b/src/xtc/itf/graph/node.py @@ -162,3 +162,6 @@ def forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]: List of output tensors """ ... + + @abstractmethod + def to_dict(self) -> dict[str, str | Sequence[TensorType]]: ... diff --git a/src/xtc/itf/operator/operator.py b/src/xtc/itf/operator/operator.py index 7ce87c8e6..e81ce31e4 100644 --- a/src/xtc/itf/operator/operator.py +++ b/src/xtc/itf/operator/operator.py @@ -3,6 +3,7 @@ # Copyright (c) 2024-2026 The XTC Project Authors # from abc import ABC, abstractmethod +from typing import Any from collections.abc import Sequence from ..data.tensor import Tensor, TensorType @@ -48,3 +49,6 @@ def forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]: List of output tensors """ ... + + @abstractmethod + def to_dict(self) -> dict[str, Any]: ... diff --git a/tests/pytest/graphs/test_graph_serialization.py b/tests/pytest/graphs/test_graph_serialization.py new file mode 100644 index 000000000..3aecd687a --- /dev/null +++ b/tests/pytest/graphs/test_graph_serialization.py @@ -0,0 +1,66 @@ +import tempfile +import xtc.graphs.xtc.op as O + +def test_matmul_relu_to_from_dict(): + I, J, K, dtype = 4, 32, 512, "float32" + a = O.tensor((I, K), dtype, name="A") + b = O.tensor((K, J), dtype, name="B") + + with O.graph(name="matmul_relu") as gb: + m = O.matmul(a, b, name="matmul") + O.relu(m, name="relu") + + graph_dict = gb.graph.to_dict() + with O.graph() as gb2: + gb2.from_dict(graph_dict) + assert graph_dict != {} + assert graph_dict == gb2.graph.to_dict() + + +def test_conv2d_pad_sdump_sload(): + N, H, W, F, R, S, C, SH, SW, dtype = 1, 8, 8, 16, 5, 5, 3, 2, 2, "float32" + a = O.tensor((N, H, W, C), dtype, name="I") + b = O.tensor((R, S, C, F), dtype, name="W") + + with O.graph(name="pad_conv2d_nhwc_mini") as gb: + p = O.pad(a, padding={1: (2), 2: (2, 2)}, name="pad") + O.conv2d(p, b, stride=(SH, SW), name="conv") + + graph_str = gb.graph.dumps() + with O.graph(name="matmul_relu") as gb2: + gb2.loads(graph_str) + assert graph_str != "" + assert graph_str == gb2.graph.dumps() + + with tempfile.NamedTemporaryFile(mode="w+", delete=True) as f: + gb.graph.dump(f.name) + with O.graph() as gb3: + gb3.load(f.name) + assert gb.graph.to_dict() == gb3.graph.to_dict() + +def test_mlp_fc_custom_output(): + img = O.tensor() + w1 = O.tensor() + w2 = O.tensor() + w3 = O.tensor() + w4 = O.tensor() + fc = lambda i, w, nout: O.matmul(O.reshape(i, shape=(1, -1)), O.reshape(w, shape=(-1, nout))) + # Multi Layer Perceptron with 3 relu(fc) + 1 fc + with O.graph(name="mlp4") as gb: + with O.graph(name="l1"): + l1 = O.relu(fc(img, w1, 512)) + with O.graph(name="l2"): + l2 = O.relu(fc(l1, w2, 256)) + with O.graph(name="l3"): + l3 = O.relu(fc(l2, w3, 128)) + with O.graph(name="l4"): + l4 = fc(l3, w4, 10) + O.reshape(l4, shape=(-1,)) + O.outputs(l1) + + graph_dict = gb.graph.to_dict() + with O.graph() as gb2: + gb2.from_dict(graph_dict) + assert graph_dict != {} + assert graph_dict == gb2.graph.to_dict() +