diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 76423e56..3aaa7a34 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -280,10 +280,14 @@ def _apply_day_of_week( """ Apply day-of-week multiplicative adjustment to predicted counts. - Tiles a 7-element effect vector across the full time axis, - aligned to the calendar via ``first_day_dow``. NaN values - in the initialization period propagate unchanged (NaN * effect = NaN), - which is correct since masked days are excluded from the likelihood. + Multiplies the finite entries of ``predicted`` by the weekday + cycle anchored at ``first_day_dow``. ``NaN`` entries (the + delay-tail at the start of the shared time axis) are preserved + through the JAX "double-where" idiom: the inner product is + evaluated against a NaN-free surrogate so its backward + cotangent is finite at every position, then the outer + ``jnp.where`` restores ``NaN`` to its original positions in + the output. Parameters ---------- @@ -291,13 +295,18 @@ def _apply_day_of_week( Predicted counts. Shape: (n_timepoints,) or (n_timepoints, n_subpops). first_day_dow : int - Day of the week for element 0 of the time axis + Day-of-week of ``predicted[0]`` on the shared time axis (0=Monday, 6=Sunday, ISO convention). Returns ------- ArrayLike Adjusted predicted counts, same shape as input. + + Notes + ----- + See https://docs.jax.dev/en/latest/faq.html#gradients-contain-nan-where-using-where + for the double-where pattern. """ dow_effect = self.day_of_week_rv() self._deterministic("day_of_week_effect", dow_effect) @@ -307,7 +316,9 @@ def _apply_day_of_week( ] if predicted.ndim == 2: daily_effect = daily_effect[:, None] - return predicted * daily_effect + finite_pred = ~jnp.isnan(predicted) + safe_predicted = jnp.where(finite_pred, predicted, 0.0) + return jnp.where(finite_pred, safe_predicted * daily_effect, predicted) def _aggregate( self, @@ -462,7 +473,7 @@ def _score_masked( safe_predicted = jnp.where(jnp.isnan(predicted), 1.0, predicted) safe_obs = None if obs is not None: - safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs) + safe_obs = jnp.where(jnp.isnan(obs), 0.0, obs) return self.noise.sample( name=self._sample_site_name("obs"), predicted=safe_predicted, diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 6c17d3a2..5fd5b6ab 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -2,10 +2,12 @@ Unit tests for PopulationCounts and SubpopulationCounts classes. """ +import jax import jax.numpy as jnp import numpyro import numpyro.distributions as dist import pytest +from numpyro.infer.util import log_density from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.observation import ( @@ -1486,5 +1488,305 @@ def test_weekly_regular_with_obs_conditions( assert result.observed.shape == (4, 2) +class TestScoreMaskedSafeObs: + """ + Tests for the safe-placeholder behavior in ``_score_masked``. + + The masked likelihood path replaces NaN entries of ``obs`` with a + placeholder so that the noise distribution's ``log_prob`` is finite + at every position. NumPyro's mask handler zeroes the *contribution* + of those positions in the forward sum, but ``jax.grad`` still + differentiates the unselected branch; a non-finite ``log_prob`` + there produces ``0 * NaN = NaN`` cotangents that escape the mask + and corrupt parameter gradients. For count noise, the placeholder + must be a value in the integer support of the distribution. + """ + + @staticmethod + def _multi_day_delay_pmf() -> jnp.ndarray: + """ + Return a 3-day delay PMF so that ``predicted`` has 2 leading NaN. + + Returns + ------- + jnp.ndarray + A length-3 delay PMF. + """ + return jnp.array([0.5, 0.3, 0.2]) + + @staticmethod + def _padded_obs(n_total: int, n_init: int, value: float) -> jnp.ndarray: + """ + Return a length-``n_total`` array with ``n_init`` leading NaN. + + Parameters + ---------- + n_total + Length of the returned array. + n_init + Number of leading positions to set to ``NaN``. + value + Constant value to fill the remaining positions. + + Returns + ------- + jnp.ndarray + Padded observation array. + """ + obs = jnp.full(n_total, value, dtype=jnp.float32) + return obs.at[:n_init].set(jnp.nan) + + def test_safe_obs_zero_at_masked_positions(self): + """ + Masked obs positions enter the noise distribution as ``0.0``. + + ``NegativeBinomial2.log_prob`` is finite at integer counts + only; the masked-position placeholder must be in support so + that the forward log_prob is finite and the backward gradient + does not leak NaN through the mask. + """ + delay_pmf = self._multi_day_delay_pmf() + n_total = 14 + n_init = 5 + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + infections = jnp.ones(n_total) * 100.0 + obs = self._padded_obs(n_total, n_init, value=5.0) + + with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as tr: + process.sample(infections=infections, obs=obs) + + site_value = tr["test_obs"]["value"] + assert jnp.all(jnp.isfinite(site_value)) + assert jnp.all(site_value[:n_init] == 0.0) + assert jnp.allclose(site_value[n_init:], obs[n_init:]) + + def test_log_prob_finite_at_every_position(self): + """ + ``noise.log_prob`` evaluates to a finite value at every slot. + + Without an in-support placeholder, masked slots would receive + the non-integer ``safe_predicted`` value and + ``NegativeBinomial2.log_prob`` would return ``-inf`` (or NaN) + there, which is the failure mode that breaks gradients. + """ + delay_pmf = self._multi_day_delay_pmf() + n_total = 14 + n_init = 5 + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + infections = jnp.ones(n_total) * 100.0 + obs = self._padded_obs(n_total, n_init, value=5.0) + + with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as tr: + process.sample(infections=infections, obs=obs) + + site = tr["test_obs"] + log_p = site["fn"].log_prob(site["value"]) + assert jnp.all(jnp.isfinite(log_p)) + + def test_dow_gradient_finite_through_masked_obs(self): + """ + Gradients w.r.t. a DOW effect are finite under masked obs. + + Isolates the obs-side NaN-cotangent leak repaired by the + in-support placeholder. With a length-1 delay PMF + ``predicted`` has no NaN tail, so any NaN gradient at the DOW + effect can only arise from the masked-obs branch of + ``_score_masked``. Before the fix, ``safe_obs = safe_predicted`` + sends non-integer obs into ``NegativeBinomial2.log_prob`` at + masked slots; the ``0 * NaN`` cotangent in the mask handler + leaks NaN back through the DOW multiplier. With the in-support + placeholder, all gradient entries are finite. + """ + delay_pmf = jnp.array([1.0]) + n_total = 21 + n_init = 5 + first_day_dow = 2 + infections = jnp.ones(n_total) * 1000.0 + obs = self._padded_obs(n_total, n_init, value=5.0) + + def model(dow_value: jnp.ndarray) -> None: + """Run a PopulationCounts sample with the given DOW effect.""" + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + day_of_week_rv=DeterministicVariable("dow", dow_value), + ) + process.sample( + infections=infections, + obs=obs, + first_day_dow=first_day_dow, + ) + + def log_p(dow_value: jnp.ndarray) -> jnp.ndarray: + """ + Return the joint log-density of the model at ``dow_value``. + + Parameters + ---------- + dow_value + Day-of-week effect vector at which to evaluate. + + Returns + ------- + jnp.ndarray + Scalar joint log-density. + """ + value, _ = log_density(model, (dow_value,), {}, params={}) + return value + + dow_value = jnp.array([2.0, 0.5, 0.5, 0.5, 0.5, 1.5, 1.5]) + grad = jax.grad(log_p)(dow_value) + assert jnp.all(jnp.isfinite(grad)) + + +class TestDayOfWeekNanGradientSafety: + """ + Tests for gradient-safe handling of the delay-tail NaN region. + + Issue #824: a multi-day delay PMF leaves + ``predicted[:len(delay_pmf)-1]`` as NaN before the day-of-week + multiplier is applied. The previous implementation tiled the + multiplier across the entire array; multiplying NaN by the + day-of-week vector produced a NaN cotangent through ``jnp.where`` + that leaked back to the day-of-week parameters under autodiff, + causing stochastic-DOW priors to diverge under NUTS. The + double-where pattern in ``_apply_day_of_week`` keeps the + multiplication gradient-safe while preserving the original NaN + positions in the output. + """ + + @staticmethod + def _multi_day_delay_pmf() -> jnp.ndarray: + """ + Return a 3-day delay PMF so ``predicted[:2]`` is NaN. + + Returns + ------- + jnp.ndarray + A length-3 delay PMF. + """ + return jnp.array([0.5, 0.3, 0.2]) + + @staticmethod + def _padded_obs(n_total: int, n_init: int, value: float) -> jnp.ndarray: + """ + Return a length-``n_total`` array with ``n_init`` leading NaN. + + Parameters + ---------- + n_total + Length of the returned array. + n_init + Number of leading positions to set to ``NaN``. + value + Constant value to fill the remaining positions. + + Returns + ------- + jnp.ndarray + Padded observation array. + """ + obs = jnp.full(n_total, value, dtype=jnp.float32) + return obs.at[:n_init].set(jnp.nan) + + def test_delay_tail_nan_preserved_through_dow(self): + """ + ``predicted`` NaN entries remain NaN after the multiplier runs. + + The double-where idiom restores the original NaN values at the + delay-tail positions, regardless of the day-of-week vector. + """ + delay_pmf = self._multi_day_delay_pmf() + n_tail = delay_pmf.shape[0] - 1 + n_total = 21 + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable( + "dow", jnp.array([2.0, 1.5, 1.0, 1.0, 0.5, 0.5, 0.5]) + ), + ) + infections = jnp.ones(n_total) * 1000.0 + + with numpyro.handlers.seed(rng_seed=0): + result = process.sample( + infections=infections, + obs=None, + first_day_dow=0, + ) + + assert jnp.all(jnp.isnan(result.predicted[:n_tail])) + assert jnp.all(jnp.isfinite(result.predicted[n_tail:])) + + def test_dow_gradient_finite_with_delay_tail_nan(self): + """ + Gradients are finite when ``predicted`` has a NaN delay-tail. + + Reproduces the issue-#824 gradient blow-up: a multi-day delay + PMF makes ``predicted[:len(delay)-1]`` NaN, and the + day-of-week multiplier is tiled across the whole array. Before + the fix, ``NaN * dow_effect[i]`` at delay-tail positions + leaked a NaN cotangent back to ``dow_effect[i]`` through + ``jnp.where``. With the double-where pattern the inner + multiplication operates on a NaN-free surrogate, so the + gradient is finite at every slot. + """ + delay_pmf = self._multi_day_delay_pmf() + n_total = 21 + n_init = 5 + infections = jnp.ones(n_total) * 1000.0 + obs = self._padded_obs(n_total, n_init, value=5.0) + + def model(dow_value: jnp.ndarray) -> None: + """Sample with the given DOW effect over the full time axis.""" + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + day_of_week_rv=DeterministicVariable("dow", dow_value), + ) + process.sample( + infections=infections, + obs=obs, + first_day_dow=2, + ) + + def log_p(dow_value: jnp.ndarray) -> jnp.ndarray: + """ + Return the joint log-density of the model at ``dow_value``. + + Parameters + ---------- + dow_value + Day-of-week effect vector at which to evaluate. + + Returns + ------- + jnp.ndarray + Scalar joint log-density. + """ + value, _ = log_density(model, (dow_value,), {}, params={}) + return value + + dow_value = jnp.array([2.0, 0.5, 0.5, 0.5, 0.5, 1.5, 1.5]) + grad = jax.grad(log_p)(dow_value) + assert jnp.all(jnp.isfinite(grad)) + + if __name__ == "__main__": pytest.main([__file__, "-v"])