Skip to content

Day-of-week effects applied on observation time axis (i.e., not before time 0)#827

Open
cdc-mitzimorris wants to merge 50 commits into
mainfrom
mem_824_day_of_week_bug
Open

Day-of-week effects applied on observation time axis (i.e., not before time 0)#827
cdc-mitzimorris wants to merge 50 commits into
mainfrom
mem_824_day_of_week_bug

Conversation

@cdc-mitzimorris
Copy link
Copy Markdown
Collaborator

The renewal process samples infections on a shared daily time axis that begins n_init days before the first observed report. Observations occupy the suffix of this axis; the prefix is initialization padding used by the convolution but never reported. The bug arises because the day-of-week reporting effect, and the masked-likelihood machinery that handles unobserved days, both leak information from the unobserved prefix into the gradient computation that drives MCMC. Two distinct paths are involved, and both must be closed.

Change 1 — safe_obs placeholder for masked count observations

In _score_masked, replace safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs) with safe_obs = jnp.where(jnp.isnan(obs), 0.0, obs).

On a padded MultiSignalModel, the observation array is dense across the full time axis, with NaN entries marking unobserved days (the entire initialization prefix, plus any gaps inside the observation window). The likelihood masks those positions so they don't contribute to the model's log-density. But every position still has a log_prob evaluated against the noise distribution; the masked ones are then zeroed out before summing.

The old placeholder used the model's own predicted mean as the dummy observation at masked positions. Predicted means are fractional (e.g., 7.42 reported visits), while count distributions like NegativeBinomial2 are only defined on whole-number counts. Feeding fractional values into a count log_prob returns -inf, which is then zeroed by the mask in the forward direction — but the gradient computation that NUTS needs in order to take its next step propagates that pathology backward, and the parameters feeding the predicted mean (notably any stochastic day-of-week prior) receive undefined gradients. NUTS interprets the resulting trajectory as divergent on every iteration.

0.0 is a valid whole-number count for every distribution we support, so the masked-position log_prob is now a finite, smoothly-differentiable number. It still doesn't contribute to the likelihood (the mask still zeroes it), but the gradient flowing backward through the mask is now a well-defined zero rather than something undefined.

Change 2 — Apply day-of-week effect to the observation suffix only

_apply_day_of_week now multiplies the day-of-week vector against predicted[first_observed_idx:], leaving the initialization prefix untouched. MultiSignalModel computes first_observed_idx = n_init and first_observed_dow = obs_start_date.weekday() from the user-supplied calendar anchor and threads them through every observation process.

Conceptually, the day-of-week effect is a reporting effect: it scales counts on actual reporting days. The initialization prefix has no reports, so applying the effect there has no physical meaning. As implemented, the effect was applied across the entire model time axis, aligned to the day-of-week of element 0 of the padded axis — which works numerically when the effect is a fixed deterministic vector, but causes failures when doing inference.

The mechanism: ascertainment × delay convolution leaves the first few prefix positions of predicted undefined (the convolution has no infection history to draw from yet). Those undefined positions then got multiplied by day-of-week values, contaminating the multiplier itself with the same kind of undefined-gradient pathology described above. Once the multiplier carries a hidden defect, the gradient with respect to the day-of-week prior is undefined, and again NUTS diverges. Restricting the multiplication to the suffix means undefined prefix values never meet the day-of-week parameters, so the gradient stays well-defined.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes two NaN-cotangent leaks that caused NUTS divergences in MultiSignalModel when day-of-week effects were used alongside padded initialization periods. The first fix replaces the masked-likelihood placeholder for count observations with 0.0 (an in-support integer for count distributions) so that backward gradients through the mask remain finite. The second fix restricts the day-of-week multiplier to the observation suffix (predicted[first_observed_idx:]), so the NaN delay-tail in the initialization prefix never feeds into the DOW parameter gradient. MultiSignalModel now derives and threads first_observed_idx/first_observed_dow from obs_start_date.

Changes:

  • _score_masked placeholder switched from safe_predicted to 0.0 for masked obs entries.
  • _apply_day_of_week now multiplies only predicted[first_observed_idx:], anchored at first_observed_dow; PopulationCounts.sample and SubpopulationCounts.sample accept new optional first_observed_idx/first_observed_dow kwargs with backwards-compatible fallbacks to first_day_dow.
  • MultiSignalModel.sample resolves first_observed_dow from obs_start_date and passes both new kwargs (plus first_observed_idx = n_initialization_points) to each observation process.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
pyrenew/observation/count_observations.py Replaces fractional masked-obs placeholder with 0.0; restricts DOW multiplier to observation suffix; adds first_observed_idx/first_observed_dow kwargs to _apply_day_of_week, _compute_predicted, and both sample methods.
pyrenew/model/multisignal_model.py Adds _resolve_first_observed_dow helper and threads first_observed_idx/first_observed_dow into each observation process call.
test/test_observation_counts.py New TestScoreMaskedSafeObs and TestDayOfWeekObservationAxis test classes covering masked-obs placeholder behavior, suffix-only DOW application, the zero-offset legacy fallback, and finite-gradient regression tests through jax.grad.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread pyrenew/observation/count_observations.py Outdated
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 18, 2026

Thank you for your contribution @cdc-mitzimorris 🚀! Your github-pages is ready for download 👉 here 👈!
(The artifact expires on 2026-05-25T23:01:11Z. You can re-generate it by re-running the workflow here.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants