diff --git a/etils/enp/numpy_utils.py b/etils/enp/numpy_utils.py index a0a260c..dd420a5 100644 --- a/etils/enp/numpy_utils.py +++ b/etils/enp/numpy_utils.py @@ -174,13 +174,15 @@ def is_dtype(self, dtype) -> bool: ) def as_np_dtype(self, dtype): - if self.is_tf_dtype(dtype): + if self.is_np_dtype(dtype) or self.is_jax_dtype(dtype): + pass + elif self.is_tf_dtype(dtype): dtype = dtype.as_numpy_dtype elif self.is_torch_dtype(dtype): from etils.enp import compat # pylint: disable=g-import-not-at-top dtype = compat.dtype_torch_to_np(dtype) - elif not self.is_jax_dtype(dtype) and not self.is_np_dtype(dtype): + else: raise TypeError(f'Invalid dtype: {dtype!r}') return np.dtype(dtype)