Skip to content
52 changes: 52 additions & 0 deletions src/xtc/graphs/xtc/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

from .graph import XTCGraph
from .context import XTCGraphContext
from .expr import XTCTensorExpr
from .operators import XTCOperator
from . import op_factory
from yaml import safe_load


class graph_builder:
Expand All @@ -25,3 +29,51 @@ def __exit__(self, *_: Any) -> None:
def graph(self) -> XTCGraph:
assert self._graph is not None, "can't get graph inside builder context"
return self._graph

@classmethod
def from_dict(cls, graph_dict: dict[str, Any]) -> Any:
def tuplify(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: tuplify(v) for k, v in obj.items()}
elif isinstance(obj, list):
return tuple(tuplify(v) for v in obj)
else:
return obj

expr_uid_map = {}
if "name" in graph_dict:
XTCGraphContext.name(graph_dict["name"])

for inp in graph_dict["inputs"]:
expr_uid_map[inp["uid"]] = XTCTensorExpr.from_dict(inp["expr"])

for node in graph_dict["nodes"]:
expr = node["expr"]
args = [expr_uid_map.get(arg) for arg in expr["args"]]
if "name" in node:
args.append(node["name"])
if not hasattr(op_factory, expr["op"]["name"]):
version_mismatch = (
"version mismatch detected, "
if XTCOperator.version_string() != graph_dict["ops_version"]
else ""
)
raise ValueError(
version_mismatch
+ f"serialized op {expr['op']['name']} is not implemented."
)
op_func = getattr(op_factory, expr["op"]["name"])
expr_uid_map[node["uid"]] = op_func(*args, **tuplify(expr["op"]["attrs"]))

outputs = [expr_uid_map[out["uid"]] for out in graph_dict["outputs"]]
XTCGraphContext.outputs(*outputs)
return graph_dict

@classmethod
def loads(cls, yaml_str: str) -> None:
cls.from_dict(safe_load(yaml_str))

@classmethod
def load(cls, file_name: str) -> None:
with open(file_name, "r") as f:
cls.from_dict(safe_load(f))
3 changes: 3 additions & 0 deletions src/xtc/graphs/xtc/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,8 @@ def outputs(self, *outs: XTCExpr) -> None:
def inputs(self, *inps: XTCExpr) -> None:
return self.current.add_inputs(*inps)

def name(self, name_str: str) -> None:
self.current._graph_kwargs["name"] = name_str


XTCGraphContext = XTCGraphScopes()
21 changes: 19 additions & 2 deletions src/xtc/graphs/xtc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024-2026 The XTC Project Authors
#
from typing_extensions import override
from typing import cast, Any
from typing_extensions import override, Self
from typing import cast, Any, Iterable
import functools
import operator
import numpy as np
Expand Down Expand Up @@ -108,6 +108,23 @@ def __eq__(self, other: object) -> bool:
return NotImplemented
return self.dtype == other.dtype and self.shape == other.shape

@override
def to_dict(self) -> dict[str, Any]:
tensor_dict: dict[str, Any] = {}
if self.shape:
tensor_dict["shape"] = list(cast(Iterable[int], self.shape))
if self.dtype:
tensor_dict["dtype"] = self.dtype
return tensor_dict

@override
@classmethod
def from_dict(cls, tensor_dict: dict[str, Any]) -> Self:
return cls(
shape=tuple(tensor_dict["shape"]) if "shape" in tensor_dict else None,
dtype=tensor_dict["dtype"] if "dtype" in tensor_dict else None,
)


class XTCConstantTensorType(XTCTensorType, ConstantTensorType):
def __init__(self, shape: ConstantShapeType, dtype: ConstantDataType):
Expand Down
19 changes: 18 additions & 1 deletion src/xtc/graphs/xtc/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing_extensions import override
from typing_extensions import override, Self
from typing import Any, TypeAlias
import threading

Expand Down Expand Up @@ -94,6 +94,10 @@ def uid(self) -> str:
def __str__(self) -> str:
return f"{self.uid} = ?"

@abstractmethod
def to_dict(self) -> dict[str, Any]:
return {"uid": self.uid}


class XTCValueExpr(XTCExpr):
@property
Expand Down Expand Up @@ -180,6 +184,15 @@ def __str__(self) -> str:
args = ", ".join([arg.uid for arg in self.args])
return f"{self.uid} = {self.op_name}({args})"

