diff --git a/transfers/transfer.py b/transfers/transfer.py index 5bca4378..1e50accb 100644 --- a/transfers/transfer.py +++ b/transfers/transfer.py @@ -20,7 +20,6 @@ from dataclasses import dataclass from dotenv import load_dotenv - from transfers.thing_transfer import ( transfer_rock_sample_locations, transfer_springs, @@ -216,13 +215,16 @@ def transfer_context(name: str, *, pad: int = 10): logger.info("Finished %s", name) -def _execute_transfer(klass, flags: dict = None): - """Execute a single transfer class. Thread-safe since each creates its own session.""" +def _get_test_pointids(): pointids = None if os.getenv("TRANSFER_TEST_POINTIDS"): pointids = os.getenv("TRANSFER_TEST_POINTIDS").split(",") + return pointids - transferer = klass(flags=flags, pointids=pointids) + +def _execute_transfer(klass, flags: dict = None): + """Execute a single transfer class. Thread-safe since each creates its own session.""" + transferer = klass(flags=flags, pointids=_get_test_pointids()) transferer.transfer() return transferer.input_df, transferer.cleaned_df, transferer.errors @@ -372,7 +374,7 @@ def transfer_all(metrics: Metrics) -> list[ProfileArtifact]: use_parallel_wells = get_bool_env("TRANSFER_PARALLEL_WELLS", True) if use_parallel_wells: logger.info("Using PARALLEL wells transfer") - transferer = WellTransferer(flags=flags) + transferer = WellTransferer(flags=flags, pointids=_get_test_pointids()) transferer.transfer_parallel() results = (transferer.input_df, transferer.cleaned_df, transferer.errors) else: diff --git a/transfers/waterlevels_transfer.py b/transfers/waterlevels_transfer.py index 6697b344..dedd72a9 100644 --- a/transfers/waterlevels_transfer.py +++ b/transfers/waterlevels_transfer.py @@ -16,10 +16,9 @@ import json import uuid from datetime import datetime, timezone, timedelta +from typing import Any import pandas as pd -from sqlalchemy.orm import Session - from db import ( Thing, Sample, @@ -31,6 +30,8 @@ Parameter, ) from db.engine import session_ctx +from sqlalchemy.exc import DatabaseError, SQLAlchemyError +from sqlalchemy.orm import Session from transfers.transferer import Transferer from transfers.util import ( filter_to_valid_point_ids, @@ -72,9 +73,10 @@ def get_contacts_info( class WaterLevelTransferer(Transferer): + source_table = "WaterLevels" + def __init__(self, *args, **kw): super().__init__(*args, **kw) - self.source_table = "WaterLevels" with session_ctx() as session: groundwater_parameter_id = ( session.query(Parameter) @@ -94,23 +96,79 @@ def _get_dfs(self) -> tuple[pd.DataFrame, pd.DataFrame]: input_df = read_csv(self.source_table, dtype={"MeasuredBy": str}) cleaned_df = filter_to_valid_point_ids(input_df) cleaned_df = filter_by_valid_measuring_agency(cleaned_df) + logger.info( + "Prepared %s rows for %s after filtering (%s -> %s)", + len(cleaned_df), + self.source_table, + len(input_df), + len(cleaned_df), + ) return input_df, cleaned_df def _transfer_hook(self, session: Session) -> None: + stats: dict[str, int] = { + "groups_total": 0, + "groups_processed": 0, + "groups_skipped_missing_thing": 0, + "groups_failed_commit": 0, + "rows_total": 0, + "rows_created": 0, + "rows_skipped_dt": 0, + "rows_skipped_reason": 0, + "rows_skipped_contacts": 0, + "rows_well_destroyed": 0, + "field_events_created": 0, + "field_activities_created": 0, + "samples_created": 0, + "observations_created": 0, + "contacts_created": 0, + "contacts_reused": 0, + } + gwd = self.cleaned_df.groupby(["PointID"]) - for index, group in gwd: + total_groups = len(gwd) + for gi, (index, group) in enumerate(gwd, start=1): + stats["groups_total"] += 1 pointid = index[0] - thing = session.query(Thing).where(Thing.name == pointid).first() + logger.info( + "Processing WaterLevels group %s/%s for PointID=%s (%s rows)", + gi, + total_groups, + pointid, + len(group), + ) + + thing = session.query(Thing).where(Thing.name == pointid).one_or_none() + if thing is None: + stats["groups_skipped_missing_thing"] += 1 + logger.warning( + "Skipping PointID=%s because Thing was not found", pointid + ) + self._capture_error(pointid, "Thing not found", "PointID") + continue for i, row in enumerate(group.itertuples()): + stats["rows_total"] += 1 dt_utc = self._get_dt_utc(row) if dt_utc is None: + stats["rows_skipped_dt"] += 1 continue - # reasons + # reasons try: glv = self._get_groundwater_level_reason(row) - except KeyError as e: + except (KeyError, ValueError) as e: + stats["rows_skipped_reason"] += 1 + logger.warning( + "Skipping %s due to invalid groundwater level reason: %s", + self._row_context(row), + e, + ) + self._capture_error( + row.PointID, + f"invalid groundwater level reason: {e}", + "LevelStatus", + ) continue release_status = "public" if row.PublicRelease else "private" @@ -122,9 +180,25 @@ def _transfer_hook(self, session: Session) -> None: release_status=release_status, ) session.add(field_event) + stats["field_events_created"] += 1 field_event_participants = self._get_field_event_participants( session, row, thing ) + stats["contacts_created"] += getattr( + self, "_last_contacts_created_count", 0 + ) + stats["contacts_reused"] += getattr( + self, "_last_contacts_reused_count", 0 + ) + + if not field_event_participants: + stats["rows_skipped_contacts"] += 1 + logger.warning( + "Skipping %s because no field event participants were found", + self._row_context(row), + ) + continue + sampler = None for i, participant in enumerate(field_event_participants): field_event_participant = FieldEventParticipant( @@ -143,8 +217,10 @@ def _transfer_hook(self, session: Session) -> None: == "Well was destroyed (no subsequent water levels should be recorded)" ): logger.warning( - "Well is destroyed - no field activity/sample/observation will be made" + "Well is destroyed for %s - no field activity/sample/observation will be made", + self._row_context(row), ) + stats["rows_well_destroyed"] += 1 field_event.notes = glv continue @@ -156,16 +232,52 @@ def _transfer_hook(self, session: Session) -> None: release_status=release_status, ) session.add(field_activity) + stats["field_activities_created"] += 1 # Sample sample = self._make_sample(row, field_activity, dt_utc, sampler) session.add(sample) + stats["samples_created"] += 1 # Observation observation = self._make_observation(row, sample, dt_utc, glv) session.add(observation) + stats["observations_created"] += 1 + stats["rows_created"] += 1 + + try: + session.commit() + session.expunge_all() + stats["groups_processed"] += 1 + except DatabaseError as e: + stats["groups_failed_commit"] += 1 + logger.exception( + "Failed committing WaterLevels group for PointID=%s: %s", + pointid, + e, + ) + session.rollback() + self._capture_database_error(pointid, e) + except SQLAlchemyError as e: + stats["groups_failed_commit"] += 1 + logger.exception( + "SQLAlchemy failure committing WaterLevels group for PointID=%s: %s", + pointid, + e, + ) + session.rollback() + self._capture_error(pointid, str(e), "UnknownField") + except Exception as e: + stats["groups_failed_commit"] += 1 + logger.exception( + "Unexpected failure committing WaterLevels group for PointID=%s: %s", + pointid, + e, + ) + session.rollback() + self._capture_error(pointid, str(e), "UnknownField") - session.commit() + self._log_transfer_summary(stats) def _make_observation( self, row: pd.Series, sample: Sample, dt_utc: datetime, glv: str @@ -265,6 +377,8 @@ def _get_groundwater_level_reason(self, row) -> str: return glv def _get_field_event_participants(self, session, row, thing) -> list[Contact]: + self._last_contacts_created_count = 0 + self._last_contacts_reused_count = 0 field_event_participants = [] measured_by = None if pd.isna(row.MeasuredBy) else row.MeasuredBy @@ -277,6 +391,7 @@ def _get_field_event_participants(self, session, row, thing) -> list[Contact]: for name, organization, role in contact_info: if (name, organization) in self._created_contacts: contact = self._created_contacts[(name, organization)] + self._last_contacts_reused_count += 1 else: try: # create new contact if not already created @@ -294,6 +409,7 @@ def _get_field_event_participants(self, session, row, thing) -> list[Contact]: ) self._created_contacts[(name, organization)] = contact + self._last_contacts_created_count += 1 except Exception as e: logger.critical( f"Contact cannot be created: Name {name} | Role {role} | Organization {organization} because of the following: {str(e)}" @@ -302,8 +418,21 @@ def _get_field_event_participants(self, session, row, thing) -> list[Contact]: field_event_participants.append(contact) else: - contact = thing.contacts[0] - field_event_participants.append(contact) + if thing.contacts: + contact = thing.contacts[0] + field_event_participants.append(contact) + self._last_contacts_reused_count += 1 + else: + logger.warning( + "Thing for PointID=%s has no contacts; cannot use owner fallback for %s", + row.PointID, + self._row_context(row), + ) + self._capture_error( + row.PointID, + "Thing has no contacts for owner fallback", + "MeasuredBy", + ) if len(field_event_participants) == 0: logger.critical( @@ -313,6 +442,36 @@ def _get_field_event_participants(self, session, row, thing) -> list[Contact]: return field_event_participants + def _row_context(self, row: Any) -> str: + return ( + f"PointID={getattr(row, 'PointID', None)}, " + f"OBJECTID={getattr(row, 'OBJECTID', None)}, " + f"GlobalID={getattr(row, 'GlobalID', None)}" + ) + + def _log_transfer_summary(self, stats: dict[str, int]) -> None: + logger.info( + "WaterLevels summary: groups total=%s processed=%s skipped_missing_thing=%s failed_commit=%s " + "rows total=%s created=%s skipped_dt=%s skipped_reason=%s skipped_contacts=%s well_destroyed=%s " + "field_events=%s activities=%s samples=%s observations=%s contacts_created=%s contacts_reused=%s", + stats["groups_total"], + stats["groups_processed"], + stats["groups_skipped_missing_thing"], + stats["groups_failed_commit"], + stats["rows_total"], + stats["rows_created"], + stats["rows_skipped_dt"], + stats["rows_skipped_reason"], + stats["rows_skipped_contacts"], + stats["rows_well_destroyed"], + stats["field_events_created"], + stats["field_activities_created"], + stats["samples_created"], + stats["observations_created"], + stats["contacts_created"], + stats["contacts_reused"], + ) + def _get_dt_utc(self, row) -> datetime | None: if pd.isna(row.DateMeasured): logger.critical(