Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion cfr/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,14 @@ def get_clim(self, fields, tag=None, verbose=False, search_dist=5, load=True, **
if tag is not None:
name = f'{tag}.{name}'

nda = field.da.sel(lat=self.lat, lon=self.lon, **_kwargs)
da = field.da
if da.indexes.get('lat') is not None and not da.indexes['lat'].is_unique:
_, lat_idx = np.unique(da['lat'].values, return_index=True)
da = da.isel(lat=lat_idx)
if da.indexes.get('lon') is not None and not da.indexes['lon'].is_unique:
_, lon_idx = np.unique(da['lon'].values, return_index=True)
da = da.isel(lon=lon_idx)
nda = da.sel(lat=self.lat, lon=self.lon, **_kwargs)
if np.all(np.isnan(nda.values)) and search_dist is not None:
for i in range(1, search_dist+1):
p_header(f'{self.pid} >>> Nearest climate is NaN. Searching around within distance of {i} deg ...')
Expand Down
103 changes: 90 additions & 13 deletions cfr/reconres.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,34 +168,108 @@ def load_proxylabels(self, verbose=False):
if verbose:
p_success(f">>> ReconRes.proxy_labels created")

def indpdt_verif(self, job_path, verbose=False, calib_period=(1850, 2000),min_verif_len=10):
def indpdt_verif(self, job_path, verbose=False, calib_period=(1850, 2000), min_verif_len=10, debug=False):
"""
Perform independent verification.
job_path (str): the path to the job.
verbose (bool, optional): print verbose information. Defaults to False.
debug (bool, optional): print diagnostic info on the first iteration. Defaults to False.
"""
# load the reconstructions for the "prior"
job = ReconJob()
job.load(job_path)
try:
import psutil, os as _os
def _rss_mb():
return psutil.Process(_os.getpid()).memory_info().rss / 1e6
except ImportError:
def _rss_mb():
return float('nan')

indpdt_info = []
for path_index ,path in enumerate(self.paths):
for path_index, path in enumerate(self.paths):
# Recreate job each iteration so proxy clim caches and prior fields
# never accumulate across ensemble members.
job = ReconJob()
job.load(job_path)

proxy_labels = self.proxy_labels[path_index]

# Identify which prior variables are present in the reconstruction file.
# Variables absent from the file (e.g. pr in a tas-only reconstruction)
# are kept from the freshly loaded original prior.
with xr.open_dataset(path) as ds_check:
recon_vars_in_file = [k for k in job.prior if k in ds_check]
non_recon_prior = {k: v for k, v in job.prior.items() if k not in recon_vars_in_file}

job.load_clim(
tag="prior",
path_dict={
"tas": path,
},
path_dict={vn: path for vn in recon_vars_in_file},
anom_period=(1951, 1980),
)
# Restore variables not in the reconstruction file from the original prior.
if non_recon_prior and path_index == 0:
p_warning(f">>> Variables {list(non_recon_prior.keys())} not found in reconstruction file — using original prior (e.g. CCSM4) for these variables.")
job.prior.update(non_recon_prior)
del non_recon_prior

# Mark reconstructed fields as already annualized (integer year coords)
# and collapse any ensemble dimension by mean before get_clim extracts
# a point time series, to avoid OOM at full proxy/ensemble scale.
for vn in recon_vars_in_file:
if 'ens' in job.prior[vn].da.dims:
job.prior[vn].da = job.prior[vn].da.mean(dim='ens')
job.prior[vn].da.attrs['annualized'] = 1

# Clear only the model.* keys each proxy's PSM actually needs,
# so forward_psms() re-fetches them from the updated prior.
for pobj in job.proxydb.records.values():
if 'clim' not in pobj.__dict__:
continue
for vn in pobj.psm.climate_required:
key = f'model.{vn}'
if key in pobj.clim:
del pobj.clim[key]

if debug and path_index == 0:
trw_proxies = [(pid, p) for pid, p in job.proxydb.records.items()
if getattr(p, 'ptype', '') == 'tree.TRW']
print(f"[debug] total calibrated proxies: {job.proxydb.filter(by='tag', keys=['calibrated']).nrec}")
print(f"[debug] TRW proxies: {len(trw_proxies)}")
print(f"[debug] prior keys loaded: {list(job.prior.keys())}")
print(f"[debug] recon vars from file: {recon_vars_in_file}")
print(f"[debug] RSS before forward_psms: {_rss_mb():.0f} MB")

job.forward_psms(verbose=verbose)

if debug and path_index == 0:
trw_proxies = [(pid, p) for pid, p in job.proxydb.records.items()
if getattr(p, 'ptype', '') == 'tree.TRW']
for pid, p in trw_proxies[:2]:
print(f"\n[debug] {pid} ({p.ptype})")
for key in ['model.tas', 'model.pr']:
if hasattr(p, 'clim') and key in p.clim and p.clim[key] is not None:
da = p.clim[key].da
print(f" {key}: {da.time.values[0]} → {da.time.values[-1]}, len={len(da.time)}")
else:
print(f" {key}: NOT FOUND")

# Drop heavy model.* clim fields now that pseudo values are computed.
# obs.* fields are left intact as they are small and may be needed.
for pobj in job.proxydb.records.values():
if hasattr(pobj, 'clim'):
pobj.clim = {k: v for k, v in pobj.clim.items() if not k.startswith('model.')}

if debug and path_index == 0:
print(f"[debug] RSS after clim cleanup: {_rss_mb():.0f} MB")

if verbose:
p_success(f">>> Prior loaded from {path}")
# compare the pesudo-proxy records with the real records

# Compare pseudo-proxy records with real proxy observations.
calib_PDB = job.proxydb.filter(by="tag", keys=["calibrated"])
for i, (pname, proxy) in enumerate(calib_PDB.records.items()):
detail = proxy.psm.calib_details
attr_dict = {}
attr_dict['name'] = pname
attr_dict['ptype'] = proxy.ptype
attr_dict['seasonality'] = detail['seasonality']
if pname in proxy_labels['pids_assim']:
attr_dict['assim'] = True
Expand All @@ -221,8 +295,8 @@ def indpdt_verif(self, job_path, verbose=False, calib_period=(1850, 2000),min_ve
Df.astype(float)
masks = {
"all": None,
"in": (Df.index >= calib_period[0]) & (Df.index <= calib_period[1]), # in the calibration period
"before": (Df.index < calib_period[0]), # before the calibration period
"in": (Df.index >= calib_period[0]) & (Df.index <= calib_period[1]),
"before": (Df.index < calib_period[0]),
}
for mask_name, mask in masks.items():
if mask is not None:
Expand All @@ -240,16 +314,19 @@ def indpdt_verif(self, job_path, verbose=False, calib_period=(1850, 2000),min_ve
attr_dict[mask_name + '_corr'] = corr
attr_dict[mask_name + '_ce'] = ce
indpdt_info.append(attr_dict)

indpdt_info = pd.DataFrame(indpdt_info)
self.indpdt_info = indpdt_info
self.indpdt_calib_period = calib_period
if verbose:
p_success(f">>> indpdt verification completed, results stored in ReconRes.indpdt_info")
p_success(f">>> Records Number: {len(indpdt_info)}")
return indpdt_info

def plot_indpdt_verif(self):
    """Plot the independent-verification results.

    Resolves the calibration period from ``self.indpdt_calib_period``
    (set by ``indpdt_verif``), falling back to [1850, 2000] when the
    verification step has not recorded one, and delegates plotting to
    ``visual.plot_indpdt_dist``.

    Returns:
        tuple: ``(fig, axs)`` as produced by ``visual.plot_indpdt_dist``.
    """
    # Default mirrors the calib_period default of indpdt_verif() so the
    # plot is still labeled sensibly if that attribute was never set.
    calib_period = getattr(self, 'indpdt_calib_period', [1850, 2000])
    fig, axs = visual.plot_indpdt_dist(self.indpdt_info, calib_period=calib_period)
    return fig, axs