@override
def to_dict(self) -> dict[str, Any]:
return {"type": self.type.to_dict()}

@classmethod
def from_dict(cls, op_dict: dict[str, Any]) -> Self:
type = XTCTensorType.from_dict(op_dict["type"])
return cls(tensor=type)


class XTCOpExpr(XTCExpr):
def __init__(self, op: XTCOperator, args: ArgumentsType) -> None:
Expand Down Expand Up @@ -220,6 +233,10 @@ def __str__(self) -> str:
args = ", ".join(params)
return f"{self.uid} = {self.op_name}({args})"

@override
def to_dict(self) -> dict[str, Any]:
return {"op": self._op.to_dict(), "args": [a.uid for a in self.args]}


class XTCMatmulExpr(XTCOpExpr):
def __init__(self, x: XTCExpr, y: XTCExpr, **attrs: Any) -> None:
Expand Down
32 changes: 31 additions & 1 deletion src/xtc/graphs/xtc/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
#
from typing_extensions import override
from collections.abc import Sequence, Mapping
from typing import TypeAlias, cast
from typing import TypeAlias, cast, Any

from xtc.itf.graph import Graph
from xtc.itf.data import TensorType, Tensor

from .node import XTCNode
from .utils import XTCGraphUtils
from .data import XTCTensor, XTCTensorType
from .operators import XTCOperator
from yaml import dump as yaml_dump, safe_dump

