Implement MCMC/blackjax Wrappers#125

Merged
mattlevine22 merged 18 commits into main from dw-blackjax-convenience
Mar 10, 2026

@DanWaxman
Collaborator

@DanWaxman DanWaxman commented Mar 2, 2026

While experimenting for #121, I found that Blackjax HMC has significantly better performance in some use cases for whatever reason (I suspect it has to do with CRN?). This PR adds support for a number of MCMC wrappers, including HMC, NUTS, and SGLD.

@DanWaxman DanWaxman marked this pull request as ready for review March 5, 2026 02:36
@DanWaxman DanWaxman requested a review from mattlevine22 March 5, 2026 02:37
Collaborator

@mattlevine22 mattlevine22 left a comment

Looks pretty solid.

  1. For PRs that create new API, we should probably include example usage in the docs:
  • Add an explicit usage example in the docstring so it appears in the docs.
  • It would also be "friendly" to show a chunk of numpyro demo code that it can replace. For example, a code block where we call NUTS/MCMC the numpyro way and get back a posterior... and then the same with FilterBasedMCMC.
  • Can we always expect the returned posterior to have the same attributes in both cases? That is, is there ever a time I can't use this?
  • How will it interact with the upcoming predict_times (i.e., supporting the Filter + Simulator pattern)?
  2. Since the purpose is "improved performance", I'd also suggest a deep dive that compares blackjax/numpyro (not exhaustively, but illustratively at least).

  3. I'm a bit uneasy about redundant user-facing APIs (with Filter -> NUTS vs FilterMCMC(NUTS)).

  • Seems like we should have a SimulatorBasedMCMC or even LatentDynamicsMCMC(type=Filter | Simulator | Smoother)
  • At that point, should we also have LatentDynamicsSVI(type=Filter | Simulator | Smoother)?
  4. Why did the linter modify the cd_dynamax and Cuthbert integrations?
  • Specifically, it created/deleted some spaces in those files.
  • If this was a mistake, let's remove it. Otherwise, hopefully the linter stays happy going forward.

@DanWaxman
Collaborator Author

  1. For PRs that create new API, we should probably include example usage in the docs:

Agreed. Part of my hesitation here was how much of the existing API in the docs we'd want to replace (similar to (3) below).

  2. Since the purpose is "improved performance", I'd also suggest a deep dive that compares blackjax/numpyro (not exhaustively, but illustratively at least).

Sure.

  3. I'm a bit uneasy about redundant user-facing APIs (with Filter -> NUTS vs FilterMCMC(NUTS)).

I'm not sure I follow on this specific point. This is mostly a simplification of using the interpretations directly (that I think should suffice for most users).

Seems like we should have a SimulatorBasedMCMC or even LatentDynamicsMCMC(type= Filter | Simulator | Smoother)

Yeah, I thought about that a bit... Seems reasonable, though I'd like to keep the corresponding aliases (FilterBasedMCMC, SimulatorBasedMCMC) as well; they're a bit more explicit, even if they just call LatentDynamicsMCMC.
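
One way to picture the "explicit aliases over a general class" idea is sketched below. This is only an illustration under assumptions: the class bodies, the `kernel` string argument, the `DynamicsKind` enum, and the `describe` method are all hypothetical and are not taken from the PR; only the names LatentDynamicsMCMC, FilterBasedMCMC, and SimulatorBasedMCMC come from the discussion above.

```python
from dataclasses import dataclass
from enum import Enum


class DynamicsKind(Enum):
    """Hypothetical enum for the Filter | Simulator | Smoother choice."""

    FILTER = "filter"
    SIMULATOR = "simulator"
    SMOOTHER = "smoother"


@dataclass
class LatentDynamicsMCMC:
    """Hypothetical general entry point; not the library's actual class."""

    kernel: str  # e.g. "NUTS"; a plain string stands in for a real kernel object
    kind: DynamicsKind

    def describe(self) -> str:
        # Stand-in for real inference: just report the configuration.
        return f"{self.kernel} via {self.kind.value}"


class FilterBasedMCMC(LatentDynamicsMCMC):
    """Explicit alias: fixes kind to FILTER and defers everything else."""

    def __init__(self, kernel: str):
        super().__init__(kernel=kernel, kind=DynamicsKind.FILTER)


class SimulatorBasedMCMC(LatentDynamicsMCMC):
    """Explicit alias: fixes kind to SIMULATOR and defers everything else."""

    def __init__(self, kernel: str):
        super().__init__(kernel=kernel, kind=DynamicsKind.SIMULATOR)
```

With this shape, the aliases add no behavior of their own, so documentation could recommend either spelling without the two drifting apart.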

At that point, should we also have LatentDynamicsSVI(type=Filter | Simulator | Smoother)?

Maybe...

Why did the linter modify cd_dynamax and Cuthbert integrations?

I think because the __init__.py changed (I was having trouble building the documentation).

@mattlevine22
Collaborator

To clarify my concern about redundant API... I think we should have one recommended way of doing this and keep it consistent in the docs.

  • So either a) we switch all the tutorials/notebooks to use FilterBasedMCMC or b) this becomes an "advanced feature" we can use for better inference (via direct blackjax integration).
  • I'm okay with switching everything to FilterBasedMCMC, but we should first verify for ourselves that it fits clearly into the broader API (predict times, simulators, etc.).

@DanWaxman
Collaborator Author

Will get around to this next week most likely, but to record it before I forget, the decision was:

  1. We should consistently avoid the data_conditioned_model pattern. It suffices to do with Filter(): mcmc.run(). This is simpler.
  2. We get rid of the FilterBasedMCMC. This should instead be replaced by a simpler helper class for blackjax MCMC that mirrors the NumPyro interface.
  • Potentially, this should mirror what is there now, where we allow for dispatching to mcmc_source, and have HMC, NUTS, SGMCMC, etc. under a single uniform hood.
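
The "dispatch to mcmc_source under a single uniform hood" idea from the decision above could look roughly like the string-keyed registry below. This is a sketch under assumptions, not the PR's implementation: `UniformMCMC`, `register_source`, and the runner functions are hypothetical names, and the runners return a small dict describing the request instead of calling NumPyro or blackjax.

```python
from typing import Callable, Dict

# Registry mapping an mcmc_source string to a runner function (all hypothetical).
_RUNNERS: Dict[str, Callable[[str, int], dict]] = {}


def register_source(name: str):
    """Decorator that registers a runner under the given mcmc_source name."""

    def decorator(fn: Callable[[str, int], dict]) -> Callable[[str, int], dict]:
        _RUNNERS[name] = fn
        return fn

    return decorator


@register_source("numpyro")
def _run_numpyro(kernel: str, num_samples: int) -> dict:
    # Placeholder: a real runner would build and run a NumPyro sampler here.
    return {"source": "numpyro", "kernel": kernel, "num_samples": num_samples}


@register_source("blackjax")
def _run_blackjax(kernel: str, num_samples: int) -> dict:
    # Placeholder: a real runner would build and step a blackjax kernel here.
    return {"source": "blackjax", "kernel": kernel, "num_samples": num_samples}


class UniformMCMC:
    """Hypothetical helper exposing one interface over several backends."""

    def __init__(self, kernel: str, mcmc_source: str = "blackjax", num_samples: int = 1000):
        if mcmc_source not in _RUNNERS:
            raise ValueError(f"unknown mcmc_source: {mcmc_source!r}")
        self.kernel = kernel
        self.mcmc_source = mcmc_source
        self.num_samples = num_samples

    def run(self) -> dict:
        # Look up the backend by name and delegate to it.
        return _RUNNERS[self.mcmc_source](self.kernel, self.num_samples)
```

Validating the source name in the constructor means a typo fails early, before any sampling work starts; adding another backend is then a matter of registering one more runner.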

@mattlevine22
Collaborator

Agreed. AND we should have a tutorial page specifically about the different legal ways to use 'with Filter' (you can build a conditioned model this way OR you can just wrap things as you go).

@DanWaxman
Collaborator Author

I've updated the notebooks to move away from the data_conditioned_model pattern. FilterBasedMCMC is now MCMCInference, with an interface similar to before. I've also included MALA.

There are probably a few documentation-related issues that remain from the refactoring, but it should otherwise be ready for review. I'd prefer to defer a more detailed notebook comparing different methods until we get a better empirical feel for things and figure out why NumPyro HMC had issues in #121.

@DanWaxman DanWaxman requested a review from mattlevine22 March 10, 2026 05:32
@mattlevine22
Collaborator

FYI I had to do uv sync --reinstall-package cuthbertlib --reinstall-package cuthbert to run things.

Collaborator

@mattlevine22 mattlevine22 left a comment

looks good!

@mattlevine22 mattlevine22 merged commit f7feadc into main Mar 10, 2026
2 checks passed
@DanWaxman DanWaxman deleted the dw-blackjax-convenience branch March 23, 2026 23:15