Conversation
Needs to be this way until cd-dynamax allows for higher jax versions
This adds a particle filter and marginal particle filter backend for PFJax. It includes the filter configs, glue code, and unit testing. Much of this code is AI-generated; I've looked it over and vetted it to a first approximation, but it requires a closer inspection (and likely some paring down) before seriously merging.
mattlevine22 left a comment
Overall looks good to me; I'd be happy including this with a commented notebook (or dropping the notebook and keeping it as a private thing).
  description = "Add your description here"
  readme = "README.md"
- requires-python = ">=3.12"
+ requires-python = ">=3.12,<3.14"
oooo that's a little scary, let's try to remember that PFJax is what pushed us to do this?
I think it's actually more of a CD-Dynamax problem.
PFJax requires jax>=0.7.2 for Python 3.14: https://github.com/mlysy/pfjax/blob/754f11c3937e25f8ed1045424a2af5fbad553cc3/pyproject.toml#L26C1-L33C2
This is (directly after) when jax started providing 3.14 wheels: https://docs.jax.dev/en/latest/changelog.html#jax-0-7-1-august-20-2025
But cd_dynamax has a pinned jax which is lower (0.6.2, I think? Something like that).
Oof got it, let's remember THAT then, and maybe deal with it in CD Dynamax refactor
  assert log_weights.shape == (obs_times.shape[0], 64)

  def test_discrete_pfjax_marginal_particle_filter_records_outputs() -> None:
Looks like this test and the previous can be combined with a param that chooses PFConfig or MarginalPFConfig (with everything else looking the same).
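The consolidation suggested above could look something like the following sketch. `PFConfig`, `MarginalPFConfig`, and `run_filter` are stand-ins for this PR's actual names and signatures, not the real API:

```python
# Sketch of combining the two tests via pytest parametrization.
# PFConfig / MarginalPFConfig / run_filter are hypothetical stand-ins
# for the config classes and filtering entry point in this PR.
from dataclasses import dataclass

import numpy as np
import pytest


@dataclass
class PFConfig:          # stand-in for the particle filter config
    n_particles: int = 64


@dataclass
class MarginalPFConfig:  # stand-in for the marginal particle filter config
    n_particles: int = 64


def run_filter(config):
    # Stand-in: pretend each filter records log-weights of
    # shape (n_obs, n_particles).
    return np.zeros((10, config.n_particles))


@pytest.mark.parametrize("config_cls", [PFConfig, MarginalPFConfig])
def test_discrete_pfjax_filter_records_outputs(config_cls) -> None:
    config = config_cls(n_particles=64)
    log_weights = run_filter(config)
    assert log_weights.shape == (10, 64)
```

Everything filter-specific lives in the config class, so the shared test body stays identical for both backends.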
  @@ -0,0 +1,596 @@
  {
Obviously you'll probably clean-up/comment the notebook, but I would start by asking "what do we want a reader to gain from this notebook example?"
- how to use blackjax with our stuff?
- PF vs MarginalPF vs EKF for parameter inference in this setting?
- other things?
I'm writing Blackjax wrappers in a separate PR, that should make this one nicer. I was thinking of it as a "paper reproduction" PR, of "hey look at this recent complicated method; it fits naturally in our paradigm in 5 lines of code, and changing a single line gets you one of X related inference methods." A nice story could be showing PF works okay, marginal PF works better, EKF is fast but biased, EnKF is fast and probably less biased if in a second problem we have Gaussian observations.
Sounds great. Do we want it to be a "deep dive" in that case?
We may want a place for non-doc scripts/notebooks for when we do something quickly and don't want to "polish" it. Alternatively we can banish this practice and push ourselves to have nice notebooks (we've been pretty good about this so far).
Right, I have it as a deep dive now (unless I misread and you were suggesting it be elsewhere). I would like to polish this -- it would be nice, I think, to have a "series" of these notebooks reproducing several relevant papers (e.g., ensemble MCMC).
  description = "Add your description here"
  readme = "README.md"
- requires-python = ">=3.12"
+ requires-python = ">=3.12,<3.14"
Was it PFJax or blackjax that created the need for <3.14? Let's try to remember this somehow so we aren't struggling when we need to update.
As commented above, it's actually CD-Dynamax
In experimenting for #121, I found that Blackjax HMC has significantly better performance in some use cases, for reasons I haven't pinned down (I suspect it has to do with CRNs?). This PR adds support for a number of MCMC wrappers, including HMC, NUTS, MALA, and SGLD, introduces an `MCMCInference` API, and includes a notebook showing these comparisons.
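For readers unfamiliar with the methods being wrapped, here is a minimal, self-contained MALA step in NumPy. This is purely illustrative; the `MCMCInference` wrappers in this PR delegate to Blackjax rather than implementing the kernel by hand:

```python
# Illustrative Metropolis-adjusted Langevin (MALA) step in plain NumPy.
# Not the PR's MCMCInference API -- just a sketch of what the wrapped
# kernel does under the hood.
import numpy as np


def mala_step(x, logpdf, grad_logpdf, step_size, rng):
    """One MALA step targeting the density exp(logpdf)."""
    noise = rng.normal(size=x.shape)
    # Langevin proposal: gradient drift plus Gaussian noise.
    prop = x + step_size * grad_logpdf(x) + np.sqrt(2 * step_size) * noise

    def log_q(a, b):
        # Log density (up to a constant) of proposing `a` from `b`.
        diff = a - b - step_size * grad_logpdf(b)
        return -np.sum(diff ** 2) / (4 * step_size)

    # Metropolis correction for the asymmetric proposal.
    log_alpha = logpdf(prop) - logpdf(x) + log_q(x, prop) - log_q(prop, x)
    if np.log(rng.uniform()) < log_alpha:
        return prop, True
    return x, False
```

Running this chain on a standard normal target (`logpdf = -0.5 * sum(x**2)`) recovers samples with mean near 0 and variance near 1; HMC, NUTS, and SGLD differ mainly in the proposal mechanism, while the accept/reject structure is shared.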