From 73c143f32dfafc35983c295b9a12ccdf0931ccd1 Mon Sep 17 00:00:00 2001 From: Liam Semeria Date: Mon, 2 Mar 2026 13:55:59 +0100 Subject: [PATCH 01/10] to_dict, from_dict recursive way --- src/xtc/graphs/xtc/builder.py | 40 +++++++++++++++++++++++++++++++- src/xtc/graphs/xtc/data.py | 13 ++++++++++- src/xtc/graphs/xtc/expr.py | 29 ++++++++++++++++++++++- src/xtc/graphs/xtc/graph.py | 23 +++++++++++++++++- src/xtc/graphs/xtc/node.py | 21 +++++++++++++++-- src/xtc/graphs/xtc/operators.py | 19 ++++++++++++++- src/xtc/itf/data/tensor.py | 11 ++++++++- src/xtc/itf/graph/node.py | 11 +++++++++ src/xtc/itf/operator/operator.py | 10 ++++++++ 9 files changed, 169 insertions(+), 8 deletions(-) diff --git a/src/xtc/graphs/xtc/builder.py b/src/xtc/graphs/xtc/builder.py index 1f8580c1b..ae1aa5976 100644 --- a/src/xtc/graphs/xtc/builder.py +++ b/src/xtc/graphs/xtc/builder.py @@ -6,7 +6,8 @@ from .graph import XTCGraph from .context import XTCGraphContext - +from .expr import XTCTensorExpr +from . import op_factory class graph_builder: def __init__(self, **graph_kwargs: Any) -> None: @@ -25,3 +26,40 @@ def __exit__(self, *_: Any) -> None: def graph(self) -> XTCGraph: assert self._graph is not None, "can't get graph inside builder context" return self._graph + + @classmethod + def from_dict(self, graph_dict: dict[str, Any]) -> Any: + XTCGraphContext.push() + # TODO: scan for names, store in context _names + # start at deepest expr than use factory to make OpExprs + def build(obj: Any) -> Any: + if isinstance(obj, dict): + if "op" in obj: + args = build(obj["args"]) + func_name: str = obj["op"]["name"] + op_func = getattr(op_factory, func_name) + print(f"calling {func_name}") + obj = op_func(*args, **obj["op"]["attrs"]) + return obj + elif "idx" in obj: + obj = XTCTensorExpr.from_dict(obj) + return obj + else: + print(obj) + return None + elif isinstance(obj, list): + obj = [build(o) for o in obj] + return obj + else: + print(f"wtf is {obj}") + return None + + + for out in graph_dict["outputs"]: + build(out["expr"]) + scope = XTCGraphContext.pop() + print(scope.graph.inputs) + for inp in scope.graph._inputs: + print(inp.uid) + print(scope.graph) + return graph_dict diff --git a/src/xtc/graphs/xtc/data.py b/src/xtc/graphs/xtc/data.py index 2451de8a9..825e87470 100644 --- a/src/xtc/graphs/xtc/data.py +++ b/src/xtc/graphs/xtc/data.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024-2026 The XTC Project Authors # -from typing_extensions import override +from typing_extensions import override, Self from typing import cast, Any import functools import operator @@ -108,6 +108,17 @@ def __eq__(self, other: object) -> bool: return NotImplemented return self.dtype == other.dtype and self.shape == other.shape + @override + def to_dict(self) -> dict[str, Any]: + return { + "shape": self.shape, + "dtype": self.dtype, + } + + @override + @classmethod + def from_dict(cls, tensor_dict: dict[str, Any]) -> Self: + return cls(shape = tensor_dict["shape"], dtype = tensor_dict["dtype"]) class XTCConstantTensorType(XTCTensorType, ConstantTensorType): def __init__(self, shape: ConstantShapeType, dtype: ConstantDataType): diff --git a/src/xtc/graphs/xtc/expr.py b/src/xtc/graphs/xtc/expr.py index e4ce0accd..8d24e0633 100644 --- a/src/xtc/graphs/xtc/expr.py +++ b/src/xtc/graphs/xtc/expr.py @@ -4,7 +4,7 @@ # from abc import ABC, abstractmethod from collections.abc import Sequence -from typing_extensions import override +from typing_extensions import override, Self from typing import Any, TypeAlias import threading @@ -94,6 +94,15 @@ def uid(self) -> str: def __str__(self) -> str: return f"{self.uid} = ?" + @abstractmethod + def to_dict(self) -> dict[str, Any]: + return {"idx" : self._idx} + + @classmethod + @abstractmethod + def from_dict(cls, node_dict: dict[str, Any]) -> Self: + ... + class XTCValueExpr(XTCExpr): @property @@ -180,6 +189,15 @@ def __str__(self) -> str: args = ", ".join([arg.uid for arg in self.args]) return f"{self.uid} = {self.op_name}({args})" + @override + def to_dict(self) -> dict[str, Any]: + return {"idx": self._idx, "type": self.type.to_dict()} + + @classmethod + @override + def from_dict(cls, node_dict: dict[str, Any]) -> Self: + type = XTCTensorType.from_dict(node_dict["type"]) + return cls(tensor=type) class XTCOpExpr(XTCExpr): def __init__(self, op: XTCOperator, args: ArgumentsType) -> None: @@ -220,6 +238,15 @@ def __str__(self) -> str: args = ", ".join(params) return f"{self.uid} = {self.op_name}({args})" + @override + def to_dict(self) -> dict[str, Any]: + return {"idx": self._idx, "op": self._op.to_dict(), "args": [a.to_dict() for a in self.args]} + + @override + @classmethod + def from_dict(cls, node_dict: dict[str, Any]) -> Self: + # should not be called on just XTCOpExpr + return cls(*node_dict["args"], **node_dict["attrs"]) class XTCMatmulExpr(XTCOpExpr): def __init__(self, x: XTCExpr, y: XTCExpr, **attrs: Any) -> None: diff --git a/src/xtc/graphs/xtc/graph.py b/src/xtc/graphs/xtc/graph.py index b8c54d056..35479d747 100644 --- a/src/xtc/graphs/xtc/graph.py +++ b/src/xtc/graphs/xtc/graph.py @@ -4,7 +4,7 @@ # from typing_extensions import override from collections.abc import Sequence, Mapping -from typing import TypeAlias, cast +from typing import TypeAlias, cast, Any from xtc.itf.graph import Graph from xtc.itf.data import TensorType, Tensor @@ -152,3 +152,24 @@ def __str__(self) -> str: else: graph_str += " nodes: {}\n" return graph_str + +# TODO: to_dict(), from_dict() + + def to_dict(self) -> dict[str, Any]: + # fold nodes into just uids + for n in self._nodes: + node_dict = n.to_dict() + print(node_dict["expr"]) + return { + #"inputs": [i.to_dict() for i in self._inputs], + "outputs": [o.to_dict() for o in self._outputs], + #"nodes": [n.to_dict() for n in self._nodes] + } + def from_dict(self, graph_dict: dict[str, Any]): + # create inputs exprs + for ins in graph_dict["inputs"]: + self._inputs.append(ins.from_dict()) + # set inputs_types here + pass +# TODO: sdump(), sload() +# TODO: dump(), load() (yaml) diff --git a/src/xtc/graphs/xtc/node.py b/src/xtc/graphs/xtc/node.py index 97491171a..7db96a6aa 100644 --- a/src/xtc/graphs/xtc/node.py +++ b/src/xtc/graphs/xtc/node.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024-2026 The XTC Project Authors # -from typing_extensions import override +from typing_extensions import override, Self from collections.abc import Sequence -from typing import cast +from typing import cast, Any from xtc.itf.graph import Node from xtc.itf.operator import Operator @@ -156,3 +156,20 @@ def __str__(self) -> str: if self.inputs_types is not None and self.outputs_types is not None: type_str = f" : {self.inputs_types} -> {self.outputs_types}" return str(self._expr).split("=", 1)[1].strip() + attrs_str + type_str + + @override + def to_dict(self) -> dict[str, str | Sequence[TensorType]]: + node_dict = {} + if self.inputs_types and False: + node_dict["input_types"] = [t.to_dict() for t in self.inputs_types] + if self.outputs_types and False: + node_dict["output_types"] = [t.to_dict() for t in self.outputs_types] + if self.uid != self.name and False: + node_dict["name"] = self.name + node_dict["expr"] = self._expr.to_dict() + return node_dict + + @override + @classmethod + def from_dict(cls, node_dict: dict[str, Any]) -> Self: + return cls(node_dict["expr"]) diff --git a/src/xtc/graphs/xtc/operators.py b/src/xtc/graphs/xtc/operators.py index 1f65bd17e..b7dc391f3 100644 --- a/src/xtc/graphs/xtc/operators.py +++ b/src/xtc/graphs/xtc/operators.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024-2026 The XTC Project Authors # -from typing_extensions import override +from typing_extensions import override, Self from typing import TypeAlias, cast, Any from types import SimpleNamespace as NS from collections.abc import Sequence, Mapping @@ -86,6 +86,23 @@ def _get_operation( outs_maps=outs_maps, ) + @override + def to_dict(self) -> dict[str, Any]: + op_dict = {"name" : self._name} + #def get_attr(obj: Any): + # if isinstance(obj, dict): + # return {k: get_attr(v) for k, v in obj.items()} + # elif isinstance(obj, (list, tuple, set)): + # return [get_attr(v) for v in obj] + # else: + # return obj + op_dict["attrs"] = {k: v for k,v in self._attrs.__dict__.items()} + return op_dict + + @classmethod + @override + def from_dict(cls, op_dict: dict[str, Any]) -> Self: + return cls() class XTCOperTensor(XTCOperator): def __init__(self) -> None: diff --git a/src/xtc/itf/data/tensor.py b/src/xtc/itf/data/tensor.py index 9695a0fd8..8cb244b4e 100644 --- a/src/xtc/itf/data/tensor.py +++ b/src/xtc/itf/data/tensor.py @@ -4,7 +4,7 @@ # from abc import ABC, abstractmethod from typing import Any, TypeAlias -from typing_extensions import override +from typing_extensions import override, Self import numpy.typing @@ -53,6 +53,15 @@ def ndim(self) -> int: """ ... + @abstractmethod + def to_dict(self) -> dict[str, Any]: + ... + + @classmethod + @abstractmethod + def from_dict(cls, tensor_dict :dict[str, Any]) -> Self: + ... + class ConstantTensorType(TensorType): @property diff --git a/src/xtc/itf/graph/node.py b/src/xtc/itf/graph/node.py index 32c958c48..ad4f1f0b5 100644 --- a/src/xtc/itf/graph/node.py +++ b/src/xtc/itf/graph/node.py @@ -3,6 +3,8 @@ # Copyright (c) 2024-2026 The XTC Project Authors # from abc import ABC, abstractmethod +from typing import Any +from typing_extensions import Self from collections.abc import Sequence from ..operator.operator import Operator from ..data import TensorType, Tensor @@ -162,3 +164,12 @@ def forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]: List of output tensors """ ... + + @abstractmethod + def to_dict(self) -> dict[str, str | Sequence[TensorType]]: + ... + + @classmethod + @abstractmethod + def from_dict(cls, node_dict: dict[str, Any]) -> Self: + ... diff --git a/src/xtc/itf/operator/operator.py b/src/xtc/itf/operator/operator.py index 7ce87c8e6..5bbb2dc3a 100644 --- a/src/xtc/itf/operator/operator.py +++ b/src/xtc/itf/operator/operator.py @@ -3,6 +3,8 @@ # Copyright (c) 2024-2026 The XTC Project Authors # from abc import ABC, abstractmethod +from typing import Any +from typing_extensions import Self from collections.abc import Sequence from ..data.tensor import Tensor, TensorType @@ -48,3 +50,11 @@ def forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]: List of output tensors """ ... + + @abstractmethod + def to_dict(self) -> dict[str, Any]: + ... + @classmethod + @abstractmethod + def from_dict(cls, op_dict: dict[str, Any]) -> Self: + ... From d71cd3147f193929a8a1e791a31338166da7011b Mon Sep 17 00:00:00 2001 From: Liam Semeria Date: Mon, 2 Mar 2026 15:43:18 +0100 Subject: [PATCH 02/10] serealization: to and from dict --- src/xtc/graphs/xtc/builder.py | 45 ++++++++++++----------------------- src/xtc/graphs/xtc/expr.py | 6 ++--- src/xtc/graphs/xtc/graph.py | 19 ++++----------- src/xtc/graphs/xtc/node.py | 9 +++---- 4 files changed, 25 insertions(+), 54 deletions(-) diff --git a/src/xtc/graphs/xtc/builder.py b/src/xtc/graphs/xtc/builder.py index ae1aa5976..d926c2f56 100644 --- a/src/xtc/graphs/xtc/builder.py +++ b/src/xtc/graphs/xtc/builder.py @@ -4,6 +4,8 @@ # from typing import Any +from xtc.graphs.xtc.node import XTCNode + from .graph import XTCGraph from .context import XTCGraphContext from .expr import XTCTensorExpr @@ -28,38 +30,21 @@ def graph(self) -> XTCGraph: return self._graph @classmethod - def from_dict(self, graph_dict: dict[str, Any]) -> Any: + def from_dict(cls, graph_dict: dict[str, Any]) -> Any: XTCGraphContext.push() - # TODO: scan for names, store in context _names - # start at deepest expr than use factory to make OpExprs - def build(obj: Any) -> Any: - if isinstance(obj, dict): - if "op" in obj: - args = build(obj["args"]) - func_name: str = obj["op"]["name"] - op_func = getattr(op_factory, func_name) - print(f"calling {func_name}") - obj = op_func(*args, **obj["op"]["attrs"]) - return obj - elif "idx" in obj: - obj = XTCTensorExpr.from_dict(obj) - return obj - else: - print(obj) - return None - elif isinstance(obj, list): - obj = [build(o) for o in obj] - return obj - else: - print(f"wtf is {obj}") - return None - - for out in graph_dict["outputs"]: - build(out["expr"]) + expr_uid_map = {} + for inp in graph_dict["inputs"]: + expr_uid_map[inp["uid"]] = XTCTensorExpr.from_dict(inp["expr"]) + + for node in graph_dict["nodes"]: + expr = node["expr"] + args = [expr_uid_map.get(arg) for arg in expr["args"]] + if "name" in node: + args.append(node["name"]) + op_func = getattr(op_factory, expr["op"]["name"]) + expr_uid_map[node["uid"]] = op_func(*args, **expr["op"]["attrs"]) + scope = XTCGraphContext.pop() - print(scope.graph.inputs) - for inp in scope.graph._inputs: - print(inp.uid) print(scope.graph) return graph_dict diff --git a/src/xtc/graphs/xtc/expr.py b/src/xtc/graphs/xtc/expr.py index 8d24e0633..e8f2c06a8 100644 --- a/src/xtc/graphs/xtc/expr.py +++ b/src/xtc/graphs/xtc/expr.py @@ -96,7 +96,7 @@ def __str__(self) -> str: @abstractmethod def to_dict(self) -> dict[str, Any]: - return {"idx" : self._idx} + return {"uid" : self.uid} @classmethod @abstractmethod @@ -191,7 +191,7 @@ def __str__(self) -> str: @override def to_dict(self) -> dict[str, Any]: - return {"idx": self._idx, "type": self.type.to_dict()} + return {"type": self.type.to_dict()} @classmethod @override @@ -240,7 +240,7 @@ def __str__(self) -> str: @override def to_dict(self) -> dict[str, Any]: - return {"idx": self._idx, "op": self._op.to_dict(), "args": [a.to_dict() for a in self.args]} + return {"op": self._op.to_dict(), "args": [a.uid for a in self.args]} @override @classmethod diff --git a/src/xtc/graphs/xtc/graph.py b/src/xtc/graphs/xtc/graph.py index 35479d747..3f29f3825 100644 --- a/src/xtc/graphs/xtc/graph.py +++ b/src/xtc/graphs/xtc/graph.py @@ -153,23 +153,12 @@ def __str__(self) -> str: graph_str += " nodes: {}\n" return graph_str -# TODO: to_dict(), from_dict() - def to_dict(self) -> dict[str, Any]: - # fold nodes into just uids - for n in self._nodes: - node_dict = n.to_dict() - print(node_dict["expr"]) return { - #"inputs": [i.to_dict() for i in self._inputs], - "outputs": [o.to_dict() for o in self._outputs], - #"nodes": [n.to_dict() for n in self._nodes] + "inputs": [i.to_dict() for i in self._inputs], + #"outputs": [o.to_dict() for o in self._outputs], + "nodes": [n.to_dict() for n in self._nodes] } - def from_dict(self, graph_dict: dict[str, Any]): - # create inputs exprs - for ins in graph_dict["inputs"]: - self._inputs.append(ins.from_dict()) - # set inputs_types here - pass + # TODO: sdump(), sload() # TODO: dump(), load() (yaml) diff --git a/src/xtc/graphs/xtc/node.py b/src/xtc/graphs/xtc/node.py index 7db96a6aa..bf1a4dd4e 100644 --- a/src/xtc/graphs/xtc/node.py +++ b/src/xtc/graphs/xtc/node.py @@ -160,16 +160,13 @@ def __str__(self) -> str: @override def to_dict(self) -> dict[str, str | Sequence[TensorType]]: node_dict = {} - if self.inputs_types and False: - node_dict["input_types"] = [t.to_dict() for t in self.inputs_types] - if self.outputs_types and False: - node_dict["output_types"] = [t.to_dict() for t in self.outputs_types] - if self.uid != self.name and False: + if self.uid != self.name: node_dict["name"] = self.name + node_dict["uid"] = self.uid node_dict["expr"] = self._expr.to_dict() return node_dict @override @classmethod def from_dict(cls, node_dict: dict[str, Any]) -> Self: - return cls(node_dict["expr"]) + ... From 60ee0d65bc4250d37a8d582b837a95cea0e626e0 Mon Sep 17 00:00:00 2001 From: Liam Semeria Date: Mon, 2 Mar 2026 17:03:01 +0100 Subject: [PATCH 03/10] serialization: optional tensor info, reformatting --- src/xtc/graphs/xtc/builder.py | 9 +++++---- src/xtc/graphs/xtc/data.py | 18 ++++++++++++------ src/xtc/graphs/xtc/expr.py | 7 ++++--- src/xtc/graphs/xtc/graph.py | 5 +++-- src/xtc/graphs/xtc/node.py | 5 ++--- src/xtc/graphs/xtc/operators.py | 9 +++++---- src/xtc/itf/data/tensor.py | 6 ++---- src/xtc/itf/graph/node.py | 6 ++---- src/xtc/itf/operator/operator.py | 8 +++----- 9 files changed, 38 insertions(+), 35 deletions(-) diff --git a/src/xtc/graphs/xtc/builder.py b/src/xtc/graphs/xtc/builder.py index d926c2f56..5572aca9c 100644 --- a/src/xtc/graphs/xtc/builder.py +++ b/src/xtc/graphs/xtc/builder.py @@ -11,6 +11,7 @@ from .expr import XTCTensorExpr from . import op_factory + class graph_builder: def __init__(self, **graph_kwargs: Any) -> None: self._graph_kwargs = graph_kwargs @@ -32,19 +33,19 @@ def graph(self) -> XTCGraph: @classmethod def from_dict(cls, graph_dict: dict[str, Any]) -> Any: XTCGraphContext.push() - + expr_uid_map = {} for inp in graph_dict["inputs"]: expr_uid_map[inp["uid"]] = XTCTensorExpr.from_dict(inp["expr"]) - + for node in graph_dict["nodes"]: expr = node["expr"] args = [expr_uid_map.get(arg) for arg in expr["args"]] - if "name" in node: + if "name" in node: args.append(node["name"]) op_func = getattr(op_factory, expr["op"]["name"]) expr_uid_map[node["uid"]] = op_func(*args, **expr["op"]["attrs"]) - + scope = XTCGraphContext.pop() print(scope.graph) return graph_dict diff --git a/src/xtc/graphs/xtc/data.py b/src/xtc/graphs/xtc/data.py index 825e87470..2fbf37399 100644 --- a/src/xtc/graphs/xtc/data.py +++ b/src/xtc/graphs/xtc/data.py @@ -3,7 +3,7 @@ # Copyright (c) 2024-2026 The XTC Project Authors # from typing_extensions import override, Self -from typing import cast, Any +from typing import cast, Any, Iterable import functools import operator import numpy as np @@ -110,15 +110,21 @@ def __eq__(self, other: object) -> bool: @override def to_dict(self) -> dict[str, Any]: - return { - "shape": self.shape, - "dtype": self.dtype, - } + tensor_dict = {} + if self.shape: + tensor_dict["shape"] = list(cast(Iterable[int], self.shape)) + if self.dtype: + tensor_dict["dtype"] = self.dtype + return tensor_dict @override @classmethod def from_dict(cls, tensor_dict: dict[str, Any]) -> Self: - return cls(shape = tensor_dict["shape"], dtype = tensor_dict["dtype"]) + return cls( + shape=tuple(tensor_dict["shape"]) if "shape" in tensor_dict else None, + dtype=tensor_dict["dtype"] if "dtype" in tensor_dict else None, + ) + class XTCConstantTensorType(XTCTensorType, ConstantTensorType): def __init__(self, shape: ConstantShapeType, dtype: ConstantDataType): diff --git a/src/xtc/graphs/xtc/expr.py b/src/xtc/graphs/xtc/expr.py index e8f2c06a8..537867ea3 100644 --- a/src/xtc/graphs/xtc/expr.py +++ b/src/xtc/graphs/xtc/expr.py @@ -96,12 +96,11 @@ def __str__(self) -> str: @abstractmethod def to_dict(self) -> dict[str, Any]: - return {"uid" : self.uid} + return {"uid": self.uid} @classmethod @abstractmethod - def from_dict(cls, node_dict: dict[str, Any]) -> Self: - ... + def from_dict(cls, node_dict: dict[str, Any]) -> Self: ... class XTCValueExpr(XTCExpr): @@ -199,6 +198,7 @@ def from_dict(cls, node_dict: dict[str, Any]) -> Self: type = XTCTensorType.from_dict(node_dict["type"]) return cls(tensor=type) + class XTCOpExpr(XTCExpr): def __init__(self, op: XTCOperator, args: ArgumentsType) -> None: super().__init__() @@ -248,6 +248,7 @@ def from_dict(cls, node_dict: dict[str, Any]) -> Self: # should not be called on just XTCOpExpr return cls(*node_dict["args"], **node_dict["attrs"]) + class XTCMatmulExpr(XTCOpExpr): def __init__(self, x: XTCExpr, y: XTCExpr, **attrs: Any) -> None: super().__init__( diff --git a/src/xtc/graphs/xtc/graph.py b/src/xtc/graphs/xtc/graph.py index 3f29f3825..d1053af4a 100644 --- a/src/xtc/graphs/xtc/graph.py +++ b/src/xtc/graphs/xtc/graph.py @@ -156,9 +156,10 @@ def __str__(self) -> str: def to_dict(self) -> dict[str, Any]: return { "inputs": [i.to_dict() for i in self._inputs], - #"outputs": [o.to_dict() for o in self._outputs], - "nodes": [n.to_dict() for n in self._nodes] + # "outputs": [o.to_dict() for o in self._outputs], + "nodes": [n.to_dict() for n in self._nodes], } + # TODO: sdump(), sload() # TODO: dump(), load() (yaml) diff --git a/src/xtc/graphs/xtc/node.py b/src/xtc/graphs/xtc/node.py index bf1a4dd4e..64d923b26 100644 --- a/src/xtc/graphs/xtc/node.py +++ b/src/xtc/graphs/xtc/node.py @@ -165,8 +165,7 @@ def to_dict(self) -> dict[str, str | Sequence[TensorType]]: node_dict["uid"] = self.uid node_dict["expr"] = self._expr.to_dict() return node_dict - + @override @classmethod - def from_dict(cls, node_dict: dict[str, Any]) -> Self: - ... + def from_dict(cls, node_dict: dict[str, Any]) -> Self: ... diff --git a/src/xtc/graphs/xtc/operators.py b/src/xtc/graphs/xtc/operators.py index b7dc391f3..c77b883d7 100644 --- a/src/xtc/graphs/xtc/operators.py +++ b/src/xtc/graphs/xtc/operators.py @@ -88,21 +88,22 @@ def _get_operation( @override def to_dict(self) -> dict[str, Any]: - op_dict = {"name" : self._name} - #def get_attr(obj: Any): + op_dict = {"name": self._name} + # def get_attr(obj: Any): # if isinstance(obj, dict): # return {k: get_attr(v) for k, v in obj.items()} # elif isinstance(obj, (list, tuple, set)): # return [get_attr(v) for v in obj] # else: # return obj - op_dict["attrs"] = {k: v for k,v in self._attrs.__dict__.items()} + op_dict["attrs"] = {k: v for k, v in self._attrs.__dict__.items()} return op_dict @classmethod @override def from_dict(cls, op_dict: dict[str, Any]) -> Self: - return cls() + return cls() + class XTCOperTensor(XTCOperator): def __init__(self) -> None: diff --git a/src/xtc/itf/data/tensor.py b/src/xtc/itf/data/tensor.py index 8cb244b4e..5144c0033 100644 --- a/src/xtc/itf/data/tensor.py +++ b/src/xtc/itf/data/tensor.py @@ -54,13 +54,11 @@ def ndim(self) -> int: ... @abstractmethod - def to_dict(self) -> dict[str, Any]: - ... + def to_dict(self) -> dict[str, Any]: ... @classmethod @abstractmethod - def from_dict(cls, tensor_dict :dict[str, Any]) -> Self: - ... + def from_dict(cls, tensor_dict: dict[str, Any]) -> Self: ... class ConstantTensorType(TensorType): diff --git a/src/xtc/itf/graph/node.py b/src/xtc/itf/graph/node.py index ad4f1f0b5..20c1d17d4 100644 --- a/src/xtc/itf/graph/node.py +++ b/src/xtc/itf/graph/node.py @@ -166,10 +166,8 @@ def forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]: ... @abstractmethod - def to_dict(self) -> dict[str, str | Sequence[TensorType]]: - ... + def to_dict(self) -> dict[str, str | Sequence[TensorType]]: ... @classmethod @abstractmethod - def from_dict(cls, node_dict: dict[str, Any]) -> Self: - ... + def from_dict(cls, node_dict: dict[str, Any]) -> Self: ... diff --git a/src/xtc/itf/operator/operator.py b/src/xtc/itf/operator/operator.py index 5bbb2dc3a..89875b1fc 100644 --- a/src/xtc/itf/operator/operator.py +++ b/src/xtc/itf/operator/operator.py @@ -50,11 +50,9 @@ def forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]: List of output tensors """ ... - + @abstractmethod - def to_dict(self) -> dict[str, Any]: - ... + def to_dict(self) -> dict[str, Any]: ... @classmethod @abstractmethod - def from_dict(cls, op_dict: dict[str, Any]) -> Self: - ... + def from_dict(cls, op_dict: dict[str, Any]) -> Self: ... From 01bf79f711e1ab204205d9c852029e6ff128b5bc Mon Sep 17 00:00:00 2001 From: Liam Semeria Date: Tue, 3 Mar 2026 15:47:47 +0100 Subject: [PATCH 04/10] serialization: yaml and s load/dump, op deserealization scraps --- src/xtc/graphs/xtc/builder.py | 17 +++++++++++++-- src/xtc/graphs/xtc/context.py | 3 +++ src/xtc/graphs/xtc/data.py | 2 +- src/xtc/graphs/xtc/expr.py | 2 +- src/xtc/graphs/xtc/graph.py | 23 ++++++++++++-------- src/xtc/graphs/xtc/operators.py | 38 ++++++++++++++++++++++++--------- 6 files changed, 62 insertions(+), 23 deletions(-) diff --git a/src/xtc/graphs/xtc/builder.py b/src/xtc/graphs/xtc/builder.py index 5572aca9c..ebb7fa483 100644 --- a/src/xtc/graphs/xtc/builder.py +++ b/src/xtc/graphs/xtc/builder.py @@ -4,12 +4,12 @@ # from typing import Any -from xtc.graphs.xtc.node import XTCNode - from .graph import XTCGraph from .context import XTCGraphContext from .expr import XTCTensorExpr from . import op_factory +from ast import literal_eval +from yaml import safe_load class graph_builder: @@ -35,6 +35,9 @@ def from_dict(cls, graph_dict: dict[str, Any]) -> Any: XTCGraphContext.push() expr_uid_map = {} + if "name" in graph_dict: + XTCGraphContext.name(graph_dict["name"]) + for inp in graph_dict["inputs"]: expr_uid_map[inp["uid"]] = XTCTensorExpr.from_dict(inp["expr"]) @@ -44,8 +47,18 @@ def from_dict(cls, graph_dict: dict[str, Any]) -> Any: if "name" in node: args.append(node["name"]) op_func = getattr(op_factory, expr["op"]["name"]) + # TODO: doesnt handle tuple conversion, (in operators.py) expr_uid_map[node["uid"]] = op_func(*args, **expr["op"]["attrs"]) scope = XTCGraphContext.pop() print(scope.graph) return graph_dict + + @classmethod + def loads(cls, dict_str: str) -> None: + cls.from_dict(literal_eval(dict_str)) + + @classmethod + def load(cls, file_name: str) -> None: + with open(file_name, "r") as f: + cls.from_dict(safe_load(f)) diff --git a/src/xtc/graphs/xtc/context.py b/src/xtc/graphs/xtc/context.py index e7e263955..47be81f1e 100644 --- a/src/xtc/graphs/xtc/context.py +++ b/src/xtc/graphs/xtc/context.py @@ -119,5 +119,8 @@ def outputs(self, *outs: XTCExpr) -> None: def inputs(self, *inps: XTCExpr) -> None: return self.current.add_inputs(*inps) + def name(self, name_str: str) -> None: + self.current._graph_kwargs["name"] = name_str + XTCGraphContext = XTCGraphScopes() diff --git a/src/xtc/graphs/xtc/data.py b/src/xtc/graphs/xtc/data.py index 2fbf37399..0e763d08a 100644 --- a/src/xtc/graphs/xtc/data.py +++ b/src/xtc/graphs/xtc/data.py @@ -110,7 +110,7 @@ def __eq__(self, other: object) -> bool: @override def to_dict(self) -> dict[str, Any]: - tensor_dict = {} + tensor_dict: dict[str, Any] = {} if self.shape: tensor_dict["shape"] = list(cast(Iterable[int], self.shape)) if self.dtype: diff --git a/src/xtc/graphs/xtc/expr.py b/src/xtc/graphs/xtc/expr.py index 537867ea3..d3b9de3f6 100644 --- a/src/xtc/graphs/xtc/expr.py +++ b/src/xtc/graphs/xtc/expr.py @@ -246,7 +246,7 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, node_dict: dict[str, Any]) -> Self: # should not be called on just XTCOpExpr - return cls(*node_dict["args"], **node_dict["attrs"]) + return cls(XTCOperator.from_dict(node_dict["op"]), *node_dict["args"]) class XTCMatmulExpr(XTCOpExpr): diff --git a/src/xtc/graphs/xtc/graph.py b/src/xtc/graphs/xtc/graph.py index d1053af4a..e0ed55e9c 100644 --- a/src/xtc/graphs/xtc/graph.py +++ b/src/xtc/graphs/xtc/graph.py @@ -12,6 +12,7 @@ from .node import XTCNode from .utils import XTCGraphUtils from .data import XTCTensor, XTCTensorType +from yaml import dump as yaml_dump __all__ = [ "XTCGraph", @@ -154,12 +155,16 @@ def __str__(self) -> str: return graph_str def to_dict(self) -> dict[str, Any]: - return { - "inputs": [i.to_dict() for i in self._inputs], - # "outputs": [o.to_dict() for o in self._outputs], - "nodes": [n.to_dict() for n in self._nodes], - } - - -# TODO: sdump(), sload() -# TODO: dump(), load() (yaml) + graph_dict: dict[str, Any] = {} + if self._name: + graph_dict["name"] = self._name + graph_dict["inputs"] = [i.to_dict() for i in self._inputs] + graph_dict["nodes"] = [n.to_dict() for n in self._nodes] + return graph_dict + + def dumps(self) -> str: + return str(self.to_dict()) + + def dump(self, file_name: str): + with open(file_name, "w") as f: + yaml_dump(self.to_dict(), f, sort_keys=False) diff --git a/src/xtc/graphs/xtc/operators.py b/src/xtc/graphs/xtc/operators.py index c77b883d7..90b6e0fb8 100644 --- a/src/xtc/graphs/xtc/operators.py +++ b/src/xtc/graphs/xtc/operators.py @@ -30,10 +30,17 @@ class XTCOperator(Operator): + #_registry = {} + def __init__(self, name: str, **attrs: XTCOperatorAttr) -> None: self._name = name self._attrs = NS(**attrs) + #def __init_subclass__(cls) -> None: + # # TODO: add _name to each op for this to work + # XTCOperator._registry[cls.name] = cls + # return super().__init_subclass__() + @property @override def name(self) -> str: @@ -88,21 +95,32 @@ def _get_operation( @override def to_dict(self) -> dict[str, Any]: - op_dict = {"name": self._name} - # def get_attr(obj: Any): - # if isinstance(obj, dict): - # return {k: get_attr(v) for k, v in obj.items()} - # elif isinstance(obj, (list, tuple, set)): - # return [get_attr(v) for v in obj] - # else: - # return obj - op_dict["attrs"] = {k: v for k, v in self._attrs.__dict__.items()} + def listify(obj: Any): + if isinstance(obj, dict): + return {k: listify(v) for k, v in obj.items()} + elif isinstance(obj, tuple): + return [listify(v) for v in obj] + else: + return obj + + op_dict: dict[str, Any] = {"name": self._name} + op_dict["attrs"] = {k: listify(v) for k, v in self._attrs.__dict__.items()} return op_dict @classmethod @override def from_dict(cls, op_dict: dict[str, Any]) -> Self: - return cls() + ... + # ops with list attr should override this +# def tuplify(obj: Any): +# if isinstance(obj, dict): +# return {k: tuplify(v) for k, v in obj.items()} +# elif isinstance(obj, list): +# return tuple(tuplify(v) for v in obj) +# else: +# return obj +# + #return cls._registry[op_dict["name"]](**tuplify(op_dict["attrs"])) class XTCOperTensor(XTCOperator): From de98f8144f66bb76263070519c0bb61bfde97567 Mon Sep 17 00:00:00 2001 From: Liam Semeria Date: Wed, 11 Mar 2026 14:29:02 +0100 Subject: [PATCH 05/10] serialization: added temp example --- serialization_temp_example.py | 25 +++++++++++++++++++++++++ src/xtc/graphs/xtc/builder.py | 5 +---- src/xtc/graphs/xtc/graph.py | 2 +- 3 files changed, 27 insertions(+), 5 deletions(-) create mode 100644 serialization_temp_example.py diff --git a/serialization_temp_example.py b/serialization_temp_example.py new file mode 100644 index 000000000..16155707c --- /dev/null +++ b/serialization_temp_example.py @@ -0,0 +1,25 @@ +from pathlib import Path +import xtc.graphs.xtc.op as O + +if not Path("output.yaml").exists(): + I, J, K, dtype = 4, 32, 512, "float32" + a = O.tensor((I, K), dtype, name="A") + b = O.tensor((K, J), dtype, name="B") + c = O.tensor((J, I), dtype, name="C") + + with O.graph(name = "matmul_relu") as gb: + m = O.matmul(a, b, name="M") + q = O.relu(m, threshold=.1) + r = O.relu(m, threshold=.1) + k = O.matmul(c, r, name="K") + + graph = gb.graph + print(graph) + graph.dump("output.yaml") + +print("loading from yaml....") + +with O.graph(name = "matmul_relu") as gb2: + gb2.load("output.yaml") + +print(gb2.graph) diff --git a/src/xtc/graphs/xtc/builder.py b/src/xtc/graphs/xtc/builder.py index ebb7fa483..65ea56986 100644 --- a/src/xtc/graphs/xtc/builder.py +++ b/src/xtc/graphs/xtc/builder.py @@ -32,7 +32,6 @@ def graph(self) -> XTCGraph: @classmethod def from_dict(cls, graph_dict: dict[str, Any]) -> Any: - XTCGraphContext.push() expr_uid_map = {} if "name" in graph_dict: @@ -47,11 +46,9 @@ def from_dict(cls, graph_dict: dict[str, Any]) -> Any: if "name" in node: args.append(node["name"]) op_func = getattr(op_factory, expr["op"]["name"]) - # TODO: doesnt handle tuple conversion, (in operators.py) + # TODO: doesnt handle tuple conversion, (still in operators.py) expr_uid_map[node["uid"]] = op_func(*args, **expr["op"]["attrs"]) - scope = XTCGraphContext.pop() - print(scope.graph) return graph_dict @classmethod diff --git a/src/xtc/graphs/xtc/graph.py b/src/xtc/graphs/xtc/graph.py index e0ed55e9c..f3f0e076c 100644 --- a/src/xtc/graphs/xtc/graph.py +++ b/src/xtc/graphs/xtc/graph.py @@ -165,6 +165,6 @@ def to_dict(self) -> dict[str, Any]: def dumps(self) -> str: return str(self.to_dict()) - def dump(self, file_name: str): + def dump(self, file_name: str) -> None: with open(file_name, "w") as f: yaml_dump(self.to_dict(), f, sort_keys=False) From 1621c3ba99aff524caed5f4ea47bf57cd2b65d47 Mon Sep 17 00:00:00 2001 From: Liam Semeria Date: Wed, 11 Mar 2026 14:51:07 +0100 Subject: [PATCH 06/10] serialization: format, type checking --- serialization_temp_example.py | 12 ++++++------ src/xtc/graphs/xtc/builder.py | 1 - src/xtc/graphs/xtc/expr.py | 15 ++------------- src/xtc/graphs/xtc/node.py | 10 +++------- src/xtc/graphs/xtc/operators.py | 17 ++++++++--------- src/xtc/itf/graph/node.py | 6 ------ src/xtc/itf/operator/operator.py | 4 ---- 7 files changed, 19 insertions(+), 46 deletions(-) diff --git a/serialization_temp_example.py b/serialization_temp_example.py index 16155707c..3207b91de 100644 --- a/serialization_temp_example.py +++ b/serialization_temp_example.py @@ -7,10 +7,10 @@ b = O.tensor((K, J), dtype, name="B") c = O.tensor((J, I), dtype, name="C") - with O.graph(name = "matmul_relu") as gb: + with O.graph(name="matmul_relu") as gb: m = O.matmul(a, b, name="M") - q = O.relu(m, threshold=.1) - r = O.relu(m, threshold=.1) + q = O.relu(m, threshold=0.1) + r = O.relu(m, threshold=0.1) k = O.matmul(c, r, name="K") graph = gb.graph @@ -19,7 +19,7 @@ print("loading from yaml....") -with O.graph(name = "matmul_relu") as gb2: - gb2.load("output.yaml") - +with O.graph(name="matmul_relu") as gb2: + gb2.load("output.yaml") + print(gb2.graph) diff --git a/src/xtc/graphs/xtc/builder.py b/src/xtc/graphs/xtc/builder.py index 65ea56986..607a5a995 100644 --- a/src/xtc/graphs/xtc/builder.py +++ b/src/xtc/graphs/xtc/builder.py @@ -32,7 +32,6 @@ def graph(self) -> XTCGraph: @classmethod def from_dict(cls, graph_dict: dict[str, Any]) -> Any: - expr_uid_map = {} if "name" in graph_dict: XTCGraphContext.name(graph_dict["name"]) diff --git a/src/xtc/graphs/xtc/expr.py b/src/xtc/graphs/xtc/expr.py index d3b9de3f6..3a58bc928 100644 --- a/src/xtc/graphs/xtc/expr.py +++ b/src/xtc/graphs/xtc/expr.py @@ -98,10 +98,6 @@ def __str__(self) -> str: def to_dict(self) -> dict[str, Any]: return {"uid": self.uid} - @classmethod - @abstractmethod - def from_dict(cls, node_dict: dict[str, Any]) -> Self: ... - class XTCValueExpr(XTCExpr): @property @@ -193,9 +189,8 @@ def to_dict(self) -> dict[str, Any]: return {"type": self.type.to_dict()} @classmethod - @override - def from_dict(cls, node_dict: dict[str, Any]) -> Self: - type = XTCTensorType.from_dict(node_dict["type"]) + def from_dict(cls, op_dict: dict[str, Any]) -> Self: + type = XTCTensorType.from_dict(op_dict["type"]) return cls(tensor=type) @@ -242,12 +237,6 @@ def __str__(self) -> str: def to_dict(self) -> dict[str, Any]: return {"op": self._op.to_dict(), "args": [a.uid for a in self.args]} - @override - @classmethod - def from_dict(cls, node_dict: dict[str, Any]) -> Self: - # should not be called on just XTCOpExpr - return cls(XTCOperator.from_dict(node_dict["op"]), *node_dict["args"]) - class XTCMatmulExpr(XTCOpExpr): def __init__(self, x: XTCExpr, y: XTCExpr, **attrs: Any) -> None: diff --git a/src/xtc/graphs/xtc/node.py b/src/xtc/graphs/xtc/node.py index 64d923b26..c2ab160fa 100644 --- a/src/xtc/graphs/xtc/node.py +++ b/src/xtc/graphs/xtc/node.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024-2026 The XTC Project Authors # -from typing_extensions import override, Self +from typing_extensions import override from collections.abc import Sequence from typing import cast, Any @@ -158,14 +158,10 @@ def __str__(self) -> str: return str(self._expr).split("=", 1)[1].strip() + attrs_str + type_str @override - def to_dict(self) -> dict[str, str | Sequence[TensorType]]: - node_dict = {} + def to_dict(self) -> dict[str, Any]: + node_dict: dict[str, Any] = {} if self.uid != self.name: node_dict["name"] = self.name node_dict["uid"] = self.uid node_dict["expr"] = self._expr.to_dict() return node_dict - - @override - @classmethod - def from_dict(cls, node_dict: dict[str, Any]) -> Self: ... diff --git a/src/xtc/graphs/xtc/operators.py b/src/xtc/graphs/xtc/operators.py index 90b6e0fb8..d7206b03f 100644 --- a/src/xtc/graphs/xtc/operators.py +++ b/src/xtc/graphs/xtc/operators.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024-2026 The XTC Project Authors # -from typing_extensions import override, Self +from typing_extensions import override from typing import TypeAlias, cast, Any from types import SimpleNamespace as NS from collections.abc import Sequence, Mapping @@ -30,13 +30,13 @@ class XTCOperator(Operator): - #_registry = {} + # _registry = {} def __init__(self, name: str, **attrs: XTCOperatorAttr) -> None: self._name = name self._attrs = NS(**attrs) - #def __init_subclass__(cls) -> None: + # def __init_subclass__(cls) -> None: # # TODO: add _name to each op for this to work # XTCOperator._registry[cls.name] = cls # return super().__init_subclass__() @@ -107,11 +107,10 @@ def listify(obj: Any): op_dict["attrs"] = {k: listify(v) for k, v in self._attrs.__dict__.items()} return op_dict - @classmethod - @override - def from_dict(cls, op_dict: dict[str, Any]) -> Self: - ... - # ops with list attr should override this + +# tuplify logic to be moved to builder + + # def tuplify(obj: Any): # if isinstance(obj, dict): # return {k: tuplify(v) for k, v in obj.items()} @@ -120,7 +119,7 @@ def from_dict(cls, op_dict: dict[str, Any]) -> Self: # else: # return obj # - #return cls._registry[op_dict["name"]](**tuplify(op_dict["attrs"])) +# return cls._registry[op_dict["name"]](**tuplify(op_dict["attrs"])) class XTCOperTensor(XTCOperator): diff --git a/src/xtc/itf/graph/node.py b/src/xtc/itf/graph/node.py index 20c1d17d4..4bd31abd4 100644 --- a/src/xtc/itf/graph/node.py +++ b/src/xtc/itf/graph/node.py @@ -3,8 +3,6 @@ # Copyright (c) 2024-2026 The XTC Project Authors # from abc import ABC, abstractmethod -from typing import Any -from typing_extensions import Self from collections.abc import Sequence from ..operator.operator import Operator from ..data import TensorType, Tensor @@ -167,7 +165,3 @@ def forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]: @abstractmethod def to_dict(self) -> dict[str, str | Sequence[TensorType]]: ... - - @classmethod - @abstractmethod - def from_dict(cls, node_dict: dict[str, Any]) -> Self: ... diff --git a/src/xtc/itf/operator/operator.py b/src/xtc/itf/operator/operator.py index 89875b1fc..e81ce31e4 100644 --- a/src/xtc/itf/operator/operator.py +++ b/src/xtc/itf/operator/operator.py @@ -4,7 +4,6 @@ # from abc import ABC, abstractmethod from typing import Any -from typing_extensions import Self from collections.abc import Sequence from ..data.tensor import Tensor, TensorType @@ -53,6 +52,3 @@ def forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]: @abstractmethod def to_dict(self) -> dict[str, Any]: ... - @classmethod - @abstractmethod - def from_dict(cls, op_dict: dict[str, Any]) -> Self: ... From 6e4904264e8a5080ef17a43b9a4734cd6af16c3b Mon Sep 17 00:00:00 2001 From: Liam Semeria Date: Wed, 11 Mar 2026 15:41:09 +0100 Subject: [PATCH 07/10] serialization: multiple and custom output exprs --- serialization_temp_example.py | 3 ++- src/xtc/graphs/xtc/builder.py | 3 +++ src/xtc/graphs/xtc/graph.py | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/serialization_temp_example.py b/serialization_temp_example.py index 3207b91de..317d28761 100644 --- a/serialization_temp_example.py +++ b/serialization_temp_example.py @@ -1,5 +1,6 @@ from pathlib import Path import xtc.graphs.xtc.op as O +from xtc.graphs.xtc.context import XTCGraphContext as C if not Path("output.yaml").exists(): I, J, K, dtype = 4, 32, 512, "float32" @@ -9,9 +10,9 @@ with O.graph(name="matmul_relu") as gb: m = O.matmul(a, b, name="M") - q = O.relu(m, threshold=0.1) r = O.relu(m, threshold=0.1) k = O.matmul(c, r, name="K") + C.outputs(m) graph = gb.graph print(graph) diff --git a/src/xtc/graphs/xtc/builder.py b/src/xtc/graphs/xtc/builder.py index 607a5a995..379af896f 100644 --- a/src/xtc/graphs/xtc/builder.py +++ b/src/xtc/graphs/xtc/builder.py @@ -48,6 +48,9 @@ def from_dict(cls, graph_dict: dict[str, Any]) -> Any: # TODO: doesnt handle tuple conversion, (still in operators.py) expr_uid_map[node["uid"]] = op_func(*args, **expr["op"]["attrs"]) + outputs = [expr_uid_map[out["uid"]] for out in graph_dict["outputs"]] + XTCGraphContext.outputs(*outputs) + return graph_dict @classmethod diff --git a/src/xtc/graphs/xtc/graph.py b/src/xtc/graphs/xtc/graph.py index f3f0e076c..37e1137b9 100644 --- a/src/xtc/graphs/xtc/graph.py +++ b/src/xtc/graphs/xtc/graph.py @@ -159,6 +159,7 @@ def to_dict(self) -> dict[str, Any]: if self._name: graph_dict["name"] = self._name graph_dict["inputs"] = [i.to_dict() for i in self._inputs] + graph_dict["outputs"] = [{"uid": o.to_dict()["uid"]} for o in self._outputs] graph_dict["nodes"] = [n.to_dict() for n in self._nodes] return graph_dict From 3386251c55c8e8d89a5b6ea7ebbdeb3cf26e8f09 Mon Sep 17 00:00:00 2001 From: Liam Semeria Date: Wed, 11 Mar 2026 16:26:55 +0100 Subject: [PATCH 08/10] serialization: support for attrs with tuples --- src/xtc/graphs/xtc/builder.py | 11 +++++++++-- src/xtc/graphs/xtc/operators.py | 23 +---------------------- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/src/xtc/graphs/xtc/builder.py b/src/xtc/graphs/xtc/builder.py index 379af896f..3fb940bb1 100644 --- a/src/xtc/graphs/xtc/builder.py +++ b/src/xtc/graphs/xtc/builder.py @@ -32,6 +32,14 @@ def graph(self) -> XTCGraph: @classmethod def from_dict(cls, graph_dict: dict[str, Any]) -> Any: + def tuplify(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: tuplify(v) for k, v in obj.items()} + elif isinstance(obj, list): + return tuple(tuplify(v) for v in obj) + else: + return obj + expr_uid_map = {} if "name" in graph_dict: XTCGraphContext.name(graph_dict["name"]) @@ -45,8 +53,7 @@ def from_dict(cls, graph_dict: dict[str, Any]) -> Any: if "name" in node: args.append(node["name"]) op_func = getattr(op_factory, expr["op"]["name"]) - # TODO: doesnt handle tuple conversion, (still in operators.py) - expr_uid_map[node["uid"]] = op_func(*args, **expr["op"]["attrs"]) + expr_uid_map[node["uid"]] = op_func(*args, **tuplify(expr["op"]["attrs"])) outputs = [expr_uid_map[out["uid"]] for out in graph_dict["outputs"]] XTCGraphContext.outputs(*outputs) diff --git a/src/xtc/graphs/xtc/operators.py b/src/xtc/graphs/xtc/operators.py index d7206b03f..d06c02aa6 100644 --- a/src/xtc/graphs/xtc/operators.py +++ b/src/xtc/graphs/xtc/operators.py @@ -30,17 +30,10 @@ class XTCOperator(Operator): - # _registry = {} - def __init__(self, name: str, **attrs: XTCOperatorAttr) -> None: self._name = name self._attrs = NS(**attrs) - # def __init_subclass__(cls) -> None: - # # TODO: add _name to each op for this to work - # XTCOperator._registry[cls.name] = cls - # return super().__init_subclass__() - @property @override def name(self) -> str: @@ -95,7 +88,7 @@ def _get_operation( @override def to_dict(self) -> dict[str, Any]: - def listify(obj: Any): + def listify(obj: Any) -> Any: if isinstance(obj, dict): return {k: listify(v) for k, v in obj.items()} elif isinstance(obj, tuple): @@ -108,20 +101,6 @@ def listify(obj: Any): return op_dict -# tuplify logic to be moved to builder - - -# def tuplify(obj: Any): -# if isinstance(obj, dict): -# return {k: tuplify(v) for k, v in obj.items()} -# elif isinstance(obj, list): -# return tuple(tuplify(v) for v in obj) -# else: -# return obj -# -# return cls._registry[op_dict["name"]](**tuplify(op_dict["attrs"])) - - class XTCOperTensor(XTCOperator): def __init__(self) -> None: super().__init__("tensor") From 8f807a167c782e52eaf3b0ebf91c277455343eed Mon Sep 17 00:00:00 2001 From: Liam Semeria Date: Thu, 12 Mar 2026 14:10:22 +0100 Subject: [PATCH 09/10] serialization: version str, deterministic uids, sdump/sload --- src/xtc/graphs/xtc/builder.py | 10 ++++++---- src/xtc/graphs/xtc/graph.py | 19 ++++++++++++++++--- src/xtc/graphs/xtc/operators.py | 7 +++++++ 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/src/xtc/graphs/xtc/builder.py b/src/xtc/graphs/xtc/builder.py index 3fb940bb1..cc39cf581 100644 --- a/src/xtc/graphs/xtc/builder.py +++ b/src/xtc/graphs/xtc/builder.py @@ -8,7 +8,6 @@ from .context import XTCGraphContext from .expr import XTCTensorExpr from . import op_factory -from ast import literal_eval from yaml import safe_load @@ -52,17 +51,20 @@ def tuplify(obj: Any) -> Any: args = [expr_uid_map.get(arg) for arg in expr["args"]] if "name" in node: args.append(node["name"]) + if not hasattr(op_factory, expr["op"]["name"]): + raise ValueError( + f"serialized op {expr['op']['name']} is not implemented!" + ) op_func = getattr(op_factory, expr["op"]["name"]) expr_uid_map[node["uid"]] = op_func(*args, **tuplify(expr["op"]["attrs"])) outputs = [expr_uid_map[out["uid"]] for out in graph_dict["outputs"]] XTCGraphContext.outputs(*outputs) - return graph_dict @classmethod - def loads(cls, dict_str: str) -> None: - cls.from_dict(literal_eval(dict_str)) + def loads(cls, yaml_str: str) -> None: + cls.from_dict(safe_load(yaml_str)) @classmethod def load(cls, file_name: str) -> None: diff --git a/src/xtc/graphs/xtc/graph.py b/src/xtc/graphs/xtc/graph.py index 37e1137b9..8ce42170c 100644 --- a/src/xtc/graphs/xtc/graph.py +++ b/src/xtc/graphs/xtc/graph.py @@ -12,7 +12,8 @@ from .node import XTCNode from .utils import XTCGraphUtils from .data import XTCTensor, XTCTensorType -from yaml import dump as yaml_dump +from .operators import XTCOperator +from yaml import dump as yaml_dump, safe_dump __all__ = [ "XTCGraph", @@ -161,10 +162,22 @@ def to_dict(self) -> dict[str, Any]: graph_dict["inputs"] = [i.to_dict() for i in self._inputs] graph_dict["outputs"] = [{"uid": o.to_dict()["uid"]} for o in self._outputs] graph_dict["nodes"] = [n.to_dict() for n in self._nodes] - return graph_dict + graph_dict["ops_version"] = XTCOperator.version_string() + lowest_uid = int(graph_dict["inputs"][0]["uid"][1:]) + + def compact_uids(obj: Any) -> Any: + if isinstance(obj, dict): + obj = {k: compact_uids(v) for k, v in obj.items()} + elif isinstance(obj, list): + obj = [compact_uids(v) for v in obj] + elif isinstance(obj, str) and obj[0] == "%": + obj = f"%{int(obj[1:]) - lowest_uid}" + return obj + + return compact_uids(graph_dict) def dumps(self) -> str: - return str(self.to_dict()) + return safe_dump(self.to_dict()) def dump(self, file_name: str) -> None: with open(file_name, "w") as f: diff --git a/src/xtc/graphs/xtc/operators.py b/src/xtc/graphs/xtc/operators.py index d06c02aa6..d157db596 100644 --- a/src/xtc/graphs/xtc/operators.py +++ b/src/xtc/graphs/xtc/operators.py @@ -9,6 +9,7 @@ import functools import operator import numpy as np +import hashlib from xtc.itf.operator import Operator from xtc.itf.data import Tensor, TensorType @@ -100,6 +101,12 @@ def listify(obj: Any) -> Any: op_dict["attrs"] = {k: listify(v) for k, v in self._attrs.__dict__.items()} return op_dict + @classmethod + def version_string(cls) -> str: + names = sorted(f"{c.__module__}.{c.__qualname__}" for c in cls.__subclasses__()) + data = "|".join(names).encode() + return hashlib.sha256(data).hexdigest()[:16] + class XTCOperTensor(XTCOperator): def __init__(self) -> None: From 1eff19b012ca300ee9341547c1def3eccd41f33c Mon Sep 17 00:00:00 2001 From: Liam Semeria Date: Fri, 13 Mar 2026 12:12:19 +0100 Subject: [PATCH 10/10] serialization: version mismatch info, added pytest --- serialization_temp_example.py | 26 -------- src/xtc/graphs/xtc/builder.py | 9 ++- .../pytest/graphs/test_graph_serialization.py | 66 +++++++++++++++++++ 3 files changed, 74 insertions(+), 27 deletions(-) delete mode 100644 serialization_temp_example.py create mode 100644 tests/pytest/graphs/test_graph_serialization.py diff --git a/serialization_temp_example.py b/serialization_temp_example.py deleted file mode 100644 index 317d28761..000000000 --- a/serialization_temp_example.py +++ /dev/null @@ -1,26 +0,0 @@ -from pathlib import Path -import xtc.graphs.xtc.op as O -from xtc.graphs.xtc.context import XTCGraphContext as C - -if not Path("output.yaml").exists(): - I, J, K, dtype = 4, 32, 512, "float32" - a = O.tensor((I, K), dtype, name="A") - b = O.tensor((K, J), dtype, name="B") - c = O.tensor((J, I), dtype, name="C") - - with O.graph(name="matmul_relu") as gb: - m = O.matmul(a, b, name="M") - r = O.relu(m, threshold=0.1) - k = O.matmul(c, r, name="K") - C.outputs(m) - - graph = gb.graph - print(graph) - graph.dump("output.yaml") - -print("loading from yaml....") - -with O.graph(name="matmul_relu") as gb2: - gb2.load("output.yaml") - -print(gb2.graph) diff --git a/src/xtc/graphs/xtc/builder.py b/src/xtc/graphs/xtc/builder.py index cc39cf581..b66952dd7 100644 --- a/src/xtc/graphs/xtc/builder.py +++ b/src/xtc/graphs/xtc/builder.py @@ -7,6 +7,7 @@ from .graph import XTCGraph from .context import XTCGraphContext from .expr import XTCTensorExpr +from .operators import XTCOperator from . import op_factory from yaml import safe_load @@ -52,8 +53,14 @@ def tuplify(obj: Any) -> Any: if "name" in node: args.append(node["name"]) if not hasattr(op_factory, expr["op"]["name"]): + version_mismatch = ( + "version mismatch detected, " + if XTCOperator.version_string() != graph_dict["ops_version"] + else "" + ) raise ValueError( - f"serialized op {expr['op']['name']} is not implemented!" + version_mismatch + + f"serialized op {expr['op']['name']} is not implemented." ) op_func = getattr(op_factory, expr["op"]["name"]) expr_uid_map[node["uid"]] = op_func(*args, **tuplify(expr["op"]["attrs"])) diff --git a/tests/pytest/graphs/test_graph_serialization.py b/tests/pytest/graphs/test_graph_serialization.py new file mode 100644 index 000000000..3aecd687a --- /dev/null +++ b/tests/pytest/graphs/test_graph_serialization.py @@ -0,0 +1,66 @@ +import tempfile +import xtc.graphs.xtc.op as O + +def test_matmul_relu_to_from_dict(): + I, J, K, dtype = 4, 32, 512, "float32" + a = O.tensor((I, K), dtype, name="A") + b = O.tensor((K, J), dtype, name="B") + + with O.graph(name="matmul_relu") as gb: + m = O.matmul(a, b, name="matmul") + O.relu(m, name="relu") + + graph_dict = gb.graph.to_dict() + with O.graph() as gb2: + gb2.from_dict(graph_dict) + assert graph_dict != {} + assert graph_dict == gb2.graph.to_dict() + + +def test_conv2d_pad_sdump_sload(): + N, H, W, F, R, S, C, SH, SW, dtype = 1, 8, 8, 16, 5, 5, 3, 2, 2, "float32" + a = O.tensor((N, H, W, C), dtype, name="I") + b = O.tensor((R, S, C, F), dtype, name="W") + + with O.graph(name="pad_conv2d_nhwc_mini") as gb: + p = O.pad(a, padding={1: (2), 2: (2, 2)}, name="pad") + O.conv2d(p, b, stride=(SH, SW), name="conv") + + graph_str = gb.graph.dumps() + with O.graph(name="matmul_relu") as gb2: + gb2.loads(graph_str) + assert graph_str != "" + assert graph_str == gb2.graph.dumps() + + with tempfile.NamedTemporaryFile(mode="w+", delete=True) as f: + gb.graph.dump(f.name) + with O.graph() as gb3: + gb3.load(f.name) + assert gb.graph.to_dict() == gb3.graph.to_dict() + +def test_mlp_fc_custom_output(): + img = O.tensor() + w1 = O.tensor() + w2 = O.tensor() + w3 = O.tensor() + w4 = O.tensor() + fc = lambda i, w, nout: O.matmul(O.reshape(i, shape=(1, -1)), O.reshape(w, shape=(-1, nout))) + # Multi Layer Perceptron with 3 relu(fc) + 1 fc + with O.graph(name="mlp4") as gb: + with O.graph(name="l1"): + l1 = O.relu(fc(img, w1, 512)) + with O.graph(name="l2"): + l2 = O.relu(fc(l1, w2, 256)) + with O.graph(name="l3"): + l3 = O.relu(fc(l2, w3, 128)) + with O.graph(name="l4"): + l4 = fc(l3, w4, 10) + O.reshape(l4, shape=(-1,)) + O.outputs(l1) + + graph_dict = gb.graph.to_dict() + with O.graph() as gb2: + gb2.from_dict(graph_dict) + assert graph_dict != {} + assert graph_dict == gb2.graph.to_dict() +