From 67ec2955f660397b78b674a044bb2adb8079154f Mon Sep 17 00:00:00 2001 From: Chaitanya Mishra Date: Sat, 24 Jan 2026 14:29:28 +0530 Subject: [PATCH] linalg: add eps to normalize --- etils/enp/linalg.py | 20 +++++++++++++++++--- etils/enp/linalg_test.py | 13 +++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/etils/enp/linalg.py b/etils/enp/linalg.py index 71557da8..fd6cd778 100644 --- a/etils/enp/linalg.py +++ b/etils/enp/linalg.py @@ -23,6 +23,20 @@ lazy = numpy_utils.lazy -def normalize(x: FloatArray['*d'], axis: int = -1) -> FloatArray['*d']: - """Normalize the vector to the unit norm.""" - return x / compat.norm(x, axis=axis, keepdims=True) +def normalize( + x: FloatArray['*d'], + axis: int = -1, + *, + eps: float = 0.0, +) -> FloatArray['*d']: + """Normalize the vector to the unit norm. + + Args: + x: Input array to normalize. + axis: Axis along which to compute the norm. + eps: Optional epsilon to avoid division by zero for zero-norm vectors. + """ + denom = compat.norm(x, axis=axis, keepdims=True) + if eps: + denom = denom + eps + return x / denom diff --git a/etils/enp/linalg_test.py b/etils/enp/linalg_test.py index b99bf46f..d6ad9ce7 100644 --- a/etils/enp/linalg_test.py +++ b/etils/enp/linalg_test.py @@ -29,6 +29,16 @@ def test_normalize(xnp: enp.NpModule): assert y.shape == x.shape np.testing.assert_allclose(y, [1.0, 0.0, 0.0]) + y = enp.linalg.normalize(x, eps=1e-6) + np.testing.assert_allclose(y, [1.0, 0.0, 0.0], rtol=1e-6, atol=1e-6) + + +@enp.testing.parametrize_xnp() +def test_normalize_eps_zero(xnp: enp.NpModule): + x = xnp.asarray([0.0, 0.0, 0.0]) + y = enp.linalg.normalize(x, eps=1e-6) + assert enp.compat.is_array_xnp(y, xnp) + np.testing.assert_allclose(y, [0.0, 0.0, 0.0]) @enp.testing.parametrize_xnp() def test_normalize_batched(xnp: enp.NpModule): @@ -52,6 +62,9 @@ def test_normalize_batched(xnp: enp.NpModule): ], ) + y_eps = enp.linalg.normalize(x, eps=1e-6) + np.testing.assert_allclose(y_eps, y, rtol=1e-6, atol=1e-6) + @enp.testing.parametrize_xnp() def test_norm(xnp: enp.NpModule):