
Missing data: partial missingness, unroll_missing, diagonal obs models#149

Open
dimkab wants to merge 33 commits into main from db-missingness-partial

Conversation

@dimkab commented Mar 18, 2026

Summary

  • Adds per-dimension (partial) missingness support to DiscreteTimeSimulator
  • unroll_missing=True (default) produces full-length output with gap imputation when entire rows are missing
  • New ObservationModel.masked_log_prob interface for per-dimension scoring
  • New DiagonalLinearGaussianObservation and DiagonalGaussianObservation observation models
  • Tutorial notebook (08) covering three missingness patterns + gap imputation + interacting particles

Core changes

dynestyx/simulators.py

unroll_missing: bool = True field on DiscreteTimeSimulator. The _simulate method dispatches to:

  • _simulate_missing_scan — when partial NaNs or unroll_missing=True with missing rows
  • _simulate_row_filter — when unroll_missing=False with missing rows
  • _simulate_plate / _simulate_scan — no missingness

_simulate_missing_scan has two sub-paths:

Sub-path A (DiracIdentityObservation) — factor + masked-sample:

  • Observed dims: numpyro.factor(obs_lp) scores the transition log-prob
  • Latent dims: numpyro.handlers.mask(mask=~obs_mask) { numpyro.sample(trans_dist) } — the mask zeros out both model and guide ELBO contributions at observed dims, so AutoNormal allocates parameters but they are inert (zero gradient)
  • Two variants: _step_dirac_partial (per-dim vector mask, requires Independent(Normal, 1) transitions, uses base_dist) and _step_dirac_wholerow (scalar mask, works with any transition including MultivariateNormal)
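The factor + masked-sample split can be illustrated with a plain NumPy sketch (illustrative only, not the actual numpyro code path): for a diagonal transition distribution, the per-dimension log-probs scored at observed dims (via `numpyro.factor`) and at latent dims (via the masked sample site) add back up to the full transition log-prob.

```python
import numpy as np

def normal_logpdf(x, mu, sigma):
    """Per-dimension Normal log-density (diagonal case)."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2

rng = np.random.default_rng(0)
mu = rng.normal(size=4)            # transition mean
sigma = np.full(4, 0.5)            # diagonal transition scale
x = rng.normal(size=4)             # next state
obs_mask = np.array([True, False, True, True])  # dim 1 is missing

lp = normal_logpdf(x, mu, sigma)
lp_observed = np.sum(lp[obs_mask])   # contributed via numpyro.factor in sub-path A
lp_latent = np.sum(lp[~obs_mask])    # contributed via the masked sample site

# The two contributions partition the full log-prob.
assert np.isclose(lp_observed + lp_latent, np.sum(lp))
```

This is why the per-dim vector mask in `_step_dirac_partial` needs an `Independent(Normal, 1)` transition: the log-prob must decompose dimension-wise for the partition to hold.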

Sub-path B (non-Dirac obs models):

  • State is always latent (sampled unconditionally)
  • Observation scoring via obs_model.masked_log_prob(y, obs_mask, x, ...) for partial rows, or full log_prob zeroed at missing rows via numpyro.handlers.mask

dynestyx/models/

  • ObservationModel.masked_log_prob: scores only observed dims; default raises NotImplementedError
  • DiagonalLinearGaussianObservation: linear H with diagonal R, supports masked_log_prob
  • DiagonalGaussianObservation: nonlinear obs function with diagonal R, supports masked_log_prob
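A minimal NumPy sketch of what a per-dimension masked score computes for a linear observation with diagonal noise (the function name and signature here mirror the description above but are illustrative, not the dynestyx API):

```python
import numpy as np

def masked_log_prob(y, obs_mask, x, H, r_diag):
    """Score only observed dims of y ~ N(H @ x, diag(r_diag)).

    NaNs at masked-out dims must not poison the result, so they are
    replaced with the mean before evaluating the density.
    """
    mean = H @ x
    y_safe = np.where(obs_mask, y, mean)  # dummy value where missing
    lp = -0.5 * (np.log(2 * np.pi * r_diag) + (y_safe - mean) ** 2 / r_diag)
    return np.sum(np.where(obs_mask, lp, 0.0))

H = np.array([[1.0, 0.0], [0.0, 2.0]])
r_diag = np.array([0.1, 0.2])
x = np.array([1.0, -1.0])
y = np.array([1.05, np.nan])        # second dim missing
obs_mask = np.array([True, False])
lp = masked_log_prob(y, obs_mask, x, H, r_diag)
```

Diagonal R is what makes this valid: with correlated observation noise the per-dimension terms would not factorize, which is why the non-diagonal models raise instead.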

Tests

  • tests/test_with_missing_data.py — 39 tests covering all combinations of:
    • Models: LTI, diagonal obs (linear + nonlinear), particle SDE, interacting particles
    • Missingness: none, random, sequential, block, partial
    • unroll_missing: True/False
    • Inference: MCMC and SVI
  • Smoke/full mode via DYNESTYX_SMOKE_TEST env var (set in scripts/test.sh)
    • Smoke: small T, few samples, gradient-flow checks only
    • Full: convergence + posterior accuracy checks

Tutorial notebook (08)

Sections: whole-row block missingness (default + row-filter paths, side-by-side posterior comparison), per-dimension partial missingness, per-particle contiguous gaps with SVI.

Test plan

  • scripts/test.sh — smoke tests pass (~2 min)
  • scripts/test_full.sh — full convergence tests pass (~16 min)
  • Notebook runs end-to-end

🤖 Generated with Claude Code

@mattlevine22 (Collaborator)

@dimkab just updated main to pin arviz #150

@dimkab dimkab self-assigned this Mar 18, 2026
@dimkab (Author) commented Mar 18, 2026

@dimkab just updated main to pin arviz #150

👍

@dimkab dimkab marked this pull request as draft March 18, 2026 19:25
@dimkab dimkab changed the base branch from db-missingness to main March 18, 2026 19:26
@dimkab dimkab changed the title from "[DRAFT] partial missingness - initial implementation" to "partial missingness - diagonal gaussian / dirac observations only" Mar 19, 2026
@dimkab dimkab marked this pull request as ready for review March 19, 2026 23:28
@mattlevine22 (Collaborator)

This is fantastic @dimkab , thank you!

So far I went through the notebook:

  • I really like the explanations / examples / plots
  • clearly it is working, yay!
  • I like the unroll_missing=False flag as being specified via a DiscreteTimeSimulator parameter...seems very intuitive.

Some thoughts / questions (will look more carefully and possibly answer my own questions):

  • I feel like unroll_missing=True can be the default. If you purposefully include a whole missing block, I'd assume the user wants us to do something with it. This way, they can just input data with NaN and it either "works" or "errors"
  • Is there an informative error for case w/ NaNs + non-diagonal observation?
  • For the particle example, I'd suggest building the SDE model and wrapping in Discretizer(). This will make it clear that we DO support continuous and discrete-time models. I like that you use DiracIdentity here because it shows/tests more key functionality (without confusing the reader, I think).
  • For funzies, I might suggest including the latent trajectory plot in the unroll_missing=False case (so they know what they get back)....although, again, I'm thinking this should be non-default, and hence could more go at the end of the notebook as an "aside" or fancier thing to do.
  • very small note: the last plot has solid lines that connect the circles, so it is not super clear that data was indeed missing there (contradicts the title of the plot).
  • zooming out...I wonder @DanWaxman 's thoughts on whether this is a Tutorial or Deep Dive. I think missingness is valuable as a tutorial, but perhaps the second example and the unroll_missing=False could get pushed to a deep dive notebook? I'm open to having it as-is as a tutorial though.

@mattlevine22 (Collaborator) commented Mar 20, 2026

  • Can be a separate PR, but maybe we should add support for missingness in ODESimulator too. It is less important because we won't typically recommend people use ODESimulator that much (the SDE relaxation is much friendlier)....so let's worry about more important things first!

  • We may also want to throw a not-yet-supported error if users include NaNs in other contexts (e.g., filtering)

@dimkab (Author) commented Mar 20, 2026

Thanks @mattlevine22 !

* I feel like `unroll_missing=True` can be the default. If you purposefully include a whole missing block, I'd assume the user wants us to do something with it. This way, they can just input data with NaN and it either "works" or "errors"

I agree, this makes more sense as the default. The question is what to do when users aren't aware they have missing data at all and so don't specify anything: should we at least emit a warning when missing data is detected, saying what we are doing (unrolling by default)?

* Is there an informative error for case w/ NaNs  + non-diagonal observation?

yes, see test_linear_gaussian_obs_masked_log_prob_raises and test_gaussian_obs_masked_log_prob_raises tests.

* For the particle example, I'd suggest building the SDE model and wrapping in `Discretizer()`. This will make it clear that we DO support continuous and discrete-time models. I like that you use DiracIdentity here because it shows/tests more key functionality (without confusing the reader, I think).

good idea, will do. This is actually how I'm using this in CIS anyway.

* For funzies, I might suggest including the latent trajectory plot in the `unroll_missing=False` case (so they know what they get back)....although, again, I'm thinking this should be non-default, and hence could more go at the end of the notebook as an "aside" or fancier thing to do.

In the notebook there are a couple of cells that print the actual returned trajectory length in this case (in both cases, actually); the question is whether a plot would add anything substantial here.

* very small note: the last plot has solid lines that connect the circles, so it is not super clear that data was indeed missing there (contradicts the title of the plot).

take a look now

@dimkab (Author) commented Mar 20, 2026

@mattlevine22 actually regarding building the SDE model - this may not work because _EulerMaruyamaDiscreteEvolution returns MultivariateNormal, which is not yet supported for partial missingness (the code expects Independent(Normal(...))). This was in fact one of the pain points for CIS.

@mattlevine22 (Collaborator)

@mattlevine22 actually regarding building the SDE model - this may not work because _EulerMaruyamaDiscreteEvolution returns MultivariateNormal, which is not yet supported for partial missingness (the code expects Independent(Normal(...))). This was in fact one of the pain points for CIS.

That's for state-evolution...in the code, it doesn't look like you require decoupled state transitions (although now I see that the notebook only treats this case).

The de-coupled state case is a little bit niche...it won't really work for the CIS application, right?

@dimkab (Author) commented Mar 20, 2026

@mattlevine22 actually regarding building the SDE model - this may not work because _EulerMaruyamaDiscreteEvolution returns MultivariateNormal, which is not yet supported for partial missingness (the code expects Independent(Normal(...))). This was in fact one of the pain points for CIS.

That's for state-evolution...in the code, it doesn't look like you require decoupled state transitions (although now I see that the notebook only treats this case).

The de-coupled state case is a little bit niche...it won't really work for the CIS application, right?

not sure I understand what you mean by building the SDE model then... you suggested to go through a Discretizer(), which as far as I see returns MultivariateNormal - unless I implement my own DiagonalEulerMaruyama assuming a diagonal diffusion coefficient...

OR: do you mean I switch to a Diagonal observation model here?

@dimkab (Author) commented Mar 20, 2026

made unroll_missing=True the default.

@DanWaxman (Collaborator)

not sure I understand what you mean by building the SDE model then... you suggested to go through a Discretizer(), which as far as I see returns MultivariateNormal - unless I implement my own DiagonalEulerMaruyama assuming a diagonal diffusion coefficient...

Euler-Maruyama is a time-discretization -- i.e., it translates the ContinuousTimeStateEvolution to a DiscreteTimeStateEvolution. The observation model remains unchanged:

```python
if isinstance(dynamics.state_evolution, ContinuousTimeStateEvolution):
    discrete_evolution = self.discretize(dynamics.state_evolution)
    dynamics = DynamicalModel(
        initial_condition=dynamics.initial_condition,
        state_evolution=discrete_evolution,
        observation_model=dynamics.observation_model,
        control_model=dynamics.control_model,
        control_dim=dynamics.control_dim,
    )
```

So if you have an observation model with diagonal Gaussians, that should persist.

@dimkab (Author) commented Mar 20, 2026

not sure I understand what you mean by building the SDE model then... you suggested to go through a Discretizer(), which as far as I see returns MultivariateNormal - unless I implement my own DiagonalEulerMaruyama assuming a diagonal diffusion coefficient...

Euler-Maruyama is a time-discretization -- i.e., it translates the ContinuousTimeStateEvolution to a DiscreteTimeStateEvolution. The observation model remains unchanged:

```python
if isinstance(dynamics.state_evolution, ContinuousTimeStateEvolution):
    discrete_evolution = self.discretize(dynamics.state_evolution)
    dynamics = DynamicalModel(
        initial_condition=dynamics.initial_condition,
        state_evolution=discrete_evolution,
        observation_model=dynamics.observation_model,
        control_model=dynamics.control_model,
        control_dim=dynamics.control_dim,
    )
```

So if you have an observation model with diagonal Gaussians, that should persist.

my observation model is dirac here...

@DanWaxman (Collaborator)

my observation model is dirac here...

Sorry, I guess I only half-read... but the point is the same, that whatever special observation model structure you have is preserved by Discretizer and SDE+Discretizer should be okay and inherit whatever features you have for DiscreteTimeSimulator.

@mattlevine22 (Collaborator)

I'm also concerned that this is addressed with the masked_log_prob strategy adding to `numpyro.factor`.

Instead, @DanWaxman and I were picturing that things in simulators.py would still boil down to a numpyro.sample(..., MVN(obs_mean_k, obs_cov_k), obs=obs_k) call;

  • the trick becomes computing obs_mean_k, obs_cov_k, obs_k --- obs_mean_k and obs_cov_k would have the (reduced) dimensionality corresponding to the number of non-NaN observations.
  • this can be done by combining the following elements:
    • state distribution MVN(m_k, Sigma_k); m_k is full state dim
    • obs distribution MVN( H @ m_k, Gamma_k); H @ m_k has dimension of a full observation (no-NaNs)
    • obs_dims_k are the non-NaN observed dimension indices at time k
    • obs_k is just the non-NaN observations at time k
    • there should be some nice linear algebra (gaussian marginalizations) that maps (m_k, Sigma_k, H, Gamma_k, obs_dims_k) -> (obs_mean_k, obs_cov_k, obs_k)
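The linear algebra sketched in the bullets above is a standard Gaussian marginalization; a NumPy illustration (illustrative only, not dynestyx code) builds a selection matrix from obs_dims_k and picks out the observed rows/columns of the marginal observation Gaussian:

```python
import numpy as np

rng = np.random.default_rng(1)
d_x, d_y = 3, 3
m_k = rng.normal(size=d_x)                      # state mean
A = rng.normal(size=(d_x, d_x))
Sigma_k = A @ A.T + np.eye(d_x)                 # state covariance (SPD)
H = rng.normal(size=(d_y, d_x))                 # observation matrix
Gamma_k = np.diag(rng.uniform(0.1, 1.0, d_y))   # observation noise covariance

obs_dims_k = np.array([0, 2])                   # non-NaN dims at time k
B = np.eye(d_y)[obs_dims_k]                     # selection matrix

full_mean = H @ m_k                             # marginal obs mean, full dim
full_cov = H @ Sigma_k @ H.T + Gamma_k          # marginal obs covariance

obs_mean_k = B @ full_mean                      # reduced to observed dims
obs_cov_k = B @ full_cov @ B.T                  # observed rows/cols only

# Marginalizing a Gaussian onto a coordinate subset is just submatrix selection.
assert np.allclose(obs_mean_k, full_mean[obs_dims_k])
assert np.allclose(obs_cov_k, full_cov[np.ix_(obs_dims_k, obs_dims_k)])
```

Nothing here requires diagonal structure anywhere, which is the extensibility argument: the same mapping handles coupled states.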

Doing the above will:

  1. make the implementation cleaner + more extensible later on
  2. deal with coupled states (i.e., agents interacting)

@dimkab (Author) commented Mar 20, 2026

The algebra is straightforward:
https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Affine_transformation

For the full MVN you would need to apply these $B$ projection matrices which are computed from obs_mask, and make sure nothing is broken in jax scan.

I am not sure what's the big plan here for this PR. Are you OK with solving the diagonal Gaussian / Dirac case first? The interactions are encoded in the drift so I don't think we lose anything

@mattlevine22 (Collaborator)

I am not sure what's the big plan here for this PR. Are you OK with solving the diagonal Gaussian / Dirac case first? The interactions are encoded in the drift so I don't think we lose anything

Yes, I'm OK with solving the simpler problem first.

But, I'm worried about doing it in a way that:
a) unnecessarily deviates from numpyro style (.factor when we could just do sample(..., MVN, ...obs=)),
and
b) assumes a model with independent states, but doesn't enforce this explicitly (I can run the notebook with off-diagonal terms in A, but it sounds like you don't want a user to do this until it is supported)

@dimkab (Author) commented Mar 20, 2026

but how can you do numpyro.sample in a scan with varying shapes for each sample statement?

@dimkab (Author) commented Mar 20, 2026

b) assumes a model with independent states, but doesn't enforce this explicitly (I can run the notebook with off-diagonal terms in A, but it sounds like you don't want a user to do this until it is supported)

to be precise: for gaussian observations you need diagonal observation covariance (and this is enforced); the state transition covariance can be anything because the states are fully sampled anyway. The enforcement for the state noise and the initial condition covariance comes into play only for dirac observations.

@mattlevine22 (Collaborator)

From the new notebook:

"The two dimensions are independent in both transitions and observations, which is precisely why DiagonalLinearGaussianObservation can score them separately — a requirement for per-dimension partial missingness."

  • I interpret this to mean that the dimension independence in both transitions + observations is a requirement for per-dimension partial missingness. Do you actually only need independent observations?

@dimkab (Author) commented Mar 20, 2026

From the new notebook:

"The two dimensions are independent in both transitions and observations, which is precisely why DiagonalLinearGaussianObservation can score them separately — a requirement for per-dimension partial missingness."

* I interpret this to mean that the dimension independence in both transitions + observations is a requirement for per-dimension partial missingness. Do you actually only need independent observations?

yeah sorry this should only concern observations.

@dimkab dimkab changed the title from "partial missingness - diagonal gaussian / dirac observations only" to "Missing data: partial missingness, unroll_missing, diagonal obs models" Mar 23, 2026