From e2b5ea32100cfe857c63fe074aa8d074e7c7d228 Mon Sep 17 00:00:00 2001 From: Klaus Greff Date: Sun, 15 Mar 2026 12:35:17 -0700 Subject: [PATCH] Fix `as_np_dtype` to check numpy/jax dtypes before torch dtypes PiperOrigin-RevId: 884072260 --- etils/enp/numpy_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/etils/enp/numpy_utils.py b/etils/enp/numpy_utils.py index a0a260c8..dd420a57 100644 --- a/etils/enp/numpy_utils.py +++ b/etils/enp/numpy_utils.py @@ -174,13 +174,15 @@ def is_dtype(self, dtype) -> bool: ) def as_np_dtype(self, dtype): - if self.is_tf_dtype(dtype): + if self.is_np_dtype(dtype) or self.is_jax_dtype(dtype): + pass + elif self.is_tf_dtype(dtype): dtype = dtype.as_numpy_dtype elif self.is_torch_dtype(dtype): from etils.enp import compat # pylint: disable=g-import-not-at-top dtype = compat.dtype_torch_to_np(dtype) - elif not self.is_jax_dtype(dtype) and not self.is_np_dtype(dtype): + else: raise TypeError(f'Invalid dtype: {dtype!r}') return np.dtype(dtype)