From 35b15e4af5880a17015fec5d8dc8d161a17bb5e2 Mon Sep 17 00:00:00 2001
From: jakeross
Date: Sat, 14 Feb 2026 21:31:17 -0700
Subject: [PATCH] perf: cache Thing and owner contact lookups and batch inserts
 in water levels transfer

---
 transfers/waterlevels_transfer.py | 349 +++++++++++++++++++++++-------
 1 file changed, 274 insertions(+), 75 deletions(-)

diff --git a/transfers/waterlevels_transfer.py b/transfers/waterlevels_transfer.py
index 43b66020..3b664e4c 100644
--- a/transfers/waterlevels_transfer.py
+++ b/transfers/waterlevels_transfer.py
@@ -21,6 +21,7 @@ import pandas as pd
 from db import (
     Thing,
+    ThingContactAssociation,
     Sample,
     Observation,
     FieldEvent,
@@ -30,6 +31,7 @@
     Parameter,
 )
 from db.engine import session_ctx
+from sqlalchemy import insert
 from sqlalchemy.exc import DatabaseError, SQLAlchemyError
 from sqlalchemy.orm import Session
 from transfers.transferer import Transferer
@@ -92,6 +94,36 @@ def __init__(self, *args, **kw):
             self._measured_by_mapper = json.load(f)
 
         self._created_contacts = {}
+        self._thing_id_by_pointid: dict[str, int] = {}
+        self._owner_contact_id_by_pointid: dict[str, int] = {}
+        self._build_caches()
+
+    def _build_caches(self) -> None:
+        with session_ctx() as session:
+            self._thing_id_by_pointid = {
+                name: thing_id
+                for name, thing_id in session.query(Thing.name, Thing.id).all()
+            }
+
+            owner_rows = (
+                session.query(Thing.name, ThingContactAssociation.contact_id)
+                .join(
+                    ThingContactAssociation,
+                    Thing.id == ThingContactAssociation.thing_id,
+                )
+                .order_by(Thing.name, ThingContactAssociation.id.asc())
+                .all()
+            )
+            owner_contact_cache: dict[str, int] = {}
+            for pointid, contact_id in owner_rows:
+                owner_contact_cache.setdefault(pointid, contact_id)
+            self._owner_contact_id_by_pointid = owner_contact_cache
+
+        logger.info(
+            "Built WaterLevels caches: %s Things, %s owner contacts",
+            len(self._thing_id_by_pointid),
+            len(self._owner_contact_id_by_pointid),
+        )
 
     def _get_dfs(self) -> tuple[pd.DataFrame, pd.DataFrame]:
         input_df = read_csv(self.source_table, dtype={"MeasuredBy": str})
@@ -140,8 +172,8 @@ def _transfer_hook(self, session: Session) -> None:
                 len(group),
             )
 
-            thing = session.query(Thing).where(Thing.name == pointid).one_or_none()
-            if thing is None:
+            thing_id = self._thing_id_by_pointid.get(pointid)
+            if thing_id is None:
                 stats["groups_skipped_missing_thing"] += 1
                 logger.warning(
                     "Skipping PointID=%s because Thing was not found", pointid
                 )
@@ -149,6 +181,7 @@
                 self._capture_error(pointid, "Thing not found", "PointID")
                 continue
 
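+            # Gather validated rows first; all inserts for this PointID group
+            # are batched after the loop to cut per-row round trips.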
+            prepared_rows: list[dict[str, Any]] = []
             for i, row in enumerate(group.itertuples()):
                 stats["rows_total"] += 1
                 dt_utc = self._get_dt_utc(row)
@@ -175,16 +208,8 @@
                 release_status = "public" if row.PublicRelease else "private"
 
-                # field event
-                field_event = FieldEvent(
-                    thing=thing,
-                    event_date=dt_utc,
-                    release_status=release_status,
-                )
-                session.add(field_event)
-                stats["field_events_created"] += 1
                 field_event_participants = self._get_field_event_participants(
-                    session, row, thing
+                    session, row
                 )
                 stats["contacts_created"] += getattr(
                     self, "_last_contacts_created_count", 0
                 )
@@ -201,53 +226,181 @@
                 )
                 continue
 
-            sampler = None
-            for i, participant in enumerate(field_event_participants):
-                field_event_participant = FieldEventParticipant(
-                    field_event=field_event, participant=participant
-                )
-                if i == 0:
-                    field_event_participant.participant_role = "Lead"
-                    sampler = field_event_participant
-                else:
-                    field_event_participant.participant_role = "Participant"
-
-                session.add(field_event_participant)
-
-                if (
+                is_destroyed = (
                     glv
                     == "Well was destroyed (no subsequent water levels should be recorded)"
-                ):
+                )
+                if is_destroyed:
                     logger.warning(
                         "Well is destroyed for %s - no field activity/sample/observation will be made",
                         self._row_context(row),
                     )
                     stats["rows_well_destroyed"] += 1
-                    field_event.notes = glv
-                    continue
 
-                # Field Activity
-                # TODO: use create schema to validate data
-                field_activity = FieldActivity(
-                    field_event=field_event,
-                    activity_type="groundwater level",
-                    release_status=release_status,
+                prepared_rows.append(
+                    {
+                        "row": row,
+                        "dt_utc": dt_utc,
+                        "glv": glv,
+                        "release_status": release_status,
+                        "participants": field_event_participants,
+                        "is_destroyed": is_destroyed,
+                    }
                 )
-                session.add(field_activity)
-                stats["field_activities_created"] += 1
-
-                # Sample
-                sample = self._make_sample(row, field_activity, dt_utc, sampler)
-                session.add(sample)
-                stats["samples_created"] += 1
-
-                # Observation
-                observation = self._make_observation(row, sample, dt_utc, glv)
-                session.add(observation)
-                stats["observations_created"] += 1
-                stats["rows_created"] += 1
+                if not is_destroyed:
+                    stats["rows_created"] += 1
 
+            if not prepared_rows:
+                stats["groups_processed"] += 1
+                continue
+
             try:
+                # Flush so Contacts created while resolving participants have
+                # primary keys before participant.id is referenced below.
+                session.flush()
+
+                # FieldEvent batch; sort_by_parameter_order=True keeps the
+                # returned ids aligned with the order of field_event_rows.
+                field_event_rows = [
+                    {
+                        "thing_id": thing_id,
+                        "event_date": prep["dt_utc"],
+                        "release_status": prep["release_status"],
+                        "notes": prep["glv"] if prep["is_destroyed"] else None,
+                    }
+                    for prep in prepared_rows
+                ]
+                field_event_ids = (
+                    session.execute(
+                        insert(FieldEvent).returning(
+                            FieldEvent.id, sort_by_parameter_order=True
+                        ),
+                        field_event_rows,
+                    )
+                    .scalars()
+                    .all()
+                )
+                stats["field_events_created"] += len(field_event_rows)
+
+                # FieldEventParticipant batch + lead participant id map
+                participant_rows: list[dict[str, Any]] = []
+                lead_row_pos_by_prepared_idx: dict[int, int] = {}
+                for prepared_idx, prep in enumerate(prepared_rows):
+                    for participant_idx, participant in enumerate(prep["participants"]):
+                        participant_rows.append(
+                            {
+                                "field_event_id": field_event_ids[prepared_idx],
+                                "contact_id": participant.id,
+                                "participant_role": (
+                                    "Lead" if participant_idx == 0 else "Participant"
+                                ),
+                                "release_status": prep["release_status"],
+                            }
+                        )
+                        if participant_idx == 0:
+                            lead_row_pos_by_prepared_idx[prepared_idx] = (
+                                len(participant_rows) - 1
+                            )
+
+                lead_participant_id_by_prepared_idx: dict[int, int] = {}
+                if participant_rows:
+                    participant_ids = (
+                        session.execute(
+                            insert(FieldEventParticipant).returning(
+                                FieldEventParticipant.id,
+                                sort_by_parameter_order=True,
+                            ),
+                            participant_rows,
+                        )
+                        .scalars()
+                        .all()
+                    )
+                    for prepared_idx, pos in lead_row_pos_by_prepared_idx.items():
+                        lead_participant_id_by_prepared_idx[prepared_idx] = (
+                            participant_ids[pos]
+                        )
+
+                # FieldActivity batch (non-destroyed rows)
+                field_activity_rows: list[dict[str, Any]] = []
+                activity_row_pos_by_prepared_idx: dict[int, int] = {}
+                for prepared_idx, prep in enumerate(prepared_rows):
+                    if prep["is_destroyed"]:
+                        continue
+                    activity_row_pos_by_prepared_idx[prepared_idx] = len(
+                        field_activity_rows
+                    )
+                    field_activity_rows.append(
+                        {
+                            "field_event_id": field_event_ids[prepared_idx],
+                            "activity_type": "groundwater level",
+                            "release_status": prep["release_status"],
+                        }
+                    )
+
+                field_activity_ids: list[int] = []
+                if field_activity_rows:
+                    field_activity_ids = (
+                        session.execute(
+                            insert(FieldActivity).returning(
+                                FieldActivity.id, sort_by_parameter_order=True
+                            ),
+                            field_activity_rows,
+                        )
+                        .scalars()
+                        .all()
+                    )
+                stats["field_activities_created"] += len(field_activity_rows)
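+
+                # Destroyed wells were excluded from the activity batch, so
+                # every remaining row resolves to a FieldActivity id here.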
+
+                # Sample batch (non-destroyed rows)
+                sample_rows: list[dict[str, Any]] = []
+                sample_row_pos_by_prepared_idx: dict[int, int] = {}
+                for prepared_idx, prep in enumerate(prepared_rows):
+                    if prep["is_destroyed"]:
+                        continue
+                    sample_row_pos_by_prepared_idx[prepared_idx] = len(sample_rows)
+                    sample_rows.append(
+                        {
+                            "nma_pk_waterlevels": prep["row"].GlobalID,
+                            "field_activity_id": field_activity_ids[
+                                activity_row_pos_by_prepared_idx[prepared_idx]
+                            ],
+                            "field_event_participant_id": lead_participant_id_by_prepared_idx.get(
+                                prepared_idx
+                            ),
+                            "sample_date": prep["dt_utc"],
+                            "sample_matrix": "water",
+                            "sample_name": str(uuid.uuid4()),
+                            "sample_method": self._get_sample_method(prep["row"]),
+                            "qc_type": "Normal",
+                            "depth_top": None,
+                            "depth_bottom": None,
+                            "release_status": prep["release_status"],
+                        }
+                    )
+
+                sample_ids: list[int] = []
+                if sample_rows:
+                    sample_ids = (
+                        session.execute(
+                            insert(Sample).returning(
+                                Sample.id, sort_by_parameter_order=True
+                            ),
+                            sample_rows,
+                        )
+                        .scalars()
+                        .all()
+                    )
+                stats["samples_created"] += len(sample_rows)
+
+                # Observation batch (non-destroyed rows)
+                observation_rows: list[dict[str, Any]] = []
+                for prepared_idx, prep in enumerate(prepared_rows):
+                    if prep["is_destroyed"]:
+                        continue
+                    sample_id = sample_ids[sample_row_pos_by_prepared_idx[prepared_idx]]
+                    observation_rows.append(
+                        self._make_observation_insert_row(
+                            prep["row"],
+                            sample_id,
+                            prep["dt_utc"],
+                            prep["glv"],
+                            prep["release_status"],
+                        )
+                    )
+
+                if observation_rows:
+                    session.execute(insert(Observation), observation_rows)
+                    stats["observations_created"] += len(observation_rows)
+
                 session.commit()
                 session.expunge_all()
                 stats["groups_processed"] += 1
@@ -284,6 +437,25 @@ def _transfer_hook(self, session: Session) -> None:
     def _make_observation(
         self, row: pd.Series, sample: Sample, dt_utc: datetime, glv: str
     ) -> Observation:
+        value, measuring_point_height, data_quality = self._get_observation_parts(row)
+        observation = Observation(
+            nma_pk_waterlevels=row.GlobalID,
+            sample=sample,
+            sensor_id=None,
+            analysis_method_id=None,
+            observation_datetime=dt_utc,
+            parameter_id=self.groundwater_parameter_id,
+            value=value,
+            unit="ft",
+            measuring_point_height=measuring_point_height,
+            groundwater_level_reason=glv,
+            nma_data_quality=data_quality,
+        )
+        return observation
+
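+    # Shared field parsing used by both _make_observation (ORM path) and
+    # _make_observation_insert_row (bulk path).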
+    def _get_observation_parts(
+        self, row: pd.Series
+    ) -> tuple[float | None, float | None, str | None]:
         if pd.isna(row.MPHeight):
             if pd.notna(row.DepthToWater) and pd.notna(row.DepthToWaterBGS):
                 logger.warning(
@@ -359,30 +531,34 @@ def _make_observation(
                     "DataQuality",
                 )
 
-        # TODO: after sensors have been added to the database update sensor_id (or sensor) for waterlevels that come from db sensors (like e probes?)
-        observation = Observation(
-            nma_pk_waterlevels=row.GlobalID,
-            sample=sample,
-            sensor_id=None,
-            analysis_method_id=None,
-            observation_datetime=dt_utc,
-            parameter_id=self.groundwater_parameter_id,
-            value=value,
-            unit="ft",
-            measuring_point_height=measuring_point_height,
-            groundwater_level_reason=glv,
-            nma_data_quality=data_quality,
-        )
-        return observation
+        return value, measuring_point_height, data_quality
+
+    def _make_observation_insert_row(
+        self,
+        row: pd.Series,
+        sample_id: int,
+        dt_utc: datetime,
+        glv: str,
+        release_status: str,
+    ) -> dict[str, Any]:
+        value, measuring_point_height, data_quality = self._get_observation_parts(row)
+        return {
+            "nma_pk_waterlevels": row.GlobalID,
+            "sample_id": sample_id,
+            "sensor_id": None,
+            "analysis_method_id": None,
+            "observation_datetime": dt_utc,
+            "parameter_id": self.groundwater_parameter_id,
+            "value": value,
+            "unit": "ft",
+            "measuring_point_height": measuring_point_height,
+            "groundwater_level_reason": glv,
+            "nma_data_quality": data_quality,
+            "release_status": release_status,
+        }
 
     def _make_sample(self, row, field_activity, dt_utc, sampler) -> Sample:
-        sample_method = (
-            "null placeholder"
-            if pd.isna(row.MeasurementMethod)
-            else lexicon_mapper.map_value(
-                f"LU_MeasurementMethod:{row.MeasurementMethod}", "null placeholder"
-            )
-        )
+        sample_method = self._get_sample_method(row)
 
         sample = Sample(
             nma_pk_waterlevels=row.GlobalID,
@@ -398,6 +574,15 @@ def _make_sample(self, row, field_activity, dt_utc, sampler) -> Sample:
         )
         return sample
 
+    def _get_sample_method(self, row) -> str:
+        return (
+            "null placeholder"
+            if pd.isna(row.MeasurementMethod)
+            else lexicon_mapper.map_value(
+                f"LU_MeasurementMethod:{row.MeasurementMethod}", "null placeholder"
+            )
+        )
+
     def _get_groundwater_level_reason(self, row) -> str:
         glv = row.LevelStatus
         if pd.isna(glv):
@@ -415,7 +600,7 @@ def _get_groundwater_level_reason(self, row) -> str:
             raise ValueError(f"Unknown groundwater level reason: {glv}")
         return glv
 
-    def _get_field_event_participants(self, session, row, thing) -> list[Contact]:
+    def _get_field_event_participants(self, session, row) -> list[Contact]:
         self._last_contacts_created_count = 0
         self._last_contacts_reused_count = 0
         field_event_participants = []
@@ -457,13 +642,10 @@ def _get_field_event_participants(self, session, row, thing) -> list[Contact]:
                 field_event_participants.append(contact)
 
         else:
-            if thing.contacts:
-                contact = thing.contacts[0]
-                field_event_participants.append(contact)
-                self._last_contacts_reused_count += 1
-            else:
+            owner_contact_id = self._owner_contact_id_by_pointid.get(row.PointID)
+            if owner_contact_id is None:
                 logger.warning(
-                    "Thing for PointID=%s has no contacts; cannot use owner fallback for %s",
+                    "Thing for PointID=%s has no owner contact; cannot use owner fallback for %s",
                     row.PointID,
                     self._row_context(row),
                 )
@@ -472,6 +654,23 @@ def _get_field_event_participants(self, session, row, thing) -> list[Contact]:
                 self._capture_error(
                     row.PointID,
-                    "Thing has no contacts for owner fallback",
+                    "Thing has no owner contact for fallback",
                     "MeasuredBy",
                 )
+            else:
+                contact = session.get(Contact, owner_contact_id)
+                if contact is None:
+                    logger.warning(
+                        "Owner contact id=%s not found for PointID=%s; cannot use owner fallback for %s",
+                        owner_contact_id,
+                        row.PointID,
+                        self._row_context(row),
+                    )
+                    self._capture_error(
+                        row.PointID,
+                        f"owner contact id {owner_contact_id} not found",
+                        "MeasuredBy",
+                    )
+                else:
+                    field_event_participants.append(contact)
+                    self._last_contacts_reused_count += 1
 
         if len(field_event_participants) == 0:
             logger.critical(