Fix BatchLM vmap+jacfwd failures caused by in-place tensor mutations#301
Conversation
Agent-Logs-Url: https://github.com/Autostronomy/AstroPhot/sessions/d13bb976-bdfb-4fba-8140-58687f11e916 Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
Agent-Logs-Url: https://github.com/Autostronomy/AstroPhot/sessions/d13bb976-bdfb-4fba-8140-58687f11e916 Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
Agent-Logs-Url: https://github.com/Autostronomy/AstroPhot/sessions/d13bb976-bdfb-4fba-8140-58687f11e916 Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com>
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## dev #301 +/- ##
==========================================
+ Coverage 91.23% 91.54% +0.30%
==========================================
Files 113 113
Lines 6198 6194 -4
==========================================
+ Hits 5655 5670 +15
+ Misses 543 524 -19
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Pull request overview
Fixes BatchLM failures under vmap(jacfwd(...)) by removing in-place tensor mutations (and some bool-mask dynamic-shape patterns) from the Torch execution path, and adds regression tests to cover the affected integration/sampling configurations.
Changes:
- Updated Torch backend indexed write/add helpers to use functional
index_putfor integer tensor indices (forward-mode AD compatible). - Reworked bool-mask accumulation paths in brightness mixins and spline extrapolation to use
where-based masking (static shapes undervmap). - Added
BatchLMtest coverage acrossintegrate_mode,sampling_mode, rotated WCS/CD, and Poisson likelihood.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
astrophot/backend_obj.py |
Switches Torch indexed updates to index_put for LongTensor indices; removes unused and_at_indices. |
astrophot/models/mixins/brightness.py |
Replaces bool-mask indexed updates with where masking to keep shapes static under vmap+jacfwd. |
astrophot/models/func/spline.py |
Replaces extrapolation-time masked fills with where to avoid bool-index mutation/shape issues. |
tests/test_fit.py |
Adds targeted regression tests ensuring BatchLM runs across integration/sampling modes and likelihood variants. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
BatchLMusesvmap(jacfwd(...))to compute batched Jacobians. PyTorch's forward-mode AD cannot trace through in-place tensor mutations, causing failures when anyintegrate_modeother than"none"was used.Core fix: backend index operations
_fill_at_indices_torchand_add_at_indices_torchnow usetorch.index_put(functional, out-of-place) for LongTensor indices — the path taken bytopk-based adaptive integration. Bool/tuple/slice indices fall back toclone()+ in-place, which preserves compatibility for non-differentiated callers.Sweep for bool-mask in-place operations in brightness functions
vmapadditionally rejects operations that produce dynamic shapes (bool-indexed selection). Fixed in three places:spline.py— extrapolation boundary: replacedfill_at_indices(I, R > profR[-1], 0)withbackend.where(R > profR[-1], zeros, I).WedgeMixin.polar_model— replacedadd_at_indices(model, bool_mask, iradial_model(s, R[mask]))withmodel + where(mask, iradial_model(s, R), zeros). The radial model is now evaluated over the full pixel array; the mask is applied viawhereto maintain static shapes.RayMixin.polar_model— same pattern for cosine-weighted segment accumulation.New tests
Added
BatchLM-specific tests intest_fit.pycovering:integrate_modevalues:"none","bright","curvature"sampling_modevalues:"midpoint","simpsons","quad:3"