Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
596 changes: 596 additions & 0 deletions docs/deep_dives/particle_hmc.ipynb

Large diffs are not rendered by default.

24 changes: 22 additions & 2 deletions dynestyx/inference/filter_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

ResamplingBaseMethod = Literal["systematic", "multinomial", "stratified"]
ResamplingDifferentiableMethod = Literal["stop_gradient", "straight_through", "soft"]
FilterSource = Literal["cuthbert", "cd_dynamax", "dynestyx"]
FilterSource = Literal["cuthbert", "cd_dynamax", "dynestyx", "pfjax"]
FilterEmissionOrder = Literal["zeroth", "first", "second"]
FilterStateOrder = Literal["zeroth", "first", "second"]

Expand Down Expand Up @@ -226,7 +226,7 @@ class PFConfig(BaseFilterConfig):
ess_threshold_ratio (float): Resampling fires when the effective
sample size drops below `ess_threshold_ratio * n_particles`.
`1.0` → always resample; `0.0` → never. Defaults to `0.7`.
filter_source (FilterSource): Backend. Defaults to `"cuthbert"`, which is currently the only available implementation.
filter_source (FilterSource): Backend. Defaults to `"cuthbert"`.

??? note "Algorithm Reference"
At each step, particles are propagated through the transition and
Expand Down Expand Up @@ -264,6 +264,25 @@ class PFConfig(BaseFilterConfig):
filter_source: FilterSource = "cuthbert"


@dataclasses.dataclass
class MarginalPFConfig(PFConfig):
r"""Marginal Particle Filter (MPF / Rao-Blackwellized PF) for discrete-time models.

This configuration targets the PFJax marginal particle filter (`particle_filter_rb`),
which computes particle weights via mixture-based corrections and can reduce variance
relative to standard bootstrap PF in some settings.

Attributes:
stop_proposal_gradient (bool): Whether to stop gradients through proposal
sampling and proposal log-density terms (PFJax `stop_proposal_gradient`).
Defaults to `True`.
filter_source (FilterSource): Backend. Defaults to `"pfjax"`.
"""

stop_proposal_gradient: bool = True
filter_source: FilterSource = "pfjax"


@dataclasses.dataclass
class EKFConfig(BaseFilterConfig):
"""Extended Kalman Filter (EKF) for discrete-time models.
Expand Down Expand Up @@ -602,6 +621,7 @@ class ContinuousTimeUKFConfig(UKFConfig, ContinuousTimeConfig):
DiscreteTimeConfigs: tuple[type, ...] = (
EnKFConfig,
PFConfig,
MarginalPFConfig,
EKFConfig,
KFConfig,
UKFConfig,
Expand Down
21 changes: 19 additions & 2 deletions dynestyx/inference/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
HMMConfig,
HMMConfigs,
KFConfig,
MarginalPFConfig,
PFConfig,
PFResamplingConfig,
UKFConfig,
Expand All @@ -33,6 +34,9 @@
from dynestyx.inference.integrations.cuthbert.discrete import (
run_discrete_filter as run_cuthbert_discrete,
)
from dynestyx.inference.integrations.pfjax.discrete import (
run_discrete_filter as run_pfjax_discrete,
)
from dynestyx.models import DynamicalModel
from dynestyx.types import FunctionOfTime

Expand Down Expand Up @@ -250,10 +254,10 @@ def _filter_discrete_time(
ctrl_values=None,
**kwargs,
) -> None:
"""Discrete-time marginal likelihood via cuthbert or cd-dynamax.
"""Discrete-time marginal likelihood via cuthbert, pfjax, or cd-dynamax.

Filter type inferred from config class: KFConfig, EKFConfig, UKFConfig (cd-dynamax)
or EKFConfig (cuthbert), PFConfig (cuthbert).
or EKFConfig (cuthbert), PFConfig (cuthbert or pfjax).

Args:
name: Name of the factor.
Expand Down Expand Up @@ -288,6 +292,18 @@ def _filter_discrete_time(
ctrl_values=ctrl_values,
**kwargs,
)
elif filter_config.filter_source == "pfjax":
run_pfjax_discrete(
name,
dynamics,
filter_config,
key=key,
obs_times=obs_times,
obs_values=obs_values,
ctrl_times=ctrl_times,
ctrl_values=ctrl_values,
**kwargs,
)
else:
raise ValueError(f"Unknown filter source: {filter_config.filter_source}")

Expand Down Expand Up @@ -342,6 +358,7 @@ def _filter_continuous_time(
"HMMConfig",
"HMMConfigs",
"KFConfig",
"MarginalPFConfig",
"PFConfig",
"PFResamplingConfig",
"UKFConfig",
Expand Down
3 changes: 3 additions & 0 deletions dynestyx/inference/integrations/pfjax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from dynestyx.inference.integrations.pfjax.discrete import run_discrete_filter

__all__ = ["run_discrete_filter"]
Loading
Loading