From 0d4a2c3ef2303e632c9350d8c9d4dfaec86d1ed1 Mon Sep 17 00:00:00 2001
From: Sam Duffield
Date: Tue, 10 Mar 2026 16:12:02 +0000
Subject: [PATCH] Add SGLRW

---
 docs/api/index.md             |   2 +
 docs/api/sgmcmc/sglrw.md      |   3 +
 mkdocs.yml                    |   1 +
 posteriors/sgmcmc/__init__.py |   1 +
 posteriors/sgmcmc/sglrw.py    | 182 ++++++++++++++++++++++++++++++++++
 tests/sgmcmc/test_sglrw.py    |  60 +++++++++++
 6 files changed, 249 insertions(+)
 create mode 100644 docs/api/sgmcmc/sglrw.md
 create mode 100644 posteriors/sgmcmc/sglrw.py
 create mode 100644 tests/sgmcmc/test_sglrw.py

diff --git a/docs/api/index.md b/docs/api/index.md
index 4cfeb7b..edfc939 100644
--- a/docs/api/index.md
+++ b/docs/api/index.md
@@ -36,6 +36,8 @@ thermostat (SGNHT) algorithm from [Ding et al, 2014](https://proceedings.neurips
 (SGHMC with adaptive friction coefficient).
 - [`sgmcmc.baoa`](sgmcmc/baoa.md) implements the BAOA integrator for SGHMC
 from [Leimkuhler and Matthews, 2015 - p271](https://link.springer.com/book/10.1007/978-3-319-16375-8).
+- [`sgmcmc.sglrw`](sgmcmc/sglrw.md) implements the stochastic gradient lattice random
+walk (SGLRW) algorithm from [Mensch et al, 2026](https://arxiv.org/abs/2602.15925).
 
 For an overview and unifying framework for SGMCMC methods, see [Ma et al, 2015](https://arxiv.org/abs/1506.04696).
 
diff --git a/docs/api/sgmcmc/sglrw.md b/docs/api/sgmcmc/sglrw.md
new file mode 100644
index 0000000..8e43b49
--- /dev/null
+++ b/docs/api/sgmcmc/sglrw.md
@@ -0,0 +1,3 @@
+# SGLRW
+
+::: posteriors.sgmcmc.sglrw
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index c600721..3b7e901 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -82,6 +82,7 @@ nav:
       - api/sgmcmc/sghmc.md
       - api/sgmcmc/sgnht.md
       - api/sgmcmc/baoa.md
+      - api/sgmcmc/sglrw.md
     - VI:
       - Dense: api/vi/dense.md
       - Diag: api/vi/diag.md
diff --git a/posteriors/sgmcmc/__init__.py b/posteriors/sgmcmc/__init__.py
index ffc9c6a..491c38a 100644
--- a/posteriors/sgmcmc/__init__.py
+++ b/posteriors/sgmcmc/__init__.py
@@ -2,3 +2,4 @@
 from posteriors.sgmcmc import sghmc
 from posteriors.sgmcmc import sgnht
 from posteriors.sgmcmc import baoa
+from posteriors.sgmcmc import sglrw
diff --git a/posteriors/sgmcmc/sglrw.py b/posteriors/sgmcmc/sglrw.py
new file mode 100644
index 0000000..7555bf5
--- /dev/null
+++ b/posteriors/sgmcmc/sglrw.py
@@ -0,0 +1,182 @@
+# type: ignore posteriors is not typed
+from typing import Any
+from functools import partial
+import torch
+from torch import Tensor
+from torch.func import grad_and_value
+from tensordict import TensorClass
+
+from posteriors.types import TensorTree, Transform, LogProbFn, Schedule
+from posteriors.tree_utils import flexi_tree_map, tree_insert_
+from posteriors.utils import CatchAuxError
+
+
+def build(
+    log_posterior: LogProbFn,
+    lr: float | Schedule,
+    temperature: float | Schedule = 1.0,
+) -> Transform:
+    """Builds SGLRW transform - Stochastic Gradient Lattice Random Walk.
+
+    Algorithm from [Mensch et al, 2026](https://arxiv.org/abs/2602.15925)
+    adapted from [Duffield et al, 2025](https://arxiv.org/abs/2508.20883):
+    $$
+    θ_{t+1} = θ_t + δx Δ(θₜ, t)
+    $$
+    where $δx = √(lr * 2 * T)$ is a spatial stepsize and $Δ(θₜ, t)$ is a random
+    binary valued vector defined in the paper.
+
+    Targets $p_T(θ) \\propto \\exp( \\log p(θ) / T)$ with temperature $T$,
+    as it discretizes the overdamped Langevin SDE:
+    $$
+    dθ = ∇ log p_T(θ) dt + √(2 T) dW
+    $$
+
+    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
+    to ensure robust scaling for a large amount of data and variable batch size.
+
+    Args:
+        log_posterior: Function that takes parameters and input batch and
+            returns the log posterior value (which can be unnormalised)
+            as well as auxiliary information, e.g. from the model call.
+        lr: Learning rate,
+            scalar or schedule (callable taking step index, returning scalar).
+        temperature: Temperature of the sampling distribution.
+            Scalar or schedule (callable taking step index, returning scalar).
+
+    Returns:
+        SGLRW transform (posteriors.types.Transform instance).
+    """
+    update_fn = partial(
+        update,
+        log_posterior=log_posterior,
+        lr=lr,
+        temperature=temperature,
+    )
+    return Transform(init, update_fn)
+
+
+class SGLRWState(TensorClass["frozen"]):
+    """State encoding params for SG-LRW (binary).
+
+    Attributes:
+        params: Parameters.
+        log_posterior: Last log posterior evaluation.
+        step: Current step count.
+    """
+
+    params: TensorTree
+    log_posterior: Tensor = torch.tensor(torch.nan)
+    step: Tensor = torch.tensor(0)
+
+
+def init(params: TensorTree) -> SGLRWState:
+    """Initialise SG-LRW."""
+    return SGLRWState(params)
+
+
+def update(
+    state: SGLRWState,
+    batch: Any,
+    log_posterior: LogProbFn,
+    lr: float | Schedule,
+    temperature: float | Schedule = 1.0,
+    inplace: bool = False,
+) -> tuple[SGLRWState, TensorTree]:
+    """Updates parameters with a single SGLRW step.
+
+    Each parameter element moves by exactly ±δx, with δx = √(lr * 2 * T), where
+    the sign is sampled with probabilities given by `ternary_probs` evaluated at
+    the log posterior gradient. See `build` for details of the algorithm.
+
+    Args:
+        state: SGLRWState containing params, last log posterior and step count.
+        batch: Data batch passed to log_posterior.
+        log_posterior: Function that takes parameters and input batch and
+            returns the log posterior value (which can be unnormalised)
+            as well as auxiliary information, e.g. from the model call.
+        lr: Learning rate,
+            scalar or schedule (callable taking step index, returning scalar).
+        temperature: Temperature of the sampling distribution.
+            Scalar or schedule (callable taking step index, returning scalar).
+        inplace: Whether to modify the state's tensors in place
+            rather than returning a new state.
+
+    Returns:
+        Updated state and auxiliary information from the log posterior call.
+    """
+    with torch.no_grad(), CatchAuxError():
+        grads, (log_post, aux) = grad_and_value(log_posterior, has_aux=True)(
+            state.params, batch
+        )
+
+    # Resolve schedules
+    lr_val = lr(state.step) if callable(lr) else lr
+    T_val = temperature(state.step) if callable(temperature) else temperature
+    lr_val = torch.as_tensor(
+        lr_val, dtype=state.params.dtype, device=state.params.device
+    )
+    T_val = torch.as_tensor(T_val, dtype=state.params.dtype, device=state.params.device)
+
+    # Spatial stepsize to make update binary
+    diffusion_val = torch.sqrt(2.0 * T_val)
+    delta_x = torch.sqrt(lr_val) * diffusion_val
+
+    # Per-parameter binary LRW transform
+    def transform_params(p, g):
+        # ternary_probs stacks [p_minus, p_zero, p_plus] on the last axis,
+        # so index with an ellipsis to support parameters of any shape.
+        p_plus = ternary_probs(g, diffusion_val, lr_val, delta_x)[..., 2]
+
+        u = torch.rand_like(p_plus)
+        step_sign = torch.where(
+            u < p_plus, torch.ones_like(p_plus), -torch.ones_like(p_plus)
+        )
+        step = delta_x * step_sign
+        return p + step
+
+    params = flexi_tree_map(transform_params, state.params, grads, inplace=inplace)
+
+    if inplace:
+        tree_insert_(state.log_posterior, log_post.detach())
+        tree_insert_(state.step, state.step + 1)
+        return state, aux
+    return SGLRWState(params, log_post.detach(), state.step + 1), aux
+
+
+def ternary_probs(
+    drift_val: Tensor,
+    diffusion_val: Tensor,
+    stepsize: Tensor,
+    delta_x: Tensor,
+) -> Tensor:
+    """
+    Generate the probabilities for the ternary update
+    from the discretization parameters.
+
+    Args:
+        drift_val: Evaluation of the Drift function.
+        diffusion_val: Evaluation of the Diffusion function.
+        stepsize: Temporal stepsize value.
+        delta_x: Spatial stepsize value.
+
+    Returns:
+        Update probabilities as a tensor, with last axis being [p_minus, p_zero, p_plus].
+    """
+    desired_mean = stepsize * drift_val
+    desired_var = stepsize * diffusion_val**2
+    scaled_mean = desired_mean / delta_x
+    scaled_var = desired_var / delta_x**2
+
+    # Ensure p_minus + p_plus <= 1
+    scaled_var = torch.clamp(scaled_var, 0.0, 1.0)
+
+    # Ensure positive probs
+    scaled_mean = torch.clamp(scaled_mean, -scaled_var, scaled_var)
+
+    # Clip probs for numerical stability
+    p_plus = torch.clamp(0.5 * (scaled_var + scaled_mean), 0.0, 1.0)
+    p_minus = torch.clamp(0.5 * (scaled_var - scaled_mean), 0.0, 1.0)
+    p_zero = torch.clamp(1 - p_plus - p_minus, 0.0, 1.0)
+
+    return torch.stack([p_minus, p_zero, p_plus], dim=-1)
diff --git a/tests/sgmcmc/test_sglrw.py b/tests/sgmcmc/test_sglrw.py
new file mode 100644
index 0000000..93c3a0c
--- /dev/null
+++ b/tests/sgmcmc/test_sglrw.py
@@ -0,0 +1,60 @@
+from functools import partial
+import torch
+from posteriors.sgmcmc import sglrw
+from tests.scenarios import get_multivariate_normal_log_prob
+from tests.utils import verify_inplace_update
+from tests.sgmcmc.utils import run_test_sgmcmc_gaussian
+
+
+def test_sglrw():
+    torch.manual_seed(42)
+
+    # Set inference parameters
+    lr = 1e-2
+
+    # Run MCMC test on Gaussian
+    run_test_sgmcmc_gaussian(
+        partial(sglrw.build, lr=lr),
+    )
+
+
+def test_sglrw_inplace_step():
+    torch.manual_seed(42)
+
+    # Load log posterior
+    dim = 5
+    log_prob, _ = get_multivariate_normal_log_prob(dim)
+
+    # Set inference parameters
+    def lr(step):
+        return 1e-2 * (step + 1) ** -0.33
+
+    # Build transform
+    transform = sglrw.build(log_prob, lr)
+
+    # Initialise
+    params = torch.randn(dim)
+
+    # Verify inplace update
+    verify_inplace_update(transform, params, None)
+
+
+def test_sglrw_multidim_params():
+    torch.manual_seed(42)
+
+    # Standard normal log posterior over matrix-shaped parameters
+    def log_posterior(p, batch):
+        return -0.5 * (p**2).sum(), torch.tensor([])
+
+    # Set inference parameters
+    lr = 1e-2
+
+    # Single step from the origin, where the gradient is zero
+    params = torch.zeros(3, 4)
+    state = sglrw.init(params)
+    new_state, _ = sglrw.update(state, None, log_posterior, lr=lr)
+
+    # Every element moves by exactly ±δx = ±√(2 * lr) at unit temperature
+    delta_x = (2 * lr) ** 0.5
+    assert new_state.params.shape == params.shape
+    assert torch.allclose(new_state.params.abs(), torch.full_like(params, delta_x))