It's unclear to me why the following code does not work, given that `MultivariateNormalDiag` supports batch dimensions for `loc` and `scale`:
```python
import distrax as dx
import jax
import jax.numpy as jnp
from jax import vmap


@jax.jit
def build():
    def single(i):
        return dx.MultivariateNormalDiag(jnp.zeros(10), jnp.ones(10))

    x = vmap(single)(jnp.arange(10))
    return x


dist = build()
dist.loc
```
produces the following error:
```
Traceback (most recent call last):
  File ".../test.py", line 17, in <module>
    dist.loc
  File ".../python3.12/site-packages/distrax/_src/distributions/mvn_from_bijector.py", line 103, in loc
    return jnp.broadcast_to(self._loc, shape=shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 2087, in broadcast_to
    return util._broadcast_to(array, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../python3.12/site-packages/jax/_src/numpy/util.py", line 422, in _broadcast_to
    raise ValueError(f"Cannot broadcast to shape with fewer dimensions: {arr_shape=} {shape=}")
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(10, 10) shape=(10,)
```
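The traceback points to a shape mismatch. The stored `_loc` has the vmapped shape `(10, 10)`, but the shape that `loc` broadcasts to is `(10,)`, which is the per-example shape. One plausible mechanism is sketched below with a hypothetical `ToyDist` pytree. This is not distrax's actual code. Suppose a distribution computes its shape once in `__init__` and carries that shape as static pytree metadata. Then `vmap` adds a batch axis to the array leaves but leaves the cached shape unchanged.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyDist:
    """Hypothetical stand-in for a distribution that caches its shape.

    `_loc` is a pytree leaf (vmap batches it); `event_shape` is static
    aux data (vmap passes it through untouched).
    """

    def __init__(self, loc):
        self._loc = loc
        self.event_shape = loc.shape  # computed once, at construction time

    def tree_flatten(self):
        return (self._loc,), self.event_shape

    @classmethod
    def tree_unflatten(cls, event_shape, children):
        # Rebuild without re-running __init__, so the shape is NOT recomputed.
        obj = object.__new__(cls)
        obj._loc = children[0]
        obj.event_shape = event_shape
        return obj

    @property
    def loc(self):
        return jnp.broadcast_to(self._loc, self.event_shape)


def single(i):
    return ToyDist(jnp.zeros(10))


batched = jax.vmap(single)(jnp.arange(10))
print(batched._loc.shape)   # (10, 10): the leaf gained a batch axis
print(batched.event_shape)  # (10,): the cached shape did not
try:
    batched.loc
except ValueError as err:
    print("ValueError:", err)  # same kind of failure as in the traceback
```

If distrax behaves like this, the distribution object is internally inconsistent after `vmap`, even though `MultivariateNormalDiag` itself accepts batched parameters.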
This seems similar to #239.
Ah, I see in the README that this distribution is specifically called out for being problematic with vmap.