Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
name: CI

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install the project
run: uv sync --dev

- name: Run tests
run: uv run pytest

lint:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v5
with:
python-version: "3.12"

- name: Install the project
run: uv sync --dev

- name: MyPy type checking
run: uv run mypy src/

- name: Lint with ruff
run: |
uv run ruff check .
uv run ruff format --check .
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ select = ["E", "F", "I", "B", "UP"]
testpaths = ["tests"]

[tool.mypy]
python_version = "3.9"
python_version = "3.10"
ignore_missing_imports = true
4 changes: 2 additions & 2 deletions src/structree/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from __future__ import annotations

from typing import Annotated, Literal, Optional, Type, TypeVar, Union
from typing import Annotated, Literal, Type, TypeVar, Union

from pydantic import BaseModel, ConfigDict, Field

Expand Down Expand Up @@ -83,7 +83,7 @@ class StructConfig(BaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True)

def __init_subclass__(cls, type: Optional[str] = None, **kwargs):
def __init_subclass__(cls, type: str | None = None, **kwargs):
super().__init_subclass__(**kwargs)
if type is not None:
cls.__annotations__ = {"type": Literal[type], **cls.__annotations__}
Expand Down
2 changes: 1 addition & 1 deletion src/structree/_flatten_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,5 +229,5 @@ def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr):
chunks = np.split(arr, indices[:-1])
return [
chunk.reshape(shape).astype(dtype)
for chunk, shape, dtype in zip(chunks, shapes, from_dtypes, strict=False)
for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)
]
14 changes: 7 additions & 7 deletions src/structree/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import dataclasses
from collections import OrderedDict
from collections.abc import Callable, Sequence
from typing import Any, NamedTuple, Optional, TypeVar
from typing import Any, NamedTuple, TypeVar

Typ = TypeVar("Typ", bound=type[Any])

Expand Down Expand Up @@ -102,24 +102,24 @@ def _dict_to_iter(d: dict):


def _dict_from_iter(keys, vals):
return dict(zip(keys, vals, strict=False))
return dict(zip(keys, vals))


register_struct(dict, _dict_to_iter, _dict_from_iter)


# OrderedDict
def _od_from_iter(keys, vals):
return OrderedDict(zip(keys, vals, strict=False))
return OrderedDict(zip(keys, vals))


register_struct(OrderedDict, _dict_to_iter, _od_from_iter)


