diff --git a/etils/enp/numpy_utils.py b/etils/enp/numpy_utils.py index 1041d3ea..ffd69334 100644 --- a/etils/enp/numpy_utils.py +++ b/etils/enp/numpy_utils.py @@ -26,6 +26,7 @@ from etils import epy import numpy as np +import pandas as pd if typing.TYPE_CHECKING: from etils.enp.typing import Array @@ -42,7 +43,7 @@ # When `strict=False` (in `get_xnp`, `is_array`,...), those types are also # accepted: -_ARRAY_LIKE_TYPES = (int, bool, float, list, tuple) +_ARRAY_LIKE_TYPES = (int, bool, float, list, tuple, pd.Series) # During the class construction, pytype fails because of name conflict between # the `np` `@property` and the module. diff --git a/etils/enp/numpy_utils_test.py b/etils/enp/numpy_utils_test.py index dab3aebb..a046c073 100644 --- a/etils/enp/numpy_utils_test.py +++ b/etils/enp/numpy_utils_test.py @@ -18,6 +18,7 @@ import jax import jax.numpy as jnp import numpy as np +import pandas as pd import pytest import tensorflow as tf import tensorflow.experimental.numpy as tnp @@ -78,6 +79,7 @@ def test_lazy(): assert lazy.get_xnp(np.array([123])) is np assert lazy.get_xnp(torch.Tensor([123])) is torch assert lazy.get_xnp([123], strict=False) is np + assert lazy.get_xnp(pd.Series([123]), strict=False) with pytest.raises(TypeError, match='Cannot infer the numpy'): lazy.get_xnp([123])