I found this type of page really helpful when extending GPJax in the past.
Explaining:
• the Filter and Smoother interfaces, which are common throughout
• build_filter, which is where the complexity comes in, as each method requires its own parts of a state space model
• why there are no explicit update/predict steps (as in the textbook Kalman presentation) and filter_prepare is used instead (might be for Add "Conventions" page to docs #189)
• how the parallelisation works
Below is a bit of a dump of some notes I made while going through the package, which cover these three points. They might be helpful for us to copy and paste parts from, so I'm putting them here. If not, you can stop reading.
Filter and Smoother
The two central inference objects are:
Filter -- a NamedTuple with three callables + a flag:
• init_prepare(model_inputs, key) -- creates the initial state (time 0)
• filter_prepare(model_inputs, key) -- converts model inputs at time t into a prepared state
• filter_combine(state_1, state_2) -- combines a previous state with a prepared state to produce the filtered state
• associative: bool -- if True, filter_combine is associative, enabling parallel filtering via jax.lax.associative_scan
Smoother -- same pattern but runs backwards in time:
• convert_filter_to_smoother_state -- converts the final filter state to seed the backward pass
• smoother_prepare(filter_state, model_inputs, key) -- prepares a state for the smoother
• smoother_combine(state_1, state_2) -- combines two smoother states
• associative: bool -- enables parallel smoothing
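As a rough sketch of the two interfaces as types (field names are taken from the descriptions above; the actual definitions in cuthbert may differ in detail):

```python
from typing import Any, Callable, NamedTuple


class Filter(NamedTuple):
    # (model_inputs, key) -> initial state at time 0
    init_prepare: Callable[..., Any]
    # (model_inputs, key) -> prepared state for one time step
    filter_prepare: Callable[..., Any]
    # (state_1, state_2) -> filtered state
    filter_combine: Callable[..., Any]
    # if True, filter_combine is associative and can be used with jax.lax.associative_scan
    associative: bool


class Smoother(NamedTuple):
    # final filter state -> seed for the backward pass
    convert_filter_to_smoother_state: Callable[..., Any]
    # (filter_state, model_inputs, key) -> prepared smoother state
    smoother_prepare: Callable[..., Any]
    # (state_1, state_2) -> smoothed state
    smoother_combine: Callable[..., Any]
    associative: bool
```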
Three families of inference methods plug into this interface:
• cuthbert.discrete -- Hidden Markov Models (forward-backward / Baum-Welch); discrete methods work with probability vectors and transition matrices
• cuthbert.gaussian -- Kalman filters/smoothers (standard, extended via Taylor linearisation, unscented/ensemble via moment transforms)
• cuthbert.smc -- Sequential Monte Carlo / particle filters
build_filter
build_filter is different for each method (Kalman, EKF, particle filter, etc.) because each inference method requires different things:
• gaussian.kalman.build_filter: get_init_params, get_dynamics_params, get_observation_params (returning matrices F, H, etc.)
• gaussian.taylor.build_filter: get_init_log_density, get_dynamics_log_density, get_observation_func (returning log density callables + linearisation points)
• gaussian.moments.build_filter: similar but returning mean/chol_cov callables
• smc.particle_filter.build_filter: init_sample, propagate_sample, log_potential, plus tuning parameters (n_particles, resampling scheme, ESS threshold)
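For illustration, here are hypothetical parameter getters for a 1D random-walk model with noisy observations, of the kind gaussian.kalman.build_filter asks for. The signatures and return conventions below are my assumptions, not cuthbert's actual API:

```python
import jax.numpy as jnp


def get_init_params(model_inputs):
    # Initial latent mean and covariance (hypothetical return convention).
    return jnp.zeros(1), jnp.eye(1)


def get_dynamics_params(model_inputs, t):
    # Transition matrix F_t and process noise covariance Q_t for a random walk.
    return jnp.eye(1), 0.1 * jnp.eye(1)


def get_observation_params(model_inputs, t):
    # Observation matrix H_t and observation noise covariance R_t.
    return jnp.eye(1), 0.5 * jnp.eye(1)
```

These would then be handed to gaussian.kalman.build_filter, presumably to be wired into the filter_prepare/filter_combine pair described above.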
parallel (associative) Kalman filter
(Notation below is from Kevin Murphy's Advanced Topics in ProbML. I added these to a general Kalman filtering note.)
The standard Kalman filter is sequential – each step depends on the previous filtered mean and covariance, giving $O(T)$ serial complexity. The parallel Kalman filter reformulates the predict + update step as an affine map on the filtered mean, enabling a parallel scan in $O(\log T)$.
Reformulation as an affine map
Substituting the standard time update (predict) into the measurement update, the filtered mean at time $t$ is:

$$\boldsymbol{\mu}_{t|t} = \boldsymbol{\mu}_{t|t-1} + \mathbf{K}_t (\mathbf{y}_t - \mathbf{H}_t \boldsymbol{\mu}_{t|t-1}) = (\mathbf{I} - \mathbf{K}_t \mathbf{H}_t) \mathbf{F}_t \boldsymbol{\mu}_{t-1|t-1} + \mathbf{K}_t \mathbf{y}_t$$

This has the form of an affine map on the previous filtered mean:

$$\boldsymbol{\mu}_{t|t} = \mathbf{A}_t \boldsymbol{\mu}_{t-1|t-1} + \mathbf{b}_t$$

where $\mathbf{A}_t = (\mathbf{I} - \mathbf{K}_t \mathbf{H}_t) \mathbf{F}_t$ and $\mathbf{b}_t = \mathbf{K}_t \mathbf{y}_t$.
The covariance update $\boldsymbol{\Sigma}_{t|t}$ does not depend on the mean – it only depends on the model parameters $(\mathbf{F}_t, \mathbf{Q}_t, \mathbf{H}_t, \mathbf{R}_t)$ and can be computed independently.
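For reference, the covariance-side recursions in this notation (the standard Kalman updates, nothing cuthbert-specific) are:

$$\boldsymbol{\Sigma}_{t|t-1} = \mathbf{F}_t \boldsymbol{\Sigma}_{t-1|t-1} \mathbf{F}_t^\top + \mathbf{Q}_t, \qquad \mathbf{K}_t = \boldsymbol{\Sigma}_{t|t-1} \mathbf{H}_t^\top (\mathbf{H}_t \boldsymbol{\Sigma}_{t|t-1} \mathbf{H}_t^\top + \mathbf{R}_t)^{-1}, \qquad \boldsymbol{\Sigma}_{t|t} = (\mathbf{I} - \mathbf{K}_t \mathbf{H}_t) \boldsymbol{\Sigma}_{t|t-1}$$

none of which involves the filtered mean or the observations.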
associative scan
The composition of two affine maps is itself an affine map: if $e_1 = (\mathbf{A}_1, \mathbf{b}_1)$ and $e_2 = (\mathbf{A}_2, \mathbf{b}_2)$ each represent $\mathbf{x} \mapsto \mathbf{A}\mathbf{x} + \mathbf{b}$, then

$$e_2 \circ e_1 = (\mathbf{A}_2 \mathbf{A}_1,\ \mathbf{A}_2 \mathbf{b}_1 + \mathbf{b}_2)$$

This composition is associative (it is function composition). To compute all $T$ filtered states, we need the prefix compositions

$$e_{0:1},\ e_{0:2},\ \ldots,\ e_{0:T}$$

The associative scan (parallel prefix scan) computes all of these in $O(\log T)$ parallel steps by composing pairs in a binary tree, using

$$e_{i:k} = e_{j:k} \circ e_{i:j} \quad (i < j < k)$$

where $e_{i:j}$ denotes the composed affine map from step $i$ to $j$. At each level, independent compositions run in parallel. Associativity guarantees that any grouping produces the correct result.
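As a small illustration of the combine step in code (mean part only; cuthbert's actual filtering_operator will carry more than the two terms shown here):

```python
def combine_affine(elem_1, elem_2):
    """Compose two affine scan elements; each element (A, b) represents x -> A @ x + b.

    Applying elem_1 first and then elem_2 gives
    x -> A2 @ (A1 @ x + b1) + b2 = (A2 @ A1) @ x + (A2 @ b1 + b2).
    """
    A1, b1 = elem_1
    A2, b2 = elem_2
    return A2 @ A1, A2 @ b1 + b2
```

Because matrix multiplication is associative, this operator is associative, which is exactly what jax.lax.associative_scan requires.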
cuthbert implements this using jax.lax.associative_scan (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.associative_scan.html) with the parallel=True flag; it has three functions that map to the maths above:
init_prepare: creates the initial element with $\mathbf{A}_0 = \mathbf{0}$, $\mathbf{b}_0 = \boldsymbol{\mu}_0$
filter_prepare: encodes one time step's parameters $(\mathbf{F}_t, \mathbf{Q}_t, \mathbf{H}_t, \mathbf{R}_t, \mathbf{y}_t)$ into the affine scan element $(\mathbf{A}_t, \mathbf{b}_t, \ldots)$ via associative_params_single
filter_combine: composes two scan elements via the associative filtering_operator
In parallel mode, all elements are prepared independently via jax.vmap(filter_prepare), then composed via jax.lax.associative_scan(filter_combine, ...). In sequential mode, the same filter_prepare and filter_combine are called inside a jax.lax.scan loop.
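A minimal sketch of how a single prepare/combine pair can drive both modes (illustrative only, not cuthbert's actual code; it assumes prepare takes one time step's inputs and that inputs have a leading time axis):

```python
import jax
import jax.numpy as jnp


def run_filter(prepare, combine, init_element, inputs, parallel=True):
    """Run a filter given a per-step `prepare` and an associative `combine`."""
    if parallel:
        # Prepare every element independently, then compose with a parallel prefix scan.
        elements = jax.vmap(prepare)(inputs)
        # Prepend the initial element so prefix i corresponds to the filtered state at time i.
        elements = jax.tree_util.tree_map(
            lambda init, rest: jnp.concatenate([init[None], rest]), init_element, elements
        )
        # Note: `combine` must broadcast over a leading batch axis for associative_scan.
        return jax.lax.associative_scan(combine, elements)

    # Sequential mode: the same prepare/combine, inside an ordinary carry loop.
    def step(carry, x):
        new_state = combine(carry, prepare(x))
        return new_state, new_state

    _, states = jax.lax.scan(step, init_element, inputs)
    return states
```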