__all__ = [
"XTCGraph",
Expand Down Expand Up @@ -152,3 +154,31 @@ def __str__(self) -> str:
else:
graph_str += " nodes: {}\n"
return graph_str

def to_dict(self) -> dict[str, Any]:
graph_dict: dict[str, Any] = {}
if self._name:
graph_dict["name"] = self._name
graph_dict["inputs"] = [i.to_dict() for i in self._inputs]
graph_dict["outputs"] = [{"uid": o.to_dict()["uid"]} for o in self._outputs]
graph_dict["nodes"] = [n.to_dict() for n in self._nodes]
graph_dict["ops_version"] = XTCOperator.version_string()
lowest_uid = int(graph_dict["inputs"][0]["uid"][1:])

def compact_uids(obj: Any) -> Any:
if isinstance(obj, dict):
obj = {k: compact_uids(v) for k, v in obj.items()}
elif isinstance(obj, list):
obj = [compact_uids(v) for v in obj]
elif isinstance(obj, str) and obj[0] == "%":
obj = f"%{int(obj[1:]) - lowest_uid}"
return obj

return compact_uids(graph_dict)

def dumps(self) -> str:
return safe_dump(self.to_dict())

def dump(self, file_name: str) -> None:
with open(file_name, "w") as f:
yaml_dump(self.to_dict(), f, sort_keys=False)
11 changes: 10 additions & 1 deletion src/xtc/graphs/xtc/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#
from typing_extensions import override
from collections.abc import Sequence
from typing import cast
from typing import cast, Any

from xtc.itf.graph import Node
from xtc.itf.operator import Operator
Expand Down Expand Up @@ -156,3 +156,12 @@ def __str__(self) -> str:
if self.inputs_types is not None and self.outputs_types is not None:
type_str = f" : {self.inputs_types} -> {self.outputs_types}"
return str(self._expr).split("=", 1)[1].strip() + attrs_str + type_str

@override
def to_dict(self) -> dict[str, Any]:
node_dict: dict[str, Any] = {}
if self.uid != self.name:
node_dict["name"] = self.name
node_dict["uid"] = self.uid
node_dict["expr"] = self._expr.to_dict()
return node_dict
21 changes: 21 additions & 0 deletions src/xtc/graphs/xtc/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import functools
import operator
import numpy as np
import hashlib

from xtc.itf.operator import Operator
from xtc.itf.data import Tensor, TensorType
Expand Down Expand Up @@ -86,6 +87,26 @@ def _get_operation(
outs_maps=outs_maps,
)

@override
def to_dict(self) -> dict[str, Any]:
def listify(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: listify(v) for k, v in obj.items()}
elif isinstance(obj, tuple):
return [listify(v) for v in obj]
else:
return obj

op_dict: dict[str, Any] = {"name": self._name}
op_dict["attrs"] = {k: listify(v) for k, v in self._attrs.__dict__.items()}
return op_dict

@classmethod
def version_string(cls) -> str:
names = sorted(f"{c.__module__}.{c.__qualname__}" for c in cls.__subclasses__())
data = "|".join(names).encode()
return hashlib.sha256(data).hexdigest()[:16]


class XTCOperTensor(XTCOperator):
def __init__(self) -> None:
Expand Down
9 changes: 8 additions & 1 deletion src/xtc/itf/data/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#
from abc import ABC, abstractmethod
from typing import Any, TypeAlias
from typing_extensions import override
from typing_extensions import override, Self
import numpy.typing


Expand Down Expand Up @@ -53,6 +53,13 @@ def ndim(self) -> int:
"""
...

@abstractmethod
def to_dict(self) -> dict[str, Any]: ...

@classmethod
@abstractmethod
def from_dict(cls, tensor_dict: dict[str, Any]) -> Self: ...


class ConstantTensorType(TensorType):
@property
Expand Down
3 changes: 3 additions & 0 deletions src/xtc/itf/graph/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,6 @@ def forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]:
List of output tensors
"""
...

@abstractmethod
def to_dict(self) -> dict[str, str | Sequence[TensorType]]: ...
4 changes: 4 additions & 0 deletions src/xtc/itf/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (c) 2024-2026 The XTC Project Authors
#
from abc import ABC, abstractmethod
from typing import Any
from collections.abc import Sequence
from ..data.tensor import Tensor, TensorType

Expand Down Expand Up @@ -48,3 +49,6 @@ def forward(self, inputs: Sequence[Tensor]) -> Sequence[Tensor]:
List of output tensors
"""
...

@abstractmethod
def to_dict(self) -> dict[str, Any]: ...
66 changes: 66 additions & 0 deletions tests/pytest/graphs/test_graph_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import tempfile
import xtc.graphs.xtc.op as O

def test_matmul_relu_to_from_dict():
I, J, K, dtype = 4, 32, 512, "float32"
a = O.tensor((I, K), dtype, name="A")
b = O.tensor((K, J), dtype, name="B")

with O.graph(name="matmul_relu") as gb:
m = O.matmul(a, b, name="matmul")
O.relu(m, name="relu")

graph_dict = gb.graph.to_dict()
with O.graph() as gb2:
gb2.from_dict(graph_dict)
assert graph_dict != {}
assert graph_dict == gb2.graph.to_dict()


def test_conv2d_pad_sdump_sload():
N, H, W, F, R, S, C, SH, SW, dtype = 1, 8, 8, 16, 5, 5, 3, 2, 2, "float32"
a = O.tensor((N, H, W, C), dtype, name="I")
b = O.tensor((R, S, C, F), dtype, name="W")

with O.graph(name="pad_conv2d_nhwc_mini") as gb:
p = O.pad(a, padding={1: (2), 2: (2, 2)}, name="pad")
O.conv2d(p, b, stride=(SH, SW), name="conv")

graph_str = gb.graph.dumps()
with O.graph(name="matmul_relu") as gb2:
gb2.loads(graph_str)
assert graph_str != ""
assert graph_str == gb2.graph.dumps()

with tempfile.NamedTemporaryFile(mode="w+", delete=True) as f:
gb.graph.dump(f.name)
with O.graph() as gb3:
gb3.load(f.name)
assert gb.graph.to_dict() == gb3.graph.to_dict()

def test_mlp_fc_custom_output():
img = O.tensor()
w1 = O.tensor()
w2 = O.tensor()
w3 = O.tensor()
w4 = O.tensor()
fc = lambda i, w, nout: O.matmul(O.reshape(i, shape=(1, -1)), O.reshape(w, shape=(-1, nout)))
# Multi Layer Perceptron with 3 relu(fc) + 1 fc
with O.graph(name="mlp4") as gb:
with O.graph(name="l1"):
l1 = O.relu(fc(img, w1, 512))
with O.graph(name="l2"):
l2 = O.relu(fc(l1, w2, 256))
with O.graph(name="l3"):
l3 = O.relu(fc(l2, w3, 128))
with O.graph(name="l4"):
l4 = fc(l3, w4, 10)
O.reshape(l4, shape=(-1,))
O.outputs(l1)

graph_dict = gb.graph.to_dict()
with O.graph() as gb2:
gb2.from_dict(graph_dict)
assert graph_dict != {}
assert graph_dict == gb2.graph.to_dict()

Loading