
BNN forward pass unification across predict and update #131

Open
shaharbar1 wants to merge 1 commit into develop from feature/forward_pass_unification

Conversation


shaharbar1 (Collaborator) commented Mar 31, 2026

Changes:

  • Add _forward_layers to BaseBayesianNeuralNetwork with injectable linear_fn and xp params for backend-agnostic layer loop
  • Refactor model() in create_update_model to call _forward_layers with JAX backend (jnp.dot, jnp)
  • Refactor forward_pass() to call _forward_layers with NumPy backend (np.einsum, np)
  • Add _Array = Union[np.ndarray, jax.Array] type alias and ModuleType import for precise type hints on _forward_layers
  • Remove redundant n_layers local in create_update_model and w_shape from weights_biases tuples
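The backend-injection idea described in the list above can be sketched as follows. This is a simplified, hypothetical version: the parameter names `linear_fn` and `xp` follow the description, but the real helper also handles activation choices, residual logic, and other details not shown here.

```python
import numpy as np


def _forward_layers(x, weights_biases, activation_fn, linear_fn, xp):
    """Simplified sketch: run the per-layer loop with an injected backend.

    `linear_fn` computes the affine transform and `xp` is the array
    namespace (numpy or jax.numpy), so one loop can serve both the
    NumPy forward_pass and the JAX/NumPyro update model.
    """
    x = xp.asarray(x)
    for w, b in weights_biases:
        x = activation_fn(linear_fn(x, w, b))
    return x


# NumPy backend, using the einsum-based linear_fn from the description
def relu(z):
    return np.maximum(z, 0.0)


def linear(x, w, b):
    return np.einsum("...i,...ij->...j", x, w) + b


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
layers = [
    (rng.normal(size=(3, 5)), rng.normal(size=(5,))),
    (rng.normal(size=(5, 1)), rng.normal(size=(1,))),
]
out = _forward_layers(x, layers, relu, linear, np)
print(out.shape)  # (4, 1)
```

Passing `jnp.dot`-based `linear_fn` and `jnp` as `xp` would run the same loop under JAX, which is the unification the PR describes.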

Summary by CodeRabbit

  • Chores

    • Released version 6.0.3
  • Refactor

    • Consolidated per-layer model computation into a backend-agnostic implementation to improve maintainability and ensure consistent behavior across runtimes
  • Tests

    • Added tests verifying both sampling and updating use the same forward computation and behave consistently

shaharbar1 added the "enhancement" (New feature or request) label on Mar 31, 2026

coderabbitai bot commented Mar 31, 2026

📝 Walkthrough

Added a backend-agnostic _forward_layers(...) helper and refactored both NumPyro/JAX and NumPy forward paths in pybandits/model.py to use it, changed per-layer storage to (w, b), added typing/imports, bumped package version, and added a test asserting _forward_layers is used.

Changes

  • Model refactor — pybandits/model.py: Introduced _forward_layers(...) to centralize the per-layer linear transform, activation, and optional residual logic. Replaced the inline per-layer loops in the NumPyro/JAX create_update_model and the NumPy forward_pass with calls to this helper. Changed the weights_biases payload to per-layer (w, b) and added the _Array typing and ModuleType import.
  • Version bump — pyproject.toml: Bumped the package version from 6.0.1 to 6.0.3.
  • Tests — tests/test_model.py: Added test_bnn_sample_proba_and_update_both_use_forward_layers, which patches BaseBayesianNeuralNetwork._forward_layers to count invocations and asserts that sample_proba and update call the helper the expected number of times in the VI and MCMC flows.

Estimated Code Review Effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

Poem

🐇 I hopped through weights and bias so neat,
I joined NumPy, JAX in a single beat,
One helper now guides each layered run,
Tests counted my hops — the job was done,
A rabbit's patch, compact and sweet.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 62.50%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the title accurately describes the main change, the unification of the forward-pass logic used across both prediction (forward_pass) and model update (create_update_model) paths via a new _forward_layers helper.
  • Description check — ✅ Passed: the description follows the template with all required sections (Changes, Tests added, Documentation) completed, providing specific implementation details that map to actual file changes.



Comment thread on pybandits/model.py

```python
        self.__dict__.update(state)
        self._init_private_attrs()

    def _forward_layers(
```

Collaborator:

should you add a test that verifies the two paths are equivalent?

Collaborator (Author, shaharbar1):

Done. Take a look...

shaharbar1 force-pushed the feature/forward_pass_unification branch from 4cc6e3a to e1decb8 on April 5, 2026 at 17:48

coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/test_model.py (1)

1672-1706: Test may be fragile for MCMC due to default num_chains=2.

The test correctly verifies that both sample_proba and update call _forward_layers. However, for MCMC, the test doesn't specify num_chains, so the default num_chains=2 from _default_mcmc_kwargs is applied. Depending on NumPyro's internal behavior, this could cause the model to be traced more than once (e.g., once per chain), making the strict call_count == ref assertion fail intermittently.

Consider either:

  1. Explicitly setting num_chains=1 for MCMC to ensure deterministic behavior, or
  2. Changing the assertion to call_count >= ref to verify _forward_layers is called at least once
♻️ Suggested fix to ensure a deterministic MCMC call count:

```diff
 @pytest.mark.parametrize(
     "update_method, update_kwargs",
-    [("VI", {"num_steps": 2}), ("MCMC", {"num_warmup": 2, "num_samples": 2})],
+    [("VI", {"num_steps": 2}), ("MCMC", {"num_warmup": 2, "num_samples": 2, "num_chains": 1})],
 )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_model.py` around lines 1672 - 1706, The test is fragile for MCMC
because the default num_chains=2 can cause multiple model traces; update the
test_bnn_sample_proba_and_update_both_use_forward_layers to ensure deterministic
behavior by setting num_chains=1 when update_method == "MCMC" (e.g., include
"num_chains": 1 in update_kwargs passed to BayesianNeuralNetwork.cold_start),
keeping the assertions that call_count == ref; alternatively, if you prefer
looser verification, change the two assertions to assert call_count >= ref to
allow multiple traces; locate BaseBayesianNeuralNetwork._forward_layers and
BayesianNeuralNetwork.cold_start in the test to apply the change.
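The call-counting pattern this test relies on can be illustrated in isolation. The sketch below uses hypothetical stand-in classes and methods (`Model`, `sample_proba`, `update` here are not the pybandits API); it shows how `unittest.mock.patch.object` with `autospec=True` wraps a method so calls are counted without changing behavior:

```python
from unittest.mock import patch


class Model:
    def _forward_layers(self, x):
        return x * 2

    def sample_proba(self, x):
        return self._forward_layers(x)

    def update(self, x):
        return self._forward_layers(x)


m = Model()
original = Model._forward_layers

# autospec=True preserves the method signature (including self), and
# side_effect=original delegates to the real implementation, so the
# spy counts invocations while leaving behavior unchanged.
with patch.object(Model, "_forward_layers", autospec=True,
                  side_effect=original) as spy:
    m.sample_proba(1.0)
    m.update(1.0)
    count = spy.call_count

print(count)  # 2
```

With `num_chains=1`, a strict `call_count == ref` assertion like this should be deterministic; with multiple chains the model may be traced more than once, which is the fragility the comment describes.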

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ae78c1b0-e38d-4692-91f1-96767ab4666b

📥 Commits

Reviewing files that changed from the base of the PR and between 4cc6e3a and e1decb8.

📒 Files selected for processing (3)
  • pybandits/model.py
  • pyproject.toml
  • tests/test_model.py
✅ Files skipped from review due to trivial changes (1)
  • pyproject.toml

shaharbar1 force-pushed the feature/forward_pass_unification branch from e1decb8 to 8349128 on April 6, 2026 at 11:34

coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pybandits/model.py`:
- Around line 1750-1759: The final logit used by forward_pass diverges from
create_update_model because create_update_model clips the last-layer output to
[-15, 15] but forward_pass feeds the raw weighted_sum into _numpy_sigmoid; fix
this by applying the same clipping to the last-layer logit in forward_pass
(i.e., after computing weighted_sum from linear_transform.squeeze(-1) and before
calling _numpy_sigmoid) so both paths use identical logits (alternatively,
remove clipping from create_update_model if you intend no clipping, but ensure
both forward_pass and create_update_model use the same behavior).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7738dccb-d748-4fb5-94e7-4c09c043871a

📥 Commits

Reviewing files that changed from the base of the PR and between e1decb8 and 8349128.

📒 Files selected for processing (3)
  • pybandits/model.py
  • pyproject.toml
  • tests/test_model.py
✅ Files skipped from review due to trivial changes (2)
  • pyproject.toml
  • tests/test_model.py

Comment thread on pybandits/model.py, lines +1750 to +1759

```python
linear_transform = self._forward_layers(
    next_layer_input,
    sampled_weights,
    self._numpy_activation_fn,
    lambda x, w, b: np.einsum("...i,...ij->...j", x, w) + b,
    np,
)

weighted_sum = linear_transform.squeeze(-1)
prob = _numpy_sigmoid(weighted_sum)
```

⚠️ Potential issue | 🟡 Minor

Predict and update still diverge on the final logit.

create_update_model() clips the last-layer output to [-15, 15] before using it, but forward_pass() now feeds the raw value into _numpy_sigmoid(). So the two paths can still return different probabilities for the same weights/context when logits get large, even though the hidden-layer loop is shared. If exact parity is the goal, clip here too or remove the clip from the update path.
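A minimal sketch of the suggested parity fix — clipping the final logit before the sigmoid, assuming the [-15, 15] bounds quoted from create_update_model (the variable names are illustrative, not the exact pybandits code):

```python
import numpy as np


def _numpy_sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Illustrative logits, including magnitudes where unclipped and
# clipped paths would produce different probabilities.
weighted_sum = np.array([-40.0, 0.0, 40.0])

# Clip the last-layer logit as create_update_model does, so the
# predict and update paths see identical values before the sigmoid.
clipped = np.clip(weighted_sum, -15.0, 15.0)
prob = _numpy_sigmoid(clipped)
print(prob)
```

With the clip in place, probabilities saturate at sigmoid(±15) rather than drifting arbitrarily close to 0 or 1, matching whatever the update path samples against.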

