Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 58 additions & 13 deletions simuk/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
except ImportError:
pass

import inspect
from collections.abc import Mapping

import numpy as np
Expand Down Expand Up @@ -102,7 +103,7 @@ def __init__(
self.numpyro_model = model
self.model = self.numpyro_model.model
self.run_simulations = self._run_simulations_numpyro
self.data_dir = data_dir
self.data_dir = data_dir if data_dir is not None else {}
else:
raise ValueError(
"model should be one of pymc.Model, bambi.Model, or numpyro.infer.mcmc.MCMCKernel"
Expand All @@ -123,6 +124,7 @@ def __init__(
self._extract_variable_names()
self.simulations = {name: [] for name in self.var_names}
self._simulations_complete = 0

if simulator is not None and not callable(simulator):
raise ValueError("simulator should be a function or None")
if simulator is not None and self.observed_vars:
Expand All @@ -138,11 +140,26 @@ def __init__(
"predictive samples. Either change the model or specify a simulator "
"with the `simulator` argument."
)

if simulator is None and self.engine == "numpyro":
if not self.observed_model_vars:
raise ValueError(
"There are no observed variables we can condition on, and NumPyro "
"will not generate prior predictive samples. Either change the model "
"or specify a simulator with the `simulator` argument."
)
missing = [name for name in self.observed_model_vars if name not in self.data_dir]
if missing:
raise ValueError(
"The following model parameters are missing from data_dir: "
+ ", ".join(sorted(missing))
)
self.simulator = simulator

def _extract_variable_names(self):
"""Extract observed and free variables from the model."""
if self.engine == "numpyro":
self.model_params = set(inspect.signature(self.model).parameters.keys())
with trace() as tr:
with seed(rng_seed=int(self._seeds[0])):
self.numpyro_model.model(**self.data_dir)
Expand All @@ -156,6 +173,14 @@ def _extract_variable_names(self):
for name, site in tr.items()
if site["type"] == "sample" and site.get("is_observed", False)
]
# Observed model variables are the variables marked as observed that are
# also parameters of the model function, because only those can be
# conditioned on. For instance, this filters out factor variables, which
# are marked as observed but cannot be conditioned on.
self.observed_model_vars = [
name for name in self.observed_vars if name in self.model_params
]

else:
self.observed_vars = [obs.name for obs in self.model.observed_RVs]
self.var_names = [v.name for v in self.model.free_RVs]
Expand Down Expand Up @@ -200,7 +225,11 @@ def _get_prior_predictive_samples(self):
def _get_prior_predictive_samples_numpyro(self):
"""Generate samples to use for the simulations using numpyro."""
predictive = Predictive(self.model, num_samples=self.num_simulations)
free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars}
free_vars_data = {
k: v
for k, v in self.data_dir.items()
if k not in self.observed_vars and k in self.model_params
}
samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data)
prior = {k: v for k, v in samples.items() if k not in self.observed_vars}
if self.simulator:
Expand All @@ -211,15 +240,16 @@ def _get_prior_predictive_samples_numpyro(self):
results.append(self.simulator(**params))
prior_pred = {key: [result[key] for result in results] for key in results[0]}
else:
prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars}
prior_pred = {k: v for k, v in samples.items() if k in self.observed_model_vars}
return prior, prior_pred

def _get_posterior_samples(self, prior_predictive_draw):
"""Generate posterior samples conditioned to a prior predictive sample."""
new_model = pm.observe(self.model, prior_predictive_draw)
with new_model:
check = pm.sample(
**self.sample_kwargs, random_seed=self._seeds[self._simulations_complete]
**self.sample_kwargs,
random_seed=self._seeds[self._simulations_complete],
)

posterior = extract(check, group="posterior", keep_dataset=True)
Expand All @@ -229,13 +259,16 @@ def _get_posterior_samples_numpyro(self, prior_predictive_draw):
"""Generate posterior samples using numpyro conditioned to a prior predictive sample."""
mcmc = MCMC(self.numpyro_model, **self.sample_kwargs)
rng_seed = jax.random.PRNGKey(self._seeds[self._simulations_complete])
# If using a custom simulator, some variables present in `prior_predictive_draw`
# might be missing from self.observed_vars.
# TODO: Not sure if the union is redundant here and perhaps prior_predictive_draw.keys()
# could be sufficient.
extended_observed_vars = set(prior_predictive_draw.keys()).union(self.observed_vars)
free_vars_data = {k: v for k, v in self.data_dir.items() if k not in extended_observed_vars}
mcmc.run(rng_seed, **free_vars_data, **prior_predictive_draw)

