diff --git a/sotodlib/preprocess/pcore.py b/sotodlib/preprocess/pcore.py index 8443e170f..0149ffbc2 100644 --- a/sotodlib/preprocess/pcore.py +++ b/sotodlib/preprocess/pcore.py @@ -37,8 +37,9 @@ def __init__(self, step_cfgs): self.save_cfgs = step_cfgs.get("save") self.select_cfgs = step_cfgs.get("select") self.plot_cfgs = step_cfgs.get("plot") - self.skip_on_sim = step_cfgs.get("skip_on_sim", False) - def process(self, aman, proc_aman, sim=False): + self.skip_on_sim = step_cfgs.get("skip_on_sim") + self.use_data_aman = step_cfgs.get("use_data_aman", False) + def process(self, aman, proc_aman, sim=False, data_aman=None): """ This function makes changes to the time ordered data AxisManager. Ex: calibrating or detrending the timestreams. This function will use any configuration information under the ``process`` key of the @@ -56,6 +57,9 @@ def process(self, aman, proc_aman, sim=False): sim: Bool False by default when analyzing data. Should be True when doing Transfer Function simulations and determining which steps should be run. + data_aman: AxisManager (Optional) + An AxisManager containing the preprocessed data to be used by + this process. """ if self.process_cfgs is None: return aman, proc_aman @@ -435,8 +439,14 @@ def extend(self, index, other): super().extend( [self._check_item(item) for item in other]) def __setitem__(self, index, item): super().__setitem__(index, self._check_item(item)) + def __getitem__(self, index): + result = super().__getitem__(index) + if isinstance(index, slice): + return Pipeline(result) + else: + return result - def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False): + def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False, data_amans=None): """ The main workhorse function for the pipeline class. 
This function takes an AxisManager TOD and successively runs the pipeline of preprocessing @@ -472,6 +482,11 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False): given ``proc_aman`` is ``aman.preprocess``. This assumes ``process.calc_and_save()`` has been run on this aman before and has injested flags and other information into ``proc_aman``. + data_amans: dict (Optional) + A dictionary of AxisManagers with keys (step, process.name) + filled with AxisManager processed up to step-1. This is used + to pre-load all data AxisManager which could be required when + processing simulations (e.g. to provide a T2P template) Returns ------- @@ -520,10 +535,21 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False): success = 'end' for step, process in enumerate(self): + if sim and (process.skip_on_sim is None): + raise ValueError(f"Process {process.name} missing required field `skip_on_sim`") if sim and process.skip_on_sim: continue self.logger.debug(f"Running {process.name}") - aman, proc_aman = process.process(aman, proc_aman, sim) + if (data_amans is not None) and process.use_data_aman: + try: + data_aman = data_amans[step, process.name] + except KeyError: + raise KeyError(f"Requested to use data AxisManager for process {process.name} but not found in data_amans") + else: + if process.use_data_aman and sim: + raise ValueError(f"Process {process.name} requested to use data_aman but none was provided to Pipeline.run()") + data_aman = None + process.process(aman, proc_aman, sim, data_aman) if run_calc: aman, proc_aman = process.calc_and_save(aman, proc_aman) process.plot(aman, proc_aman, filename=os.path.join(self.plot_dir, '{ctime}/{obsid}', f'{step+1}_{{name}}.png')) diff --git a/sotodlib/preprocess/preprocess_util.py b/sotodlib/preprocess/preprocess_util.py index 86bea130e..a5c40ceb6 100644 --- a/sotodlib/preprocess/preprocess_util.py +++ b/sotodlib/preprocess/preprocess_util.py @@ -549,7 +549,8 @@ def 
load_and_preprocess(obs_id, configs, context=None, dets=None, meta=None, def multilayer_load_and_preprocess(obs_id, configs_init, configs_proc, dets=None, meta=None, no_signal=None, logger=None, init_only=False, - ignore_cfg_check=False): + ignore_cfg_check=False, + stop_for_sims=False): """Loads the saved information from the preprocessing pipeline from a reference and a dependent database and runs the processing section of the pipeline for each. @@ -583,6 +584,11 @@ def multilayer_load_and_preprocess(obs_id, configs_init, configs_proc, ignore_cfg_check : bool If True, do not attempt to validate that configs_init is the same as the config used to create the existing init db. + stop_for_sims: bool + Optional. If True, will stop before each step of the pipeline + with the flag `use_data_aman` set to True. The intended use is + to prepare all necessary data products that cannot be stored in + the preprocessing database, to process simulations. Returns ------- @@ -600,6 +606,21 @@ def multilayer_load_and_preprocess(obs_id, configs_init, configs_proc, configs_proc, context_proc = get_preprocess_context(configs_proc) meta_proc = context_proc.get_meta(obs_id, dets=dets, meta=meta) + # Count number of stops + if stop_for_sims: + num_stops = 0 + for process in configs_init["process_pipe"]: + if process.get("use_data_aman", False): + num_stops += 1 + for process in configs_proc["process_pipe"]: + if process.get("use_data_aman", False): + num_stops += 1 + logger.warning( + "Currently running with `stop_for_sims=True`. " + f"It will generate {num_stops} additional copies " + "of the data AxisManager with a higher memory usage." 
+ ) + group_by_init = np.atleast_1d(configs_init['subobs'].get('use', 'detset')) group_by_proc = np.atleast_1d(configs_proc['subobs'].get('use', 'detset')) @@ -611,7 +632,9 @@ def multilayer_load_and_preprocess(obs_id, configs_init, configs_proc, return None else: pipe_init = Pipeline(configs_init["process_pipe"], logger=logger) - aman_cfgs_ref = get_pcfg_check_aman(pipe_init) + + if not ignore_cfg_check: + aman_cfgs_ref = get_pcfg_check_aman(pipe_init) if ( ignore_cfg_check or @@ -637,27 +660,97 @@ def multilayer_load_and_preprocess(obs_id, configs_init, configs_proc, aman = context_init.get_obs(meta_init, no_signal=no_signal) logger.info("Running initial pipeline") - pipe_init.run(aman, aman.preprocess, select=False) - if init_only: - return aman + if stop_for_sims: + out_amans_init = run_pipeline_stepgroups( + pipe_init, + aman, + run_last_step=not(init_only) + ) + if init_only: + return out_amans_init + else: + pipe_init.run(aman, aman.preprocess, select=False) + if init_only: + return aman logger.info("Running dependent pipeline") + if stop_for_sims: + aman = out_amans_init[(len(pipe_init), 'last')] proc_aman = context_proc.get_meta(obs_id, meta=aman) if 'valid_data' in aman.preprocess: aman.preprocess.move('valid_data', None) aman.preprocess.merge(proc_aman.preprocess) - pipe_proc.run(aman, aman.preprocess, select=False) + if stop_for_sims: + out_amans = run_pipeline_stepgroups( + pipe_proc, + out_amans_init[(len(pipe_init), 'last')], + ) + out_amans.update({ + (step, name): out_amans_init[(step, name)] + for (step, name) in out_amans_init + if name != 'last' + }) + return out_amans - return aman + else: + pipe_proc.run(aman, aman.preprocess, select=False) + return aman else: raise ValueError('Dependency check between configs failed.') +def run_pipeline_stepgroups(pipe, aman, run_last_step=False): + """ + Run a Pipeline object, grouping steps based on + the flag `use_data_aman` in the configuration + file. 
+ Arguments + ---------- + pipe : Pipeline + Pipeline object to run. + aman : AxisManager + AxisManager to process. + run_last_step : bool + If True, will create a dict item containing the + AxisManager after running the full pipeline. + """ + batch_idx = [ + (step, process.name) + for step, process in enumerate(pipe) + if process.use_data_aman + ] + if batch_idx or run_last_step: + batch_idx = [(0, pipe[0].name)] + batch_idx + if run_last_step: + batch_idx += [(len(pipe), 'last')] + pipes = {} + for idx in range(len(batch_idx)-1): + start, start_name = batch_idx[idx] + end, end_name = batch_idx[idx+1] + # If asked to stop at the first process + # one needs to save the current state of + # the AxisManager + if end == 0: + pipes[end, end_name] = None + else: + pipes[end, end_name] = pipe[start:end] + out_amans = {} + loc_aman = aman.copy() + for (step, name), pipe in pipes.items(): + if pipe is not None: + pipe.run(loc_aman, aman.preprocess, select=False) + out_amans[step, name] = loc_aman.copy() + return out_amans + else: + return {} + + + def multilayer_load_and_preprocess_sim(obs_id, configs_init, configs_proc, sim_map, meta=None, logger=None, init_only=False, - t2ptemplate_aman=None): + ignore_cfg_check=False, + data_amans=None): """Loads the saved information from the preprocessing pipeline from a reference and a dependent database, loads the signal from a (simulated) map into the AxisManager and runs the processing section of the pipeline @@ -689,9 +782,14 @@ def multilayer_load_and_preprocess_sim(obs_id, configs_init, configs_proc, Optional. Logger object or None will generate a new one. init_only : bool Optional. Whether or not to run the dependent pipeline. - t2ptemplate_aman : AxisManager - Optional. AxisManager to use as a template for t2p leakage - deprojection. + ignore_cfg_check : bool + If True, do not attempt to validate that configs_init is the same as + the config used to create the existing init db. 
+ data_amans: dict (Optional) + A dictionary of AxisManagers with keys (step, process.name) + filled with AxisManager processed up to step-1. This is used + to pre-load all data AxisManager which could be required when + processing simulations (e.g. to provide a T2P template) Returns ------- @@ -719,10 +817,14 @@ def multilayer_load_and_preprocess_sim(obs_id, configs_init, configs_proc, return None else: pipe_init = Pipeline(configs_init["process_pipe"], logger=logger) - aman_cfgs_ref = get_pcfg_check_aman(pipe_init) - if check_cfg_match(aman_cfgs_ref, meta_proc.preprocess['pcfg_ref'], - logger=logger): + if not ignore_cfg_check: + aman_cfgs_ref = get_pcfg_check_aman(pipe_init) + + if ignore_cfg_check or check_cfg_match( + aman_cfgs_ref, + meta_proc.preprocess['pcfg_ref'], + logger=logger): pipe_proc = Pipeline(configs_proc["process_pipe"], logger=logger) logger.info("Restricting detectors on all proc pipeline processes") @@ -755,27 +857,12 @@ def multilayer_load_and_preprocess_sim(obs_id, configs_init, configs_proc, if init_only: return aman - if t2ptemplate_aman is not None: - # Replace Q,U with simulated timestreams - t2ptemplate_aman.wrap("demodQ", aman.demodQ, [(0, 'dets'), (1, 'samps')], overwrite=True) - t2ptemplate_aman.wrap("demodU", aman.demodU, [(0, 'dets'), (1, 'samps')], overwrite=True) - - t2p_aman = t2pleakage.get_t2p_coeffs( - t2ptemplate_aman, - merge_stats=False - ) - t2pleakage.subtract_t2p( - aman, - t2p_aman, - T_signal=t2ptemplate_aman.dsT - ) - logger.info("Running dependent pipeline") proc_aman = context_proc.get_meta(obs_id, meta=aman) if 'valid_data' in aman.preprocess: aman.preprocess.move('valid_data', None) aman.preprocess.merge(proc_aman.preprocess) - pipe_proc.run(aman, aman.preprocess, sim=True) + pipe_proc.run(aman, aman.preprocess, sim=True, data_amans=data_amans) return aman else: diff --git a/sotodlib/preprocess/processes.py b/sotodlib/preprocess/processes.py index 8b951256d..af423ec48 100644 --- 
a/sotodlib/preprocess/processes.py +++ b/sotodlib/preprocess/processes.py @@ -37,7 +37,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") start_stop = tod_ops.fft_trim(aman, **self.process_cfgs) proc_aman.restrict(self.process_cfgs.get('axis', 'samps'), (start_stop)) @@ -55,8 +57,10 @@ def __init__(self, step_cfgs): self.save_name = None super().__init__(step_cfgs) - - def process(self, aman, proc_aman, sim=False): + + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") tod_ops.detrend_tod(aman, signal_name=self.signal, **self.process_cfgs) return aman, proc_aman @@ -291,7 +295,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") field = self.process_cfgs['jumps_aman'] aman[self.signal] = tod_ops.jumps.jumpfix_subtract_heights( aman[self.signal], proc_aman[field].jump_flag.mask(), @@ -421,7 +427,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") full_output = self.process_cfgs.get('full_output') if full_output: freqs, Pxx, nseg = tod_ops.fft_ops.calc_psd(aman, signal=aman[self.signal], @@ -903,7 +911,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, 
data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") if self.process_cfgs["kind"] == "single_value": if self.process_cfgs.get("divide", False): aman[self.signal] /= self.process_cfgs["val"] @@ -1096,7 +1106,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") if not(proc_aman[self.hwpss_stats] is None): modes = [int(m[1:]) for m in proc_aman[self.hwpss_stats].modes.vals[::2]] if sim: @@ -1182,7 +1194,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") tod_ops.apodize.apodize_cosine(aman, **self.process_cfgs) return aman, proc_aman @@ -1221,7 +1235,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") hwp.demod_tod(aman, **self.process_cfgs["demod_cfgs"]) if self.process_cfgs.get("trim_samps"): trim = self.process_cfgs["trim_samps"] @@ -1324,7 +1340,9 @@ def save(self, proc_aman, azss_stats): if self.save_cfgs: proc_aman.wrap(self.save_name, azss_stats) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") if 'subtract_in_place' in self.calc_cfgs: raise ValueError('calc_cfgs.subtract_in_place is not allowed use process_cfgs.subtract') if 
self.process_cfgs is None: @@ -1392,7 +1410,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") process_cfgs = copy.deepcopy(self.process_cfgs) if sim: process_cfgs["azss"] = proc_aman.get(process_cfgs["azss"]) @@ -1429,7 +1449,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") if self.flags is not None: glitch_flags=proc_aman.get(self.flags) tod_ops.gapfill.fill_glitches( @@ -1487,7 +1509,9 @@ def save(self, proc_aman, turn_aman): if self.save_cfgs: proc_aman.wrap(self.save_name, turn_aman) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") tod_ops.flags.get_turnaround_flags(aman, **self.process_cfgs) return aman, proc_aman @@ -1502,9 +1526,11 @@ class SubPolyf(_Preprocess): def __init__(self, step_cfgs): self.save_name = None - super().__init__(step_cfgs) - - def process(self, aman, proc_aman, sim=False): + super().__init__(step_cfgs) + + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") tod_ops.sub_polyf.subscan_polyfilter(aman, **self.process_cfgs) return aman, proc_aman @@ -1944,7 +1970,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise 
NotImplementedError("No support for using data AxisManager in process") if (not 'hwp_angle' in aman._fields) and ('hwp_angle' in proc_aman._fields): aman.wrap('hwp_angle', proc_aman['hwp_angle']['hwp_angle'], [(0, 'samps')]) @@ -2029,7 +2057,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") field = self.process_cfgs.get("noise_fit_array", None) if field: noise_fit = proc_aman[field] @@ -2268,7 +2298,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") n_modes = self.process_cfgs.get('n_modes') signal = aman.get(self.signal) if self.model_signal: @@ -2356,7 +2388,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") n_modes = self.process_cfgs.get('n_modes') signal = aman.get(self.signal) flags = aman.flags.get(self.process_cfgs.get('source_flags')) @@ -2547,8 +2581,32 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): - tod_ops.t2pleakage.subtract_t2p(aman, proc_aman['t2p']) + def process(self, aman, proc_aman, sim=False, data_aman=None): + if sim and data_aman is not None: + + data_aman.restrict("dets", aman.dets.vals) + data_aman.wrap("demodQ", aman.demodQ, [(0, 'dets'), (1, 'samps')], overwrite=True) + data_aman.wrap("demodU", aman.demodU, [(0, 'dets'), (1, 'samps')], overwrite=True) + + if 
self.process_cfgs.get("fit_in_freq"): + t2p_aman = tod_ops.t2pleakage.get_t2p_coeffs_in_freq( + data_aman, + merge_stats=False, + ML_fit=True + ) + else: + t2p_aman = tod_ops.t2pleakage.get_t2p_coeffs( + data_aman, + merge_stats=False + ) + tod_ops.t2pleakage.subtract_t2p( + aman, + t2p_aman, + T_signal=data_aman.dsT + ) + else: + tod_ops.t2pleakage.subtract_t2p(aman, proc_aman['t2p'], + **self.process_cfgs) return aman, proc_aman class SplitFlags(_Preprocess): @@ -2615,7 +2673,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") warnings.warn("UnionFlags function is deprecated and only kept to allow loading of old process archives. Use generalized method CombineFlags") from so3g.proj import RangesMatrix @@ -2656,7 +2716,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") from so3g.proj import RangesMatrix if isinstance(self.process_cfgs['method'], list): if len(self.process_cfgs['flag_labels']) != len(self.process_cfgs['method']): @@ -2710,7 +2772,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") from sotodlib.coords import demod demod.rotate_focal_plane(aman, **self.process_cfgs) return aman, proc_aman @@ -2735,7 +2799,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, 
sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") from sotodlib.coords import demod demod.rotate_demodQU(aman, **self.process_cfgs) return aman, proc_aman @@ -2775,7 +2841,9 @@ def save(self, proc_aman, aman): if self.save_cfgs: proc_aman.wrap(self.save_name, aman['qu_common_mode_coeffs']) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") if 'qu_common_mode_coeffs' in proc_aman: tod_ops.deproject.subtract_qu_common_mode(aman, self.signal_name_Q, self.signal_name_U, coeff_aman=proc_aman['qu_common_mode_coeffs'], @@ -2804,7 +2872,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") scan_freq = tod_ops.utils.get_scan_freq(aman) hpf_cfg = {'type': 'sine2', 'cutoff': scan_freq, 'trans_width': scan_freq/10} filt = tod_ops.get_hpf(hpf_cfg) @@ -2886,7 +2956,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") from sotodlib.coords import pointing_model if self.process_cfgs: pointing_model.apply_pointing_model(aman) @@ -2972,7 +3044,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") from sotodlib.obs_ops import correct_iir_params 
correct_iir_params(aman) @@ -3010,7 +3084,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") flags = aman.flags.get(self.process_cfgs.get('flags')) trimst, trimen = core.flagman.find_common_edge_idx(flags) aman.restrict('samps', (aman.samps.offset + trimst, @@ -3162,7 +3238,9 @@ def __init__(self, step_cfgs): super().__init__(step_cfgs) - def process(self, aman, proc_aman, sim=False): + def process(self, aman, proc_aman, sim=False, data_aman=None): + if data_aman is not None: + raise NotImplementedError("No support for using data AxisManager in process") aman.move(**self.process_cfgs) return aman, proc_aman diff --git a/sotodlib/tod_ops/t2pleakage.py b/sotodlib/tod_ops/t2pleakage.py index 6a3804aa8..972105b29 100644 --- a/sotodlib/tod_ops/t2pleakage.py +++ b/sotodlib/tod_ops/t2pleakage.py @@ -304,6 +304,7 @@ def get_t2p_coeffs(aman, T_sig_name='dsT', Q_sig_name='demodQ', U_sig_name='demo def get_t2p_coeffs_in_freq(aman, T_sig_name='dsT', Q_sig_name='demodQ', U_sig_name='demodU', fs=None, fit_freq_range=(0.01, 0.1), wn_freq_range=(0.2, 1.9), subtract_sig=False, merge_stats=True, t2p_stats_name='t2p_stats', + ML_fit=False ): """ Compute the leakage coefficients from temperature (T) to polarization (Q and U) in Fourier @@ -333,6 +334,13 @@ def get_t2p_coeffs_in_freq(aman, T_sig_name='dsT', Q_sig_name='demodQ', U_sig_na Whether to merge the calculated statistics back into `aman`. Default is True. t2p_stats_name : str Name under which to wrap the output AxisManager containing statistics. Default is 't2p_stats'. + ML_fit : bool + Whether to use the ML solution (analytic) to fit for coefficients. Default is False, + which uses ODR to perform a linear fit. + We are assuming a linear model: y_i = b * x_i + eps_i, where eps_i is white noise. 
+ The log likelihood for this model is: log L(b) = - (1 / (2*sigma^2)) * sum_i (y_i - b*x_i)^2 + const + For which the maximum likelihood estimator for b is: b_hat = sum_i (x_i * y_i) / sum_i (x_i^2) + This is equivalent to Ordinary Least Square (OLS) regression with no intercept. Returns ------- @@ -347,64 +355,82 @@ def get_t2p_coeffs_in_freq(aman, T_sig_name='dsT', Q_sig_name='demodQ', U_sig_na I_fs = rfft(aman[T_sig_name], axis=1) Q_fs = rfft(aman[Q_sig_name], axis=1) U_fs = rfft(aman[U_sig_name], axis=1) - - coeffsQ = np.zeros(aman.dets.count) - errorsQ = np.zeros(aman.dets.count) - redchi2sQ = np.zeros(aman.dets.count) - coeffsU = np.zeros(aman.dets.count) - errorsU = np.zeros(aman.dets.count) - redchi2sU = np.zeros(aman.dets.count) fit_mask = (fit_freq_range[0] < freqs) & (freqs < fit_freq_range[1]) - wn_mask = (wn_freq_range[0] < freqs) & (freqs < wn_freq_range[1]) - def leakage_model(B, x): - return B[0] * x + if ML_fit: + x = np.real(I_fs[:, fit_mask]) + yQ = np.real(Q_fs[:, fit_mask]) + yU = np.real(U_fs[:, fit_mask]) + coeffsQ = np.sum(x * yQ, axis=1) / np.sum(x**2, axis=1) + coeffsU = np.sum(x * yU, axis=1) / np.sum(x**2, axis=1) - model = Model(leakage_model) + stdQ = np.sqrt(np.sum((yQ - coeffsQ[:, np.newaxis] * x)**2, axis=1) / (x.shape[1] - 1)) + stdU = np.sqrt(np.sum((yU - coeffsU[:, np.newaxis] * x)**2, axis=1) / (x.shape[1] - 1)) - for i in range(aman.dets.count): - # fit Q - Q_wnl = np.nanmean(np.abs(Q_fs[i][wn_mask])) - x = np.real(I_fs[i])[fit_mask] - y = np.real(Q_fs[i])[fit_mask] - sx = Q_wnl / np.sqrt(2) * np.ones_like(x) - sy = Q_wnl * np.ones_like(y) - try: - data = RealData(x=x, - y=y, - sx=sx, - sy=sy) - odr = ODR(data, model, beta0=[1e-3]) - output = odr.run() - coeffsQ[i] = output.beta[0] - errorsQ[i] = output.sd_beta[0] - redchi2sQ[i] = output.sum_square / (len(x) - 2) - except: - coeffsQ[i] = np.nan - errorsQ[i] = np.nan - redchi2sQ[i] = np.nan - - #fit U - U_wnl = np.nanmean(np.abs(U_fs[i][wn_mask])) - x = 
np.real(I_fs[i])[fit_mask] - y = np.real(U_fs[i])[fit_mask] - sx = U_wnl / np.sqrt(2) * np.ones_like(x) - sy = U_wnl * np.ones_like(y) - try: - data = RealData(x=x, - y=y, - sx=sx, - sy=sy) - odr = ODR(data, model, beta0=[1e-3]) - output = odr.run() - coeffsU[i] = output.beta[0] - errorsU[i] = output.sd_beta[0] - redchi2sU[i] = output.sum_square / (len(x) - 2) - except: - coeffsU[i] = np.nan - errorsU[i] = np.nan - redchi2sU[i] = np.nan + errorsQ = stdQ / np.sqrt(np.sum(x**2, axis=1)) + errorsU = stdU / np.sqrt(np.sum(x**2, axis=1)) + + redchi2sQ = np.sum((yQ - coeffsQ[:, np.newaxis] * x) ** 2 / stdQ[:, np.newaxis] ** 2, axis=1) / (x.shape[1] - 1) + redchi2sU = np.sum((yU - coeffsU[:, np.newaxis] * x) ** 2 / stdU[:, np.newaxis] ** 2, axis=1) / (x.shape[1] - 1) + else: + coeffsQ = np.zeros(aman.dets.count) + errorsQ = np.zeros(aman.dets.count) + redchi2sQ = np.zeros(aman.dets.count) + coeffsU = np.zeros(aman.dets.count) + errorsU = np.zeros(aman.dets.count) + redchi2sU = np.zeros(aman.dets.count) + + + wn_mask = (wn_freq_range[0] < freqs) & (freqs < wn_freq_range[1]) + + def leakage_model(B, x): + return B[0] * x + + model = Model(leakage_model) + + for i in range(aman.dets.count): + # fit Q + Q_wnl = np.nanmean(np.abs(Q_fs[i][wn_mask])) + x = np.real(I_fs[i])[fit_mask] + y = np.real(Q_fs[i])[fit_mask] + sx = Q_wnl / np.sqrt(2) * np.ones_like(x) + sy = Q_wnl * np.ones_like(y) + try: + data = RealData(x=x, + y=y, + sx=sx, + sy=sy) + odr = ODR(data, model, beta0=[1e-3]) + output = odr.run() + coeffsQ[i] = output.beta[0] + errorsQ[i] = output.sd_beta[0] + redchi2sQ[i] = output.sum_square / (len(x) - 2) + except: + coeffsQ[i] = np.nan + errorsQ[i] = np.nan + redchi2sQ[i] = np.nan + + #fit U + U_wnl = np.nanmean(np.abs(U_fs[i][wn_mask])) + x = np.real(I_fs[i])[fit_mask] + y = np.real(U_fs[i])[fit_mask] + sx = U_wnl / np.sqrt(2) * np.ones_like(x) + sy = U_wnl * np.ones_like(y) + try: + data = RealData(x=x, + y=y, + sx=sx, + sy=sy) + odr = ODR(data, model, beta0=[1e-3]) + 
output = odr.run() + coeffsU[i] = output.beta[0] + errorsU[i] = output.sd_beta[0] + redchi2sU[i] = output.sum_square / (len(x) - 2) + except: + coeffsU[i] = np.nan + errorsU[i] = np.nan + redchi2sU[i] = np.nan out_aman = core.AxisManager(aman.dets, aman.samps) out_aman.wrap('coeffsQ', coeffsQ, [(0, 'dets')])