From 80dc5175d0760bb4772ecbd2a289aae69b15b1df Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 18 May 2026 18:37:49 -0400 Subject: [PATCH 1/4] bug fix and unit tests --- pyrenew/model/multisignal_model.py | 33 ++ pyrenew/observation/count_observations.py | 126 +++++-- test/test_observation_counts.py | 404 ++++++++++++++++++++++ 3 files changed, 531 insertions(+), 32 deletions(-) diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py index 1a0f4b88..12f16d9e 100644 --- a/pyrenew/model/multisignal_model.py +++ b/pyrenew/model/multisignal_model.py @@ -186,6 +186,35 @@ def _resolve_first_day_dow( n_init = self.latent.n_initialization_points return (convert_date(obs_start_date).weekday() - n_init) % 7 + def _resolve_first_observed_dow( + self, + obs_start_date: dt.date | dt.datetime | np.datetime64 | None, + ) -> int | None: + """ + Derive the observation-axis day-of-week from ``obs_start_date``. + + Returns the day-of-week of the first observed day directly, + without the initialization-period offset that + ``_resolve_first_day_dow`` applies. Used to anchor reporting + effects (day-of-week multipliers) that should apply only to + observation days, not to the padded initialization period. + + Parameters + ---------- + obs_start_date + Date of the first observation day, or ``None``. + + Returns + ------- + int or None + Day-of-week index in ``{0, ..., 6}`` (0=Monday, ISO + convention) of the first observed day. ``None`` when + ``obs_start_date`` is ``None``. + """ + if obs_start_date is None: + return None + return convert_date(obs_start_date).weekday() + def _check_obs_start_date( self, obs_start_date: dt.date | dt.datetime | np.datetime64 | None, @@ -366,6 +395,8 @@ def sample( """ self._check_obs_start_date(obs_start_date) first_day_dow = self._resolve_first_day_dow(obs_start_date) + first_observed_dow = self._resolve_first_observed_dow(obs_start_date) + first_observed_idx = self.latent.n_initialization_points # Generate latent infections (proportions) latent_sample = self.latent.sample( @@ -413,6 +444,8 @@ def sample( obs_process.sample( infections=latent_infections, first_day_dow=first_day_dow, + first_observed_idx=first_observed_idx, + first_observed_dow=first_observed_dow, **obs_data, ) diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 76423e56..e9bd44b5 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -275,23 +275,29 @@ def _apply_right_truncation( def _apply_day_of_week( self, predicted: ArrayLike, - first_day_dow: int, + first_observed_idx: int, + first_observed_dow: int, ) -> ArrayLike: """ Apply day-of-week multiplicative adjustment to predicted counts. - Tiles a 7-element effect vector across the full time axis, - aligned to the calendar via ``first_day_dow``. NaN values - in the initialization period propagate unchanged (NaN * effect = NaN), - which is correct since masked days are excluded from the likelihood. + Multiplies ``predicted[first_observed_idx:]`` by the weekday + cycle anchored at ``first_observed_dow``. Positions before + ``first_observed_idx`` (initialization / pre-observation period) + are left unchanged so that ``NaN`` entries in the delay-tail do + not feed the multiplier. Parameters ---------- predicted : ArrayLike Predicted counts. Shape: (n_timepoints,) or (n_timepoints, n_subpops). - first_day_dow : int - Day of the week for element 0 of the time axis + first_observed_idx : int + Index on the shared time axis where the first observation + period begins. Positions before this index receive no + day-of-week adjustment. + first_observed_dow : int + Day-of-week of ``predicted[first_observed_idx]`` (0=Monday, 6=Sunday, ISO convention). Returns @@ -301,13 +307,13 @@ def _apply_day_of_week( """ dow_effect = self.day_of_week_rv() self._deterministic("day_of_week_effect", dow_effect) - n_timepoints = predicted.shape[0] - daily_effect = dow_effect[ - get_sequential_day_of_week_indices(first_day_dow, n_timepoints) + n_obs_days = predicted.shape[0] - first_observed_idx + obs_effect = dow_effect[ + get_sequential_day_of_week_indices(first_observed_dow, n_obs_days) ] if predicted.ndim == 2: - daily_effect = daily_effect[:, None] - return predicted * daily_effect + obs_effect = obs_effect[:, None] + return predicted.at[first_observed_idx:].multiply(obs_effect) def _aggregate( self, @@ -362,6 +368,8 @@ def _compute_predicted( infections: ArrayLike, first_day_dow: int | None, right_truncation_offset: int | None, + first_observed_idx: int = 0, + first_observed_dow: int | None = None, ) -> ArrayLike: """ Build the predicted counts on the reporting-period grid. @@ -380,12 +388,20 @@ def _compute_predicted( first_day_dow Day-of-week index of element 0 of the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when - ``day_of_week_rv`` was set at construction or when ``aggregation == "weekly"``. right_truncation_offset If set (together with ``right_truncation_rv``), the number of additional reporting days that have occurred since the last observation. + first_observed_idx + Index on the shared time axis where the first observation + period begins. The day-of-week effect is applied only to + positions ``[first_observed_idx:]``. Defaults to ``0`` for + unpadded direct callers. + first_observed_dow + Day-of-week of ``predicted[first_observed_idx]`` + (0=Monday, 6=Sunday, ISO convention). Required when + ``day_of_week_rv`` was set at construction. Returns ------- @@ -397,16 +413,18 @@ def _compute_predicted( Raises ------ ValueError - If ``day_of_week_rv`` was set but ``first_day_dow`` is - ``None``. + If ``day_of_week_rv`` was set but ``first_observed_dow`` + is ``None``. """ predicted_daily = self._predicted_obs(infections) if self.day_of_week_rv is not None: - if first_day_dow is None: + if first_observed_dow is None: raise ValueError( "first_day_dow is required when day_of_week_rv is set." ) - predicted_daily = self._apply_day_of_week(predicted_daily, first_day_dow) + predicted_daily = self._apply_day_of_week( + predicted_daily, first_observed_idx, first_observed_dow + ) if self.right_truncation_rv is not None and right_truncation_offset is not None: predicted_daily = self._apply_right_truncation( predicted_daily, right_truncation_offset @@ -462,7 +480,7 @@ def _score_masked( safe_predicted = jnp.where(jnp.isnan(predicted), 1.0, predicted) safe_obs = None if obs is not None: - safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs) + safe_obs = jnp.where(jnp.isnan(obs), 0.0, obs) return self.noise.sample( name=self._sample_site_name("obs"), predicted=safe_predicted, @@ -655,6 +673,8 @@ def sample( right_truncation_offset: int | None = None, first_day_dow: int | None = None, period_end_times: ArrayLike | None = None, + first_observed_idx: int | None = None, + first_observed_dow: int | None = None, ) -> ObservationSample: """ Sample aggregated counts. @@ -689,15 +709,28 @@ def sample( construction), apply right-truncation adjustment to the daily predictions. first_day_dow - Day-of-week index of the first timepoint on the shared - time axis (0=Monday, 6=Sunday, ISO convention). Required - when ``day_of_week_rv`` was set at construction or when - ``aggregation == "weekly"``. This aligns observation-level - day-of-week effects or weekly aggregation to the shared daily - model axis. + Day-of-week index of element 0 of the shared time axis + (0=Monday, 6=Sunday, ISO convention). Required when + ``aggregation == "weekly"``; also serves as the + backwards-compatible day-of-week anchor for direct + callers without padding (used as ``first_observed_dow`` + when the latter is not supplied). period_end_times Daily-axis indices of each observed period's final day. Required when ``reporting_schedule == "irregular"``. + first_observed_idx + Index on the shared time axis where the first observation + period begins. Day-of-week effects are applied only to + positions ``[first_observed_idx:]``. Supplied by + ``MultiSignalModel`` as ``n_initialization_points``. + Defaults to ``0`` for direct callers without padding. + first_observed_dow + Day-of-week of the first observed period + (0=Monday, 6=Sunday, ISO convention). Required when + ``day_of_week_rv`` was set at construction. Supplied by + ``MultiSignalModel`` as ``obs_start_date.weekday()``. + When not supplied, falls back to ``first_day_dow`` for + backwards compatibility with direct unpadded callers. Returns ------- @@ -707,8 +740,16 @@ def sample( grid; equal to daily predictions when ``aggregation == "daily"``). """ + if first_observed_idx is None: + first_observed_idx = 0 + if first_observed_dow is None: + first_observed_dow = first_day_dow predicted = self._compute_predicted( - infections, first_day_dow, right_truncation_offset + infections, + first_day_dow, + right_truncation_offset, + first_observed_idx=first_observed_idx, + first_observed_dow=first_observed_dow, ) if self.reporting_schedule == "regular": @@ -892,6 +933,8 @@ def sample( first_day_dow: int | None = None, period_end_times: ArrayLike | None = None, subpop_indices: ArrayLike | None = None, + first_observed_idx: int | None = None, + first_observed_dow: int | None = None, ) -> ObservationSample: """ Sample subpopulation-level counts. @@ -927,12 +970,11 @@ def sample( construction), apply right-truncation adjustment to the daily predictions. first_day_dow - Day-of-week index of the first timepoint on the shared - time axis (0=Monday, 6=Sunday, ISO convention). Required - when ``day_of_week_rv`` was set at construction or when - ``aggregation == "weekly"``. This aligns observation-level - day-of-week effects or weekly aggregation to the shared daily - model axis. + Day-of-week index of element 0 of the shared time axis + (0=Monday, 6=Sunday, ISO convention). Required when + ``aggregation == "weekly"``; also serves as the + backwards-compatible day-of-week anchor for direct + callers without padding. period_end_times Daily-axis indices of each observed period's final day. Required when ``reporting_schedule == "irregular"``. @@ -943,6 +985,18 @@ def sample( columns of the aggregated array enter the likelihood. For ``reporting_schedule="irregular"``: shape ``(n_obs,)`` with one subpopulation per observation. + first_observed_idx + Index on the shared time axis where the first observation + period begins. Day-of-week effects are applied only to + positions ``[first_observed_idx:]``. Supplied by + ``MultiSignalModel`` as ``n_initialization_points``. + Defaults to ``0`` for direct callers without padding. + first_observed_dow + Day-of-week of the first observed period + (0=Monday, 6=Sunday, ISO convention). Required when + ``day_of_week_rv`` was set at construction. Supplied by + ``MultiSignalModel`` as ``obs_start_date.weekday()``. + When not supplied, falls back to ``first_day_dow``. Returns ------- @@ -955,8 +1009,16 @@ def sample( if subpop_indices is None: raise ValueError(f"Observation '{self.name}': subpop_indices is required.") + if first_observed_idx is None: + first_observed_idx = 0 + if first_observed_dow is None: + first_observed_dow = first_day_dow predicted = self._compute_predicted( - infections, first_day_dow, right_truncation_offset + infections, + first_day_dow, + right_truncation_offset, + first_observed_idx=first_observed_idx, + first_observed_dow=first_observed_dow, ) if self.reporting_schedule == "regular": diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 6c17d3a2..0972a063 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -2,10 +2,12 @@ Unit tests for PopulationCounts and SubpopulationCounts classes. """ +import jax import jax.numpy as jnp import numpyro import numpyro.distributions as dist import pytest +from numpyro.infer.util import log_density from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.observation import ( @@ -1486,5 +1488,407 @@ def test_weekly_regular_with_obs_conditions( assert result.observed.shape == (4, 2) +class TestScoreMaskedSafeObs: + """ + Tests for the safe-placeholder behavior in ``_score_masked``. + + The masked likelihood path replaces NaN entries of ``obs`` with a + placeholder so that the noise distribution's ``log_prob`` is finite + at every position. NumPyro's mask handler zeroes the *contribution* + of those positions in the forward sum, but ``jax.grad`` still + differentiates the unselected branch; a non-finite ``log_prob`` + there produces ``0 * NaN = NaN`` cotangents that escape the mask + and corrupt parameter gradients. For count noise, the placeholder + must be a value in the integer support of the distribution. + """ + + @staticmethod + def _multi_day_delay_pmf() -> jnp.ndarray: + """ + Return a 3-day delay PMF so that ``predicted`` has 2 leading NaN. + + Returns + ------- + jnp.ndarray + A length-3 delay PMF. + """ + return jnp.array([0.5, 0.3, 0.2]) + + @staticmethod + def _padded_obs(n_total: int, n_init: int, value: float) -> jnp.ndarray: + """ + Return a length-``n_total`` array with ``n_init`` leading NaN. + + Parameters + ---------- + n_total + Length of the returned array. + n_init + Number of leading positions to set to ``NaN``. + value + Constant value to fill the remaining positions. + + Returns + ------- + jnp.ndarray + Padded observation array. + """ + obs = jnp.full(n_total, value, dtype=jnp.float32) + return obs.at[:n_init].set(jnp.nan) + + def test_safe_obs_zero_at_masked_positions(self): + """ + Masked obs positions enter the noise distribution as ``0.0``. + + ``NegativeBinomial2.log_prob`` is finite at integer counts + only; the masked-position placeholder must be in support so + that the forward log_prob is finite and the backward gradient + does not leak NaN through the mask. + """ + delay_pmf = self._multi_day_delay_pmf() + n_total = 14 + n_init = 5 + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + infections = jnp.ones(n_total) * 100.0 + obs = self._padded_obs(n_total, n_init, value=5.0) + + with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as tr: + process.sample(infections=infections, obs=obs) + + site_value = tr["test_obs"]["value"] + assert jnp.all(jnp.isfinite(site_value)) + assert jnp.all(site_value[:n_init] == 0.0) + assert jnp.allclose(site_value[n_init:], obs[n_init:]) + + def test_log_prob_finite_at_every_position(self): + """ + ``noise.log_prob`` evaluates to a finite value at every slot. + + Without an in-support placeholder, masked slots would receive + the non-integer ``safe_predicted`` value and + ``NegativeBinomial2.log_prob`` would return ``-inf`` (or NaN) + there, which is the failure mode that breaks gradients. + """ + delay_pmf = self._multi_day_delay_pmf() + n_total = 14 + n_init = 5 + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + infections = jnp.ones(n_total) * 100.0 + obs = self._padded_obs(n_total, n_init, value=5.0) + + with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as tr: + process.sample(infections=infections, obs=obs) + + site = tr["test_obs"] + log_p = site["fn"].log_prob(site["value"]) + assert jnp.all(jnp.isfinite(log_p)) + + def test_dow_gradient_finite_through_masked_obs(self): + """ + Gradients w.r.t. a DOW effect are finite under masked obs. + + Isolates the obs-side NaN-cotangent leak repaired by the + in-support placeholder. With a length-1 delay PMF + ``predicted`` has no NaN tail, so any NaN gradient at the DOW + effect can only arise from the masked-obs branch of + ``_score_masked``. Before the fix, ``safe_obs = safe_predicted`` + sends non-integer obs into ``NegativeBinomial2.log_prob`` at + masked slots; the ``0 * NaN`` cotangent in the mask handler + leaks NaN back through the DOW multiplier. With the in-support + placeholder, all gradient entries are finite. + """ + delay_pmf = jnp.array([1.0]) + n_total = 21 + n_init = 5 + first_day_dow = 2 + infections = jnp.ones(n_total) * 1000.0 + obs = self._padded_obs(n_total, n_init, value=5.0) + + def model(dow_value: jnp.ndarray) -> None: + """Run a PopulationCounts sample with the given DOW effect.""" + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + day_of_week_rv=DeterministicVariable("dow", dow_value), + ) + process.sample( + infections=infections, + obs=obs, + first_day_dow=first_day_dow, + ) + + def log_p(dow_value: jnp.ndarray) -> jnp.ndarray: + """ + Return the joint log-density of the model at ``dow_value``. + + Parameters + ---------- + dow_value + Day-of-week effect vector at which to evaluate. + + Returns + ------- + jnp.ndarray + Scalar joint log-density. + """ + value, _ = log_density(model, (dow_value,), {}, params={}) + return value + + dow_value = jnp.array([2.0, 0.5, 0.5, 0.5, 0.5, 1.5, 1.5]) + grad = jax.grad(log_p)(dow_value) + assert jnp.all(jnp.isfinite(grad)) + + +class TestDayOfWeekObservationAxis: + """ + Tests for applying the day-of-week effect on the observation axis. + + Issue #824: the day-of-week multiplier must apply only to + positions at and after the first observed day, not across the + padded initialization period. Multiplying through the delay-tail + NaN region of ``predicted`` produces a second ``0 * NaN = NaN`` + cotangent path back to the day-of-week effect under autodiff, + independent of the masked-obs leak that ``TestScoreMaskedSafeObs`` + covers. + """ + + @staticmethod + def _multi_day_delay_pmf() -> jnp.ndarray: + """ + Return a 3-day delay PMF so ``predicted[:2]`` is NaN. + + Returns + ------- + jnp.ndarray + A length-3 delay PMF. + """ + return jnp.array([0.5, 0.3, 0.2]) + + @staticmethod + def _padded_obs(n_total: int, n_init: int, value: float) -> jnp.ndarray: + """ + Return a length-``n_total`` array with ``n_init`` leading NaN. + + Parameters + ---------- + n_total + Length of the returned array. + n_init + Number of leading positions to set to ``NaN``. + value + Constant value to fill the remaining positions. + + Returns + ------- + jnp.ndarray + Padded observation array. + """ + obs = jnp.full(n_total, value, dtype=jnp.float32) + return obs.at[:n_init].set(jnp.nan) + + def test_prefix_unchanged_by_dow_multiplier(self): + """ + Positions before ``first_observed_idx`` are not multiplied. + + Built two processes with identical setup except for the DOW + effect (one uniform ``ones(7)``, the other non-uniform). On + the observation suffix the predictions differ by the DOW + ratio; on the initialization prefix the predictions are + identical because the prefix is excluded from the multiplier. + """ + delay_pmf = jnp.array([1.0]) + n_total = 21 + first_observed_idx = 7 + first_observed_dow = 0 + + infections = jnp.ones(n_total) * 1000.0 + + process_uniform = PopulationCounts( + name="uniform", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable("dow", jnp.ones(7)), + ) + process_nonuniform = PopulationCounts( + name="nonuniform", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable( + "dow", jnp.array([2.0, 1.5, 1.0, 1.0, 0.5, 0.5, 0.5]) + ), + ) + + with numpyro.handlers.seed(rng_seed=0): + uniform = process_uniform.sample( + infections=infections, + obs=None, + first_observed_idx=first_observed_idx, + first_observed_dow=first_observed_dow, + ) + with numpyro.handlers.seed(rng_seed=0): + nonuniform = process_nonuniform.sample( + infections=infections, + obs=None, + first_observed_idx=first_observed_idx, + first_observed_dow=first_observed_dow, + ) + + assert jnp.allclose( + uniform.predicted[:first_observed_idx], + nonuniform.predicted[:first_observed_idx], + ) + assert not jnp.allclose( + uniform.predicted[first_observed_idx:], + nonuniform.predicted[first_observed_idx:], + ) + + def test_suffix_starts_at_first_observed_dow(self): + """ + The suffix multiplier starts at the supplied observed DOW. + + With ``first_observed_dow=6`` (Sunday), the multiplier at + ``predicted[first_observed_idx]`` is ``dow_effect[6]``, the + next position is ``dow_effect[0]`` (Monday), and so on. + """ + delay_pmf = jnp.array([1.0]) + n_total = 14 + first_observed_idx = 4 + first_observed_dow = 6 + dow_effect = jnp.array([2.0, 1.5, 1.0, 1.0, 0.5, 0.5, 3.0]) + + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable("dow", dow_effect), + ) + infections = jnp.ones(n_total) * 100.0 + + with numpyro.handlers.seed(rng_seed=0): + result = process.sample( + infections=infections, + obs=None, + first_observed_idx=first_observed_idx, + first_observed_dow=first_observed_dow, + ) + + assert jnp.isclose(result.predicted[first_observed_idx], 100.0 * dow_effect[6]) + assert jnp.isclose( + result.predicted[first_observed_idx + 1], 100.0 * dow_effect[0] + ) + assert jnp.isclose( + result.predicted[first_observed_idx + 7], 100.0 * dow_effect[6] + ) + + def test_zero_offset_matches_legacy_behavior(self): + """ + ``first_observed_idx=0`` reproduces the legacy "DOW across whole axis" behavior. + + Direct callers that pass only ``first_day_dow`` (e.g., the + existing ``TestDayOfWeek`` tests) must continue to see DOW + applied starting at index 0 with the supplied phase. + """ + delay_pmf = jnp.array([1.0]) + n_total = 14 + first_day_dow = 2 + dow_effect = jnp.array([2.0, 1.5, 1.0, 0.8, 0.7, 0.5, 0.5]) + + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=PoissonNoise(), + day_of_week_rv=DeterministicVariable("dow", dow_effect), + ) + infections = jnp.ones(n_total) * 100.0 + + with numpyro.handlers.seed(rng_seed=0): + legacy = process.sample( + infections=infections, + obs=None, + first_day_dow=first_day_dow, + ) + with numpyro.handlers.seed(rng_seed=0): + explicit = process.sample( + infections=infections, + obs=None, + first_observed_idx=0, + first_observed_dow=first_day_dow, + ) + + assert jnp.allclose(legacy.predicted, explicit.predicted) + + def test_dow_gradient_finite_with_delay_tail_nan(self): + """ + Gradients are finite when ``predicted`` has a NaN delay-tail. + + Reproduces the issue-#824 gradient blow-up in its full form: + a multi-day delay PMF makes ``predicted[:len(delay)-1]`` NaN, + and a non-trivial ``first_observed_idx`` places the DOW + multiplier outside that region. Before Fix B the DOW + multiplier was tiled across the NaN tail and ``0 * NaN`` + leaked NaN back to the DOW effect; after Fix B the gradient + is finite at every slot. + """ + delay_pmf = self._multi_day_delay_pmf() + n_total = 21 + first_observed_idx = 5 + first_observed_dow = 0 + infections = jnp.ones(n_total) * 1000.0 + obs = self._padded_obs(n_total, first_observed_idx, value=5.0) + + def model(dow_value: jnp.ndarray) -> None: + """Sample with the given DOW effect on the observation suffix.""" + process = PopulationCounts( + name="test", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + day_of_week_rv=DeterministicVariable("dow", dow_value), + ) + process.sample( + infections=infections, + obs=obs, + first_observed_idx=first_observed_idx, + first_observed_dow=first_observed_dow, + ) + + def log_p(dow_value: jnp.ndarray) -> jnp.ndarray: + """ + Return the joint log-density of the model at ``dow_value``. + + Parameters + ---------- + dow_value + Day-of-week effect vector at which to evaluate. + + Returns + ------- + jnp.ndarray + Scalar joint log-density. + """ + value, _ = log_density(model, (dow_value,), {}, params={}) + return value + + dow_value = jnp.array([2.0, 0.5, 0.5, 0.5, 0.5, 1.5, 1.5]) + grad = jax.grad(log_p)(dow_value) + assert jnp.all(jnp.isfinite(grad)) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 3aa2eaa5ff792c88f2b05b146ef0dfe060600e5f Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 18 May 2026 18:56:53 -0400 Subject: [PATCH 2/4] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- pyrenew/observation/count_observations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index e9bd44b5..306835fb 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -420,7 +420,7 @@ def _compute_predicted( if self.day_of_week_rv is not None: if first_observed_dow is None: raise ValueError( - "first_day_dow is required when day_of_week_rv is set." + "first_observed_dow is required when day_of_week_rv is set." ) predicted_daily = self._apply_day_of_week( predicted_daily, first_observed_idx, first_observed_dow From deb5f242d086e5ad29a67ae6413145a56f1debe0 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 19 May 2026 07:58:50 -0400 Subject: [PATCH 3/4] update unit test to match code --- test/test_observation_counts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 0972a063..d1482d22 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -610,7 +610,7 @@ def test_dow_rv_without_offset_raises(self, simple_delay_pmf): infections = jnp.ones(20) * 1000 with numpyro.handlers.seed(rng_seed=42): - with pytest.raises(ValueError, match="first_day_dow is required"): + with pytest.raises(ValueError, match="first_observed_dow is required"): process.sample(infections=infections, obs=None, first_day_dow=None) def test_uniform_dow_effect_unchanged(self, simple_delay_pmf): @@ -856,7 +856,7 @@ def test_counts_by_subpop_dow_without_offset_raises(self): subpop_indices = jnp.array([0, 1]) with numpyro.handlers.seed(rng_seed=42): - with pytest.raises(ValueError, match="first_day_dow is required"): + with pytest.raises(ValueError, match="first_observed_dow is required"): process.sample( infections=infections, period_end_times=period_end_times, From 7f8b4988e8b6398d63e3c46693fa6981d4abbae5 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 19 May 2026 13:45:00 -0400 Subject: [PATCH 4/4] revert changes, apply simpler fix --- pyrenew/model/multisignal_model.py | 33 ---- pyrenew/observation/count_observations.py | 137 ++++++---------- test/test_observation_counts.py | 180 +++++----------------- 3 files changed, 82 insertions(+), 268 deletions(-) diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py index 12f16d9e..1a0f4b88 100644 --- a/pyrenew/model/multisignal_model.py +++ b/pyrenew/model/multisignal_model.py @@ -186,35 +186,6 @@ def _resolve_first_day_dow( n_init = self.latent.n_initialization_points return (convert_date(obs_start_date).weekday() - n_init) % 7 - def _resolve_first_observed_dow( - self, - obs_start_date: dt.date | dt.datetime | np.datetime64 | None, - ) -> int | None: - """ - Derive the observation-axis day-of-week from ``obs_start_date``. - - Returns the day-of-week of the first observed day directly, - without the initialization-period offset that - ``_resolve_first_day_dow`` applies. Used to anchor reporting - effects (day-of-week multipliers) that should apply only to - observation days, not to the padded initialization period. - - Parameters - ---------- - obs_start_date - Date of the first observation day, or ``None``. - - Returns - ------- - int or None - Day-of-week index in ``{0, ..., 6}`` (0=Monday, ISO - convention) of the first observed day. ``None`` when - ``obs_start_date`` is ``None``. - """ - if obs_start_date is None: - return None - return convert_date(obs_start_date).weekday() - def _check_obs_start_date( self, obs_start_date: dt.date | dt.datetime | np.datetime64 | None, @@ -395,8 +366,6 @@ def sample( """ self._check_obs_start_date(obs_start_date) first_day_dow = self._resolve_first_day_dow(obs_start_date) - first_observed_dow = self._resolve_first_observed_dow(obs_start_date) - first_observed_idx = self.latent.n_initialization_points # Generate latent infections (proportions) latent_sample = self.latent.sample( @@ -444,8 +413,6 @@ def sample( obs_process.sample( infections=latent_infections, first_day_dow=first_day_dow, - first_observed_idx=first_observed_idx, - first_observed_dow=first_observed_dow, **obs_data, ) diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 306835fb..3aaa7a34 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -275,45 +275,50 @@ def _apply_right_truncation( def _apply_day_of_week( self, predicted: ArrayLike, - first_observed_idx: int, - first_observed_dow: int, + first_day_dow: int, ) -> ArrayLike: """ Apply day-of-week multiplicative adjustment to predicted counts. - Multiplies ``predicted[first_observed_idx:]`` by the weekday - cycle anchored at ``first_observed_dow``. Positions before - ``first_observed_idx`` (initialization / pre-observation period) - are left unchanged so that ``NaN`` entries in the delay-tail do - not feed the multiplier. + Multiplies the finite entries of ``predicted`` by the weekday + cycle anchored at ``first_day_dow``. ``NaN`` entries (the + delay-tail at the start of the shared time axis) are preserved + through the JAX "double-where" idiom: the inner product is + evaluated against a NaN-free surrogate so its backward + cotangent is finite at every position, then the outer + ``jnp.where`` restores ``NaN`` to its original positions in + the output. Parameters ---------- predicted : ArrayLike Predicted counts. Shape: (n_timepoints,) or (n_timepoints, n_subpops). - first_observed_idx : int - Index on the shared time axis where the first observation - period begins. Positions before this index receive no - day-of-week adjustment. - first_observed_dow : int - Day-of-week of ``predicted[first_observed_idx]`` + first_day_dow : int + Day-of-week of ``predicted[0]`` on the shared time axis (0=Monday, 6=Sunday, ISO convention). Returns ------- ArrayLike Adjusted predicted counts, same shape as input. + + Notes + ----- + See https://docs.jax.dev/en/latest/faq.html#gradients-contain-nan-where-using-where + for the double-where pattern. """ dow_effect = self.day_of_week_rv() self._deterministic("day_of_week_effect", dow_effect) - n_obs_days = predicted.shape[0] - first_observed_idx - obs_effect = dow_effect[ - get_sequential_day_of_week_indices(first_observed_dow, n_obs_days) + n_timepoints = predicted.shape[0] + daily_effect = dow_effect[ + get_sequential_day_of_week_indices(first_day_dow, n_timepoints) ] if predicted.ndim == 2: - obs_effect = obs_effect[:, None] - return predicted.at[first_observed_idx:].multiply(obs_effect) + daily_effect = daily_effect[:, None] + finite_pred = ~jnp.isnan(predicted) + safe_predicted = jnp.where(finite_pred, predicted, 0.0) + return jnp.where(finite_pred, safe_predicted * daily_effect, predicted) def _aggregate( self, @@ -368,8 +373,6 @@ def _compute_predicted( infections: ArrayLike, first_day_dow: int | None, right_truncation_offset: int | None, - first_observed_idx: int = 0, - first_observed_dow: int | None = None, ) -> ArrayLike: """ Build the predicted counts on the reporting-period grid. @@ -388,20 +391,12 @@ def _compute_predicted( first_day_dow Day-of-week index of element 0 of the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when + ``day_of_week_rv`` was set at construction or when ``aggregation == "weekly"``. right_truncation_offset If set (together with ``right_truncation_rv``), the number of additional reporting days that have occurred since the last observation. - first_observed_idx - Index on the shared time axis where the first observation - period begins. The day-of-week effect is applied only to - positions ``[first_observed_idx:]``. Defaults to ``0`` for - unpadded direct callers. - first_observed_dow - Day-of-week of ``predicted[first_observed_idx]`` - (0=Monday, 6=Sunday, ISO convention). Required when - ``day_of_week_rv`` was set at construction. Returns ------- @@ -413,18 +408,16 @@ def _compute_predicted( Raises ------ ValueError - If ``day_of_week_rv`` was set but ``first_observed_dow`` - is ``None``. + If ``day_of_week_rv`` was set but ``first_day_dow`` is + ``None``. """ predicted_daily = self._predicted_obs(infections) if self.day_of_week_rv is not None: - if first_observed_dow is None: + if first_day_dow is None: raise ValueError( - "first_observed_dow is required when day_of_week_rv is set." + "first_day_dow is required when day_of_week_rv is set." ) - predicted_daily = self._apply_day_of_week( - predicted_daily, first_observed_idx, first_observed_dow - ) + predicted_daily = self._apply_day_of_week(predicted_daily, first_day_dow) if self.right_truncation_rv is not None and right_truncation_offset is not None: predicted_daily = self._apply_right_truncation( predicted_daily, right_truncation_offset @@ -673,8 +666,6 @@ def sample( right_truncation_offset: int | None = None, first_day_dow: int | None = None, period_end_times: ArrayLike | None = None, - first_observed_idx: int | None = None, - first_observed_dow: int | None = None, ) -> ObservationSample: """ Sample aggregated counts. @@ -709,28 +700,15 @@ def sample( construction), apply right-truncation adjustment to the daily predictions. first_day_dow - Day-of-week index of element 0 of the shared time axis - (0=Monday, 6=Sunday, ISO convention). Required when - ``aggregation == "weekly"``; also serves as the - backwards-compatible day-of-week anchor for direct - callers without padding (used as ``first_observed_dow`` - when the latter is not supplied). + Day-of-week index of the first timepoint on the shared + time axis (0=Monday, 6=Sunday, ISO convention). Required + when ``day_of_week_rv`` was set at construction or when + ``aggregation == "weekly"``. This aligns observation-level + day-of-week effects or weekly aggregation to the shared daily + model axis. period_end_times Daily-axis indices of each observed period's final day. Required when ``reporting_schedule == "irregular"``. - first_observed_idx - Index on the shared time axis where the first observation - period begins. Day-of-week effects are applied only to - positions ``[first_observed_idx:]``. Supplied by - ``MultiSignalModel`` as ``n_initialization_points``. - Defaults to ``0`` for direct callers without padding. - first_observed_dow - Day-of-week of the first observed period - (0=Monday, 6=Sunday, ISO convention). Required when - ``day_of_week_rv`` was set at construction. Supplied by - ``MultiSignalModel`` as ``obs_start_date.weekday()``. - When not supplied, falls back to ``first_day_dow`` for - backwards compatibility with direct unpadded callers. Returns ------- @@ -740,16 +718,8 @@ def sample( grid; equal to daily predictions when ``aggregation == "daily"``). """ - if first_observed_idx is None: - first_observed_idx = 0 - if first_observed_dow is None: - first_observed_dow = first_day_dow predicted = self._compute_predicted( - infections, - first_day_dow, - right_truncation_offset, - first_observed_idx=first_observed_idx, - first_observed_dow=first_observed_dow, + infections, first_day_dow, right_truncation_offset ) if self.reporting_schedule == "regular": @@ -933,8 +903,6 @@ def sample( first_day_dow: int | None = None, period_end_times: ArrayLike | None = None, subpop_indices: ArrayLike | None = None, - first_observed_idx: int | None = None, - first_observed_dow: int | None = None, ) -> ObservationSample: """ Sample subpopulation-level counts. @@ -970,11 +938,12 @@ def sample( construction), apply right-truncation adjustment to the daily predictions. first_day_dow - Day-of-week index of element 0 of the shared time axis - (0=Monday, 6=Sunday, ISO convention). Required when - ``aggregation == "weekly"``; also serves as the - backwards-compatible day-of-week anchor for direct - callers without padding. + Day-of-week index of the first timepoint on the shared + time axis (0=Monday, 6=Sunday, ISO convention). Required + when ``day_of_week_rv`` was set at construction or when + ``aggregation == "weekly"``. This aligns observation-level + day-of-week effects or weekly aggregation to the shared daily + model axis. period_end_times Daily-axis indices of each observed period's final day. Required when ``reporting_schedule == "irregular"``. @@ -985,18 +954,6 @@ def sample( columns of the aggregated array enter the likelihood. For ``reporting_schedule="irregular"``: shape ``(n_obs,)`` with one subpopulation per observation. - first_observed_idx - Index on the shared time axis where the first observation - period begins. Day-of-week effects are applied only to - positions ``[first_observed_idx:]``. Supplied by - ``MultiSignalModel`` as ``n_initialization_points``. - Defaults to ``0`` for direct callers without padding. - first_observed_dow - Day-of-week of the first observed period - (0=Monday, 6=Sunday, ISO convention). Required when - ``day_of_week_rv`` was set at construction. Supplied by - ``MultiSignalModel`` as ``obs_start_date.weekday()``. - When not supplied, falls back to ``first_day_dow``. Returns ------- @@ -1009,16 +966,8 @@ def sample( if subpop_indices is None: raise ValueError(f"Observation '{self.name}': subpop_indices is required.") - if first_observed_idx is None: - first_observed_idx = 0 - if first_observed_dow is None: - first_observed_dow = first_day_dow predicted = self._compute_predicted( - infections, - first_day_dow, - right_truncation_offset, - first_observed_idx=first_observed_idx, - first_observed_dow=first_observed_dow, + infections, first_day_dow, right_truncation_offset ) if self.reporting_schedule == "regular": diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index d1482d22..5fd5b6ab 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -610,7 +610,7 @@ def test_dow_rv_without_offset_raises(self, simple_delay_pmf): infections = jnp.ones(20) * 1000 with numpyro.handlers.seed(rng_seed=42): - with pytest.raises(ValueError, match="first_observed_dow is required"): + with pytest.raises(ValueError, match="first_day_dow is required"): process.sample(infections=infections, obs=None, first_day_dow=None) def test_uniform_dow_effect_unchanged(self, simple_delay_pmf): @@ -856,7 +856,7 @@ def test_counts_by_subpop_dow_without_offset_raises(self): subpop_indices = jnp.array([0, 1]) with numpyro.handlers.seed(rng_seed=42): - with pytest.raises(ValueError, match="first_observed_dow is required"): + with pytest.raises(ValueError, match="first_day_dow is required"): process.sample( infections=infections, period_end_times=period_end_times, @@ -1651,17 +1651,20 @@ def log_p(dow_value: jnp.ndarray) -> jnp.ndarray: assert jnp.all(jnp.isfinite(grad)) -class TestDayOfWeekObservationAxis: +class TestDayOfWeekNanGradientSafety: """ - Tests for applying the day-of-week effect on the observation axis. - - Issue #824: the day-of-week multiplier must apply only to - positions at and after the first observed day, not across the - padded initialization period. Multiplying through the delay-tail - NaN region of ``predicted`` produces a second ``0 * NaN = NaN`` - cotangent path back to the day-of-week effect under autodiff, - independent of the masked-obs leak that ``TestScoreMaskedSafeObs`` - covers. + Tests for gradient-safe handling of the delay-tail NaN region. + + Issue #824: a multi-day delay PMF leaves + ``predicted[:len(delay_pmf)-1]`` as NaN before the day-of-week + multiplier is applied. The previous implementation tiled the + multiplier across the entire array; multiplying NaN by the + day-of-week vector produced a NaN cotangent through ``jnp.where`` + that leaked back to the day-of-week parameters under autodiff, + causing stochastic-DOW priors to diverge under NUTS. The + double-where pattern in ``_apply_day_of_week`` keeps the + multiplication gradient-safe while preserving the original NaN + positions in the output. """ @staticmethod @@ -1698,32 +1701,18 @@ def _padded_obs(n_total: int, n_init: int, value: float) -> jnp.ndarray: obs = jnp.full(n_total, value, dtype=jnp.float32) return obs.at[:n_init].set(jnp.nan) - def test_prefix_unchanged_by_dow_multiplier(self): + def test_delay_tail_nan_preserved_through_dow(self): """ - Positions before ``first_observed_idx`` are not multiplied. + ``predicted`` NaN entries remain NaN after the multiplier runs. - Built two processes with identical setup except for the DOW - effect (one uniform ``ones(7)``, the other non-uniform). On - the observation suffix the predictions differ by the DOW - ratio; on the initialization prefix the predictions are - identical because the prefix is excluded from the multiplier. + The double-where idiom restores the original NaN values at the + delay-tail positions, regardless of the day-of-week vector. """ - delay_pmf = jnp.array([1.0]) + delay_pmf = self._multi_day_delay_pmf() + n_tail = delay_pmf.shape[0] - 1 n_total = 21 - first_observed_idx = 7 - first_observed_dow = 0 - - infections = jnp.ones(n_total) * 1000.0 - - process_uniform = PopulationCounts( - name="uniform", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), - delay_distribution_rv=DeterministicPMF("delay", delay_pmf), - noise=PoissonNoise(), - day_of_week_rv=DeterministicVariable("dow", jnp.ones(7)), - ) - process_nonuniform = PopulationCounts( - name="nonuniform", + process = PopulationCounts( + name="test", ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), delay_distribution_rv=DeterministicPMF("delay", delay_pmf), noise=PoissonNoise(), @@ -1731,129 +1720,39 @@ def test_prefix_unchanged_by_dow_multiplier(self): "dow", jnp.array([2.0, 1.5, 1.0, 1.0, 0.5, 0.5, 0.5]) ), ) - - with numpyro.handlers.seed(rng_seed=0): - uniform = process_uniform.sample( - infections=infections, - obs=None, - first_observed_idx=first_observed_idx, - first_observed_dow=first_observed_dow, - ) - with numpyro.handlers.seed(rng_seed=0): - nonuniform = process_nonuniform.sample( - infections=infections, - obs=None, - first_observed_idx=first_observed_idx, - first_observed_dow=first_observed_dow, - ) - - assert jnp.allclose( - uniform.predicted[:first_observed_idx], - nonuniform.predicted[:first_observed_idx], - ) - assert not jnp.allclose( - uniform.predicted[first_observed_idx:], - nonuniform.predicted[first_observed_idx:], - ) - - def test_suffix_starts_at_first_observed_dow(self): - """ - The suffix multiplier starts at the supplied observed DOW. - - With ``first_observed_dow=6`` (Sunday), the multiplier at - ``predicted[first_observed_idx]`` is ``dow_effect[6]``, the - next position is ``dow_effect[0]`` (Monday), and so on. - """ - delay_pmf = jnp.array([1.0]) - n_total = 14 - first_observed_idx = 4 - first_observed_dow = 6 - dow_effect = jnp.array([2.0, 1.5, 1.0, 1.0, 0.5, 0.5, 3.0]) - - process = PopulationCounts( - name="test", - ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), - delay_distribution_rv=DeterministicPMF("delay", delay_pmf), - noise=PoissonNoise(), - day_of_week_rv=DeterministicVariable("dow", dow_effect), - ) - infections = jnp.ones(n_total) * 100.0 + infections = jnp.ones(n_total) * 1000.0 with numpyro.handlers.seed(rng_seed=0): result = process.sample( infections=infections, obs=None, - first_observed_idx=first_observed_idx, - first_observed_dow=first_observed_dow, - ) - - assert jnp.isclose(result.predicted[first_observed_idx], 100.0 * dow_effect[6]) - assert jnp.isclose( - result.predicted[first_observed_idx + 1], 100.0 * dow_effect[0] - ) - assert jnp.isclose( - result.predicted[first_observed_idx + 7], 100.0 * dow_effect[6] - ) - - def test_zero_offset_matches_legacy_behavior(self): - """ - ``first_observed_idx=0`` reproduces the legacy "DOW across whole axis" behavior. - - Direct callers that pass only ``first_day_dow`` (e.g., the - existing ``TestDayOfWeek`` tests) must continue to see DOW - applied starting at index 0 with the supplied phase. - """ - delay_pmf = jnp.array([1.0]) - n_total = 14 - first_day_dow = 2 - dow_effect = jnp.array([2.0, 1.5, 1.0, 0.8, 0.7, 0.5, 0.5]) - - process = PopulationCounts( - name="test", - ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), - delay_distribution_rv=DeterministicPMF("delay", delay_pmf), - noise=PoissonNoise(), - day_of_week_rv=DeterministicVariable("dow", dow_effect), - ) - infections = jnp.ones(n_total) * 100.0 - - with numpyro.handlers.seed(rng_seed=0): - legacy = process.sample( - infections=infections, - obs=None, - first_day_dow=first_day_dow, - ) - with numpyro.handlers.seed(rng_seed=0): - explicit = process.sample( - infections=infections, - obs=None, - first_observed_idx=0, - first_observed_dow=first_day_dow, + first_day_dow=0, ) - assert jnp.allclose(legacy.predicted, explicit.predicted) + assert jnp.all(jnp.isnan(result.predicted[:n_tail])) + assert jnp.all(jnp.isfinite(result.predicted[n_tail:])) def test_dow_gradient_finite_with_delay_tail_nan(self): """ Gradients are finite when ``predicted`` has a NaN delay-tail. - Reproduces the issue-#824 gradient blow-up in its full form: - a multi-day delay PMF makes ``predicted[:len(delay)-1]`` NaN, - and a non-trivial ``first_observed_idx`` places the DOW - multiplier outside that region. Before Fix B the DOW - multiplier was tiled across the NaN tail and ``0 * NaN`` - leaked NaN back to the DOW effect; after Fix B the gradient - is finite at every slot. + Reproduces the issue-#824 gradient blow-up: a multi-day delay + PMF makes ``predicted[:len(delay)-1]`` NaN, and the + day-of-week multiplier is tiled across the whole array. Before + the fix, ``NaN * dow_effect[i]`` at delay-tail positions + leaked a NaN cotangent back to ``dow_effect[i]`` through + ``jnp.where``. With the double-where pattern the inner + multiplication operates on a NaN-free surrogate, so the + gradient is finite at every slot. """ delay_pmf = self._multi_day_delay_pmf() n_total = 21 - first_observed_idx = 5 - first_observed_dow = 0 + n_init = 5 infections = jnp.ones(n_total) * 1000.0 - obs = self._padded_obs(n_total, first_observed_idx, value=5.0) + obs = self._padded_obs(n_total, n_init, value=5.0) def model(dow_value: jnp.ndarray) -> None: - """Sample with the given DOW effect on the observation suffix.""" + """Sample with the given DOW effect over the full time axis.""" process = PopulationCounts( name="test", ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), @@ -1864,8 +1763,7 @@ def model(dow_value: jnp.ndarray) -> None: process.sample( infections=infections, obs=obs, - first_observed_idx=first_observed_idx, - first_observed_dow=first_observed_dow, + first_day_dow=2, ) def log_p(dow_value: jnp.ndarray) -> jnp.ndarray: