diff --git a/etils/enp/array_types/typing.py b/etils/enp/array_types/typing.py index 5c177cb4..2e9453f2 100644 --- a/etils/enp/array_types/typing.py +++ b/etils/enp/array_types/typing.py @@ -19,6 +19,7 @@ from etils.enp.array_types import dtypes import numpy as np + _T = TypeVar('_T') # Match both `np.dtype('int32')` and np.int32 @@ -98,6 +99,16 @@ def __instancecheck__(cls, instance: np.ndarray) -> bool: """`isinstance(array, f32['h w c'])`.""" raise NotImplementedError +try: + import dill # pylint: disable=g-import-not-at-top # pytype: disable=import-error + + @dill.register(ArrayAliasMeta) + def _save_array_alias_meta(pickler, obj: ArrayAliasMeta) -> None: + args = (obj.shape, obj.dtype) + pickler.save_reduce(ArrayAliasMeta, args, obj=obj) + +except ImportError: + pass def _normalize_shape_item(item: _ShapeItem) -> ShapeSpec: """Returns the `str` representation associated with the shape element.""" diff --git a/etils/enp/array_types/typing_test.py b/etils/enp/array_types/typing_test.py index d49a56bf..ddfdfb8d 100644 --- a/etils/enp/array_types/typing_test.py +++ b/etils/enp/array_types/typing_test.py @@ -20,6 +20,14 @@ import numpy as np import pytest +try: + import dill + + _DILL_AVAILABLE = True +except ImportError: + dill = None + _DILL_AVAILABLE = False + # TODO(epot): Add `bfloat16` to array_types. Not this might require some # LazyDType to lazy-load jax. bf16 = enp.typing.ArrayAliasMeta(shape=None, dtype=np.dtype(jnp.bfloat16)) @@ -89,3 +97,14 @@ def test_array_eq(): assert f32['h w'] != ui8['h w'] assert {f32['h w'], f32['h w'], f32['h', 'w']} == {f32['h w']} + + +@pytest.mark.skipif(not _DILL_AVAILABLE, reason='dill not available') +def test_f32_can_be_pickled_unpickled_with_dill(): + assert dill is not None, ( + 'dill library is not available. We should have skipped this test, but for' + ' some reason we did not.' + ) + my_type = f32['N'] + my_type_pickled_unpickled = dill.loads(dill.dumps(my_type)) + assert my_type_pickled_unpickled == my_type