diff --git a/docs/site_pipeline.rst b/docs/site_pipeline.rst index e0fd5b511..0556a1650 100644 --- a/docs/site_pipeline.rst +++ b/docs/site_pipeline.rst @@ -377,8 +377,251 @@ The output database ``wafer_info.sqlite`` and HDF5 file ``wafer_info.h5`` are written to the ``output_dir``, which is created if it does not exist. -make_read_det_match + +get_brightsrc_pointing_part1 +---------------------------- + +The two-part ``get_brightsrc_pointing_part{}`` script set will solve for the xieta +coordinates of detectors that observe a bright source during an observation. + +It is a two part process that requires a map step and then a TOD step. +The scripts require the settings and preprocessing config files described below. + +For job submission and parallelization, see example NERSC slurm submission config at the end of this section. + +The code will process all wafers unless otherwise specified. +It is recommended to run with ``parallel_job: True`` in the config files if analyzing +multiple wafers at once. +Otherwise, specify a wafer slot or restrict detectors in command line args to debug. + +Command Line arguments: +.. argparse:: + :module: sotodlib.site_pipeline.get_brightsrc_pointing_part1 + :func: get_parser + + +There options to include min_ctime and max_ctime arguments, which will process all observations +in the time frame, is not recommended unless severely restricting the detectors for debugging. + + +Generated results ``````````````````` +The Step 1 map-based analysis scripts will generate the following outputs in the specified directory: + + 1. Single detector maps in ``/results/single_det_maps/_.hdf``. + + * All single maps are packaged in a single hdf file, with detector readout_id as the keys in the h5py file. + + 2. Fitted xi-eta focal plane position results saved as ResultSet in ``/path/to/results/map_based_results`` + as specified in the Step 1 config file. Script will append 'force_zero_roll' onto the specified results_dir + if True in config file. Load ResultSet with keyword 'focal_plane' + + * Contents: ``ResultSet<[dets:readout_id, xi, eta, gamma, R2], N rows>`` + + +The Step 2 TOD-based analysis scripts will use the map-based results as a starting point + and then generate the finalized outputs in the specified directory: + + 1. Fitted xi-eta focal plane position results saved as ResultSet in ``/path/to/results/tod_based_results`` + as specified in config file for Step-2. Script will append 'force_zero_roll' onto the specified results_dir + if True in config file. Load ResultSet with keyword 'focal_plane'. + The median boresight values from small time range the source was visible to each detector is included. + + * Contents: ``ResultSet<[dets:readout_id, xi, eta, gamma, xi_err, eta_err, R2, redchi2, az, el, roll], N rows>`` + +Configuration Files +``````````````````` +The configuration files to be input as ``configs`` in the command line should +have the following arguments as well as any preprocessing steps wished to be taken. +Only processing steps that are agnostic of det-match can be used to do +initial analyses without formalized metadata. + +The parameters in these examples could be used for SAT mid-freq moon observations. + +Step 1 Config: + +.. code-block:: yaml + + context_file: /path/to/context.yaml + query_tags: ['moon=1'] #(alternatively specify --sso_name in kwargs) + + optics_config_fn: '/global/cfs/cdirs/sobs/users/elleshaw/process_brightsrc/ufm_to_fp.yaml' + single_det_maps_dir: /path/to/results/single_det_maps + results_dir: /path/to/results/map_based_results + + parallel_job: True #For job submission. Parallel across wafers. + wafer_mask_det: 8. # (degrees) mask around detector to cut TOD when source too far away. + res_deg: 0.3 + xieta_bs_offset: [0., 0.] #Good to input xieta offset in radians. (!!! for satp2) + save_normal_roll: False #false for SAT, true for LAT + save_force_zero_roll: True #true for SAT, false for LAT + + hit_circle_r_deg: 7. # radial mask to decide which UFMs are hit by source and should be analyzed. + hit_time_threshold: 600 #seconds, if hit_time not met then UFM does not get analyzed. + + process_pipe: + - name: 'detrend' + process: + count: 2000 + method: 'linear' + - name: 'apodize' + process: + apodize_samps: 2000 + - name: 'fourier_filter' + process: + signal_name: "signal" + wrap_name: null + filt_function: "low_pass_sine2" + trim_samps: null + filter_params: + cutoff: 1.9 + width: 0.2 + - name: 'fourier_filter' + process: + signal_name: "signal" + wrap_name: null + filt_function: "high_pass_sine2" + trim_samps: 2000 + filter_params: + cutoff: 0.05 + width: 0.1 + +Part 2 is the TOD-based step. Its config file should look like the following. +The parameters in these examples are used for SAT mid-freq moon observations. + +.. code-block:: yaml + + context_file: /path/to/context.yaml + query_tags: ['moon=1'] #(alternatively specify --sso_name in kwargs) + + optics_config_fn: '/global/cfs/cdirs/sobs/users/elleshaw/process_brightsrc/ufm_to_fp.yaml' + fp_hdf_dir: /path/to/results/map_based_results from step 1 config file. + # If force_zero_roll is was True, then append _force_zero_roll to the end. + # Just make sure it matches where the results from Step 1 are. + result_dir: /path/to/resuls/tod_based_results #Where you want the final Step2 results to show up. + + parallel_job: True + force_zero_roll: True #Results will show up roatated in the xi-eta results as they are on the sky. + ds_factor: 40 + mask_deg: 2.5 # (degrees) size for circular mask around SSO (helps exclude focal plane reflections too) + fit_func_name: 'gaussian2d_nonlin' + max_non_linear_order: 3 #Suggested to use 1 for jupiter or sso's + #that do not saturate. + fwhm_init_deg: 0.5 # (degrees) Lower for SATp2 + error_estimation_method: 'force_one_redchi2' + flag_name_rms_calc: 'around_source' + flag_rms_calc_exclusive: False + + process_pipe: + - name: 'detrend' + process: + count: 2000 + method: 'linear' + - name: 'fourier_filter' + process: + signal_name: 'signal' + filt_function: 'iir_filter' + trim_samps: null + filter_params: + invert: True + - name: 'apodize' + process: + apodize_samps: 2000 + - name: 'fourier_filter' + process: + signal_name: "signal" + wrap_name: null + filt_function: "low_pass_sine2" + trim_samps: null + filter_params: + cutoff: 1.9 + width: 0.2 + - name: 'source_flags' + source_flags_name: 'source_wide' + save: True + calc: + mask: + shape: circle + xyr: [0., 0., 5.0] + merge: True + max_pix: 10000000000 + - name: 'source_flags' + source_flags_name: 'source_narrow' + save: True + calc: + mask: + shape: circle + xyr: [0., 0., 3.0] + merge: True + max_pix: 10000000000 + - name: 'combine_flags' + process: + flag_labels: ['source_wide.moon', 'source_narrow.moon'] + method: 'except' + total_flags_label: 'around_source' + - name: 'flag_turnarounds' + process: + truncate: True + - name: 'sub_polyf' + process: + method: 'legendre' + degree: 2 + mask: 'around_source' + exclusive: False + +Example NERSC slurm job submission config file +`````````````````````````````````````````````` + +.. code-block:: yaml + #!/bin/bash -l + + #SBATCH --qos=shared + #SBATCH --constraint=cpu + #SBATCH --nodes=1 + #SBATCH --ntasks=1 + + #SBATCH --cpus-per-task=14 + #SBATCH --time=00:30:00 + #SBATCH --mem=220G`` #(may need regular queue & up to 400 Gb for long obs) + + export OMP_NUM_THREADS=1 + set -e + + tele=$1 + obs=$2 + map=$3 + basis=$4 + source="moon_from_moon" + + ymldir="/path/to/processing_settings_config_folder" + yfile="${ymldir}/preprocess_config_moon_${basis}_based_${tele}.yaml" + + if (($map)); then + echo submitted map job; + srun -n 1 -N 1 -c 14 python3 + /path/to/sotodlib/site_pipeline/get_brightsrc_pointing_step1.py $yfile + --obs_id=${2} --sso_name="moon"; + else + echo submitted tod job; + srun -n 1 -N 1 -c 14 python3 + /path/to/sotodlib/site_pipeline/get_brightsrc_pointing_step2.py $yfile + --obs_id=${2} --sso_name="moon"; + fi + + +Submit the job submission file with the following commands: + +1. For Step 1 map-based + + * ``sbatch submit_moon_job_script.sh 1 map`` + +2. For Step 2 TOD based + + * ``sbatch submit_moon_job_script.sh 0 tod`` + + +make_read_det_match +------------------- This script generates the readout ID to detector ID mapping required to translate between the detector hardware information (ex: pixel position) and the readout IDs of the resonators used to index the SMuRF data. The script uses the @@ -447,6 +690,16 @@ entries mater. det_info: true multi: true + +get_brightsrc_pointing_part2 +---------------------------- +See Part 1 for description + +.. argparse:: + :module: sotodlib.site_pipeline.get_brightsrc_pointing_part2 + :func: get_parser + + update_det_match ------------------ @@ -692,6 +945,10 @@ The ``focal_plane_full`` dataset contains nine columns: - ``eta_m``: The measured eta in radians - ``gamma_m``: The measured gamma in radians. - ``weights``: The average weights of the measurements for this det. +- ``r2``: The fit weight passed in from the get_brightsrc_pointing dataset +- ``az``: The median Az value in radians from source-detector crossing +- ``el``: The median El value in radians from source-detector crossing +- ``roll``: The median Roll value in radians from source-detector crossing - ``n_point``: The number of pointing fits used for the det. - ``n_gamma``: The number of gamma fits used for this det. @@ -732,7 +989,7 @@ always be ``(1, 1, 1)`` and ``shear`` will be ``0``. ``finalize_focal_plane`` will also output a ``ManifestDb`` as a file called ``db.sqlite`` in the output directory. By default this will be indexed by ``stream_id`` and ``obs:timestamp`` and will point to the ``focal_plane`` dataset. -If you are running in ``per_obs`` mode then it wirbe indexed by ``obs_id`` and will point +If you are running in ``per_obs`` mode then it will be indexed by ``obs_id`` and will point to results associated with data observation. Be warned that in this case there will only be entries for observations with pointing fits, so design your context accordingly. diff --git a/sotodlib/coords/brightsrc_pointing.py b/sotodlib/coords/brightsrc_pointing.py new file mode 100644 index 000000000..0ed141e21 --- /dev/null +++ b/sotodlib/coords/brightsrc_pointing.py @@ -0,0 +1,518 @@ +# These functions are used for fitting detector positions from bright point sources +# called by site_pipeline.get_brightsrc_pointing_step1 and site_pipeline.get_brightsrc_pointing_step2 +import os +import re +from tqdm import tqdm +import numpy as np +from scipy import interpolate +from scipy.optimize import curve_fit +#from joblib import Parallel, delayed + +from sotodlib import core +from sotodlib import coords +from sotodlib.coords import optics +from sotodlib.core import metadata +from sotodlib.io.metadata import write_dataset, read_dataset + +from so3g.proj import quat +from pixell import enmap +import h5py +from scipy.ndimage import maximum_filter + + +def get_planet_trajectory(tod, planet, _split=20, return_model=False): + """ + Generate the trajectory in horizon coordinates of a given planet over a specified time range. + + Parameters: + tod : An axis manager containing a timestamps field, which is used to + determine the time range and generate the trajectory. + planet (str): The name of the object for which to generate the trajectory. e.g. "moon" or "saturn" + _split (int, optional): Number of points to interpolate the trajectory. Defaults to 20. + return_model (bool, optional): If True, returns interpolation functions of az and el. Defaults to False. + + Returns: + If return_model is True: + tuple: Tuple containing interpolation functions for azimuth and elevation. + If return_model is False: + array: Array of quaternions representing trajectory of the planet at each timestamp. + """ + timestamps_sparse = np.linspace(tod.timestamps[0], tod.timestamps[-1], _split) + + planet_az_sparse = np.zeros_like(timestamps_sparse) + planet_el_sparse = np.zeros_like(timestamps_sparse) + for i, timestamp in enumerate(timestamps_sparse): + az, el, _ = coords.planets.get_source_azel(planet, timestamp) + planet_az_sparse[i] = az + planet_el_sparse[i] = el + planet_az_func = interpolate.interp1d(timestamps_sparse, planet_az_sparse, kind="quadratic", fill_value='extrapolate') + planet_el_func = interpolate.interp1d(timestamps_sparse, planet_el_sparse, kind="quadratic", fill_value='extrapolate') + if return_model: + return planet_az_func, planet_el_func + else: + planet_az = planet_az_func(tod.timestamps) + planet_el = planet_el_func(tod.timestamps) + q_planet = quat.rotation_lonlat(-1 * planet_az, planet_el) + return q_planet + +def get_wafer_centered_sight(tod=None, planet=None, q_planet=None, q_bs=None, q_wafer=None): + """ + Calculate the sightline vector from the focal plane, centered on the wafer, to a planet. + + Parameters: + tod : An axis manager + planet (str): The name of the planet to calculate the sightline vector. + q_planet (optional): Quaternion representing the trajectory of the planet. + If None, it will be computed using get_planet_trajectory. Defaults to None. + q_bs (optional): Quaternion representing the trajectory of the boresight. + If None, it will be computed using the current boresight angles from tod. Defaults to None. + q_wafer (optional): Quaternion representing the center of wafer to the center of boresight. + If None, it will be computed using the median of the focal plane xi and eta from tod.focal_plane. + Defaults to None. + + Returns: + Sightline vector for the planet trajectory centered on the center of the wafer. + """ + #breakpoint() + if q_planet is None: + q_planet = get_planet_trajectory(tod, planet) + if q_bs is None: + q_bs = quat.rotation_lonlat(-1 * tod.boresight.az, tod.boresight.el) + if q_wafer is None: + q_wafer = quat.rotation_xieta(np.nanmedian(tod.focal_plane.xi), + np.nanmedian(tod.focal_plane.eta)) + + xi_wafer, eta_wafer, _ = quat.decompose_xieta(q_wafer) + z_to_x = quat.rotation_lonlat(0, 0) + sight = z_to_x * ~(q_bs * q_wafer) * q_planet + return sight + +def get_wafer_xieta(wafer_slot, optics_config_fn, xieta_bs_offset=(0., 0.), + roll_bs_offset=0., tod=None, wrap_to_tod=True,): + """ + Calculate the xi and eta coordinates for a given wafer slot on the focal plane. + + Parameters: + wafer_slot (str): The slot identifier of the wafer. + optics_config_fn (str): File name containing the optics configuration. + xieta_bs_offset (tuple): Offset in xieta coordinates for the focal plane, default is (0., 0.). + roll_bs_offset (float): Boresight roll offset. Default is 0 + tod (TimeOrderedData): TOD object to which focal plane infomation that all detectors have uniform pointing at center of the wafer is wrapped. + wrap_to_tod (bool): If True, wrap the calculated xi and eta coordinates to the Time-Ordered Data (TOD), default is True. + + Returns: + tuple: A tuple containing the calculated xi and eta coordinates for the specified wafer slot. + """ + + optics_config = optics.load_ufm_to_fp_config(optics_config_fn)['SAT'] + wafer_x, wafer_y = optics_config[wafer_slot]['dx'], optics_config[wafer_slot]['dy'] + wafer_r = np.sqrt(wafer_x**2 + wafer_y**2) + wafer_theta = np.arctan2(wafer_y, wafer_x) + + fp_to_sky = optics.sat_to_sky(optics.SAT_R_FP, optics.SAT_R_SKY) + lon = fp_to_sky(wafer_r) + + q1 = quat.rotation_iso(lon, 0) + q2 = quat.rotation_iso(0, 0, np.pi/2 - wafer_theta + roll_bs_offset) + q3 = quat.rotation_xieta(xieta_bs_offset[0], xieta_bs_offset[1]) + q = q3 * q2 * q1 + + xi_wafer, eta_wafer, _ = quat.decompose_xieta(q) + if wrap_to_tod: + if tod is None: + raise ValueError('tod is not provided.') + if 'focal_plane' in tod._fields.keys(): + tod.move('focal_plane', None) + focal_plane = core.AxisManager(tod.dets) + focal_plane.wrap('xi', np.ones(tod.dets.count, dtype='float32') * xi_wafer, [(0, 'dets')]) + focal_plane.wrap('eta', np.ones(tod.dets.count, dtype='float32') * eta_wafer, [(0, 'dets')]) + focal_plane.wrap('gamma', np.zeros(tod.dets.count, dtype='float32'), [(0, 'dets')]) + tod.wrap('focal_plane', focal_plane) + + # set boresight roll to zero + tod.boresight.wrap('roll_original', tod.boresight.roll, [(0, 'samps')]) + tod.boresight.roll *= 0. + + return xi_wafer, eta_wafer + + +def get_rough_hit_time(tod, wafer_slot, sso_name, circle_r_deg=7.,optics_config_fn=None): + """ + Estimate the rough hit time, which is the amount of time for which the source + is within some distance from the center of a wafer slot. + + Parameters: + tod : An AxisManager object + wafer_slot (str): Identifier for the wafer slot. + sso_name (str): Name of the Solar System Object (e.g., 'moon', 'jupiter'). + circle_r_deg (float, optional): Radius in degrees defining the circular region around the wafer center. + Defaults to 7 degrees. + + Returns: + float: Estimated rough hit time within the circular region around the wafer center. + """ + q_bs = quat.rotation_lonlat(-1 * tod.boresight.az, tod.boresight.el) + q_planet = get_planet_trajectory(tod, sso_name) + xi_wafer, eta_wafer = get_wafer_xieta(wafer_slot, optics_config_fn=optics_config_fn, + roll_bs_offset=np.median(tod.boresight.roll), wrap_to_tod=False) + q_wafer = quat.rotation_xieta(xi_wafer, eta_wafer) + + q_wafer_centered = get_wafer_centered_sight(q_planet=q_planet, q_bs=q_bs, q_wafer=q_wafer) + x_to_z = ~quat.rotation_lonlat(0, 0) + xi_wafer_centered, eta_wafer_centered, _ = quat.decompose_xieta(x_to_z * q_wafer_centered) + r_wafer_centered = np.sqrt(xi_wafer_centered**2 + eta_wafer_centered**2) + hit_time = (tod.timestamps[-1] - tod.timestamps[0]) * np.mean(np.rad2deg(r_wafer_centered) < circle_r_deg) + return hit_time + +def make_wafer_centered_maps(tod, sso_name, optics_config_fn, map_hdf, + xieta_bs_offset=(0., 0.), roll_bs_offset=None, + signal='signal', wafer_mask_deg=8., res_deg=0.3, cuts=None,): + """ + Generate boresight-centered maps from Time-Ordered Data (TOD) for each individual detector. + This script modifies tod.focal_plane and tod.boresight + + Parameters: + tod : an axismanager object + sso_name (str): Name of the planet for which the trajectory is calculated. + optics_config_fn (str): File name containing the optics configuration. + map_hdf (str): Path to the HDF5 file where the maps will be saved. + xieta_bs_offset (tuple): Offset in xieta coordinates for the boresight, default is (0., 0.). + roll_bs_offset (float): Offset in roll angle for the boresight, default is None. + signal (str): Name of the signal to be used, default is 'signal'. + wcs_kernel (ndarray): WCS kernel for mapping, default is None. + res_deg (float): Resolution of the map in degrees, default is 0.3 degrees. + cuts (tuple): Cuts to be applied to the map, default is None. + + Returns: + None + """ + #breakpoint() + q_planet = get_planet_trajectory(tod, sso_name) + q_bs = quat.rotation_lonlat(-1 * tod.boresight.az, tod.boresight.el) + if roll_bs_offset is None: + roll_bs_offset = np.mean(tod.boresight.roll) + + # wafer + if np.unique(tod.det_info.wafer_slot).shape[0] > 1: + raise ValueError('tod include detectors from more than one wafer') + wafer_slot = tod.det_info.wafer_slot[0] + xi_wafer, eta_wafer = get_wafer_xieta(wafer_slot=wafer_slot, + xieta_bs_offset=xieta_bs_offset, + roll_bs_offset=roll_bs_offset, + tod=tod, + optics_config_fn=optics_config_fn, + wrap_to_tod=True) + + coords.planets.compute_source_flags(tod, center_on=sso_name, max_pix=100000000, + wrap='source', mask={'shape':'circle', 'xyr':[0., 0., wafer_mask_deg]}) + + + + q_wafer = quat.rotation_xieta(xi_wafer, eta_wafer) + sight = get_wafer_centered_sight(tod, sso_name, q_planet, q_bs, q_wafer) + xi0 = tod.focal_plane.xi[0] + eta0 = tod.focal_plane.eta[0] + xi_bs_offset, eta_bs_offset = xieta_bs_offset + tod.focal_plane.xi *= 0. + tod.focal_plane.eta *= 0. + tod.boresight.roll *= 0. + + + box = np.deg2rad([[-wafer_mask_deg, -wafer_mask_deg], [wafer_mask_deg, wafer_mask_deg]]) + geom = enmap.geometry(pos=box, res=res_deg*coords.DEG) + if cuts is None: + cuts = ~tod.flags['source'] + P = coords.P.for_tod(tod=tod, geom=geom, comps='T', cuts=cuts, sight=sight, threads=False) + + wT = None + for di, det in enumerate(tqdm(tod.dets.vals)): + det_weights = np.zeros(tod.dets.count, dtype='float32') + det_weights[di] = 1. + mT_weighted = P.to_map(tod=tod, signal=signal, comps='T', det_weights=det_weights) + if wT is None: + wT = P.to_weights(tod, signal=signal, comps='T', det_weights=det_weights) + mT = P.remove_weights(signal_map=mT_weighted, weights_map=wT, comps='T')[0] + + enmap.write_hdf(map_hdf, mT, address=det, + extra={'xi0': xi0, + 'eta0': eta0, + 'xi_bs_offset': xi_bs_offset, + 'eta_bs_offset': eta_bs_offset, + 'roll_bs_offset': roll_bs_offset}) + return + +def detect_peak_xieta(mT, filter_size=None): + """ + Detects the peak in a given pixcell map and converts it to ξ and η coordinates. + + Parameters: + - mT (enmap.ndmap): a map object + - filter_size (int, optional): Size of the filter window for peak detection. + If not provided, it's calculated as a fraction + of the minimum dimension of mT. + + Returns: + - xi_peak (float): xi coordinate of the peak. + - eta_peak (float): eta coordinate of the peak. + - ra_peak (float): ra coordinate of the peak. + - dec_peak (float): dec coordinate of the peak. + - peak_i (int): Row index of the peak. + - peak_j (int): Column index of the peak. + """ + if filter_size is None: + filter_size = int(np.min(mT.shape)//10) + local_max = maximum_filter(mT, footprint=np.ones((filter_size, filter_size)), + mode='constant', cval=np.nan) + peak_i, peak_j = np.where(mT == np.nanmax(local_max)) + peak_i = int(np.median(peak_i)) + peak_j = int(np.median(peak_j)) + dec_grid, ra_grid = mT.posmap() + + ra_peak = ra_grid[peak_i][peak_j] + dec_peak = dec_grid[peak_i][peak_j] + xi_peak, eta_peak = _radec2xieta(ra_peak, dec_peak) + return xi_peak, eta_peak, ra_peak, dec_peak, peak_i, peak_j + +def get_center_of_mass(x, y, z, + circle_mask={'x0':0, 'y0':0, 'r_circle':3.0*coords.DEG}, + percentile_mask = {'q': 50}): + """ + Calculates the center of mass of a dataset within specified masks. + + Parameters: + - x (ndarray): Array of x-coordinates. + - y (ndarray): Array of y-coordinates. + - z (ndarray): Array of data values corresponding to the coordinates. + - circle_mask (dict, optional): Parameters defining circular mask. + Should contain keys 'x0', 'y0', and 'r_circle'. + Defaults to a circle centered at (0, 0) with radius 3.0 degrees. + - percentile_mask (dict, optional): Parameters defining percentile mask. + Should contain key 'q' representing the percentile threshold. + Defaults to the 50th percentile. + + Returns: + - x_center (float): x-coordinate of the center of mass. + - y_center (float): y-coordinate of the center of mass. + """ + mask = ~np.isnan(z) + if circle_mask is not None: + x0, y0 = circle_mask['x0'], circle_mask['y0'] + r_circle = circle_mask['r_circle'] + r = np.sqrt((x-x0)**2 + (y-y0)**2) + mask = np.logical_and(mask, rnp.nanpercentile(z[mask], q)) + + _x = x[mask] + _y = y[mask] + _z = z[mask] + + total_mass = np.nansum(_z) + x_center = np.nansum(_x * _z) / total_mass + y_center = np.nansum(_y * _z) / total_mass + return x_center, y_center + +def get_edgemap(mT, edge_avoidance=1*coords.DEG, edge_check='nan'): + """ + Generates an edge map for a given map, marking regions near the edges where data is potentially unreliable. + + Parameters: + - mT (enmap.ndmap): a map object + - edge_avoidance (float, optional): Size of the edge avoidance region, defaults to 1 degree. + - edge_check (str, optional): Method for checking edges. Should be one of {'nan', 'zero'}. + 'nan': Checks for NaN values at edges. + 'zero': Checks for zero values at edges. + Defaults to 'nan'. + + Returns: + - edge_map (enmap.ndmap): 2D boolean array representing the edge map, where True indicates regions near the edges. + """ + if edge_check not in ('nan', 'zero'): + raise ValueError('only `nan` or `zero` is supported') + + edge_map = enmap.zeros(mT.shape, mT.wcs) + edge_margin_size = int(edge_avoidance/np.mean(mT.pixshape())) + + for i, row in enumerate(mT): + if edge_check == 'nan': + nonzero_idxes = np.where(~np.isnan(row))[0] + elif edge_check == 'zero': + nonzero_idxes = np.where(row != 0)[0] + if len(nonzero_idxes>0): + edge_map[i, :nonzero_idxes[0] + edge_margin_size] = True + edge_map[i, nonzero_idxes[-1] - edge_margin_size:] = True + else: + edge_map[i, :] = True + + for j, col in enumerate(mT.T): + if edge_check == 'nan': + nonzero_idxes = np.where(~np.isnan(col))[0] + elif edge_check == 'zero': + nonzero_idxes = np.where(col != 0)[0] + if len(nonzero_idxes>0): + edge_map[:nonzero_idxes[0] + edge_margin_size, j] = True + edge_map[nonzero_idxes[-1] - edge_margin_size:, j] = True + else: + edge_map[:, j] = True + return edge_map + + + +def map_to_xieta(mT, edge_avoidance=1.0*coords.DEG, edge_check='nan', + r_tune_circle=1.0*coords.DEG, q_tune=50, + r_fit_circle=3.0*coords.DEG, beam_sigma_init=0.5*coords.DEG, ): + """ + Derive (xi,eta) coordinate of a peak from a given map and calculates the coefficient of determination (R^2) + as a measure of how well the data fits a Gaussian model around the peak. + + Parameters: + - mT (enmap.ndmap): a map object. + - edge_avoidance (float, optional): Size of the edge avoidance region, defaults to 1 degree. + - edge_check (str, optional): Method for checking edges. Should be one of {'nan', 'zero'}. Defaults to 'nan'. + - r_tune_circle (float, optional): Radius of the circle used for tuning the peak position, specified in radians. Defaults to 1 degree. + - q_tune (int, optional): Percentile threshold used for tuning the peak position. Defaults to 50. + - r_fit_circle (float, optional): Radius of the circle used for fitting the Gaussian model, specified in radians. Defaults to 3 degrees. + - beam_sigma_init (float, optional): Initial guess for the sigma parameter of the Gaussian beam, specified in radians. Defaults to 0.5 degree. + + Returns: + - xi_det (float): xi coordinate of the detected peak. + - eta_det (float): eta coordinate of the detected peak. + - R2_det (float): Coefficient of determination (R^2) indicating the goodness of fit of the data around the peak. + If no valid peak is detected or if fitting fails, returns NaN. + """ + if np.all(np.isnan(mT)): + xi_det, eta_det, R2_det = np.nan, np.nan, np.nan + + else: + xi_peak, eta_peak, ra_peak, dec_peak, peak_i, peak_j = detect_peak_xieta(mT) + if edge_avoidance > 0.: + edge_map = get_edgemap(mT, edge_avoidance=edge_avoidance, edge_check=edge_check) + edge_valid = not edge_map[peak_i, peak_j] + else: + edge_valid = True + + if edge_valid: + dec_flat, ra_flat = mT.posmap() + dec_flat, ra_flat = dec_flat.flatten(), ra_flat.flatten() + xi_flat, eta_flat = _radec2xieta(ra_flat, dec_flat) + + circle_mask = {'x0':xi_peak, 'y0':eta_peak, 'r_circle':r_tune_circle} + percentile_mask = {'q': q_tune} + xi_peak, eta_peak = get_center_of_mass(xi_flat, eta_flat, mT.flatten(), + circle_mask=circle_mask, percentile_mask=percentile_mask) + + # check R2(=coefficient of determination) + r = np.sqrt((xi_flat - xi_peak)**2 + (eta_flat - eta_peak)**2) + z = mT.flatten() + mask_fit = np.logical_and(~np.isnan(z), r= 2: + logger.info(f'sso_names of {sso_names} are found from observation tags.' + + f'Processing only {sso_name}') + + # Load data + logger.info(f'loading meta data: {wafer_slot}') + meta = ctx.get_meta(obs_id, dets={'wafer_slot': wafer_slot}) + logger.info(f'finished loading meta data: {wafer_slot}') + try: + meta.restrict('dets', meta.detcal.bg > -1) + except: + pass + if restrict_dets_for_debug is not False: + try: + restrict_dets_for_debug = int(restrict_dets_for_debug) + meta.restrict('dets', meta.dets.vals[:restrict_dets_for_debug]) + except ValueError: + _testdets = restrict_dets_for_debug.split(',') + restrict_list = [det.split('\'')[1].strip() for det in _testdets] + meta.restrict('dets', restrict_list) + logger.info(f'loading tod data: {wafer_slot}') + tod = ctx.get_obs(meta) + logger.info(f'finished loading tod data: {wafer_slot}') + # tod processing + logger.info(f'tod processing {wafer_slot}') + pipe = Pipeline(configs["process_pipe"], logger=logger) + proc_aman, success = pipe.run(tod) + logger.info(f'done with tod processing {wafer_slot}') + # make single detecctor maps + logger.info(f'Making single detector maps') + os.makedirs(single_det_maps_dir, exist_ok=True) + map_hdf = os.path.join(single_det_maps_dir, f'{obs_id}_{wafer_slot}.hdf') + bsp.make_wafer_centered_maps(tod, sso_name, optics_config_fn, map_hdf=map_hdf, + xieta_bs_offset=xieta_bs_offset, + wafer_mask_deg=wafer_mask_deg, res_deg=res_deg) + + #next step + result_filename = f'focal_plane_{obs_id}_{wafer_slot}.hdf' + # reconstruct pointing from single detector maps + if save_normal_roll: + logger.info(f'Saving map-based pointing results') + + fp_rset_map_based = bsp.get_xieta_from_maps(map_hdf, save=True, + output_dir=result_dir, + filename=result_filename, + force_zero_roll=False, + edge_avoidance = edge_avoidance_deg*coords.DEG) + + if save_force_zero_roll: + logger.info(f'Saving map-based pointing results (force-zero-roll)') + result_dir_force_zero_roll = result_dir + '_force_zero_roll' + fp_rset_map_based_force_zero_roll = bsp.get_xieta_from_maps(map_hdf, save=True, + output_dir=result_dir_force_zero_roll, + filename=result_filename, + force_zero_roll=True, + edge_avoidance = edge_avoidance_deg*coords.DEG) + return f"Finished processing {obs_id}, {wafer_slot}" + +def main_one_wafer_dummy(configs, obs_id, wafer_slot, restrict_dets_for_debug=False): + if type(configs) == str: + configs = yaml.safe_load(open(configs, "r")) + ctx = core.Context(configs.get('context_file')) + single_det_maps_dir = configs.get('single_det_maps_dir') + result_dir = configs.get('result_dir') + save_normal_roll = configs.get('save_normal_roll', True) + save_force_zero_roll = configs.get('save_force_zero_roll', True) + + meta = ctx.get_meta(obs_id, dets={'wafer_slot': wafer_slot}) + if restrict_dets_for_debug is not False: + try: + restrict_dets_for_debug = int(restrict_dets_for_debug) + meta.restrict('dets', meta.dets.vals[:restrict_dets_for_debug]) + except ValueError: + _testdets = restrict_dets_for_debug.split(',') + restrict_list = [det.split('\'')[1].strip() for det in _testdets] + meta.restrict('dets', restrict_list) + result_filename = f'focal_plane_{obs_id}_{wafer_slot}.hdf' + + fp_rset_dummy_map_based = metadata.ResultSet(keys=['dets:readout_id', 'xi', 'eta', 'gamma', 'R2']) + for det in meta.dets.vals: + fp_rset_dummy_map_based.rows.append((det, np.nan, np.nan, np.nan, np.nan)) + + if save_normal_roll: + os.makedirs(result_dir, exist_ok=True) + write_dataset(fp_rset_dummy_map_based, + filename=os.path.join(result_dir, result_filename), + address='focal_plane', + overwrite=True) + + if save_force_zero_roll: + result_dir_force_zero_roll = result_dir + '_force_zero_roll' + os.makedirs(result_dir_force_zero_roll, exist_ok=True) + write_dataset(fp_rset_dummy_map_based, + filename=os.path.join(result_dir_force_zero_roll, result_filename), + address='focal_plane', + overwrite=True) + return + +def combine_pointings(pointing_result_files): + combined_dict = {} + for file in pointing_result_files: + rset = read_dataset(file, 'focal_plane') + for row in rset[:]: + if row['dets:readout_id'] not in combined_dict.keys(): + combined_dict[row['dets:readout_id']] = {} + combined_dict[row['dets:readout_id']]['xi'] = row['xi'] + combined_dict[row['dets:readout_id']]['eta'] = row['eta'] + combined_dict[row['dets:readout_id']]['gamma'] = row['gamma'] + combined_dict[row['dets:readout_id']]['R2'] = row['R2'] + + focal_plane = metadata.ResultSet(keys=['dets:readout_id', 'xi', 'eta', 'gamma', 'R2']) + + for det, val in combined_dict.items(): + focal_plane.rows.append((det, val['xi'], val['eta'], val['gamma'], val['R2'])) + return focal_plane + +def main_one_obs(configs, obs_id, sso_name=None, + restrict_dets_for_debug=False): + if type(configs) == str: + configs = yaml.safe_load(open(configs, "r")) + ctx = core.Context(configs.get('context_file')) + optics_config_fn = configs.get('optics_config_fn') + + result_dir = configs.get('result_dir') + save_normal_roll = configs.get('save_normal_roll') + save_force_zero_roll = configs.get('save_force_zero_roll') + + hit_time_threshold = configs.get('hit_time_threshold', 600) + hit_circle_r_deg = configs.get('hit_circle_r_deg', 7.0) + + if sso_name is None: + logger.info('deriving sso_name from observation tag') + obs_tags = ctx.obsdb.get(obs_id, tags=True)['tags'] + sso_names = _get_sso_names_from_tags(ctx, obs_id) + sso_name = sso_names[0] + if len(sso_names) >= 2: + logger.info(f'sso_names of {sso_names} are found from observation tags.' + + f'Processing only {sso_name}') + + tod = ctx.get_obs(obs_id, no_signal=True) + streamed_wafer_slots = ['ws{}'.format(index) for index, bit in enumerate(obs_id.split('_')[-1]) if bit == '1'] + processed_wafer_slots = [] + finished_wafer_slots = [] + skipped_wafer_slots = [] + check_dir = result_dir + '_force_zero_roll' if save_force_zero_roll else result_dir + + for ws in streamed_wafer_slots: + hit_time = bsp.get_rough_hit_time(tod, + wafer_slot=ws, + sso_name=sso_name, + circle_r_deg=hit_circle_r_deg, + optics_config_fn=optics_config_fn) + logger.info(f'hit_time for {ws} is {hit_time:.1f} [sec]') + if hit_time >= hit_time_threshold: + if os.path.exists(os.path.join(check_dir, f'focal_plane_{obs_id}_{ws}.hdf')): + finished_wafer_slots.append(ws) + else: + processed_wafer_slots.append(ws) + else: + skipped_wafer_slots.append(ws) + + logger.info(f'Found saved data for these wafer_slots: {finished_wafer_slots}') + logger.info(f'Will continue for these wafer_slots: {processed_wafer_slots}') + + if configs.get('parallel_job'): + logger.info('Continuing with parallel job') + try: + n_jobs = int(os.environ.get('SLURM_CPUS_PER_TASK', 1)) + except: + n_jobs = -1 + + logger.info('Entering wafer pool') + rank, executor, as_completed_callable = get_exec_env(nprocs=n_jobs) + futures = [executor.submit( + main_one_wafer, + configs=configs, + obs_id=obs_id, + wafer_slot=wafer_slot, + sso_name=sso_name, + restrict_dets_for_debug=restrict_dets_for_debug, + ) + for wafer_slot in processed_wafer_slots] + for future in as_completed_callable(futures): + logger.info(future.result()) + + else: + logger.info('Continuing with serial processing of wafers.') + for wafer_slot in processed_wafer_slots: + logger.info(f'Processing {obs_id}, {wafer_slot}') + main_one_wafer(configs=configs, + obs_id=obs_id, + wafer_slot=wafer_slot, + sso_name=sso_name, + restrict_dets_for_debug=restrict_dets_for_debug) + + logger.info(f'create dummy hdf for non-hitting wafer: {skipped_wafer_slots}') + for wafer_slot in skipped_wafer_slots: + main_one_wafer_dummy(configs=configs, + obs_id=obs_id, + wafer_slot=wafer_slot, + restrict_dets_for_debug=restrict_dets_for_debug) + + logger.info('making combined result') + if save_normal_roll: + pointing_result_files = glob.glob(os.path.join(result_dir, f'focal_plane_{obs_id}_ws[0-6].hdf')) + fp_rset_full = combine_pointings(pointing_result_files) + fp_rset_full_file = os.path.join(os.path.join(result_dir, f'focal_plane_{obs_id}_all.hdf')) + write_dataset(fp_rset_full, filename=fp_rset_full_file, + address='focal_plane', overwrite=True) + + + if save_force_zero_roll: + result_dir_force_zero_roll = result_dir + '_force_zero_roll' + pointing_result_files = glob.glob(os.path.join(result_dir_force_zero_roll, f'focal_plane_{obs_id}_ws[0-6].hdf')) + fp_rset_full = combine_pointings(pointing_result_files) + fp_rset_full_file = os.path.join(os.path.join(result_dir_force_zero_roll, f'focal_plane_{obs_id}_all.hdf')) + write_dataset(fp_rset_full, filename=fp_rset_full_file, + address='focal_plane', overwrite=True) + + logger.info(f'ta da! Finished with {obs_id}') + return + +def main(configs, min_ctime=None, max_ctime=None, update_delay=None, + obs_id=None, wafer_slot=None, sso_name=None, restrict_dets_for_debug=False): + if (min_ctime is None) and (update_delay is not None): + # If min_ctime is provided it will use that.. + # Otherwise it will use update_delay to set min_ctime. + min_ctime = int(time.time()) - update_delay*86400 + + if type(configs) == str: + configs = yaml.safe_load(open(configs, "r")) + ctx = core.Context(configs.get('context_file')) + + if obs_id is None: + query_text = configs.get('query_text', None) + query_tags = configs.get('query_tags', None) + tot_query = "and " + if query_text is not None: + tot_query += f"{query_text} and " + if min_ctime is not None: + tot_query += f"timestamp>={min_ctime} and " + if max_ctime is not None: + tot_query += f"timestamp<={max_ctime} and " + tot_query = tot_query[4:-4] + if tot_query == "": + tot_query = "1" + + logger.info(f'tot_query: {tot_query}') + obs_list= ctx.obsdb.query(tot_query, query_tags) + + for obs in obs_list: + obs_id = obs['obs_id'] + logger.info(f'Processing {obs_id}') + main_one_obs(configs=configs, obs_id=obs_id, + restrict_dets_for_debug=restrict_dets_for_debug) + + elif obs_id is not None: + logger.info(f'Processing {obs_id}') + if wafer_slot is None: + main_one_obs(configs=configs, obs_id=obs_id, sso_name=sso_name, + restrict_dets_for_debug=restrict_dets_for_debug) + else: + main_one_wafer(configs=configs, obs_id=obs_id, wafer_slot=wafer_slot, sso_name=sso_name, + restrict_dets_for_debug=restrict_dets_for_debug) + +def get_parser(): + parser = argparse.ArgumentParser(description="Process TOD data and update pointing") + parser.add_argument("configs", type=str, help="Path to the configuration file") + parser.add_argument('--min-ctime', type=int, help="Minimum timestamp for the beginning of an observation list") + parser.add_argument('--max-ctime', type=int, help="Maximum timestamp for the beginning of an observation list") + parser.add_argument('--update-delay', type=int, help="Number of days (unit is days) in the past to start observation list.") + parser.add_argument("--obs-id", type=str, + help="Specific observation obs_id to process. If provided, overrides other filtering parameters.") + + parser.add_argument("--wafer-slot", type=str, default=None, + help="Wafer slot to be processed (e.g., 'ws0', 'ws3'). Valid only when obs_id is specified.") + + parser.add_argument("--sso-name", type=str, default=None, + help="Name of solar system object (e.g., 'moon', 'jupiter'). If not specified, get sso_name from observation tags. "\ + + "Valid only when obs_id is specified") + parser.add_argument("--restrict-dets-for-debug", type=str, default=False) + return parser + +if __name__ == '__main__': + main_launcher(main, get_parser) diff --git a/sotodlib/site_pipeline/get_brightsrc_pointing_step2.py b/sotodlib/site_pipeline/get_brightsrc_pointing_step2.py new file mode 100644 index 000000000..ef1a77853 --- /dev/null +++ b/sotodlib/site_pipeline/get_brightsrc_pointing_step2.py @@ -0,0 +1,595 @@ +import os +import numpy as np +import yaml +import h5py +import argparse +import time +import glob +from tqdm import tqdm +import logging + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +from scipy.optimize import curve_fit +from sotodlib.core import metadata +from sotodlib.io.metadata import read_dataset, write_dataset +from sotodlib.coords import brightsrc_pointing as bsp +from sotodlib import core +from sotodlib import coords +from sotodlib import tod_ops +import so3g +from so3g.proj import quat +import sotodlib.coords.planets as planets +from sotodlib.site_pipeline.utils.pipeline import main_launcher +from sotodlib.preprocess import Pipeline +from sotodlib.utils.procs_pool import get_exec_env +from sotodlib.site_pipeline.utils.logging import init_logger as sp_init_logger + +logger = logging.getLogger("get_brightsrc_pointing_step2") +if not logger.hasHandlers(): + sp_init_logger("get_brightsrc_pointing_step2") + +def _get_sso_names_from_tags(ctx, obs_id, candidate_names=['moon', 'jupiter', 'mars', 'saturn']): + obs_tags = ctx.obsdb.get(obs_id, tags=True)['tags'] + sso_names = [] + for _name in candidate_names: + if _name in obs_tags: + sso_names.append(_name) + if len(sso_names) == 0: + raise NameError('Could not find sso_name from observation tags') + else: + return sso_names + +def gaussian2d_nonlin(xieta, xi0, eta0, fwhm_xi, fwhm_eta, phi, a, nonlin_coeffs): + """ An Gaussian beam model with non-linear response + Args + ------ + xi, eta: cordinates in the detector's system + xi0, eta0: float, float + center position of the Gaussian beam model + fwhm_xi, fwhm_eta, phi: float, float, float + fwhm along the xi, eta axis (rotated) + and the rotation angle (in radians) + a: float + amplitude of the Gaussian beam model + nonlin_coeffs: float + Coefficient of non-linear term normalized by linear term (from 2nd term). + The order is ascending. + Ouput: + ------ + Model at xieta + """ + xi, eta = xieta + xi_rot = xi * np.cos(phi) - eta * np.sin(phi) + eta_rot = xi * np.sin(phi) + eta * np.cos(phi) + factor = 2 * np.sqrt(2 * np.log(2)) + xi_coef = -0.5 * (xi_rot - xi0) ** 2 / (fwhm_xi / factor) ** 2 + eta_coef = -0.5 * (eta_rot - eta0) ** 2 / (fwhm_eta / factor) ** 2 + lin_gauss = np.exp(xi_coef + eta_coef) + polycoeffs_discending = np.hstack([nonlin_coeffs[::-1], [1, 0]]) + return a * np.poly1d(polycoeffs_discending)(lin_gauss) + +def wrapper_gaussian2d_nonlin(xieta, xi0, eta0, fwhm_xi, fwhm_eta, phi, a, *args): + """ + A wrapper for `gaussian2d_nonlin` + """ + nonlin_coeffs = np.array(args) + return gaussian2d_nonlin(xieta, xi0, eta0, fwhm_xi, fwhm_eta, phi, a, nonlin_coeffs) + +def wrap_fp_rset(tod, fp_rset): + tod.restrict('dets', tod.dets.vals[np.isin(tod.dets.vals, fp_rset['dets:readout_id'])]) + _, ind_tod, ind_rset = core.util.get_coindices(tod.dets.vals, fp_rset['dets:readout_id']) + focal_plane = core.AxisManager(tod.dets) + focal_plane.wrap_new('xi', shape=('dets', )) + focal_plane.wrap_new('eta', shape=('dets', )) + focal_plane.wrap_new('gamma', shape=('dets', )) + focal_plane.xi[ind_tod] = fp_rset['xi'][ind_rset] + focal_plane.eta[ind_tod] = fp_rset['eta'][ind_rset] + focal_plane.gamma[ind_tod] = fp_rset['gamma'][ind_rset] + + if 'focal_plane' in tod._fields.keys(): + tod.move('focal_plane', None) + tod.wrap('focal_plane', focal_plane) + return + +def wrap_fp_from_hdf(tod, fp_hdf_file, data_set='focal_plane'): + fp_rset = read_dataset(fp_hdf_file, data_set) + wrap_fp_rset(tod, fp_rset) + return + + +def update_xieta(tod, + sso_name=None, + fp_hdf_file=None, + force_zero_roll=False, + pipe=None, + ds_factor=10, + mask_deg=3, + fit_func_name = 'gaussian2d_nonlin', + max_non_linear_order = 1, + fwhm_init_deg = 0.5, + error_estimation_method='force_one_redchi2', + flag_name_rms_calc = 'source', + flag_rms_calc_exclusive = True, + ): + """ + Update xieta parameters for each detector by TOD fitting of a point source observation. + + Parameters: + - tod : + an Axismanager object + - sso_name (str): + Name of the Solar System Object (SSO). + - fp_hdf_file (str or None): + Path to the HDF file containing focal plane information. Default is None. + If None, tod.focal_plane is used for focal plane information. + - force_zero_roll (bool): + Flag indicating whether to force the roll to be zero. Default is False. + If True, input and output focal plane information assumes force_zero_roll condition. + - pipe (Pipeline or None): + Preprocessing pipeline to be applied to the TOD. Default is None, which + do not apply any processing. + - ds_factor (int): + Downsampling factor for fitting. Default is 10. + - mask_deg (float): + Mask radius in degrees for source flagging. Default is 3. + - fit_func_name (str): + Name of the fitting function. Default is 'gaussian2d_nonlin'. 'gaussian2d_nonlin' is only supported. + - max_non_linear_order (int): + Maximum non-linear order for fitting function. Default is 1. If you want to use simple gaussian set it to be 1. + Higher order is for the case that detector response is distorted by non-point-like source or too-strogng source, such as the Moon. + - fwhm_init_deg (float): + Initial guess for full width at half maximum in degrees. Default is 0.5. + - error_estimation_method (str): + Method for error estimation. Default is 'force_one_redchi2'. 'force_one_redchi2' and 'rms_from_data' are supported. + If 'rms_from_data', errorbar of each data point is set by root-mean-square of the data points flaged by 'flag_name_rms_calc', + and errorbar of xi,eta is set from the fit covariance matrix. If 'force_one_redchi2', the errorbar of (xi,eta) is equivalent the case + if the error bar of each data point is set as the reduced chi-square is equal to unity. + - flag_name_rms_calc (str): + Name of the flag used for RMS calculation. Default is 'source'. + - flag_rms_calc_exclusive (bool): + Flag indicating whether the RMS calculation is exclusive to the flag. Default is True. + + Returns: + - focal_plane (ResultSet): ResultSet containing updated xieta parameters for each detector. + """ + # if focal_plane result is specified, use the information as a prior + if fp_hdf_file is not None: + wrap_fp_from_hdf(tod, fp_hdf_file) + + # set dets without focal_plane info to have (xi, eta, gamma) = (0, 0, 0), just to avoid error + xieta_isnan = (np.isnan(tod.focal_plane.xi)) | (np.isnan(tod.focal_plane.eta)) + gamma_isnan = np.isnan(tod.focal_plane.gamma) + tod.focal_plane.xi[xieta_isnan] = 0. + tod.focal_plane.eta[xieta_isnan] = 0. + tod.focal_plane.gamma[gamma_isnan] = 0. + + # If input focal_plane is a result with `force_zero_roll`, set the roll to be zero + # Original value is stored to `roll_original` + if force_zero_roll: + if 'roll_original' in tod.boresight._fields.keys(): + pass + else: + tod.boresight.wrap('roll_original', tod.boresight.roll, [(0, 'samps')]) + tod.boresight.roll *= 0. + + # compute source flags + if 'source' in tod.flags._fields.keys(): + tod.flags.move('source', None) + coords.planets.compute_source_flags(tod, + center_on=sso_name, + max_pix=1e10, + wrap='source', + mask={'shape':'circle', 'xyr':[0.,0.,mask_deg]}) + + # restrict data to duration when at least one detector hit the source + summed_flag = np.sum(tod.flags['source'].mask()[~xieta_isnan], axis=0).astype('bool') + idx_hit = np.where(summed_flag)[0] + idx_first, idx_last = idx_hit[0], idx_hit[-1] + tod.restrict('samps', (tod.samps.offset+idx_first, tod.samps.offset+idx_last)) + + # run preprocess pipeline if provided + if pipe is not None: + proc_aman, success = pipe.run(tod) + + # get rms of flagged region for later error estimation + if flag_rms_calc_exclusive: + mask_for_rms_calc = tod.flags[flag_name_rms_calc].mask() + else: + mask_for_rms_calc = ~tod.flags[flag_name_rms_calc].mask() + rms = np.ma.std(np.ma.masked_array(tod.signal, mask_for_rms_calc), axis=1).data + if 'rms' in tod._fields.keys(): + tod.move('rms', None) + tod.wrap('rms', rms, [(0, 'dets')]) + + # use downsampled data for faster fitting + mask_ds = slice(None, None, ds_factor) + ts_ds = tod.timestamps[mask_ds] + q_bore = so3g.proj.CelestialSightLine.az_el(ts_ds, tod.boresight.az[mask_ds], + tod.boresight.el[mask_ds], weather="typical").Q + q_bore_roll = quat.rotation_iso(0, 0, np.median(tod.boresight.roll)) + sig_ds = tod.signal[:, mask_ds] + source_flags_ds = tod.flags['source'].mask()[:, mask_ds] + + # fit each detector data + xieta_dict = {} + for di, det in enumerate(tqdm(tod.dets.vals)): + mask_di = source_flags_ds[di] + bs_az = np.nanmedian(tod.boresight.az[mask_ds][mask_di]) + bs_el = np.nanmedian(tod.boresight.el[mask_ds][mask_di]) + bs_roll = np.nanmedian(tod.boresight.roll[mask_ds][mask_di]) + + if np.any([xieta_isnan[di], np.all(mask_di==False), tod.rms[di]==0.]): + xieta_dict[det] = {'xi': np.nan, 'eta': np.nan, 'xi_err': np.nan, 'eta_err': np.nan, + 'R2': np.nan, 'redchi2': np.nan, 'az': np.nan, 'el': np.nan, 'roll': np.nan} + else: + ts = ts_ds[mask_di] + d1_unix = np.median(ts) + + xieta_det = np.array([tod.focal_plane.xi[di], tod.focal_plane.eta[di]]) + q_det = so3g.proj.quat.rotation_xieta(xieta_det[0], xieta_det[1]) + planet = planets.SlowSource.for_named_source(sso_name, d1_unix * 1.) + ra0, dec0 = planet.pos(d1_unix) + q_obj = so3g.proj.quat.rotation_lonlat(ra0, dec0) + q_total = ~q_det * ~q_bore_roll * ~q_bore * q_obj + + xi_src, eta_src, _ = quat.decompose_xieta(q_total) + xieta_src = np.array([xi_src, eta_src]) + xieta_src = xieta_src[:, mask_di] + sig = sig_ds[di][mask_di] + ptp_val = np.ptp(np.percentile(sig, [0.1, 99.9])) + + if fit_func_name == 'gaussian2d_nonlin': + p0 = np.array([0., 0., fwhm_init_deg*coords.DEG, fwhm_init_deg*coords.DEG, 0., ptp_val]) + bounds = np.array( + [[-np.inf, -np.inf, fwhm_init_deg*coords.DEG/5., fwhm_init_deg*coords.DEG/5., -np.pi, 0.1*ptp_val], + [np.inf, np.inf, fwhm_init_deg*coords.DEG*5, fwhm_init_deg*coords.DEG*5, np.pi, 10*ptp_val]] + ) + if max_non_linear_order >= 2: + p0 = np.append(p0, np.zeros(max_non_linear_order-1)) + bounds = np.hstack([bounds, + np.vstack([[-np.inf * np.ones(max_non_linear_order-1), + np.inf * np.ones(max_non_linear_order-1)]]) + ]) + fit_func = wrapper_gaussian2d_nonlin + else: + raise NameError("Unsupported name for 'fit_func_name'") + + try: + popt, pcov = curve_fit(fit_func, xdata=xieta_src, ydata=sig, sigma=tod.rms[di]*np.ones_like(sig), + p0=p0, bounds=bounds, absolute_sigma=True) + + chi2 = np.sum(((fit_func(xieta_src, *popt) - sig)/tod.rms[di])**2) + redchi2 = chi2 / (np.prod(xieta_src.shape) - popt.shape[0]) + R2 = 1. - np.sum((fit_func(xieta_src, *popt) - sig)**2) / np.sum((sig - sig.mean())**2) + xi_opt, eta_opt = popt[0], popt[1] + + if error_estimation_method == 'rms_from_data': + xi_err, eta_err = np.sqrt(pcov[0,0]), np.sqrt(pcov[1,1]) + elif error_estimation_method == 'force_one_redchi2': + # The error of (xi, eta) is equivalent the case if the error bar of each data point is set + # as the reduced chi-square is equal to unity. + xi_err, eta_err = np.sqrt(pcov[0,0] * redchi2), np.sqrt(pcov[1,1] * redchi2) + redchi2 = 1. + else: + raise NameError("Unsupported name for 'error_estimation_method'") + + xieta_det += np.array([xi_opt, eta_opt]) + xieta_dict[det] = {'xi': xieta_det[0], 'eta': xieta_det[1], 'xi_err': xi_err, 'eta_err': eta_err, + 'R2': R2, 'redchi2': redchi2, 'az' : bs_az, 'el': bs_el, 'roll': bs_roll} + except RuntimeError: + xieta_dict[det] = {'xi': np.nan, 'eta': np.nan, 'xi_err': np.nan, 'eta_err': np.nan, + 'R2': np.nan, 'redchi2': np.nan, 'az': np.nan, 'el': np.nan, 'roll': np.nan} + + focal_plane = metadata.ResultSet(keys=['dets:readout_id', 'xi', 'eta', 'gamma', 'xi_err', 'eta_err', 'R2', 'redchi2', 'az', 'el', 'roll']) + for det in tod.dets.vals: + focal_plane.rows.append((det, xieta_dict[det]['xi'], xieta_dict[det]['eta'], 0., + xieta_dict[det]['xi_err'], xieta_dict[det]['eta_err'], + xieta_dict[det]['R2'], xieta_dict[det]['redchi2'], + xieta_dict[det]['az'], xieta_dict[det]['el'], xieta_dict[det]['roll'], + )) + + return focal_plane + +def main_one_wafer(configs, obs_id, wafer_slot, sso_name=None, + restrict_dets_for_debug=False): + if type(configs) == str: + configs = yaml.safe_load(open(configs, "r")) + + # Derive parameters from config file + ctx = core.Context(configs.get('context_file')) + + # get prior + fp_hdf_file = configs.get('fp_hdf_file', None) + fp_hdf_dir = configs.get('fp_hdf_dir', None) + if fp_hdf_file is None: + if fp_hdf_dir is not None: + fp_hdf_file = os.path.join(fp_hdf_dir, f'focal_plane_{obs_id}_{wafer_slot}.hdf') + if not os.path.exists(fp_hdf_file): + fp_hdf_file = None + + result_dir = configs.get('result_dir') + force_zero_roll = configs.get('force_zero_roll', True) + if force_zero_roll: + result_dir = result_dir + '_force_zero_roll' + + # get sso_name if it is not specified + if sso_name is None: + logger.info('deriving sso_name from observation tag') + obs_tags = ctx.obsdb.get(obs_id, tags=True)['tags'] + sso_names = _get_sso_names_from_tags(ctx, obs_id) + sso_name = sso_names[0] + if len(sso_names) >= 2: + logger.info(f'sso_names of {sso_names} are found from observation tags.' + + f'Processing only {sso_name}') + + # construct pipeline from configs + pipe = Pipeline(configs["process_pipe"], logger=logger) + for pipe_component in pipe: + if pipe_component.name == 'source_flags': + pipe_component.calc_cfgs['center_on'] = sso_name + + # Other parameters + force_zero_roll = configs.get('force_zero_roll') + ds_factor = configs.get('ds_factor', 20) + mask_deg = configs.get('mask_deg', 3.0) + fit_func_name = configs.get('fit_func_name', 'gaussian2d_nonlin') + max_non_linear_order = configs.get('max_non_linear_order', 2) + fwhm_init_deg = configs.get('fwhm_init_deg', 0.5) + error_estimation_method = configs.get('error_estimation_method', 'force_one_redchi2') + flag_name_rms_calc = configs.get('flag_name_rms_calc', 'source') + flag_rms_calc_exclusive = configs.get('flag_rms_calc_exclusive', True) + + + # Load data + logger.info('loading data') + meta = ctx.get_meta(obs_id, dets={'wafer_slot': wafer_slot}) + if restrict_dets_for_debug is not False: + try: + restrict_dets_for_debug = int(restrict_dets_for_debug) + meta.restrict('dets', meta.dets.vals[:restrict_dets_for_debug]) + except ValueError: + _testdets = restrict_dets_for_debug.split(',') + restrict_list = [det.split('\'')[1].strip() for det in _testdets] + meta.restrict('dets', restrict_list) + + tod = ctx.get_obs(meta) + + # get pointing + focal_plane_rset = update_xieta( tod, + sso_name=sso_name, + fp_hdf_file=fp_hdf_file, + force_zero_roll=force_zero_roll, + pipe=pipe, + ds_factor=ds_factor, + mask_deg=mask_deg, + fit_func_name = fit_func_name, + max_non_linear_order = max_non_linear_order, + fwhm_init_deg = fwhm_init_deg, + error_estimation_method=error_estimation_method, + flag_name_rms_calc = flag_name_rms_calc, + flag_rms_calc_exclusive = flag_rms_calc_exclusive, + ) + + os.makedirs(result_dir, exist_ok=True) + write_dataset(focal_plane_rset, + filename=os.path.join(result_dir, f'focal_plane_{obs_id}_{wafer_slot}.hdf'), + address='focal_plane', + overwrite=True) + + return f"Finished processing {obs_id}, {wafer_slot}" + +def main_one_wafer_dummy(configs, obs_id, wafer_slot, restrict_dets_for_debug=False): + if type(configs) == str: + configs = yaml.safe_load(open(configs, "r")) + ctx = core.Context(configs.get('context_file')) + result_dir = configs.get('result_dir') + force_zero_roll = configs.get('force_zero_roll', True) + if force_zero_roll: + result_dir = result_dir + '_force_zero_roll' + + meta = ctx.get_meta(obs_id, dets={'wafer_slot': wafer_slot}) + if restrict_dets_for_debug is not False: + try: + restrict_dets_for_debug = int(restrict_dets_for_debug) + meta.restrict('dets', meta.dets.vals[:restrict_dets_for_debug]) + except ValueError: + _testdets = restrict_dets_for_debug.split(',') + restrict_list = [det.split('\'')[1].strip() for det in _testdets] + meta.restrict('dets', restrict_list) + result_filename = f'focal_plane_{obs_id}_{wafer_slot}.hdf' + + fp_rset_dummy = metadata.ResultSet(keys=['dets:readout_id', 'xi', 'eta', 'gamma', + 'xi_err', 'eta_err', 'R2', 'redchi2', 'az', 'el', 'roll']) + for det in meta.dets.vals: + fp_rset_dummy.rows.append((det, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan)) + + os.makedirs(result_dir, exist_ok=True) + write_dataset(fp_rset_dummy, + filename=os.path.join(result_dir, result_filename), + address='focal_plane', + overwrite=True) + return + +def combine_pointings(pointing_result_files): + combined_dict = {} + for file in pointing_result_files: + rset = read_dataset(file, 'focal_plane') + for row in rset[:]: + if row['dets:readout_id'] not in combined_dict.keys(): + combined_dict[row['dets:readout_id']] = {} + combined_dict[row['dets:readout_id']]['xi'] = row['xi'] + combined_dict[row['dets:readout_id']]['eta'] = row['eta'] + combined_dict[row['dets:readout_id']]['gamma'] = row['gamma'] + combined_dict[row['dets:readout_id']]['xi_err'] = row['xi_err'] + combined_dict[row['dets:readout_id']]['eta_err'] = row['eta_err'] + combined_dict[row['dets:readout_id']]['R2'] = row['R2'] + combined_dict[row['dets:readout_id']]['redchi2'] = row['redchi2'] + combined_dict[row['dets:readout_id']]['az'] = row['az'] + combined_dict[row['dets:readout_id']]['el'] = row['el'] + combined_dict[row['dets:readout_id']]['roll'] = row['roll'] + + focal_plane = metadata.ResultSet(keys=['dets:readout_id', 'xi', 'eta', 'gamma', 'xi_err', 'eta_err', 'R2', 'redchi2', 'az', 'el', 'roll']) + + for det, val in combined_dict.items(): + focal_plane.rows.append((det, val['xi'], val['eta'], val['gamma'], val['xi_err'], val['eta_err'], val['R2'], val['redchi2'],val['az'], val['el'], val['roll'])) + return focal_plane + +def main_one_obs(configs, obs_id, sso_name=None, + restrict_dets_for_debug=False): + if type(configs) == str: + configs = yaml.safe_load(open(configs, "r")) + ctx = core.Context(configs.get('context_file')) + result_dir = configs.get('result_dir') + force_zero_roll = configs.get('force_zero_roll', True) + if force_zero_roll: + result_dir = result_dir + '_force_zero_roll' + optics_config_fn = configs.get('optics_config_fn') + + hit_time_threshold = configs.get('hit_time_threshold', 600) + hit_circle_r_deg = configs.get('hit_circle_r_deg', 7.0) + + if sso_name is None: + logger.info('deriving sso_name from observation tag') + obs_tags = ctx.obsdb.get(obs_id, tags=True)['tags'] + sso_names = _get_sso_names_from_tags(ctx, obs_id) + sso_name = sso_names[0] + if len(sso_names) >= 2: + logger.info(f'sso_names of {sso_names} are found from observation tags.' + + f'Processing only {sso_name}') + + tod = ctx.get_obs(obs_id, no_signal=True) + streamed_wafer_slots = ['ws{}'.format(index) for index, bit in enumerate(obs_id.split('_')[-1]) if bit == '1'] + processed_wafer_slots = [] + finished_wafer_slots = [] + skipped_wafer_slots = [] + check_dir = result_dir + '_force_zero_roll' if force_zero_roll else result_dir + + for ws in streamed_wafer_slots: + hit_time = bsp.get_rough_hit_time(tod, + wafer_slot=ws, + sso_name=sso_name, + circle_r_deg=hit_circle_r_deg, + optics_config_fn=optics_config_fn) + logger.info(f'hit_time for {ws} is {hit_time:.1f} [sec]') + if hit_time >= hit_time_threshold: + if os.path.exists(os.path.join(check_dir, f'focal_plane_{obs_id}_{ws}.hdf')): + finished_wafer_slots.append(ws) + else: + processed_wafer_slots.append(ws) + else: + skipped_wafer_slots.append(ws) + + logger.info(f'Found saved data for these wafer_slots: {finished_wafer_slots}') + logger.info(f'Will continue for these wafer_slots: {processed_wafer_slots}') + + if configs.get('parallel_job'): + logger.info('Continuing with parallel job') + try: + n_jobs = int(os.environ.get('SLURM_CPUS_PER_TASK', 1)) + except: + n_jobs = -1 + rank, executor, as_completed_callable = get_exec_env(nprocs=n_jobs) + futures = [executor.submit( + main_one_wafer, + configs=configs, + obs_id=obs_id, + wafer_slot=wafer_slot, + sso_name=sso_name, + restrict_dets_for_debug=restrict_dets_for_debug, + ) + for wafer_slot in processed_wafer_slots] + for future in as_completed_callable(futures): + logger.info(future.result()) + else: + logger.info('Continuing with serial processing of wafers.') + for wafer_slot in processed_wafer_slots: + logger.info(f'Processing {obs_id}, {wafer_slot}') + main_one_wafer(configs=configs, + obs_id=obs_id, + wafer_slot=wafer_slot, + sso_name=sso_name, + restrict_dets_for_debug=restrict_dets_for_debug) + + logger.info(f'create dummy hdf for non-hitting wafer: {skipped_wafer_slots}') + for wafer_slot in skipped_wafer_slots: + main_one_wafer_dummy(configs=configs, + obs_id=obs_id, + wafer_slot=wafer_slot, + restrict_dets_for_debug=restrict_dets_for_debug) + + logger.info('making combined result') + pointing_result_files = glob.glob(os.path.join(result_dir, f'focal_plane_{obs_id}_ws[0-6].hdf')) + fp_rset_full = combine_pointings(pointing_result_files) + fp_rset_full_file = os.path.join(os.path.join(result_dir, f'focal_plane_{obs_id}_all.hdf')) + write_dataset(fp_rset_full, filename=fp_rset_full_file, + address='focal_plane', overwrite=True) + logger.info(f'ta da! Finsihed with {obs_id}') + +def main(configs, min_ctime=None, max_ctime=None, update_delay=None, + obs_id=None, wafer_slot=None, sso_name=None, restrict_dets_for_debug=False): + if (min_ctime is None) and (update_delay is not None): + # If min_ctime is provided it will use that.. + # Otherwise it will use update_delay to set min_ctime. + min_ctime = int(time.time()) - update_delay*86400 + + if type(configs) == str: + configs = yaml.safe_load(open(configs, "r")) + ctx = core.Context(configs.get('context_file')) + + if obs_id is None: + query_text = configs.get('query_text', None) + query_tags = configs.get('query_tags', None) + tot_query = "and " + if query_text is not None: + tot_query += f"{query_text} and " + if min_ctime is not None: + tot_query += f"timestamp>={min_ctime} and " + if max_ctime is not None: + tot_query += f"timestamp<={max_ctime} and " + tot_query = tot_query[4:-4] + if tot_query == "": + tot_query = "1" + + logger.info(f'tot_query: {tot_query}') + obs_list= ctx.obsdb.query(tot_query, query_tags) + + for obs in obs_list: + obs_id = obs['obs_id'] + logger.info(f'Processing {obs_id}') + main_one_obs(configs=configs, obs_id=obs_id, + restrict_dets_for_debug=restrict_dets_for_debug) + + elif obs_id is not None: + logger.info(f'Processing {obs_id}') + if wafer_slot is None: + main_one_obs(configs=configs, obs_id=obs_id, sso_name=sso_name, + restrict_dets_for_debug=restrict_dets_for_debug) + else: + main_one_wafer(configs=configs, obs_id=obs_id, wafer_slot=wafer_slot, sso_name=sso_name, + restrict_dets_for_debug=restrict_dets_for_debug) + return + + +def get_parser(): + parser = argparse.ArgumentParser(description="Get updated result of pointings with tod-based results") + parser.add_argument("configs", type=str, help="Path to the configuration file") + parser.add_argument('--min-ctime', type=int, help="Minimum timestamp for the beginning of an observation list") + parser.add_argument('--max-ctime', type=int, help="Maximum timestamp for the beginning of an observation list") + parser.add_argument('--update-delay', type=int, help="Number of days (unit is days) in the past to start observation list.") + parser.add_argument("--obs-id", type=str, + help="Specific observation obs_id to process. If provided, overrides other filtering parameters.") + + parser.add_argument("--wafer-slot", type=str, default=None, + help="Wafer slot to be processed (e.g., 'ws0', 'ws3'). Valid only when obs_id is specified.") + + parser.add_argument("--sso-name", type=str, default=None, + help="Name of solar system object (e.g., 'moon', 'jupiter'). If not specified, get sso_name from observation tags. "\ + + "Valid only when obs_id is specified") + parser.add_argument("--restrict-dets-for-debug", type=str, default=False) + return parser + +if __name__ == '__main__': + main_launcher(main, get_parser)