Day-of-week effects applied on observation time axis (i.e., not before time 0)#827
Day-of-week effects applied on observation time axis (i.e., not before time 0)#827cdc-mitzimorris wants to merge 50 commits into
Conversation
There was a problem hiding this comment.
Pull request overview
Fixes two NaN-cotangent leaks that caused NUTS divergences in MultiSignalModel when day-of-week effects were used alongside padded initialization periods. The first fix replaces the masked-likelihood placeholder for count observations with 0.0 (an in-support integer for count distributions) so that backward gradients through the mask remain finite. The second fix restricts the day-of-week multiplier to the observation suffix (predicted[first_observed_idx:]), so the NaN delay-tail in the initialization prefix never feeds into the DOW parameter gradient. MultiSignalModel now derives and threads first_observed_idx/first_observed_dow from obs_start_date.
Changes:
_score_maskedplaceholder switched fromsafe_predictedto0.0for masked obs entries._apply_day_of_weeknow multiplies onlypredicted[first_observed_idx:], anchored atfirst_observed_dow;PopulationCounts.sampleandSubpopulationCounts.sampleaccept new optionalfirst_observed_idx/first_observed_dowkwargs with backwards-compatible fallbacks tofirst_day_dow.MultiSignalModel.sampleresolvesfirst_observed_dowfromobs_start_dateand passes both new kwargs (plusfirst_observed_idx = n_initialization_points) to each observation process.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| pyrenew/observation/count_observations.py | Replaces fractional masked-obs placeholder with 0.0; restricts DOW multiplier to observation suffix; adds first_observed_idx/first_observed_dow kwargs to _apply_day_of_week, _compute_predicted, and both sample methods. |
| pyrenew/model/multisignal_model.py | Adds _resolve_first_observed_dow helper and threads first_observed_idx/first_observed_dow into each observation process call. |
| test/test_observation_counts.py | New TestScoreMaskedSafeObs and TestDayOfWeekObservationAxis test classes covering masked-obs placeholder behavior, suffix-only DOW application, the zero-offset legacy fallback, and finite-gradient regression tests through jax.grad. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
|
Thank you for your contribution @cdc-mitzimorris 🚀! Your github-pages is ready for download 👉 here 👈! |
The renewal process samples infections on a shared daily time axis that begins n_init days before the first observed report. Observations occupy the suffix of this axis; the prefix is initialization padding used by the convolution but never reported. The bug arises because the day-of-week reporting effect, and the masked-likelihood machinery that handles unobserved days, both leak information from the unobserved prefix into the gradient computation that drives MCMC. Two distinct paths are involved, and both must be closed.
Change 1 —
safe_obsplaceholder for masked count observationsIn
_score_masked, replacesafe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs)withsafe_obs = jnp.where(jnp.isnan(obs), 0.0, obs).On a padded MultiSignalModel, the observation array is dense across the full time axis, with NaN entries marking unobserved days (the entire initialization prefix, plus any gaps inside the observation window). The likelihood masks those positions so they don't contribute to the model's log-density. But every position still has a log_prob evaluated against the noise distribution; the masked ones are then zeroed out before summing.
The old placeholder used the model's own predicted mean as the dummy observation at masked positions. Predicted means are fractional (e.g., 7.42 reported visits), while count distributions like NegativeBinomial2 are only defined on whole-number counts. Feeding fractional values into a count log_prob returns -inf, which is then zeroed by the mask in the forward direction — but the gradient computation that NUTS needs in order to take its next step propagates that pathology backward, and the parameters feeding the predicted mean (notably any stochastic day-of-week prior) receive undefined gradients. NUTS interprets the resulting trajectory as divergent on every iteration.
0.0is a valid whole-number count for every distribution we support, so the masked-position log_prob is now a finite, smoothly-differentiable number. It still doesn't contribute to the likelihood (the mask still zeroes it), but the gradient flowing backward through the mask is now a well-defined zero rather than something undefined.Change 2 — Apply day-of-week effect to the observation suffix only
_apply_day_of_weeknow multiplies the day-of-week vector againstpredicted[first_observed_idx:], leaving the initialization prefix untouched. MultiSignalModel computesfirst_observed_idx = n_initandfirst_observed_dow = obs_start_date.weekday()from the user-supplied calendar anchor and threads them through every observation process.Conceptually, the day-of-week effect is a reporting effect: it scales counts on actual reporting days. The initialization prefix has no reports, so applying the effect there has no physical meaning. As implemented, the effect was applied across the entire model time axis, aligned to the day-of-week of element 0 of the padded axis — which works numerically when the effect is a fixed deterministic vector, but causes failures when doing inference.
The mechanism: ascertainment × delay convolution leaves the first few prefix positions of predicted undefined (the convolution has no infection history to draw from yet). Those undefined positions then got multiplied by day-of-week values, contaminating the multiplier itself with the same kind of undefined-gradient pathology described above. Once the multiplier carries a hidden defect, the gradient with respect to the day-of-week prior is undefined, and again NUTS diverges. Restricting the multiplication to the suffix means undefined prefix values never meet the day-of-week parameters, so the gradient stays well-defined.