diff --git a/tunix/generate/utils.py b/tunix/generate/utils.py index 4f5b86cd8..37d236eac 100644 --- a/tunix/generate/utils.py +++ b/tunix/generate/utils.py @@ -706,8 +706,10 @@ def _apply_dtype_cast( val: jnp.ndarray, tgt_dtype: jnp.dtype, src_key: str ) -> jnp.ndarray: if val.dtype != tgt_dtype: - logging.warning( + logging.log_first_n( + logging.WARNING, 'Type mismatch on %s: %s -> %s', + 1, src_key, val.dtype, tgt_dtype,