Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 59 additions & 3 deletions cosmopower/wrappers/cobaya/cosmopower.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def initialize(self) -> None:
for nw in self._networks:
self._params_required_by_networks.update({str(p) : None for p in self._networks[nw].parameters})

# z is a special case (not actually a sampled parameter)
del(self._params_required_by_networks['z'])

def must_provide(self, **requirements: dict) -> dict:
super().must_provide(**requirements)

Expand All @@ -51,6 +54,16 @@ def must_provide(self, **requirements: dict) -> dict:
if k == "Cl":
for a in v:
required_quantities.append(f"Cl/{a.lower()}")

elif "Pk_interpolator" in requirements.keys():
v = self._must_provide[('Pk_grid', True, 'delta_tot', 'delta_tot')]
self.pk_redshifts = v.pop("z")
self.pk_k_max = v.pop("k_max")
self.pk_nonlin = v.pop("nonlinear", True)
if self.pk_nonlin:
required_quantities.append("Pk/nonlin")
else:
required_quantities.append("Pk/lin")
else:
required_quantities.append(f"{k}" if v is not None else k)

Expand Down Expand Up @@ -114,6 +127,9 @@ def must_provide(self, **requirements: dict) -> dict:

self.log.debug(f"Will evaluate networks {self._networks_to_eval}")

# z is a special case (not actually a sampled parameter)
del(must_provide['z'])

return must_provide

def calculate(self, state: dict, want_derived: bool = True,
Expand Down Expand Up @@ -155,12 +171,44 @@ def calculate(self, state: dict, want_derived: bool = True,

network = self.networks[quantity]

used_params = {p: input_params[p] for p in network.parameters}
# if 'z' in network.parameters:
# input_params['z'] = np.asarray(self.pk_redshifts)

# used_params = {p: input_params[p] for p in network.parameters}

used_params = {}

for p in network.parameters:
if p == 'z':
continue # 'z' is treated as an emulated but not sampled parameter
else:
used_params[p] = input_params[p]

if self.parser.is_log(quantity):
data = network.ten_to_predictions_np(used_params)[0, :]
get_data = network.ten_to_predictions_np
else:
data = network.predictions_np(used_params)[0, :]
get_data = network.predictions_np

if 'z' in network.parameters:

data = np.empty([len(np.asarray(self.pk_redshifts)), len(network.modes)])

for iz, z_eval in enumerate(self.pk_redshifts):
used_params = {**used_params, **{'z': z_eval}}
data[iz] = get_data(used_params)[0, :]

state["z"] = self.pk_redshifts
state[('Pk_grid', True, 'delta_tot', 'delta_tot')] = data

else:
data = get_data(used_params)[0, :]

import pdb; pdb.set_trace()

# if self.parser.is_log(quantity):
# data = network.ten_to_predictions_np(used_params)[0, :]
# else:
# data = network.predictions_np(used_params)[0, :]

self.set_in_dict(state, quantity, data)
if self.parser.modes_label(quantity) == "l":
Expand Down Expand Up @@ -254,6 +302,14 @@ def get_fsigma8(self, z: Union[float, np.ndarray]) -> np.ndarray:
compute fsigma8, but neither are \
computed.")

def get_Pk_grid(self, var_pair=("delta_tot", "delta_tot"), nonlinear=True):

k = self.current_state["k"]
z = self.current_state["z"]
Pk = self.current_state[('Pk_grid', True, 'delta_tot', 'delta_tot')]

return k, z, Pk

def get_Cl(self, ell_factor: bool = False,
units: str = "FIRASmuK2") -> dict:
cls_old = self.current_state.copy()
Expand Down