def register_dataclass(
nodetype: Typ,
data_fields: Optional[Sequence[str]] = None,
meta_fields: Optional[Sequence[str]] = None,
data_fields: Sequence[str] | None = None,
meta_fields: Sequence[str] | None = None,
drop_fields: Sequence[str] = (),
) -> Typ:
"""
Expand Down Expand Up @@ -227,8 +227,8 @@ def register_dataclass(
raise ValueError(msg)

def unflatten_func(meta, data):
meta_args = tuple(zip(meta_fields, meta, strict=False))
data_args = tuple(zip(data_fields, data, strict=False))
meta_args = tuple(zip(meta_fields, meta))
data_args = tuple(zip(data_fields, data))
kwargs = dict(meta_args + data_args)
return nodetype(**kwargs)

Expand Down
6 changes: 3 additions & 3 deletions src/structree/_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import functools
from collections.abc import Callable
from dataclasses import InitVar, fields, replace
from typing import Any, Optional, TypeVar, Union
from typing import Any, TypeVar

from typing_extensions import dataclass_transform

Expand All @@ -55,7 +55,7 @@
def field(
static: bool = False,
*,
metadata: Optional[dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
**kwargs,
) -> dataclasses.Field:
"""
Expand Down Expand Up @@ -107,7 +107,7 @@ def field(


@dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
def struct(cls: Optional[T] = None, **kwargs) -> Union[T, Callable]:
def struct(cls: T | None = None, **kwargs) -> T | Callable:
"""
Decorator to convert a class into a tree-compatible frozen dataclass.

Expand Down
19 changes: 9 additions & 10 deletions src/structree/_tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
TYPE_CHECKING,
Any,
NamedTuple,
Optional,
TypeVar,
cast,
)
Expand All @@ -32,7 +31,7 @@


class TreeDef(NamedTuple):
node_data: Optional[tuple[type, Hashable]]
node_data: None | tuple[type, Hashable]
children: tuple[TreeDef, ...]
num_leaves: int

Expand Down Expand Up @@ -64,7 +63,7 @@ def __eq__(self, other: object) -> bool:
# Flatten/unflatten functions
#
def tree_flatten(
x: Tree, is_leaf: Optional[Callable[[Any], bool]] = None
x: Tree, is_leaf: Callable[[Any], bool] | None = None
) -> tuple[list[ArrayLike], TreeDef]:
"""
Flatten a tree into a list of leaves and a treedef.
Expand Down Expand Up @@ -113,7 +112,7 @@ def tree_flatten(


def _tree_flatten(
x: Tree, is_leaf: Optional[Callable[[Any], bool]]
x: Tree, is_leaf: Callable[[Any], bool] | None
) -> tuple[Iterable, TreeDef]:
if x is None:
return [], NONE_DEF
Expand Down Expand Up @@ -197,7 +196,7 @@ def _tree_unflatten(treedef: TreeDef, xs: Iterator) -> Tree:
#


def tree_structure(tree: Tree, is_leaf: Optional[Callable[[Any], bool]] = None) -> TreeDef:
def tree_structure(tree: Tree, is_leaf: Callable[[Any], bool] | None = None) -> TreeDef:
"""
Extract the structure of a tree without the leaf values.

Expand Down Expand Up @@ -227,7 +226,7 @@ def tree_structure(tree: Tree, is_leaf: Optional[Callable[[Any], bool]] = None)


def tree_leaves(
tree: Tree, is_leaf: Optional[Callable[[Any], bool]] = None
tree: Tree, is_leaf: Callable[[Any], bool] | None = None
) -> list[ArrayLike]:
"""
Extract all leaf values from a tree.
Expand All @@ -254,7 +253,7 @@ def tree_leaves(
return flat


def tree_all(tree: Tree, is_leaf: Optional[Callable[[Any], bool]] = None) -> bool:
def tree_all(tree: Tree, is_leaf: Callable[[Any], bool] | None = None) -> bool:
"""
Check if all leaves in the tree evaluate to True.

Expand Down Expand Up @@ -284,7 +283,7 @@ def tree_map(
f: Callable,
tree: T,
*rest: tuple[T, ...],
is_leaf: Optional[Callable[[Any], bool]] = None,
is_leaf: Callable[[Any], bool] | None = None,
) -> T:
"""
Apply a function to each leaf in a tree.
Expand Down Expand Up @@ -343,7 +342,7 @@ def tree_map(
)
flat.append(r_flat)

flat = [f(*args) for args in zip(*flat, strict=False)]
flat = [f(*args) for args in zip(*flat)]
tree_out: T = tree_unflatten(treedef, flat) # type: ignore
return tree_out

Expand All @@ -352,7 +351,7 @@ def tree_reduce(
function: Callable[[V, ArrayLike], V],
tree: Tree,
initializer: V,
is_leaf: Optional[Callable[[Any], bool]] = None,
is_leaf: Callable[[Any], bool] | None = None,
) -> V:
"""
Reduce a tree to a single value using a function and initializer.
Expand Down
4 changes: 2 additions & 2 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ class _VariantBase:
pass


@struct(kw_only=True)
@struct
class _VariantA(_VariantBase):
param: float

def __call__(self):
return self.param


@struct(kw_only=True)
@struct
class _VariantB(_VariantBase):
param: float

Expand Down
7 changes: 4 additions & 3 deletions tests/test_ravel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import numpy as np
import pytest

# Symbolic leaves come from Archimedes (dev dependency).
# Symbolic leaves come from Archimedes (dev dependency, requires Python 3.11+).
archimedes = pytest.importorskip("archimedes", reason="archimedes not available")
from archimedes._core import SymbolicArray, compile, sym # noqa: E402

import structree as tree
from structree import field, struct
import structree as tree # noqa: E402
from structree import field, struct # noqa: E402


class TestRavel:
Expand Down
6 changes: 5 additions & 1 deletion tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@

def assert_flat_equal(f1, f2):
assert len(f1) == len(f2)
for a, b in zip(f1, f2, strict=False):
for a, b in zip(f1, f2):
if isinstance(a, np.ndarray):
assert np.allclose(a, b)
else:
Expand Down Expand Up @@ -284,6 +284,10 @@ class Point5:
y: float


@pytest.mark.skipif(
__import__("sys").version_info < (3, 10),
reason="kw_only requires Python 3.10+",
)
def test_register_struct_decorator():
@struct(kw_only=True)
class Point:
Expand Down
Loading