free_vars_data = {
k: v
for k, v in self.data_dir.items()
if k not in self.observed_model_vars and k in self.model_params
}
prior_predictive_args = {
k: v for k, v in prior_predictive_draw.items() if k in self.observed_model_vars
}
mcmc.run(rng_seed, **free_vars_data, **prior_predictive_args)
return from_numpyro(mcmc)["posterior"]

def _convert_to_datatree(self):
Expand Down Expand Up @@ -271,7 +304,10 @@ def run_simulations(self):
if self.simulator is not None:
self.observed_vars = list(prior_pred.data_vars)
self.var_names = list(
filter(lambda var_name: var_name not in self.observed_vars, list(prior.data_vars))
filter(
lambda var_name: var_name not in self.observed_vars,
list(prior.data_vars),
)
)
self.simulations = {var_name: [] for var_name in self.var_names}

Expand Down Expand Up @@ -309,8 +345,17 @@ def _run_simulations_numpyro(self):
# if simulator is used, ignore observed_vars
if self.simulator is not None:
self.observed_vars = list(prior_pred.keys())
self.observed_model_vars = [
name for name in self.observed_vars if name in self.model_params
]
if not self.observed_model_vars:
raise ValueError("No observed variables to condition on")

self.var_names = list(
filter(lambda var_name: var_name not in self.observed_vars, list(prior.keys()))
filter(
lambda var_name: var_name not in self.observed_vars,
list(prior.keys()),
)
)
self.simulations = {var_name: [] for var_name in self.var_names}
try:
Expand Down
40 changes: 40 additions & 0 deletions simuk/tests/test_sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def eight_schools_cauchy_prior_no_observed(J, sigma, y=None):
numpyro.factor("custom_likelihood", log_likelihood)


def numpyro_model_double_observed(y1=jnp.array([0.0]), y2=jnp.array([0.0])):
    """NumPyro model with two sample sites, ``y1`` and ``y2``.

    Each site draws from ``Normal(0, 1)`` and is observed through the
    argument of the same name.
    """
    for site_name, observed_value in (("y1", y1), ("y2", y2)):
        numpyro.sample(site_name, dist.Normal(0, 1), obs=observed_value)


# Custom simulator functions
def centered_eight_simulator(theta, seed, **kwargs):
rng = np.random.default_rng(seed)
Expand Down Expand Up @@ -175,3 +180,38 @@ def test_sbc_numpyro_fail_no_observed_variable():
sample_kwargs={"num_warmup": 50, "num_samples": 25},
)
sbc.run_simulations()


def test_sbc_numpyro_missing_observed_data():
    """Constructing SBC must fail when ``data_dir`` omits an observed argument.

    ``numpyro_model_double_observed`` takes ``y1`` and ``y2``; only ``y2`` is
    supplied here, and the resulting ``ValueError`` must mention
    "missing from data_dir".
    """
    sbc_kwargs = {
        "data_dir": {"y2": [0.0]},
        "num_simulations": 10,
        "sample_kwargs": {"num_warmup": 10, "num_samples": 5},
    }
    with pytest.raises(ValueError, match="missing from data_dir"):
        simuk.SBC(NUTS(numpyro_model_double_observed), **sbc_kwargs)


def test_sbc_numpyro_empty_observed_data():
    """Constructing SBC must fail when nothing observed can be conditioned on.

    Only ``J`` and ``sigma`` are passed in ``data_dir`` and no simulator is
    given; the resulting ``ValueError`` must mention "no observed variables".
    """
    sbc_kwargs = {
        "data_dir": {"J": 8, "sigma": sigma},
        "num_simulations": 10,
        "sample_kwargs": {"num_warmup": 10, "num_samples": 5},
    }
    with pytest.raises(ValueError, match="no observed variables"):
        simuk.SBC(NUTS(eight_schools_cauchy_prior), **sbc_kwargs)


def test_sbc_numpyro_simulator_no_conditionable_observed():
    """Running SBC must fail when the simulator yields nothing to condition on.

    The custom simulator returns a single key, ``not_a_param`` (presumably not
    an argument of the model function), so ``run_simulations`` is expected to
    raise a ``ValueError`` mentioning "No observed variables to condition on".
    Construction itself is expected to succeed.
    """

    def simulator_with_unknown_key(**kwargs):
        # Ignores every input and returns one entry under an unrelated name.
        return {"not_a_param": np.array([0.0])}

    kernel = NUTS(eight_schools_cauchy_prior)
    model_inputs = {"J": 8, "sigma": sigma, "y": data}
    sampler_settings = {"num_warmup": 10, "num_samples": 5}

    sbc = simuk.SBC(
        kernel,
        data_dir=model_inputs,
        num_simulations=5,
        sample_kwargs=sampler_settings,
        simulator=simulator_with_unknown_key,
    )

    expected_message = "No observed variables to condition on"
    with pytest.raises(ValueError, match=expected_message):
        sbc.run_simulations()
Loading