From 91dd84d881cdc3e0a76d5b655f9e5e8e751b86ba Mon Sep 17 00:00:00 2001 From: Peter Zhizhin Date: Tue, 11 Jun 2024 02:29:56 -0700 Subject: [PATCH] Add dill support for enp typings Currently, running some functions on Apache Beam leads to errors. This can be boiled down to this piece of code not working: ``` import dill from etils.enp.typing import ui8 dill.dumps(ui8['N']) >> PicklingError: Can't pickle ui8[N]: it's not found as etils.enp.array_types.typing.ui8 ``` Additionally, if `dill` is not available, the behavior doesn't change. PiperOrigin-RevId: 642193800 --- etils/enp/array_types/typing.py | 11 +++++++++++ etils/enp/array_types/typing_test.py | 19 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/etils/enp/array_types/typing.py b/etils/enp/array_types/typing.py index 5c177cb4..2e9453f2 100644 --- a/etils/enp/array_types/typing.py +++ b/etils/enp/array_types/typing.py @@ -19,6 +19,7 @@ from etils.enp.array_types import dtypes import numpy as np + _T = TypeVar('_T') # Match both `np.dtype('int32')` and np.int32 @@ -98,6 +99,16 @@ def __instancecheck__(cls, instance: np.ndarray) -> bool: """`isinstance(array, f32['h w c'])`.""" raise NotImplementedError +try: + import dill # pylint: disable=g-import-not-at-top # pytype: disable=import-error + + @dill.register(ArrayAliasMeta) + def _save_array_alias_meta(pickler, obj: ArrayAliasMeta) -> None: + args = (obj.shape, obj.dtype) + pickler.save_reduce(ArrayAliasMeta, args, obj=obj) + +except ImportError: + pass def _normalize_shape_item(item: _ShapeItem) -> ShapeSpec: """Returns the `str` representation associated with the shape element.""" diff --git a/etils/enp/array_types/typing_test.py b/etils/enp/array_types/typing_test.py index d49a56bf..ddfdfb8d 100644 --- a/etils/enp/array_types/typing_test.py +++ b/etils/enp/array_types/typing_test.py @@ -20,6 +20,14 @@ import numpy as np import pytest +try: + import dill + + _DILL_AVAILABLE = True +except ImportError: + dill = None + _DILL_AVAILABLE = False + # TODO(epot): Add `bfloat16` to array_types. Not this might require some # LazyDType to lazy-load jax. bf16 = enp.typing.ArrayAliasMeta(shape=None, dtype=np.dtype(jnp.bfloat16)) @@ -89,3 +97,14 @@ def test_array_eq(): assert f32['h w'] != ui8['h w'] assert {f32['h w'], f32['h w'], f32['h', 'w']} == {f32['h w']} + + +@pytest.mark.skipif(not _DILL_AVAILABLE, reason='dill not available') +def test_f32_can_be_pickled_unpickled_with_dill(): + assert dill is not None, ( + 'dill library is not available. We should have skipped this test, but for' + ' some reason we did not.' + ) + my_type = f32['N'] + my_type_pickled_unpickled = dill.loads(dill.dumps(my_type)) + assert my_type_pickled_unpickled == my_type