Skip to content

Include PFJax Support#121

Draft
DanWaxman wants to merge 8 commits intomainfrom
dw-include-pfjax
Draft

Include PFJax Support#121
DanWaxman wants to merge 8 commits intomainfrom
dw-include-pfjax

Conversation

@DanWaxman
Copy link
Collaborator

No description provided.

Needs to be this way until cd-dynamax allows for higher jax versions
This adds a particle filter and marginal particle filter backend for PFJax. It includes the filter configs, glue code, and unittesting.

Lots of this code is AI-generated; I've looked over & used to first approximation, but it requires a closer inspection (and likely some pairing down) before seriously merging.
Copy link
Collaborator

@mattlevine22 mattlevine22 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good to me, would be happy with including with a commented notebook (or dropping the notebook as a private thing)

description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.12,<3.14"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oooo that's a little scary, let's try to remember that PFJax is what pushed us to do this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's actually more of a CD-Dynamax problem.

PFJax requires jax>=0.7.2 for Python 3.14: https://github.com/mlysy/pfjax/blob/754f11c3937e25f8ed1045424a2af5fbad553cc3/pyproject.toml#L26C1-L33C2

This is (directly after) when jax started providing 3.14 wheels: https://docs.jax.dev/en/latest/changelog.html#jax-0-7-1-august-20-2025

But cd_dynamax has a pinned jax which is lower (0.6.2, I think? Something like that).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oof got it, let's remember THAT then, and maybe deal with it in CD Dynamax refactor

assert log_weights.shape == (obs_times.shape[0], 64)


def test_discrete_pfjax_marginal_particle_filter_records_outputs() -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this test and the previous can be combined with a param that chooses PFConfig or MarginalPFConfig (with everything else looking the same).

@@ -0,0 +1,596 @@
{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Obviously you'll probably clean-up/comment the notebook, but I would start by asking "what do we want a reader to gain from this notebook example?"

  • how to use blackjax with our stuff?
  • PF vs MarginalPF vs EKF for parameter inference in this setting?
  • other things?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm writing Blackjax wrappers in a separate PR, that should make this one nicer. I was thinking of it as a "paper reproduction" PR, of "hey look at this recent complicated method; it fits naturally in our paradigm in 5 lines of code, and changing a single line gets you one of X related inference methods." A nice story could be showing PF works okay, marginal PF works better, EKF is fast but biased, EnKF is fast and probably less biased if in a second problem we have Gaussian observations.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds great. Do we want it to be a "deep dive" in that case?

We may want a place for non-doc scripts/notebooks for when we do something quickly and don't want to "polish" it. Alternatively we can banish this practice and push ourselves to have nice notebooks (we've been pretty good about this so far).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I have it as a deep dive now (unless I misread and you were suggesting it be elsewhere). I would like to polish this -- it would be nice, I think, to have a "series" of these notebooks reproducing several relevant papers (e.g., ensemble MCMC).

description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.12,<3.14"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was it PFJax or blackjax that created the need for <3.14? Let's try to remember this somehow so we aren't struggling when we need to update.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As commented above, it's actually CD-Dynamax

mattlevine22 added a commit that referenced this pull request Mar 10, 2026
In experimenting for #121, I found that Blackjax HMC has significantly
better performance in some use cases for whatever reason (I suspect to
do with CRN?). This PR adds support for a number of MCMC wrappers,
including HMC, NUTS, MALA, and SGLD.

Introduces `MCMCInference` API and a notebook to show these comparisons.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants