Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
680bb1e
Merge branch 'main' of https://github.com/CDCgov/PyRenew
cdc-mitzimorris Sep 15, 2025
2cb876b
update
cdc-mitzimorris Sep 18, 2025
60db8df
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Sep 22, 2025
32a5314
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Oct 5, 2025
d6213f2
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Oct 8, 2025
96f27c9
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Nov 17, 2025
1cb6fa2
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Nov 24, 2025
f62e1e4
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Dec 4, 2025
0c6785d
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Dec 22, 2025
1ee62b9
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Jan 29, 2026
0629461
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 4, 2026
efeadee
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 5, 2026
371ba98
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 5, 2026
0304bed
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 6, 2026
ffeea65
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 9, 2026
50e7261
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 9, 2026
dae6af8
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 10, 2026
5cb3097
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 11, 2026
1d80ccc
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 11, 2026
e73b401
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 12, 2026
b1473b5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 18, 2026
0b929b5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 18, 2026
3ee00a7
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 24, 2026
307982a
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 24, 2026
b862bc6
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 26, 2026
2c665a5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Mar 11, 2026
60d6458
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Mar 12, 2026
ec8c464
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Mar 19, 2026
c018bf7
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Mar 24, 2026
d0207dd
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 4, 2026
f3c706a
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 9, 2026
684c6c5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 10, 2026
ca2454f
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 13, 2026
0f38afc
merge
cdc-mitzimorris Apr 14, 2026
d8e7a57
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 16, 2026
7e9b5fe
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 24, 2026
e1d8014
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 24, 2026
83ddbf0
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 24, 2026
69ea4ea
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 24, 2026
555e87b
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 28, 2026
fa5a7cb
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 4, 2026
69cdab0
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 6, 2026
c28a89f
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 6, 2026
fd091ca
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 7, 2026
8cee471
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 7, 2026
b2a1e1a
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 11, 2026
2006afd
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 12, 2026
a31ec85
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 13, 2026
80dc517
bug fix and unit tests
cdc-mitzimorris May 18, 2026
3aa2eaa
Potential fix for pull request finding
cdc-mitzimorris May 18, 2026
cbba2a2
Merge branch 'mem_824_day_of_week_bug' of github-bf06:CDCgov/PyRenew …
cdc-mitzimorris May 18, 2026
deb5f24
update unit test to match code
cdc-mitzimorris May 19, 2026
7f8b498
revert changes, apply simpler fix
cdc-mitzimorris May 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions pyrenew/observation/count_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,24 +280,33 @@ def _apply_day_of_week(
"""
Apply day-of-week multiplicative adjustment to predicted counts.

Tiles a 7-element effect vector across the full time axis,
aligned to the calendar via ``first_day_dow``. NaN values
in the initialization period propagate unchanged (NaN * effect = NaN),
which is correct since masked days are excluded from the likelihood.
Multiplies the finite entries of ``predicted`` by the weekday
cycle anchored at ``first_day_dow``. ``NaN`` entries (the
delay-tail at the start of the shared time axis) are preserved
through the JAX "double-where" idiom: the inner product is
evaluated against a NaN-free surrogate so its backward
cotangent is finite at every position, then the outer
``jnp.where`` restores ``NaN`` to its original positions in
the output.

Parameters
----------
predicted : ArrayLike
Predicted counts. Shape: (n_timepoints,) or
(n_timepoints, n_subpops).
first_day_dow : int
Day of the week for element 0 of the time axis
Day-of-week of ``predicted[0]`` on the shared time axis
(0=Monday, 6=Sunday, ISO convention).

Returns
-------
ArrayLike
Adjusted predicted counts, same shape as input.

Notes
-----
See https://docs.jax.dev/en/latest/faq.html#gradients-contain-nan-where-using-where
for the double-where pattern.
"""
dow_effect = self.day_of_week_rv()
self._deterministic("day_of_week_effect", dow_effect)
Expand All @@ -307,7 +316,9 @@ def _apply_day_of_week(
]
if predicted.ndim == 2:
daily_effect = daily_effect[:, None]
return predicted * daily_effect
finite_pred = ~jnp.isnan(predicted)
safe_predicted = jnp.where(finite_pred, predicted, 0.0)
return jnp.where(finite_pred, safe_predicted * daily_effect, predicted)

def _aggregate(
self,
Expand Down Expand Up @@ -462,7 +473,7 @@ def _score_masked(
safe_predicted = jnp.where(jnp.isnan(predicted), 1.0, predicted)
Comment thread
cdc-mitzimorris marked this conversation as resolved.
safe_obs = None
if obs is not None:
safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs)
safe_obs = jnp.where(jnp.isnan(obs), 0.0, obs)
return self.noise.sample(
name=self._sample_site_name("obs"),
predicted=safe_predicted,
Expand Down
302 changes: 302 additions & 0 deletions test/test_observation_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
Unit tests for PopulationCounts and SubpopulationCounts classes.
"""

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import pytest
from numpyro.infer.util import log_density

from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.observation import (
Expand Down Expand Up @@ -1486,5 +1488,305 @@ def test_weekly_regular_with_obs_conditions(
assert result.observed.shape == (4, 2)


class TestScoreMaskedSafeObs:
"""
Tests for the safe-placeholder behavior in ``_score_masked``.

The masked likelihood path replaces NaN entries of ``obs`` with a
placeholder so that the noise distribution's ``log_prob`` is finite
at every position. NumPyro's mask handler zeroes the *contribution*
of those positions in the forward sum, but ``jax.grad`` still
differentiates the unselected branch; a non-finite ``log_prob``
there produces ``0 * NaN = NaN`` cotangents that escape the mask
and corrupt parameter gradients. For count noise, the placeholder
must be a value in the integer support of the distribution.
"""

@staticmethod
def _multi_day_delay_pmf() -> jnp.ndarray:
"""
Return a 3-day delay PMF so that ``predicted`` has 2 leading NaN.

Returns
-------
jnp.ndarray
A length-3 delay PMF.
"""
return jnp.array([0.5, 0.3, 0.2])

@staticmethod
def _padded_obs(n_total: int, n_init: int, value: float) -> jnp.ndarray:
"""
Return a length-``n_total`` array with ``n_init`` leading NaN.

Parameters
----------
n_total
Length of the returned array.
n_init
Number of leading positions to set to ``NaN``.
value
Constant value to fill the remaining positions.

Returns
-------
jnp.ndarray
Padded observation array.
"""
obs = jnp.full(n_total, value, dtype=jnp.float32)
return obs.at[:n_init].set(jnp.nan)

def test_safe_obs_zero_at_masked_positions(self):
"""
Masked obs positions enter the noise distribution as ``0.0``.

``NegativeBinomial2.log_prob`` is finite at integer counts
only; the masked-position placeholder must be in support so
that the forward log_prob is finite and the backward gradient
does not leak NaN through the mask.
"""
delay_pmf = self._multi_day_delay_pmf()
n_total = 14
n_init = 5
process = PopulationCounts(
name="test",
ascertainment_rate_rv=DeterministicVariable("ihr", 0.01),
delay_distribution_rv=DeterministicPMF("delay", delay_pmf),
noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)),
)
infections = jnp.ones(n_total) * 100.0
obs = self._padded_obs(n_total, n_init, value=5.0)

with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as tr:
process.sample(infections=infections, obs=obs)

site_value = tr["test_obs"]["value"]
assert jnp.all(jnp.isfinite(site_value))
assert jnp.all(site_value[:n_init] == 0.0)
assert jnp.allclose(site_value[n_init:], obs[n_init:])

def test_log_prob_finite_at_every_position(self):
"""
``noise.log_prob`` evaluates to a finite value at every slot.

Without an in-support placeholder, masked slots would receive
the non-integer ``safe_predicted`` value and
``NegativeBinomial2.log_prob`` would return ``-inf`` (or NaN)
there, which is the failure mode that breaks gradients.
"""
delay_pmf = self._multi_day_delay_pmf()
n_total = 14
n_init = 5
process = PopulationCounts(
name="test",
ascertainment_rate_rv=DeterministicVariable("ihr", 0.01),
delay_distribution_rv=DeterministicPMF("delay", delay_pmf),
noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)),
)
infections = jnp.ones(n_total) * 100.0
obs = self._padded_obs(n_total, n_init, value=5.0)

with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as tr:
process.sample(infections=infections, obs=obs)

site = tr["test_obs"]
log_p = site["fn"].log_prob(site["value"])
assert jnp.all(jnp.isfinite(log_p))

def test_dow_gradient_finite_through_masked_obs(self):
"""
Gradients w.r.t. a DOW effect are finite under masked obs.

Isolates the obs-side NaN-cotangent leak repaired by the
in-support placeholder. With a length-1 delay PMF
``predicted`` has no NaN tail, so any NaN gradient at the DOW
effect can only arise from the masked-obs branch of
``_score_masked``. Before the fix, ``safe_obs = safe_predicted``
sends non-integer obs into ``NegativeBinomial2.log_prob`` at
masked slots; the ``0 * NaN`` cotangent in the mask handler
leaks NaN back through the DOW multiplier. With the in-support
placeholder, all gradient entries are finite.
"""
delay_pmf = jnp.array([1.0])
n_total = 21
n_init = 5
first_day_dow = 2
infections = jnp.ones(n_total) * 1000.0
obs = self._padded_obs(n_total, n_init, value=5.0)

def model(dow_value: jnp.ndarray) -> None:
"""Run a PopulationCounts sample with the given DOW effect."""
process = PopulationCounts(
name="test",
ascertainment_rate_rv=DeterministicVariable("ihr", 0.01),
delay_distribution_rv=DeterministicPMF("delay", delay_pmf),
noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)),
day_of_week_rv=DeterministicVariable("dow", dow_value),
)
process.sample(
infections=infections,
obs=obs,
first_day_dow=first_day_dow,
)

def log_p(dow_value: jnp.ndarray) -> jnp.ndarray:
"""
Return the joint log-density of the model at ``dow_value``.

Parameters
----------
dow_value
Day-of-week effect vector at which to evaluate.

Returns
-------
jnp.ndarray
Scalar joint log-density.
"""
value, _ = log_density(model, (dow_value,), {}, params={})
return value

dow_value = jnp.array([2.0, 0.5, 0.5, 0.5, 0.5, 1.5, 1.5])
grad = jax.grad(log_p)(dow_value)
assert jnp.all(jnp.isfinite(grad))


class TestDayOfWeekNanGradientSafety:
"""
Tests for gradient-safe handling of the delay-tail NaN region.

Issue #824: a multi-day delay PMF leaves
``predicted[:len(delay_pmf)-1]`` as NaN before the day-of-week
multiplier is applied. The previous implementation tiled the
multiplier across the entire array; multiplying NaN by the
day-of-week vector produced a NaN cotangent through ``jnp.where``
that leaked back to the day-of-week parameters under autodiff,
causing stochastic-DOW priors to diverge under NUTS. The
double-where pattern in ``_apply_day_of_week`` keeps the
multiplication gradient-safe while preserving the original NaN
positions in the output.
"""

@staticmethod
def _multi_day_delay_pmf() -> jnp.ndarray:
"""
Return a 3-day delay PMF so ``predicted[:2]`` is NaN.

Returns
-------
jnp.ndarray
A length-3 delay PMF.
"""
return jnp.array([0.5, 0.3, 0.2])

@staticmethod
def _padded_obs(n_total: int, n_init: int, value: float) -> jnp.ndarray:
"""
Return a length-``n_total`` array with ``n_init`` leading NaN.

Parameters
----------
n_total
Length of the returned array.
n_init
Number of leading positions to set to ``NaN``.
value
Constant value to fill the remaining positions.

Returns
-------
jnp.ndarray
Padded observation array.
"""
obs = jnp.full(n_total, value, dtype=jnp.float32)
return obs.at[:n_init].set(jnp.nan)

def test_delay_tail_nan_preserved_through_dow(self):
"""
``predicted`` NaN entries remain NaN after the multiplier runs.

The double-where idiom restores the original NaN values at the
delay-tail positions, regardless of the day-of-week vector.
"""
delay_pmf = self._multi_day_delay_pmf()
n_tail = delay_pmf.shape[0] - 1
n_total = 21
process = PopulationCounts(
name="test",
ascertainment_rate_rv=DeterministicVariable("ihr", 0.01),
delay_distribution_rv=DeterministicPMF("delay", delay_pmf),
noise=PoissonNoise(),
day_of_week_rv=DeterministicVariable(
"dow", jnp.array([2.0, 1.5, 1.0, 1.0, 0.5, 0.5, 0.5])
),
)
infections = jnp.ones(n_total) * 1000.0

with numpyro.handlers.seed(rng_seed=0):
result = process.sample(
infections=infections,
obs=None,
first_day_dow=0,
)

assert jnp.all(jnp.isnan(result.predicted[:n_tail]))
assert jnp.all(jnp.isfinite(result.predicted[n_tail:]))

def test_dow_gradient_finite_with_delay_tail_nan(self):
"""
Gradients are finite when ``predicted`` has a NaN delay-tail.

Reproduces the issue-#824 gradient blow-up: a multi-day delay
PMF makes ``predicted[:len(delay)-1]`` NaN, and the
day-of-week multiplier is tiled across the whole array. Before
the fix, ``NaN * dow_effect[i]`` at delay-tail positions
leaked a NaN cotangent back to ``dow_effect[i]`` through
``jnp.where``. With the double-where pattern the inner
multiplication operates on a NaN-free surrogate, so the
gradient is finite at every slot.
"""
delay_pmf = self._multi_day_delay_pmf()
n_total = 21
n_init = 5
infections = jnp.ones(n_total) * 1000.0
obs = self._padded_obs(n_total, n_init, value=5.0)

def model(dow_value: jnp.ndarray) -> None:
"""Sample with the given DOW effect over the full time axis."""
process = PopulationCounts(
name="test",
ascertainment_rate_rv=DeterministicVariable("ihr", 0.01),
delay_distribution_rv=DeterministicPMF("delay", delay_pmf),
noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)),
day_of_week_rv=DeterministicVariable("dow", dow_value),
)
process.sample(
infections=infections,
obs=obs,
first_day_dow=2,
)

def log_p(dow_value: jnp.ndarray) -> jnp.ndarray:
"""
Return the joint log-density of the model at ``dow_value``.

Parameters
----------
dow_value
Day-of-week effect vector at which to evaluate.

Returns
-------
jnp.ndarray
Scalar joint log-density.
"""
value, _ = log_density(model, (dow_value,), {}, params={})
return value

dow_value = jnp.array([2.0, 0.5, 0.5, 0.5, 0.5, 1.5, 1.5])
grad = jax.grad(log_p)(dow_value)
assert jnp.all(jnp.isfinite(grad))


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading