BNN forward pass unification across predict and update#131
shaharbar1 wants to merge 1 commit into develop from
Conversation
📝 Walkthrough
Added a backend-agnostic `_forward_layers` helper shared by the predict and update paths.
Estimated Code Review Effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
```python
self.__dict__.update(state)
self._init_private_attrs()

def _forward_layers(
```
should you add a test that verifies the two paths are equivalent?
Done. Take a look...
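For reference, the core equivalence such a test exercises (that the einsum-based linear step of the NumPy predict path matches the plain dot-product step the JAX update path uses) can be checked with a minimal standalone sketch; shapes here are illustrative, not taken from the PR:

```python
import numpy as np

# The JAX update path uses a dot product (jnp.dot), while the NumPy
# predict path uses einsum. For 2-D weights these are the same affine map.
rng = np.random.default_rng(42)
x = rng.normal(size=(8, 3))   # batch of contexts
w = rng.normal(size=(3, 4))   # layer weights
b = rng.normal(size=4)        # layer biases

via_dot = x @ w + b
via_einsum = np.einsum("...i,...ij->...j", x, w) + b

np.testing.assert_allclose(via_dot, via_einsum, rtol=1e-12)
```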
shaharbar1 force-pushed from 4cc6e3a to e1decb8 (Compare)
🧹 Nitpick comments (1)
tests/test_model.py (1)
1672-1706: Test may be fragile for MCMC due to default `num_chains=2`.

The test correctly verifies that both `sample_proba` and `update` call `_forward_layers`. However, for MCMC the test doesn't specify `num_chains`, so the default `num_chains=2` from `_default_mcmc_kwargs` is applied. Depending on NumPyro's internal behavior, this could cause the model to be traced more than once (e.g., once per chain), making the strict `call_count == ref` assertion fail intermittently. Consider either:

- Explicitly setting `num_chains=1` for MCMC to ensure deterministic behavior, or
- Changing the assertion to `call_count >= ref` to verify `_forward_layers` is called at least once.

♻️ Suggested fix to ensure deterministic MCMC call count
```diff
 @pytest.mark.parametrize(
     "update_method, update_kwargs",
-    [("VI", {"num_steps": 2}), ("MCMC", {"num_warmup": 2, "num_samples": 2})],
+    [("VI", {"num_steps": 2}), ("MCMC", {"num_warmup": 2, "num_samples": 2, "num_chains": 1})],
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_model.py` around lines 1672 - 1706, The test is fragile for MCMC because the default num_chains=2 can cause multiple model traces; update the test_bnn_sample_proba_and_update_both_use_forward_layers to ensure deterministic behavior by setting num_chains=1 when update_method == "MCMC" (e.g., include "num_chains": 1 in update_kwargs passed to BayesianNeuralNetwork.cold_start), keeping the assertions that call_count == ref; alternatively, if you prefer looser verification, change the two assertions to assert call_count >= ref to allow multiple traces; locate BaseBayesianNeuralNetwork._forward_layers and BayesianNeuralNetwork.cold_start in the test to apply the change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ae78c1b0-e38d-4692-91f1-96767ab4666b
📒 Files selected for processing (3)
- pybandits/model.py
- pyproject.toml
- tests/test_model.py
✅ Files skipped from review due to trivial changes (1)
- pyproject.toml
### Changes:

- Add `_forward_layers` to `BaseBayesianNeuralNetwork` with injectable `linear_fn` and `xp` params for backend-agnostic layer loop
- Refactor `model()` in `create_update_model` to call `_forward_layers` with JAX backend (`jnp.dot`, `jnp`)
- Refactor `forward_pass()` to call `_forward_layers` with NumPy backend (`np.einsum`, `np`)
- Add `_Array = Union[np.ndarray, jax.Array]` type alias and `ModuleType` import for precise type hints on `_forward_layers`
- Remove redundant `n_layers` local in `create_update_model` and `w_shape` from `weights_biases` tuples
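The injectable-backend pattern these changes describe can be sketched roughly as follows. This is an illustrative stand-in, not the actual pybandits implementation: the function name, layer shapes, and ReLU activation are assumptions; only the idea of passing `linear_fn` and the array module `xp` comes from the changes above.

```python
import numpy as np

def forward_layers(x, weights_biases, activation_fn, linear_fn, xp):
    """Backend-agnostic hidden-layer loop (illustrative sketch).

    `xp` is the array module (numpy here, jax.numpy in the update path)
    and `linear_fn` performs the affine step, so one loop can serve both
    the NumPy predict path and the JAX update path.
    """
    out = xp.asarray(x)
    for i, (w, b) in enumerate(weights_biases):
        out = linear_fn(out, w, b)
        if i < len(weights_biases) - 1:  # no activation after the final layer
            out = activation_fn(out)
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
wb = [(rng.normal(size=(3, 5)), rng.normal(size=5)),
      (rng.normal(size=(5, 1)), rng.normal(size=1))]

# NumPy backend, mirroring the einsum-based predict path
logits = forward_layers(
    x,
    wb,
    activation_fn=lambda a: np.maximum(a, 0.0),  # ReLU, as an example
    linear_fn=lambda a, w, b: np.einsum("...i,...ij->...j", a, w) + b,
    xp=np,
)
print(logits.shape)  # (4, 1)
```

Swapping `xp=jnp` and `linear_fn` built on `jnp.dot` would reuse the identical loop inside a NumPyro model, which is what keeps the two paths from drifting apart.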
shaharbar1 force-pushed from e1decb8 to 8349128 (Compare)
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@pybandits/model.py`:
- Around line 1750-1759: The final logit used by forward_pass diverges from
create_update_model because create_update_model clips the last-layer output to
[-15, 15] but forward_pass feeds the raw weighted_sum into _numpy_sigmoid; fix
this by applying the same clipping to the last-layer logit in forward_pass
(i.e., after computing weighted_sum from linear_transform.squeeze(-1) and before
calling _numpy_sigmoid) so both paths use identical logits (alternatively,
remove clipping from create_update_model if you intend no clipping, but ensure
both forward_pass and create_update_model use the same behavior).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 7738dccb-d748-4fb5-94e7-4c09c043871a
📒 Files selected for processing (3)
- pybandits/model.py
- pyproject.toml
- tests/test_model.py
✅ Files skipped from review due to trivial changes (2)
- pyproject.toml
- tests/test_model.py
```python
linear_transform = self._forward_layers(
    next_layer_input,
    sampled_weights,
    self._numpy_activation_fn,
    lambda x, w, b: np.einsum("...i,...ij->...j", x, w) + b,
    np,
)

weighted_sum = linear_transform.squeeze(-1)
prob = _numpy_sigmoid(weighted_sum)
```
Predict and update still diverge on the final logit.
create_update_model() clips the last-layer output to [-15, 15] before using it, but forward_pass() now feeds the raw value into _numpy_sigmoid(). So the two paths can still return different probabilities for the same weights/context when logits get large, even though the hidden-layer loop is shared. If exact parity is the goal, clip here too or remove the clip from the update path.