diff --git a/simuk/sbc.py b/simuk/sbc.py index 575e7e0..9a9633d 100644 --- a/simuk/sbc.py +++ b/simuk/sbc.py @@ -169,7 +169,7 @@ def _get_prior_predictive_samples(self): """Generate samples to use for the simulations.""" with self.model: idata = pm.sample_prior_predictive( - samples=self.num_simulations, random_seed=self._seeds[0] + draws=self.num_simulations, random_seed=self._seeds[0] ) prior = extract(idata, group="prior", keep_dataset=True) if self.simulator is None: @@ -270,10 +270,11 @@ def run_simulations(self): # if simulator is used, ignore observed_vars if self.simulator is not None: self.observed_vars = list(prior_pred.data_vars) - self.var_names = list(filter(lambda var_name: var_name not in self.observed_vars, - list(prior.data_vars))) + self.var_names = list( + filter(lambda var_name: var_name not in self.observed_vars, list(prior.data_vars)) + ) self.simulations = {var_name: [] for var_name in self.var_names} - + try: while self._simulations_complete < self.num_simulations: idx = self._simulations_complete @@ -308,8 +309,9 @@ def _run_simulations_numpyro(self): # if simulator is used, ignore observed_vars if self.simulator is not None: self.observed_vars = list(prior_pred.keys()) - self.var_names = list(filter(lambda var_name: var_name not in self.observed_vars, - list(prior.keys()))) + self.var_names = list( + filter(lambda var_name: var_name not in self.observed_vars, list(prior.keys())) + ) self.simulations = {var_name: [] for var_name in self.var_names} try: while self._simulations_complete < self.num_simulations: