diff --git a/simuk/sbc.py b/simuk/sbc.py index 9a9633d..78d21b6 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -16,6 +16,7 @@ except ImportError: pass +import inspect from collections.abc import Mapping import numpy as np @@ -102,7 +103,7 @@ def __init__( self.numpyro_model = model self.model = self.numpyro_model.model self.run_simulations = self._run_simulations_numpyro - self.data_dir = data_dir + self.data_dir = data_dir if data_dir is not None else {} else: raise ValueError( "model should be one of pymc.Model, bambi.Model, or numpyro.infer.mcmc.MCMCKernel" @@ -123,6 +124,7 @@ def __init__( self._extract_variable_names() self.simulations = {name: [] for name in self.var_names} self._simulations_complete = 0 + if simulator is not None and not callable(simulator): raise ValueError("simulator should be a function or None") if simulator is not None and self.observed_vars: @@ -138,11 +140,26 @@ def __init__( "predictive samples. Either change the model or specify a simulator " "with the `simulator` argument." ) + + if simulator is None and self.engine == "numpyro": + if not self.observed_model_vars: + raise ValueError( + "There are no observed variables we can condition on, and NumPyro " + "will not generate prior predictive samples. Either change the model " + "or specify a simulator with the `simulator` argument." + ) + missing = [name for name in self.observed_model_vars if name not in self.data_dir] + if missing: + raise ValueError( + "The following model parameters are missing from data_dir: " + + ", ".join(sorted(missing)) + ) self.simulator = simulator def _extract_variable_names(self): """Extract observed and free variables from the model.""" if self.engine == "numpyro": + self.model_params = set(inspect.signature(self.model).parameters.keys()) with trace() as tr: with seed(rng_seed=int(self._seeds[0])): self.numpyro_model.model(**self.data_dir) @@ -156,6 +173,14 @@ def _extract_variable_names(self): for name, site in tr.items() if site["type"] == "sample" and site.get("is_observed", False) ] + # Observed model variables are those that are marked as observed + # and are also model function parameters in order to be able to condition on them. + # For instance, this is used to filter out factor variables that are marked as observed + # but cannot be conditioned on. + self.observed_model_vars = [ + name for name in self.observed_vars if name in self.model_params + ] + else: self.observed_vars = [obs.name for obs in self.model.observed_RVs] self.var_names = [v.name for v in self.model.free_RVs] @@ -200,7 +225,11 @@ def _get_prior_predictive_samples(self): def _get_prior_predictive_samples_numpyro(self): """Generate samples to use for the simulations using numpyro.""" predictive = Predictive(self.model, num_samples=self.num_simulations) - free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars} + free_vars_data = { + k: v + for k, v in self.data_dir.items() + if k not in self.observed_vars and k in self.model_params + } samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data) prior = {k: v for k, v in samples.items() if k not in self.observed_vars} if self.simulator: @@ -211,7 +240,7 @@ def _get_prior_predictive_samples_numpyro(self): results.append(self.simulator(**params)) prior_pred = {key: [result[key] for result in results] for key in results[0]} else: - prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars} + prior_pred = {k: v for k, v in samples.items() if k in self.observed_model_vars} return prior, prior_pred def _get_posterior_samples(self, prior_predictive_draw): @@ -219,7 +248,8 @@ def _get_posterior_samples(self, prior_predictive_draw): new_model = pm.observe(self.model, prior_predictive_draw) with new_model: check = pm.sample( - **self.sample_kwargs, random_seed=self._seeds[self._simulations_complete] + **self.sample_kwargs, + random_seed=self._seeds[self._simulations_complete], ) posterior = extract(check, group="posterior", keep_dataset=True) @@ -229,13 +259,16 @@ def _get_posterior_samples_numpyro(self, prior_predictive_draw): """Generate posterior samples using numpyro conditioned to a prior predictive sample.""" mcmc = MCMC(self.numpyro_model, **self.sample_kwargs) rng_seed = jax.random.PRNGKey(self._seeds[self._simulations_complete]) - # If using a custom simulator, some variables present in `prior_predictive_draw` - # might be missing from self.observed_vars. - # TODO: Not sure if the union is redundant here and perhaps prior_predictive_draw.keys() - # could be sufficient. - extended_observed_vars = set(prior_predictive_draw.keys()).union(self.observed_vars) - free_vars_data = {k: v for k, v in self.data_dir.items() if k not in extended_observed_vars} - mcmc.run(rng_seed, **free_vars_data, **prior_predictive_draw) + + free_vars_data = { + k: v + for k, v in self.data_dir.items() + if k not in self.observed_model_vars and k in self.model_params + } + prior_predictive_args = { + k: v for k, v in prior_predictive_draw.items() if k in self.observed_model_vars + } + mcmc.run(rng_seed, **free_vars_data, **prior_predictive_args) return from_numpyro(mcmc)["posterior"] def _convert_to_datatree(self): @@ -271,7 +304,10 @@ def run_simulations(self): if self.simulator is not None: self.observed_vars = list(prior_pred.data_vars) self.var_names = list( - filter(lambda var_name: var_name not in self.observed_vars, list(prior.data_vars)) + filter( + lambda var_name: var_name not in self.observed_vars, + list(prior.data_vars), + ) ) self.simulations = {var_name: [] for var_name in self.var_names} @@ -309,8 +345,17 @@ def _run_simulations_numpyro(self): # if simulator is used, ignore observed_vars if self.simulator is not None: self.observed_vars = list(prior_pred.keys()) + self.observed_model_vars = [ + name for name in self.observed_vars if name in self.model_params + ] + if not self.observed_model_vars: + raise ValueError("No observed variables to condition on") + self.var_names = list( - filter(lambda var_name: var_name not in self.observed_vars, list(prior.keys())) + filter( + lambda var_name: var_name not in self.observed_vars, + list(prior.keys()), + ) ) self.simulations = {var_name: [] for var_name in self.var_names} try: diff --git a/simuk/tests/test_sbc.py b/simuk/tests/test_sbc.py index 0204e88..6e7ab8d 100644 --- a/simuk/tests/test_sbc.py +++ b/simuk/tests/test_sbc.py @@ -56,6 +56,11 @@ def eight_schools_cauchy_prior_no_observed(J, sigma, y=None): numpyro.factor("custom_likelihood", log_likelihood) +def numpyro_model_double_observed(y1=jnp.array([0.0]), y2=jnp.array([0.0])): + numpyro.sample("y1", dist.Normal(0, 1), obs=y1) + numpyro.sample("y2", dist.Normal(0, 1), obs=y2) + + # Custom simulator functions def centered_eight_simulator(theta, seed, **kwargs): rng = np.random.default_rng(seed) @@ -175,3 +180,38 @@ def test_sbc_numpyro_fail_no_observed_variable(): sample_kwargs={"num_warmup": 50, "num_samples": 25}, ) sbc.run_simulations() + + +def test_sbc_numpyro_missing_observed_data(): + with pytest.raises(ValueError, match="missing from data_dir"): + simuk.SBC( + NUTS(numpyro_model_double_observed), + data_dir={"y2": [0.0]}, + num_simulations=10, + sample_kwargs={"num_warmup": 10, "num_samples": 5}, + ) + + +def test_sbc_numpyro_empty_observed_data(): + with pytest.raises(ValueError, match="no observed variables"): + simuk.SBC( + NUTS(eight_schools_cauchy_prior), + data_dir={"J": 8, "sigma": sigma}, + num_simulations=10, + sample_kwargs={"num_warmup": 10, "num_samples": 5}, + ) + + +def test_sbc_numpyro_simulator_no_conditionable_observed(): + def bad_simulator(**kwargs): + return {"not_a_param": np.array([0.0])} + + sbc = simuk.SBC( + NUTS(eight_schools_cauchy_prior), + data_dir={"J": 8, "sigma": sigma, "y": data}, + num_simulations=5, + sample_kwargs={"num_warmup": 10, "num_samples": 5}, + simulator=bad_simulator, + ) + with pytest.raises(ValueError, match="No observed variables to condition on"): + sbc.run_simulations()