From 3e80178779c7902b24f4456da07060c438546420 Mon Sep 17 00:00:00 2001 From: cab14bacc <86755693+Cab14bacc@users.noreply.github.com> Date: Fri, 22 May 2026 18:26:10 +0300 Subject: [PATCH 1/4] fix: numpyro wrongly conditioning on non-sample sites --- simuk/sbc.py | 100 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 79 insertions(+), 21 deletions(-) diff --git a/simuk/sbc.py b/simuk/sbc.py index 9a9633d..dd79fb0 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -17,6 +17,7 @@ pass from collections.abc import Mapping +import inspect import numpy as np from arviz_base import dict_to_dataset, extract, from_dict, from_numpyro @@ -102,7 +103,7 @@ def __init__( self.numpyro_model = model self.model = self.numpyro_model.model self.run_simulations = self._run_simulations_numpyro - self.data_dir = data_dir + self.data_dir = data_dir if data_dir is not None else {} else: raise ValueError( "model should be one of pymc.Model, bambi.Model, or numpyro.infer.mcmc.MCMCKernel" @@ -123,6 +124,7 @@ def __init__( self._extract_variable_names() self.simulations = {name: [] for name in self.var_names} self._simulations_complete = 0 + if simulator is not None and not callable(simulator): raise ValueError("simulator should be a function or None") if simulator is not None and self.observed_vars: @@ -138,11 +140,28 @@ def __init__( "predictive samples. Either change the model or specify a simulator " "with the `simulator` argument." ) + + if simulator is None and self.engine == "numpyro": + if not self.observed_model_vars: + raise ValueError( + "There are no observed variables we can condition on, and NumPyro " + "will not generate prior predictive samples. Either change the model " + "or specify a simulator with the `simulator` argument." + ) + missing = [ + name for name in self.observed_model_vars if name not in self.data_dir + ] + if missing: + raise ValueError( + "The following model parameters are missing from data_dir: " + ", ".join(sorted(missing)) + ) self.simulator = simulator def _extract_variable_names(self): """Extract observed and free variables from the model.""" if self.engine == "numpyro": + self.model_params = set(inspect.signature(self.model).parameters.keys()) with trace() as tr: with seed(rng_seed=int(self._seeds[0])): self.numpyro_model.model(**self.data_dir) @@ -156,6 +175,14 @@ def _extract_variable_names(self): for name, site in tr.items() if site["type"] == "sample" and site.get("is_observed", False) ] + # Observed model variables are those that are marked as observed + # and are also model function parameters in order to be able to condition on them. + # For instance, this is used to filter out factor variables that are marked as observed + # but cannot be conditioned on. + self.observed_model_vars = [ + name for name in self.observed_vars if name in self.model_params + ] + else: self.observed_vars = [obs.name for obs in self.model.observed_RVs] self.var_names = [v.name for v in self.model.free_RVs] @@ -178,20 +205,25 @@ def _get_prior_predictive_samples(self): # Deal with custom simulator prior_pred = [] for i in range(prior.sizes["sample"]): - params = {var: prior[var].isel(sample=i).values for var in prior.data_vars} + params = { + var: prior[var].isel(sample=i).values for var in prior.data_vars + } params["seed"] = self._seeds[i] try: res = self.simulator(**params) - assert isinstance(res, Mapping), ( - f"Simulator must return a dictionary, got {type(res)}" - ) + assert isinstance( + res, Mapping + ), f"Simulator must return a dictionary, got {type(res)}" prior_pred.append(res) except Exception as e: raise ValueError( f"Error generating prior predictive sample with parameters {params}: {e}." ) prior_pred = dict_to_dataset( - {key: np.stack([pp[key] for pp in prior_pred]) for key in prior_pred[0]}, + { + key: np.stack([pp[key] for pp in prior_pred]) + for key in prior_pred[0] + }, sample_dims=["sample"], coords={**prior.coords}, ) @@ -200,7 +232,11 @@ def _get_prior_predictive_samples(self): def _get_prior_predictive_samples_numpyro(self): """Generate samples to use for the simulations using numpyro.""" predictive = Predictive(self.model, num_samples=self.num_simulations) - free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars} + free_vars_data = { + k: v for k, v in self.data_dir.items() + if k not in self.observed_vars + and k in self.model_params + } samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data) prior = {k: v for k, v in samples.items() if k not in self.observed_vars} if self.simulator: @@ -209,9 +245,11 @@ def _get_prior_predictive_samples_numpyro(self): params = dict(zip(prior.keys(), vals)) params["seed"] = self._seeds[i] results.append(self.simulator(**params)) - prior_pred = {key: [result[key] for result in results] for key in results[0]} + prior_pred = { + key: [result[key] for result in results] for key in results[0] + } else: - prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars} + prior_pred = {k: v for k, v in samples.items() if k in self.observed_model_vars} return prior, prior_pred def _get_posterior_samples(self, prior_predictive_draw): @@ -219,7 +257,8 @@ def _get_posterior_samples(self, prior_predictive_draw): new_model = pm.observe(self.model, prior_predictive_draw) with new_model: check = pm.sample( - **self.sample_kwargs, random_seed=self._seeds[self._simulations_complete] + **self.sample_kwargs, + random_seed=self._seeds[self._simulations_complete], ) posterior = extract(check, group="posterior", keep_dataset=True) @@ -229,13 +268,16 @@ def _get_posterior_samples_numpyro(self, prior_predictive_draw): """Generate posterior samples using numpyro conditioned to a prior predictive sample.""" mcmc = MCMC(self.numpyro_model, **self.sample_kwargs) rng_seed = jax.random.PRNGKey(self._seeds[self._simulations_complete]) - # If using a custom simulator, some variables present in `prior_predictive_draw` - # might be missing from self.observed_vars. - # TODO: Not sure if the union is redundant here and perhaps prior_predictive_draw.keys() - # could be sufficient. - extended_observed_vars = set(prior_predictive_draw.keys()).union(self.observed_vars) - free_vars_data = {k: v for k, v in self.data_dir.items() if k not in extended_observed_vars} - mcmc.run(rng_seed, **free_vars_data, **prior_predictive_draw) + + free_vars_data = { + k: v + for k, v in self.data_dir.items() + if k not in self.observed_model_vars and k in self.model_params + } + prior_predictive_args = { + k: v for k, v in prior_predictive_draw.items() if k in self.observed_model_vars + } + mcmc.run(rng_seed, **free_vars_data, **prior_predictive_args) return from_numpyro(mcmc)["posterior"] def _convert_to_datatree(self): @@ -271,7 +313,10 @@ def run_simulations(self): if self.simulator is not None: self.observed_vars = list(prior_pred.data_vars) self.var_names = list( - filter(lambda var_name: var_name not in self.observed_vars, list(prior.data_vars)) + filter( + lambda var_name: var_name not in self.observed_vars, + list(prior.data_vars), + ) ) self.simulations = {var_name: [] for var_name in self.var_names} @@ -286,7 +331,9 @@ def run_simulations(self): posterior = self._get_posterior_samples(prior_predictive_draw) for name in self.var_names: self.simulations[name].append( - (posterior[name] < prior[name].sel(chain=0, draw=idx)).sum("sample").values + (posterior[name] < prior[name].sel(chain=0, draw=idx)) + .sum("sample") + .values ) self._simulations_complete += 1 progress.update() @@ -309,8 +356,17 @@ def _run_simulations_numpyro(self): # if simulator is used, ignore observed_vars if self.simulator is not None: self.observed_vars = list(prior_pred.keys()) + self.observed_model_vars = [ + name for name in self.observed_vars if name in self.model_params + ] + if not self.observed_model_vars: + raise ValueError("No observed variables to condition on") + self.var_names = list( - filter(lambda var_name: var_name not in self.observed_vars, list(prior.keys())) + filter( + lambda var_name: var_name not in self.observed_vars, + list(prior.keys()), + ) ) self.simulations = {var_name: [] for var_name in self.var_names} try: @@ -320,7 +376,9 @@ def _run_simulations_numpyro(self): posterior = self._get_posterior_samples_numpyro(prior_predictive_draw) for name in self.var_names: self.simulations[name].append( - (posterior[name].sel(chain=0) < prior[name][idx]).sum(axis=0).values + (posterior[name].sel(chain=0) < prior[name][idx]) + .sum(axis=0) + .values ) self._simulations_complete += 1 progress.update() From b01add5a841ab5c8004ef880e3afe538afa40d83 Mon Sep 17 00:00:00 2001 From: cab14bacc <86755693+Cab14bacc@users.noreply.github.com> Date: Fri, 22 May 2026 18:31:15 +0300 Subject: [PATCH 2/4] chore: fix linting errors --- simuk/sbc.py | 48 ++++++++++++++++++------------------------------ 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/simuk/sbc.py b/simuk/sbc.py index dd79fb0..c484a46 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -140,7 +140,7 @@ def __init__( "predictive samples. Either change the model or specify a simulator " "with the `simulator` argument." ) - + if simulator is None and self.engine == "numpyro": if not self.observed_model_vars: raise ValueError( @@ -148,13 +148,12 @@ def __init__( "will not generate prior predictive samples. Either change the model " "or specify a simulator with the `simulator` argument." ) - missing = [ - name for name in self.observed_model_vars if name not in self.data_dir - ] + missing = [name for name in self.observed_model_vars if name not in self.data_dir] if missing: raise ValueError( - "The following model parameters are missing from data_dir: " - ", ".join(sorted(missing)) + "The following model parameters are missing from data_dir: , ".join( + sorted(missing) + ) ) self.simulator = simulator @@ -177,7 +176,7 @@ def _extract_variable_names(self): ] # Observed model variables are those that are marked as observed # and are also model function parameters in order to be able to condition on them. - # For instance, this is used to filter out factor variables that are marked as observed + # For instance, this is used to filter out factor variables that are marked as observed # but cannot be conditioned on. self.observed_model_vars = [ name for name in self.observed_vars if name in self.model_params @@ -205,25 +204,20 @@ def _get_prior_predictive_samples(self): # Deal with custom simulator prior_pred = [] for i in range(prior.sizes["sample"]): - params = { - var: prior[var].isel(sample=i).values for var in prior.data_vars - } + params = {var: prior[var].isel(sample=i).values for var in prior.data_vars} params["seed"] = self._seeds[i] try: res = self.simulator(**params) - assert isinstance( - res, Mapping - ), f"Simulator must return a dictionary, got {type(res)}" + assert isinstance(res, Mapping), ( + f"Simulator must return a dictionary, got {type(res)}" + ) prior_pred.append(res) except Exception as e: raise ValueError( f"Error generating prior predictive sample with parameters {params}: {e}." ) prior_pred = dict_to_dataset( - { - key: np.stack([pp[key] for pp in prior_pred]) - for key in prior_pred[0] - }, + {key: np.stack([pp[key] for pp in prior_pred]) for key in prior_pred[0]}, sample_dims=["sample"], coords={**prior.coords}, ) @@ -233,9 +227,9 @@ def _get_prior_predictive_samples_numpyro(self): """Generate samples to use for the simulations using numpyro.""" predictive = Predictive(self.model, num_samples=self.num_simulations) free_vars_data = { - k: v for k, v in self.data_dir.items() - if k not in self.observed_vars - and k in self.model_params + k: v + for k, v in self.data_dir.items() + if k not in self.observed_vars and k in self.model_params } samples = predictive(jax.random.PRNGKey(self._seeds[0]), **free_vars_data) prior = {k: v for k, v in samples.items() if k not in self.observed_vars} @@ -245,9 +239,7 @@ def _get_prior_predictive_samples_numpyro(self): params = dict(zip(prior.keys(), vals)) params["seed"] = self._seeds[i] results.append(self.simulator(**params)) - prior_pred = { - key: [result[key] for result in results] for key in results[0] - } + prior_pred = {key: [result[key] for result in results] for key in results[0]} else: prior_pred = {k: v for k, v in samples.items() if k in self.observed_model_vars} return prior, prior_pred @@ -331,9 +323,7 @@ def run_simulations(self): posterior = self._get_posterior_samples(prior_predictive_draw) for name in self.var_names: self.simulations[name].append( - (posterior[name] < prior[name].sel(chain=0, draw=idx)) - .sum("sample") - .values + (posterior[name] < prior[name].sel(chain=0, draw=idx)).sum("sample").values ) self._simulations_complete += 1 progress.update() @@ -361,7 +351,7 @@ def _run_simulations_numpyro(self): ] if not self.observed_model_vars: raise ValueError("No observed variables to condition on") - + self.var_names = list( filter( lambda var_name: var_name not in self.observed_vars, @@ -376,9 +366,7 @@ def _run_simulations_numpyro(self): posterior = self._get_posterior_samples_numpyro(prior_predictive_draw) for name in self.var_names: self.simulations[name].append( - (posterior[name].sel(chain=0) < prior[name][idx]) - .sum(axis=0) - .values + (posterior[name].sel(chain=0) < prior[name][idx]).sum(axis=0).values ) self._simulations_complete += 1 progress.update() From 52ca16703b8981d42feeb41b99d715cbd75c8fc8 Mon Sep 17 00:00:00 2001 From: cab14bacc <86755693+Cab14bacc@users.noreply.github.com> Date: Fri, 22 May 2026 18:46:01 +0300 Subject: [PATCH 3/4] chore: fix linting errors --- simuk/sbc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simuk/sbc.py b/simuk/sbc.py index c484a46..aa63d9a 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -16,8 +16,8 @@ except ImportError: pass -from collections.abc import Mapping import inspect +from collections.abc import Mapping import numpy as np from arviz_base import dict_to_dataset, extract, from_dict, from_numpyro From b0241be2771395beb90bd4503b1e78b575cf9eea Mon Sep 17 00:00:00 2001 From: cab14bacc <86755693+Cab14bacc@users.noreply.github.com> Date: Fri, 22 May 2026 20:11:58 +0300 Subject: [PATCH 4/4] feat(test): add tests for numpyro fix --- simuk/sbc.py | 5 ++--- simuk/tests/test_sbc.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/simuk/sbc.py b/simuk/sbc.py index aa63d9a..78d21b6 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -151,9 +151,8 @@ def __init__( missing = [name for name in self.observed_model_vars if name not in self.data_dir] if missing: raise ValueError( - "The following model parameters are missing from data_dir: , ".join( - sorted(missing) - ) + "The following model parameters are missing from data_dir: " + + ", ".join(sorted(missing)) ) self.simulator = simulator diff --git a/simuk/tests/test_sbc.py b/simuk/tests/test_sbc.py index 0204e88..6e7ab8d 100644 --- a/simuk/tests/test_sbc.py +++ b/simuk/tests/test_sbc.py @@ -56,6 +56,11 @@ def eight_schools_cauchy_prior_no_observed(J, sigma, y=None): numpyro.factor("custom_likelihood", log_likelihood) +def numpyro_model_double_observed(y1=jnp.array([0.0]), y2=jnp.array([0.0])): + numpyro.sample("y1", dist.Normal(0, 1), obs=y1) + numpyro.sample("y2", dist.Normal(0, 1), obs=y2) + + # Custom simulator functions def centered_eight_simulator(theta, seed, **kwargs): rng = np.random.default_rng(seed) @@ -175,3 +180,38 @@ def test_sbc_numpyro_fail_no_observed_variable(): sample_kwargs={"num_warmup": 50, "num_samples": 25}, ) sbc.run_simulations() + + +def test_sbc_numpyro_missing_observed_data(): + with pytest.raises(ValueError, match="missing from data_dir"): + simuk.SBC( + NUTS(numpyro_model_double_observed), + data_dir={"y2": [0.0]}, + num_simulations=10, + sample_kwargs={"num_warmup": 10, "num_samples": 5}, + ) + + +def test_sbc_numpyro_empty_observed_data(): + with pytest.raises(ValueError, match="no observed variables"): + simuk.SBC( + NUTS(eight_schools_cauchy_prior), + data_dir={"J": 8, "sigma": sigma}, + num_simulations=10, + sample_kwargs={"num_warmup": 10, "num_samples": 5}, + ) + + +def test_sbc_numpyro_simulator_no_conditionable_observed(): + def bad_simulator(**kwargs): + return {"not_a_param": np.array([0.0])} + + sbc = simuk.SBC( + NUTS(eight_schools_cauchy_prior), + data_dir={"J": 8, "sigma": sigma, "y": data}, + num_simulations=5, + sample_kwargs={"num_warmup": 10, "num_samples": 5}, + simulator=bad_simulator, + ) + with pytest.raises(ValueError, match="No observed variables to condition on"): + sbc.run_simulations()