Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ thermostat (SGNHT) algorithm from [Ding et al, 2014](https://proceedings.neurips
(SGHMC with adaptive friction coefficient).
- [`sgmcmc.baoa`](sgmcmc/baoa.md) implements the BAOA integrator for SGHMC
from [Leimkuhler and Matthews, 2015 - p271](https://link.springer.com/book/10.1007/978-3-319-16375-8).
- [`sgmcmc.sglrw`](sgmcmc/sglrw.md) implements the stochastic gradient lattice random
walk (SGLRW) algorithm from [Mensch et al, 2026](https://arxiv.org/abs/2602.15925).

For an overview and unifying framework for SGMCMC methods, see [Ma et al, 2015](https://arxiv.org/abs/1506.04696).

Expand Down
3 changes: 3 additions & 0 deletions docs/api/sgmcmc/sglrw.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SGLRW

::: posteriors.sgmcmc.sglrw
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ nav:
- api/sgmcmc/sghmc.md
- api/sgmcmc/sgnht.md
- api/sgmcmc/baoa.md
- api/sgmcmc/sglrw.md
- VI:
- Dense: api/vi/dense.md
- Diag: api/vi/diag.md
Expand Down
1 change: 1 addition & 0 deletions posteriors/sgmcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from posteriors.sgmcmc import sghmc
from posteriors.sgmcmc import sgnht
from posteriors.sgmcmc import baoa
from posteriors.sgmcmc import sglrw
158 changes: 158 additions & 0 deletions posteriors/sgmcmc/sglrw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# type: ignore posteriors is not typed
from typing import Any
from functools import partial
import torch
from torch import Tensor
from torch.func import grad_and_value
from tensordict import TensorClass

from posteriors.types import TensorTree, Transform, LogProbFn, Schedule
from posteriors.tree_utils import flexi_tree_map, tree_insert_
from posteriors.utils import CatchAuxError


def build(
    log_posterior: LogProbFn,
    lr: float | Schedule,
    temperature: float | Schedule = 1.0,
) -> Transform:
    """Builds SGLRW transform - Stochastic Gradient Lattice Random Walk.

    Algorithm from [Mensch et al, 2026](https://arxiv.org/abs/2602.15925)
    adapted from [Duffield et al, 2025](https://arxiv.org/abs/2508.20883):
    $$
    θ_{t+1} = θ_t + δx Δ(θₜ, t)
    $$
    where $δx = √(lr * 2 * T)$ is a spatial stepsize and $Δ(θₜ, t)$ is a random
    binary valued vector defined in the paper.

    Targets $p_T(θ) \\propto \\exp( \\log p(θ) / T)$ with temperature $T$,
    as it discretizes the overdamped Langevin SDE:
    $$
    dθ = ∇ log p_T(θ) dt + √(2 T) dW
    $$

    The log posterior and temperature are recommended to be [constructed in tandem](../../log_posteriors.md)
    to ensure robust scaling for a large amount of data and variable batch size.

    Args:
        log_posterior: Function that takes parameters and input batch and
            returns the log posterior value (which can be unnormalised)
            as well as auxiliary information, e.g. from the model call.
        lr: Learning rate,
            scalar or schedule (callable taking step index, returning scalar).
        temperature: Temperature of the sampling distribution.
            Scalar or schedule (callable taking step index, returning scalar).

    Returns:
        SGLRW transform (posteriors.types.Transform instance).
    """
    # Bind the sampler configuration into the update function; init needs
    # no configuration so it is passed through untouched.
    return Transform(
        init,
        partial(
            update,
            log_posterior=log_posterior,
            lr=lr,
            temperature=temperature,
        ),
    )


class SGLRWState(TensorClass["frozen"]):
    """State encoding params for SG-LRW (binary).

    Immutable (frozen) tensorclass holding everything the SGLRW update
    needs between steps. `update` either returns a fresh instance or, when
    `inplace=True`, writes into the existing tensors via `tree_insert_`.

    Attributes:
        params: Parameters.
        log_posterior: Last log posterior evaluation.
        step: Current step count.
    """

    params: TensorTree
    # NaN sentinel: no log-posterior evaluation has happened yet.
    log_posterior: Tensor = torch.tensor(torch.nan)
    # Step counter, used to resolve lr/temperature schedules.
    step: Tensor = torch.tensor(0)


def init(params: TensorTree) -> SGLRWState:
    """Initialise SG-LRW state from starting parameters.

    Args:
        params: Initial parameters.

    Returns:
        Fresh SGLRWState with default log_posterior (NaN) and step 0.
    """
    return SGLRWState(params=params)


def update(
    state: SGLRWState,
    batch: Any,
    log_posterior: LogProbFn,
    lr: float | Schedule,
    temperature: float | Schedule = 1.0,
    inplace: bool = False,
) -> tuple[SGLRWState, TensorTree]:
    """Performs a single SGLRW (binary lattice random walk) update step.

    Evaluates the stochastic gradient of the log posterior on `batch`, then
    moves every parameter component by ±δx, where δx = √(2 T lr). With this
    choice of δx the lattice scheme's scaled variance is exactly one, so the
    zero-move probability vanishes and the step is binary.

    Args:
        state: Current SGLRWState.
        batch: Data batch passed to `log_posterior`.
        log_posterior: Function returning (log posterior value, aux).
        lr: Learning rate, scalar or schedule (callable on step index).
        temperature: Sampling temperature, scalar or schedule.
        inplace: If True, mutate `state` tensors in place and return it;
            otherwise return a new SGLRWState.

    Returns:
        Tuple of updated SGLRWState and auxiliary information from the
        log posterior evaluation.
    """
    with torch.no_grad(), CatchAuxError():
        grads, (log_post, aux) = grad_and_value(log_posterior, has_aux=True)(
            state.params, batch
        )

        # Resolve schedules at the current step index
        lr_val = lr(state.step) if callable(lr) else lr
        T_val = temperature(state.step) if callable(temperature) else temperature
        lr_val = torch.as_tensor(
            lr_val, dtype=state.params.dtype, device=state.params.device
        )
        T_val = torch.as_tensor(
            T_val, dtype=state.params.dtype, device=state.params.device
        )

        # Spatial stepsize chosen to make the ternary update binary
        diffusion_val = torch.sqrt(2.0 * T_val)
        delta_x = torch.sqrt(lr_val) * diffusion_val

        # Per-parameter binary LRW transform
        def transform_params(p, g):
            # Index with Ellipsis so parameters of any shape (scalars,
            # vectors, matrices) are supported — `[:, 2]` only worked for
            # 1-D parameter tensors.
            p_plus = ternary_probs(g, diffusion_val, lr_val, delta_x)[..., 2]

            u = torch.rand_like(p_plus)
            step_sign = torch.where(
                u < p_plus, torch.ones_like(p_plus), -torch.ones_like(p_plus)
            )
            step = delta_x * step_sign
            return p + step

        params = flexi_tree_map(transform_params, state.params, grads, inplace=inplace)

        if inplace:
            tree_insert_(state.log_posterior, log_post.detach())
            tree_insert_(state.step, state.step + 1)
            return state, aux
        return SGLRWState(params, log_post.detach(), state.step + 1), aux


def ternary_probs(
    drift_val: Tensor,
    diffusion_val: Tensor,
    stepsize: Tensor,
    delta_x: Tensor,
) -> Tensor:
    """
    Generate the probabilities for the ternary update
    from the discretization parameters.

    Matches the first two moments of the Euler–Maruyama increment
    (mean = stepsize * drift, variance = stepsize * diffusion²) on the
    lattice {-δx, 0, +δx}, then clamps so all probabilities are valid.

    Args:
        drift_val: Evaluation of the Drift function.
        diffusion_val: Evaluation of the Diffusion function.
        stepsize: Temporal stepsize value.
        delta_x: Spatial stepsize value.

    Returns:
        Update probabilities as a tensor, with last axis being [p_minus, p_zero, p_plus].
    """
    # Moments of the continuous increment, rescaled to lattice units
    mean_lattice = (stepsize * drift_val) / delta_x
    var_lattice = (stepsize * diffusion_val**2) / delta_x**2

    # Guarantee p_minus + p_plus <= 1
    var_lattice = var_lattice.clamp(0.0, 1.0)

    # Guarantee non-negative individual probabilities
    mean_lattice = mean_lattice.clamp(-var_lattice, var_lattice)

    # Assemble and clip for numerical stability
    p_plus = (0.5 * (var_lattice + mean_lattice)).clamp(0.0, 1.0)
    p_minus = (0.5 * (var_lattice - mean_lattice)).clamp(0.0, 1.0)
    p_zero = (1 - p_plus - p_minus).clamp(0.0, 1.0)

    return torch.stack([p_minus, p_zero, p_plus], dim=-1)
39 changes: 39 additions & 0 deletions tests/sgmcmc/test_sglrw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from functools import partial
import torch
from posteriors.sgmcmc import sglrw
from tests.scenarios import get_multivariate_normal_log_prob
from tests.utils import verify_inplace_update
from tests.sgmcmc.utils import run_test_sgmcmc_gaussian


def test_sglrw():
    torch.manual_seed(42)

    # Fixed learning rate for the sampler
    learning_rate = 1e-2

    # Check the sampler recovers a multivariate Gaussian target
    run_test_sgmcmc_gaussian(partial(sglrw.build, lr=learning_rate))


def test_sglrw_inplace_step():
    torch.manual_seed(42)

    # Multivariate normal log posterior in 5 dimensions
    dim = 5
    log_prob, _ = get_multivariate_normal_log_prob(dim)

    # Polynomially decaying learning-rate schedule
    def lr(step):
        return 1e-2 * (step + 1) ** -0.33

    # Build transform and random initial parameters
    transform = sglrw.build(log_prob, lr)
    init_params = torch.randn(dim)

    # Check the inplace update mutates state consistently
    verify_inplace_update(transform, init_params, None)