From 66b5c25ccb545cb4723092c25de165102cfd8659 Mon Sep 17 00:00:00 2001 From: Ian Harrison Date: Fri, 27 Mar 2026 12:02:10 +0000 Subject: [PATCH] add calc and get_pk_grid --- cosmopower/wrappers/cobaya/cosmopower.py | 62 ++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 3 deletions(-) diff --git a/cosmopower/wrappers/cobaya/cosmopower.py b/cosmopower/wrappers/cobaya/cosmopower.py index 645af7a..b15082a 100644 --- a/cosmopower/wrappers/cobaya/cosmopower.py +++ b/cosmopower/wrappers/cobaya/cosmopower.py @@ -38,6 +38,9 @@ def initialize(self) -> None: for nw in self._networks: self._params_required_by_networks.update({str(p) : None for p in self._networks[nw].parameters}) + # z is a special case (not actually a sampled parameter) + del(self._params_required_by_networks['z']) + def must_provide(self, **requirements: dict) -> dict: super().must_provide(**requirements) @@ -51,6 +54,16 @@ def must_provide(self, **requirements: dict) -> dict: if k == "Cl": for a in v: required_quantities.append(f"Cl/{a.lower()}") + + elif "Pk_interpolator" in requirements.keys(): + v = self._must_provide[('Pk_grid', True, 'delta_tot', 'delta_tot')] + self.pk_redshifts = v.pop("z") + self.pk_k_max = v.pop("k_max") + self.pk_nonlin = v.pop("nonlinear", True) + if self.pk_nonlin: + required_quantities.append("Pk/nonlin") + else: + required_quantities.append("Pk/lin") else: required_quantities.append(f"{k}" if v is not None else k) @@ -114,6 +127,9 @@ def must_provide(self, **requirements: dict) -> dict: self.log.debug(f"Will evaluate networks {self._networks_to_eval}") + # z is a special case (not actually a sampled parameter) + del(must_provide['z']) + return must_provide def calculate(self, state: dict, want_derived: bool = True, @@ -155,12 +171,44 @@ def calculate(self, state: dict, want_derived: bool = True, network = self.networks[quantity] - used_params = {p: input_params[p] for p in network.parameters} + # if 'z' in network.parameters: + # input_params['z'] = np.asarray(self.pk_redshifts) + + # used_params = {p: input_params[p] for p in network.parameters} + + used_params = {} + + for p in network.parameters: + if p == 'z': + continue # 'z' is treated as an emulated but not sampled parameter + else: + used_params[p] = input_params[p] if self.parser.is_log(quantity): - data = network.ten_to_predictions_np(used_params)[0, :] + get_data = network.ten_to_predictions_np else: - data = network.predictions_np(used_params)[0, :] + get_data = network.predictions_np + + if 'z' in network.parameters: + + data = np.empty([len(np.asarray(self.pk_redshifts)), len(network.modes)]) + + for iz, z_eval in enumerate(self.pk_redshifts): + used_params = {**used_params, **{'z': z_eval}} + data[iz] = get_data(used_params)[0, :] + + state["z"] = self.pk_redshifts + state[('Pk_grid', True, 'delta_tot', 'delta_tot')] = data + + else: + data = get_data(used_params)[0, :] + + import pdb; pdb.set_trace() + + # if self.parser.is_log(quantity): + # data = network.ten_to_predictions_np(used_params)[0, :] + # else: + # data = network.predictions_np(used_params)[0, :] self.set_in_dict(state, quantity, data) if self.parser.modes_label(quantity) == "l": @@ -254,6 +302,14 @@ def get_fsigma8(self, z: Union[float, np.ndarray]) -> np.ndarray: compute fsigma8, but neither are \ computed.") + def get_Pk_grid(self, var_pair=("delta_tot", "delta_tot"), nonlinear=True): + + k = self.current_state["k"] + z = self.current_state["z"] + Pk = self.current_state[('Pk_grid', True, 'delta_tot', 'delta_tot')] + + return k, z, Pk + def get_Cl(self, ell_factor: bool = False, units: str = "FIRASmuK2") -> dict: cls_old = self.current_state.copy()