diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..b3b9f44 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,50 @@ +name: CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install the project + run: uv sync --dev + + - name: Run tests + run: uv run pytest + + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.12" + + - name: Install the project + run: uv sync --dev + + - name: MyPy type checking + run: uv run mypy src/ + + - name: Lint with ruff + run: | + uv run ruff check . + uv run ruff format --check . diff --git a/pyproject.toml b/pyproject.toml index 8aa3110..a021ef4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,5 +49,5 @@ select = ["E", "F", "I", "B", "UP"] testpaths = ["tests"] [tool.mypy] -python_version = "3.9" +python_version = "3.10" ignore_missing_imports = true diff --git a/src/structree/_config.py b/src/structree/_config.py index 39252c7..ba73287 100644 --- a/src/structree/_config.py +++ b/src/structree/_config.py @@ -12,7 +12,7 @@ from __future__ import annotations -from typing import Annotated, Literal, Optional, Type, TypeVar, Union +from typing import Annotated, Literal, Type, TypeVar, Union from pydantic import BaseModel, ConfigDict, Field @@ -83,7 +83,7 @@ class StructConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - def __init_subclass__(cls, type: Optional[str] = None, **kwargs): + def __init_subclass__(cls, type: str | None = None, **kwargs): super().__init_subclass__(**kwargs) if type is not None: cls.__annotations__ = {"type": Literal[type], **cls.__annotations__} diff --git a/src/structree/_flatten_util.py b/src/structree/_flatten_util.py index b556207..4870a5f 100644 --- a/src/structree/_flatten_util.py +++ b/src/structree/_flatten_util.py @@ -229,5 +229,5 @@ def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr): chunks = np.split(arr, indices[:-1]) return [ chunk.reshape(shape).astype(dtype) - for chunk, shape, dtype in zip(chunks, shapes, from_dtypes, strict=False) + for chunk, shape, dtype in zip(chunks, shapes, from_dtypes) ] diff --git a/src/structree/_registry.py b/src/structree/_registry.py index 43b5ef9..5cf6844 100644 --- a/src/structree/_registry.py +++ b/src/structree/_registry.py @@ -15,7 +15,7 @@ import dataclasses from collections import OrderedDict from collections.abc import Callable, Sequence -from typing import Any, NamedTuple, Optional, TypeVar +from typing import Any, NamedTuple, TypeVar Typ = TypeVar("Typ", bound=type[Any]) @@ -102,7 +102,7 @@ def _dict_to_iter(d: dict): def _dict_from_iter(keys, vals): - return dict(zip(keys, vals, strict=False)) + return dict(zip(keys, vals)) register_struct(dict, _dict_to_iter, _dict_from_iter) @@ -110,7 +110,7 @@ def _dict_from_iter(keys, vals): # OrderedDict def _od_from_iter(keys, vals): - return OrderedDict(zip(keys, vals, strict=False)) + return OrderedDict(zip(keys, vals)) register_struct(OrderedDict, _dict_to_iter, _od_from_iter) @@ -118,8 +118,8 @@ def _od_from_iter(keys, vals): def register_dataclass( nodetype: Typ, - data_fields: Optional[Sequence[str]] = None, - meta_fields: Optional[Sequence[str]] = None, + data_fields: Sequence[str] | None = None, + meta_fields: Sequence[str] | None = None, drop_fields: Sequence[str] = (), ) -> Typ: """ @@ -227,8 +227,8 @@ def register_dataclass( raise ValueError(msg) def unflatten_func(meta, data): - meta_args = tuple(zip(meta_fields, meta, strict=False)) - data_args = tuple(zip(data_fields, data, strict=False)) + meta_args = tuple(zip(meta_fields, meta)) + data_args = tuple(zip(data_fields, data)) kwargs = dict(meta_args + data_args) return nodetype(**kwargs) diff --git a/src/structree/_struct.py b/src/structree/_struct.py index 7d82471..da475af 100644 --- a/src/structree/_struct.py +++ b/src/structree/_struct.py @@ -33,7 +33,7 @@ import functools from collections.abc import Callable from dataclasses import InitVar, fields, replace -from typing import Any, Optional, TypeVar, Union +from typing import Any, TypeVar from typing_extensions import dataclass_transform @@ -55,7 +55,7 @@ def field( static: bool = False, *, - metadata: Optional[dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, **kwargs, ) -> dataclasses.Field: """ @@ -107,7 +107,7 @@ def field( @dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] -def struct(cls: Optional[T] = None, **kwargs) -> Union[T, Callable]: +def struct(cls: T | None = None, **kwargs) -> T | Callable: """ Decorator to convert a class into a tree-compatible frozen dataclass. diff --git a/src/structree/_tree_util.py b/src/structree/_tree_util.py index f15e126..884f968 100644 --- a/src/structree/_tree_util.py +++ b/src/structree/_tree_util.py @@ -15,7 +15,6 @@ TYPE_CHECKING, Any, NamedTuple, - Optional, TypeVar, cast, ) @@ -32,7 +31,7 @@ class TreeDef(NamedTuple): - node_data: Optional[tuple[type, Hashable]] + node_data: None | tuple[type, Hashable] children: tuple[TreeDef, ...] num_leaves: int @@ -64,7 +63,7 @@ def __eq__(self, other: object) -> bool: # Flatten/unflatten functions # def tree_flatten( - x: Tree, is_leaf: Optional[Callable[[Any], bool]] = None + x: Tree, is_leaf: Callable[[Any], bool] | None = None ) -> tuple[list[ArrayLike], TreeDef]: """ Flatten a tree into a list of leaves and a treedef. @@ -113,7 +112,7 @@ def tree_flatten( def _tree_flatten( - x: Tree, is_leaf: Optional[Callable[[Any], bool]] + x: Tree, is_leaf: Callable[[Any], bool] | None ) -> tuple[Iterable, TreeDef]: if x is None: return [], NONE_DEF @@ -197,7 +196,7 @@ def _tree_unflatten(treedef: TreeDef, xs: Iterator) -> Tree: # -def tree_structure(tree: Tree, is_leaf: Optional[Callable[[Any], bool]] = None) -> TreeDef: +def tree_structure(tree: Tree, is_leaf: Callable[[Any], bool] | None = None) -> TreeDef: """ Extract the structure of a tree without the leaf values. @@ -227,7 +226,7 @@ def tree_structure(tree: Tree, is_leaf: Optional[Callable[[Any], bool]] = None) def tree_leaves( - tree: Tree, is_leaf: Optional[Callable[[Any], bool]] = None + tree: Tree, is_leaf: Callable[[Any], bool] | None = None ) -> list[ArrayLike]: """ Extract all leaf values from a tree. @@ -254,7 +253,7 @@ def tree_leaves( return flat -def tree_all(tree: Tree, is_leaf: Optional[Callable[[Any], bool]] = None) -> bool: +def tree_all(tree: Tree, is_leaf: Callable[[Any], bool] | None = None) -> bool: """ Check if all leaves in the tree evaluate to True. @@ -284,7 +283,7 @@ def tree_map( f: Callable, tree: T, *rest: tuple[T, ...], - is_leaf: Optional[Callable[[Any], bool]] = None, + is_leaf: Callable[[Any], bool] | None = None, ) -> T: """ Apply a function to each leaf in a tree. @@ -343,7 +342,7 @@ def tree_map( ) flat.append(r_flat) - flat = [f(*args) for args in zip(*flat, strict=False)] + flat = [f(*args) for args in zip(*flat)] tree_out: T = tree_unflatten(treedef, flat) # type: ignore return tree_out @@ -352,7 +351,7 @@ def tree_reduce( function: Callable[[V, ArrayLike], V], tree: Tree, initializer: V, - is_leaf: Optional[Callable[[Any], bool]] = None, + is_leaf: Callable[[Any], bool] | None = None, ) -> V: """ Reduce a tree to a single value using a function and initializer. diff --git a/tests/test_config.py b/tests/test_config.py index dc1e866..f4f97f6 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,7 +13,7 @@ class _VariantBase: pass -@struct(kw_only=True) +@struct class _VariantA(_VariantBase): param: float @@ -21,7 +21,7 @@ def __call__(self): return self.param -@struct(kw_only=True) +@struct class _VariantB(_VariantBase): param: float diff --git a/tests/test_ravel.py b/tests/test_ravel.py index ec0d952..fd85ff0 100644 --- a/tests/test_ravel.py +++ b/tests/test_ravel.py @@ -7,11 +7,12 @@ import numpy as np import pytest -# Symbolic leaves come from Archimedes (dev dependency). +# Symbolic leaves come from Archimedes (dev dependency, requires Python 3.11+). +archimedes = pytest.importorskip("archimedes", reason="archimedes not available") from archimedes._core import SymbolicArray, compile, sym # noqa: E402 -import structree as tree -from structree import field, struct +import structree as tree # noqa: E402 +from structree import field, struct # noqa: E402 class TestRavel: diff --git a/tests/test_struct.py b/tests/test_struct.py index 76cfff5..39a9736 100644 --- a/tests/test_struct.py +++ b/tests/test_struct.py @@ -66,7 +66,7 @@ def assert_flat_equal(f1, f2): assert len(f1) == len(f2) - for a, b in zip(f1, f2, strict=False): + for a, b in zip(f1, f2): if isinstance(a, np.ndarray): assert np.allclose(a, b) else: @@ -284,6 +284,10 @@ class Point5: y: float +@pytest.mark.skipif( + __import__("sys").version_info < (3, 10), + reason="kw_only requires Python 3.10+", +) def test_register_struct_decorator(): @struct(kw_only=True) class Point: