Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
c185f50
added a multi-gaussian model and failing tests
dimkab Mar 16, 2026
751440f
Dirac observation path passing
dimkab Mar 17, 2026
44e1d96
simplified
dimkab Mar 17, 2026
0ba2fb4
Merge branch 'main' into db-missingness
dimkab Mar 17, 2026
2ee9a36
lint
dimkab Mar 17, 2026
ca089c9
Merge branch 'main' into db-missingness
dimkab Mar 18, 2026
e6e4d6f
Merge branch 'main' into db-missingness
dimkab Mar 18, 2026
52427d1
partial missingness - initial implementation
dimkab Mar 18, 2026
62bffb7
arviz
dimkab Mar 18, 2026
af8f60b
Merge branch 'main' into db-missingness-partial
dimkab Mar 18, 2026
5ae14ca
pinning arviz
dimkab Mar 18, 2026
14f95a8
switched back to plot_posterior; more plots to missingness tests
dimkab Mar 18, 2026
c3e3d16
Merge branch 'main' into db-missingness
DanWaxman Mar 18, 2026
9bb4d52
Merge branch 'db-missingness' into db-missingness-partial
dimkab Mar 18, 2026
1fbd93b
Merge branch 'main' into db-missingness-partial
dimkab Mar 18, 2026
b0ba128
refactored simulate and introduced diagonal observation models
dimkab Mar 18, 2026
6faba3f
Merge branch 'main' into db-missingness-partial
dimkab Mar 19, 2026
6a8c448
Merge branch 'db-missingness-partial' of https://github.com/BasisRese…
dimkab Mar 19, 2026
35d3ce4
lint
dimkab Mar 19, 2026
d799371
consolidated tests
dimkab Mar 19, 2026
18c8e34
masked_log_prob already implemented in base class
dimkab Mar 19, 2026
bc060a7
removed the birth-death explicit support, simplified the simulator ac…
dimkab Mar 19, 2026
3ac9278
nit
dimkab Mar 19, 2026
aa99758
further cleanup
dimkab Mar 19, 2026
054f5dc
Merge branch 'main' into db-missingness-partial
dimkab Mar 20, 2026
6cd748b
plot tweaking
dimkab Mar 20, 2026
45957ca
fixed title
dimkab Mar 20, 2026
cb03d55
unroll_missing=True is the default now
dimkab Mar 20, 2026
baa89e2
nit
dimkab Mar 20, 2026
8d7498c
refactor to use numpyro.mask
dimkab Mar 20, 2026
e109e74
tests pass
dimkab Mar 21, 2026
90999d0
tests run faster
dimkab Mar 23, 2026
f47ad01
better mcmc posterior
dimkab Mar 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,185 changes: 1,185 additions & 0 deletions docs/tutorials/gentle_intro/08_missing_observations.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions dynestyx/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)
from dynestyx.models.lti_dynamics import LTI_continuous, LTI_discrete
from dynestyx.models.observations import (
DiagonalGaussianObservation,
DiagonalLinearGaussianObservation,
DiracIdentityObservation,
GaussianObservation,
LinearGaussianObservation,
Expand All @@ -25,6 +27,8 @@
__all__ = [
"ContinuousTimeStateEvolution",
"AffineDrift",
"DiagonalGaussianObservation",
"DiagonalLinearGaussianObservation",
"DiracIdentityObservation",
"DiscreteTimeStateEvolution",
"DynamicalModel",
Expand Down
40 changes: 40 additions & 0 deletions dynestyx/models/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Core interfaces and base classes for dynamical models."""

import dataclasses
from abc import abstractmethod
from collections.abc import Callable
from typing import Any, Protocol

Expand Down Expand Up @@ -375,3 +376,42 @@ def sample(self, x, u, t, *args, **kwargs):
seed = kwargs.pop("seed")
kwargs["key"] = seed
return dist.sample(*args, **kwargs)

# Subclasses must return the conditional observation distribution for
# state x, control u, and time t; masked_log_prob below consumes it.
@abstractmethod
def __call__(self, x, u, t) -> DistributionT: ...

def masked_log_prob(
    self,
    y: jax.Array,
    obs_mask: jax.Array,
    x: Any,
    u: Any = None,
    t: Any = None,
) -> jax.Array:
    """Log p(y_obs | x), counting only the observed dimensions.

    Args:
        y: Observation with NaN replaced by safe values. Shape (obs_dim,).
        obs_mask: Boolean array, True = observed. Shape (obs_dim,).
        x: Latent state.
        u: Control or None.
        t: Time or None.

    Returns:
        Scalar log-probability summed over observed dims only.
    """
    import numpyro.distributions as _dist_mod

    full_dist = self(x, u, t)
    # Peel off a single Independent wrapper so that log_prob is evaluated
    # element-wise instead of summed over the event dimension.
    if (
        isinstance(full_dist, _dist_mod.Independent)
        and full_dist.reinterpreted_batch_ndims == 1
    ):
        elementwise = full_dist.base_dist
    else:
        elementwise = full_dist
    per_dim_lp = elementwise.log_prob(y)  # (obs_dim,) when element-wise
    if jnp.ndim(per_dim_lp) == 0:
        # A scalar log_prob means the distribution couples its dimensions
        # (e.g. a full-covariance model) — masking per dim is impossible here.
        raise NotImplementedError(
            f"{type(self).__name__}.masked_log_prob: distribution "
            f"{type(full_dist).__name__} does not decompose per-dimension. "
            "Override masked_log_prob in the subclass."
        )
    # Unobserved dims contribute exactly zero to the sum.
    return jnp.sum(jnp.where(obs_mask, per_dim_lp, 0.0))
96 changes: 96 additions & 0 deletions dynestyx/models/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,95 @@ def __call__(self, x, u, t):
return dist.MultivariateNormal(loc=loc, covariance_matrix=self.R)


class DiagonalLinearGaussianObservation(ObservationModel):
    """
    Linear-Gaussian observation model with independent per-dimension noise.

    The measurement model is

    $$
    y_t \\sim \\mathcal{N}(H x_t + D u_t + b, \\mathrm{diag}(R_{\\mathrm{diag}})).
    $$

    Because the noise dims are independent, `masked_log_prob` is exact for any
    obs mask — no sub-block extraction is required.
    """

    H: jax.Array
    R_diag: jax.Array
    D: jax.Array | None = None
    bias: jax.Array | None = None

    def __init__(
        self,
        H: jax.Array,
        R_diag: jax.Array,
        D: jax.Array | None = None,
        bias: jax.Array | None = None,
    ):
        """
        Args:
            H (jax.Array): Observation matrix with shape $(d_y, d_x)$.
            R_diag (jax.Array): Per-dimension noise variances with shape
                $(d_y,)$.
            D (jax.Array | None): Optional control matrix with shape
                $(d_y, d_u)$.
            bias (jax.Array | None): Optional additive bias with shape
                $(d_y,)$.
        """
        self.H = H
        self.R_diag = R_diag
        self.D = D
        self.bias = bias

    def __call__(self, x, u, t):
        # Mean is H x plus the optional control and bias contributions.
        mean = self.H @ x
        if self.D is not None and u is not None:
            mean = mean + self.D @ u
        if self.bias is not None:
            mean = mean + self.bias
        # Independent(..., 1) makes the event dim a vector while keeping
        # per-dimension log-probs recoverable for masking.
        return dist.Independent(dist.Normal(mean, jnp.sqrt(self.R_diag)), 1)


class DiagonalGaussianObservation(ObservationModel):
    """
    Nonlinear Gaussian observation model with independent per-dimension noise.

    The measurement model is

    $$
    y_t \\sim \\mathcal{N}(h(x_t, u_t, t), \\mathrm{diag}(R_{\\mathrm{diag}})),
    $$

    with a user-supplied measurement function $h$.

    Because the noise dims are independent, `masked_log_prob` is exact for any
    obs mask — no sub-block extraction is required.
    """

    h: Callable[[State, Control, Time], Observation]
    R_diag: jax.Array

    def __init__(
        self,
        h: Callable[[State, Control, Time], jax.Array],
        R_diag: jax.Array,
    ):
        """
        Args:
            h (Callable[[State, Control, Time], jax.Array]): Measurement
                function mapping $(x, u, t)$ to the mean observation.
            R_diag (jax.Array): Per-dimension noise variances with shape
                $(d_y,)$.
        """
        self.h = h
        self.R_diag = R_diag

    def __call__(self, x, u, t):
        # Element-wise Normal around h(x, u, t); Independent(..., 1) makes
        # the event dim a vector while keeping per-dimension log-probs.
        mean = self.h(x, u, t)
        return dist.Independent(dist.Normal(mean, jnp.sqrt(self.R_diag)), 1)


class DiracIdentityObservation(ObservationModel):
"""
Noise-free identity observation model.
Expand All @@ -110,3 +199,10 @@ class DiracIdentityObservation(ObservationModel):

def __call__(self, x, u, t):
    # Deterministic observation: a point-mass (Delta) distribution centered
    # at the latent state x itself; u and t are ignored.
    return dist.Delta(x)

def masked_log_prob(self, y, obs_mask, x, u=None, t=None):
    # Deliberately unsupported: a Delta point mass on the full state does
    # not decompose into per-dimension terms, so a partial observation
    # mask cannot be scored. Callers should switch to a diagonal-noise
    # Gaussian model instead.
    raise NotImplementedError(
        "DiracIdentityObservation does not support partial missingness. "
        "Use DiagonalLinearGaussianObservation or DiagonalGaussianObservation "
        "with small sigma_obs instead."
    )
Loading