diff --git a/core/initializers.py b/core/initializers.py index 330ade9f..4ffbfb74 100644 --- a/core/initializers.py +++ b/core/initializers.py @@ -16,13 +16,18 @@ from pathlib import Path from fastapi_pagination import add_pagination -from sqlalchemy import text +from sqlalchemy import text, select +from sqlalchemy.dialects.postgresql import insert from sqlalchemy.exc import DatabaseError from db import Base from db.engine import session_ctx +from db.lexicon import ( + LexiconCategory, + LexiconTerm, + LexiconTermCategoryAssociation, +) from db.parameter import Parameter -from services.lexicon_helper import add_lexicon_term, add_lexicon_category def init_parameter(path: str = None) -> None: @@ -77,33 +82,112 @@ def init_lexicon(path: str = None) -> None: default_lexicon = json.load(f) - # populate lexicon - with session_ctx() as session: terms = default_lexicon["terms"] categories = default_lexicon["categories"] - for category in categories: - try: - add_lexicon_category(session, category["name"], category["description"]) - except DatabaseError as e: - print(f"Failed to add category {category['name']}: error: {e}") - session.rollback() - continue - - for term_dict in terms: - try: - add_lexicon_term( - session, - term_dict["term"], - term_dict["definition"], - term_dict["categories"], + category_names = [category["name"] for category in categories] + existing_categories = dict( + session.execute( + select(LexiconCategory.name, LexiconCategory.id).where( + LexiconCategory.name.in_(category_names) ) - except DatabaseError as e: - print( - f"Failed to add term {term_dict['term']}: {term_dict['definition']} error: {e}" + ).all() + ) + category_rows = [ + {"name": category["name"], "description": category["description"]} + for category in categories + if category["name"] not in existing_categories + ] + if category_rows: + session.execute( + insert(LexiconCategory) + .values(category_rows) + .on_conflict_do_nothing(index_elements=["name"]) + ) + session.commit() + existing_categories = dict( + session.execute( + select(LexiconCategory.name, LexiconCategory.id).where( + LexiconCategory.name.in_(category_names) + ) + ).all() + ) + + term_names = [term_dict["term"] for term_dict in terms] + existing_terms = dict( + session.execute( + select(LexiconTerm.term, LexiconTerm.id).where( + LexiconTerm.term.in_(term_names) + ) + ).all() + ) + term_rows = [ + {"term": term_dict["term"], "definition": term_dict["definition"]} + for term_dict in terms + if term_dict["term"] not in existing_terms + ] + if term_rows: + session.execute( + insert(LexiconTerm) + .values(term_rows) + .on_conflict_do_nothing(index_elements=["term"]) + ) + session.commit() + existing_terms = dict( + session.execute( + select(LexiconTerm.term, LexiconTerm.id).where( + LexiconTerm.term.in_(term_names) + ) + ).all() + ) + + term_ids = [existing_terms.get(term_name) for term_name in term_names] + category_ids = [ + existing_categories.get(category_name) for category_name in category_names + ] + existing_links = set() + if term_ids and category_ids: + existing_links = set( + session.execute( + select( + LexiconTermCategoryAssociation.term_id, + LexiconTermCategoryAssociation.category_id, + ).where( + LexiconTermCategoryAssociation.term_id.in_( + [term_id for term_id in term_ids if term_id is not None] + ), + LexiconTermCategoryAssociation.category_id.in_( + [ + category_id + for category_id in category_ids + if category_id is not None + ] + ), + ) + ).all() + ) + + association_rows = [] + for term_dict in terms: + term_id = existing_terms.get(term_dict["term"]) + if term_id is None: + continue + for category in term_dict["categories"]: + category_id = existing_categories.get(category) + if category_id is None: + continue + key = (term_id, category_id) + if key in existing_links: + continue + association_rows.append( + {"term_id": term_id, "category_id": category_id} ) - session.rollback() + if association_rows: + session.execute( + insert(LexiconTermCategoryAssociation).values(association_rows) + ) + session.commit() def register_routes(app): diff --git a/transfers/associated_data.py b/transfers/associated_data.py index 6c667aca..ebe1cebe 100644 --- a/transfers/associated_data.py +++ b/transfers/associated_data.py @@ -169,18 +169,6 @@ def _normalize_point_id(value: str) -> str: def _normalize_location_id(value: str) -> str: return value.strip().lower() - def _dedupe_rows( - self, rows: list[dict[str, Any]], key: str - ) -> list[dict[str, Any]]: - """Dedupe rows by unique key to avoid ON CONFLICT loops. Later rows win.""" - deduped = {} - for row in rows: - assoc_id = row.get(key) - if assoc_id is None: - continue - deduped[assoc_id] = row - return list(deduped.values()) - def _uuid_val(self, value: Any) -> Optional[UUID]: if value is None or pd.isna(value): return None diff --git a/transfers/chemistry_sampleinfo.py b/transfers/chemistry_sampleinfo.py index 395c063f..ce867436 100644 --- a/transfers/chemistry_sampleinfo.py +++ b/transfers/chemistry_sampleinfo.py @@ -361,21 +361,6 @@ def bool_val(key: str) -> Optional[bool]: "SampleNotes": str_val("SampleNotes"), } - def _dedupe_rows( - self, rows: list[dict[str, Any]], key: str - ) -> list[dict[str, Any]]: - """ - Deduplicate rows within a batch by the given key to avoid ON CONFLICT loops. - Later rows win. - """ - deduped = {} - for row in rows: - oid = row.get(key) - if oid is None: - continue - deduped[oid] = row - return list(deduped.values()) - def run(batch_size: int = 1000) -> None: """Entrypoint to execute the transfer.""" diff --git a/transfers/field_parameters_transfer.py b/transfers/field_parameters_transfer.py index d7dc77d7..3a894222 100644 --- a/transfers/field_parameters_transfer.py +++ b/transfers/field_parameters_transfer.py @@ -31,20 +31,17 @@ from __future__ import annotations from typing import Any, Optional -from uuid import UUID import pandas as pd from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session -from db import NMA_Chemistry_SampleInfo, NMA_FieldParameters -from db.engine import session_ctx +from db import NMA_FieldParameters from transfers.logger import logger -from transfers.transferer import Transferer -from transfers.util import read_csv +from transfers.transferer import ChemistryTransferer -class FieldParametersTransferer(Transferer): +class FieldParametersTransferer(ChemistryTransferer): """ Transfer FieldParameters records to NMA_FieldParameters. @@ -54,59 +51,6 @@ class FieldParametersTransferer(Transferer): source_table = "FieldParameters" - def __init__(self, *args, batch_size: int = 1000, **kwargs): - super().__init__(*args, **kwargs) - self.batch_size = batch_size - # Cache: legacy UUID -> Integer id - self._sample_info_cache: dict[UUID, int] = {} - self._build_sample_info_cache() - - def _build_sample_info_cache(self) -> None: - """Build cache of nma_sample_pt_id -> id for FK lookups.""" - with session_ctx() as session: - sample_infos = ( - session.query( - NMA_Chemistry_SampleInfo.nma_sample_pt_id, - NMA_Chemistry_SampleInfo.id, - ) - .filter(NMA_Chemistry_SampleInfo.nma_sample_pt_id.isnot(None)) - .all() - ) - self._sample_info_cache = { - nma_sample_pt_id: csi_id for nma_sample_pt_id, csi_id in sample_infos - } - logger.info( - f"Built ChemistrySampleInfo cache with {len(self._sample_info_cache)} entries" - ) - - def _get_dfs(self) -> tuple[pd.DataFrame, pd.DataFrame]: - input_df = read_csv(self.source_table) - cleaned_df = self._filter_to_valid_sample_infos(input_df) - return input_df, cleaned_df - - def _filter_to_valid_sample_infos(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Filter to only include rows where SamplePtID matches a ChemistrySampleInfo. - - This prevents orphan records and ensures the FK constraint will be satisfied. - """ - valid_sample_pt_ids = set(self._sample_info_cache.keys()) - before_count = len(df) - mask = df["SamplePtID"].apply( - lambda value: self._uuid_val(value) in valid_sample_pt_ids - ) - filtered_df = df[mask].copy() - after_count = len(filtered_df) - - if before_count > after_count: - skipped = before_count - after_count - logger.warning( - f"Filtered out {skipped} FieldParameters records without matching " - f"ChemistrySampleInfo ({after_count} valid, {skipped} orphan records prevented)" - ) - - return filtered_df - def _transfer_hook(self, session: Session) -> None: """ Override transfer hook to use batch upsert for idempotent transfers. @@ -206,55 +150,6 @@ def _row_to_dict(self, row) -> Optional[dict[str, Any]]: "AnalysesAgency": self._safe_str(row, "AnalysesAgency"), } - def _dedupe_rows(self, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Dedupe rows by unique key to avoid ON CONFLICT loops. Later rows win.""" - deduped = {} - for row in rows: - key = row.get("nma_GlobalID") - if key is None: - continue - deduped[key] = row - return list(deduped.values()) - - def _safe_str(self, row, attr: str) -> Optional[str]: - """Safely get a string value, returning None for NaN.""" - val = getattr(row, attr, None) - if val is None or pd.isna(val): - return None - return str(val) - - def _safe_float(self, row, attr: str) -> Optional[float]: - """Safely get a float value, returning None for NaN.""" - val = getattr(row, attr, None) - if val is None or pd.isna(val): - return None - try: - return float(val) - except (TypeError, ValueError): - return None - - def _safe_int(self, row, attr: str) -> Optional[int]: - """Safely get an int value, returning None for NaN.""" - val = getattr(row, attr, None) - if val is None or pd.isna(val): - return None - try: - return int(val) - except (TypeError, ValueError): - return None - - def _uuid_val(self, value: Any) -> Optional[UUID]: - if value is None or pd.isna(value): - return None - if isinstance(value, UUID): - return value - if isinstance(value, str): - try: - return UUID(value) - except ValueError: - return None - return None - def run(flags: dict = None) -> tuple[pd.DataFrame, pd.DataFrame, list]: """Entrypoint to execute the transfer.""" diff --git a/transfers/hydraulicsdata.py b/transfers/hydraulicsdata.py index bfaee00f..d5a2b180 100644 --- a/transfers/hydraulicsdata.py +++ b/transfers/hydraulicsdata.py @@ -100,7 +100,7 @@ def _transfer_hook(self, session: Session) -> None: f"(orphan prevention)" ) - rows = self._dedupe_rows(row_dicts, key="nma_GlobalID") + rows = self._dedupe_rows(row_dicts) insert_stmt = insert(NMA_HydraulicsData) excluded = insert_stmt.excluded @@ -198,21 +198,6 @@ def as_int(key: str) -> Optional[int]: "Data Source": val("Data Source"), } - def _dedupe_rows( - self, rows: list[dict[str, Any]], key: str - ) -> list[dict[str, Any]]: - """ - Deduplicate rows within a batch by the given key to avoid ON CONFLICT loops. - Later rows win. - """ - deduped = {} - for row in rows: - gid = row.get(key) - if gid is None: - continue - deduped[gid] = row - return list(deduped.values()) - def run(batch_size: int = 1000) -> None: """Entrypoint to execute the transfer.""" diff --git a/transfers/major_chemistry.py b/transfers/major_chemistry.py index 1aab8da7..e6acf023 100644 --- a/transfers/major_chemistry.py +++ b/transfers/major_chemistry.py @@ -30,20 +30,17 @@ from datetime import datetime from typing import Any, Optional -from uuid import UUID import pandas as pd from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session -from db import NMA_Chemistry_SampleInfo, NMA_MajorChemistry -from db.engine import session_ctx +from db import NMA_MajorChemistry from transfers.logger import logger -from transfers.transferer import Transferer -from transfers.util import read_csv +from transfers.transferer import ChemistryTransferer -class MajorChemistryTransferer(Transferer): +class MajorChemistryTransferer(ChemistryTransferer): """ Transfer for the legacy MajorChemistry table. @@ -52,59 +49,15 @@ class MajorChemistryTransferer(Transferer): source_table = "MajorChemistry" - def __init__(self, *args, batch_size: int = 1000, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.batch_size = batch_size - # Cache: legacy UUID -> Integer id - self._sample_info_cache: dict[UUID, int] = {} - self._build_sample_info_cache() - - def _build_sample_info_cache(self) -> None: - """Build cache of nma_sample_pt_id -> id for FK lookups.""" - with session_ctx() as session: - sample_infos = ( - session.query( - NMA_Chemistry_SampleInfo.nma_sample_pt_id, - NMA_Chemistry_SampleInfo.id, - ) - .filter(NMA_Chemistry_SampleInfo.nma_sample_pt_id.isnot(None)) - .all() - ) - self._sample_info_cache = { - nma_sample_pt_id: csi_id for nma_sample_pt_id, csi_id in sample_infos - } - logger.info( - f"Built ChemistrySampleInfo cache with {len(self._sample_info_cache)} entries" - ) - - def _get_dfs(self) -> tuple[pd.DataFrame, pd.DataFrame]: - input_df = read_csv(self.source_table, parse_dates=["AnalysisDate"]) - cleaned_df = self._filter_to_valid_sample_infos(input_df) - return input_df, cleaned_df - - def _filter_to_valid_sample_infos(self, df: pd.DataFrame) -> pd.DataFrame: - valid_sample_pt_ids = set(self._sample_info_cache.keys()) - mask = df["SamplePtID"].apply( - lambda value: self._uuid_val(value) in valid_sample_pt_ids - ) - before_count = len(df) - filtered_df = df[mask].copy() - after_count = len(filtered_df) - - if before_count > after_count: - skipped = before_count - after_count - logger.warning( - f"Filtered out {skipped} MajorChemistry records without matching " - f"ChemistrySampleInfo ({after_count} valid, {skipped} orphan records prevented)" - ) - - return filtered_df + self._parse_dates = ["AnalysisDate"] def _transfer_hook(self, session: Session) -> None: row_dicts = [] skipped_global_id = 0 skipped_csi_id = 0 - for row in self.cleaned_df.to_dict("records"): + for row in self.cleaned_df.itertuples(): row_dict = self._row_dict(row) if row_dict is None: continue @@ -135,7 +88,7 @@ def _transfer_hook(self, session: Session) -> None: skipped_csi_id, ) - rows = self._dedupe_rows(row_dicts, key="nma_GlobalID") + rows = self._dedupe_rows(row_dicts) insert_stmt = insert(NMA_MajorChemistry) excluded = insert_stmt.excluded @@ -170,43 +123,22 @@ def _transfer_hook(self, session: Session) -> None: session.commit() session.expunge_all() - def _row_dict(self, row: dict[str, Any]) -> Optional[dict[str, Any]]: - def val(key: str) -> Optional[Any]: - v = row.get(key) - if pd.isna(v): - return None - return v - - def float_val(key: str) -> Optional[float]: - v = val(key) - if v is None: - return None - try: - return float(v) - except (TypeError, ValueError): - return None - - def int_val(key: str) -> Optional[int]: - v = val(key) - if v is None: - return None - try: - return int(v) - except (TypeError, ValueError): - return None - - analysis_date = val("AnalysisDate") + def _row_dict(self, row: Any) -> Optional[dict[str, Any]]: + analysis_date = getattr(row, "AnalysisDate", None) + if analysis_date is None or pd.isna(analysis_date): + analysis_date = None if hasattr(analysis_date, "to_pydatetime"): analysis_date = analysis_date.to_pydatetime() if isinstance(analysis_date, datetime): analysis_date = analysis_date.replace(tzinfo=None) # Get legacy UUID FK - legacy_sample_pt_id = self._uuid_val(val("SamplePtID")) + sample_pt_raw = getattr(row, "SamplePtID", None) + legacy_sample_pt_id = self._uuid_val(sample_pt_raw) if legacy_sample_pt_id is None: self._capture_error( - val("SamplePtID"), - f"Invalid SamplePtID: {val('SamplePtID')}", + sample_pt_raw, + f"Invalid SamplePtID: {sample_pt_raw}", "SamplePtID", ) return None @@ -214,7 +146,8 @@ def int_val(key: str) -> Optional[int]: # Look up Integer FK from cache chemistry_sample_info_id = self._sample_info_cache.get(legacy_sample_pt_id) - nma_global_id = self._uuid_val(val("GlobalID")) + global_id_raw = getattr(row, "GlobalID", None) + nma_global_id = self._uuid_val(global_id_raw) return { # Legacy UUID PK -> nma_global_id (unique audit column) @@ -223,47 +156,23 @@ def int_val(key: str) -> Optional[int]: "chemistry_sample_info_id": chemistry_sample_info_id, # Legacy ID columns (renamed with nma_ prefix) "nma_SamplePtID": legacy_sample_pt_id, - "nma_SamplePointID": val("SamplePointID"), - "nma_OBJECTID": val("OBJECTID"), - "nma_WCLab_ID": val("WCLab_ID"), + "nma_SamplePointID": self._safe_str(row, "SamplePointID"), + "nma_OBJECTID": self._safe_int(row, "OBJECTID"), + "nma_WCLab_ID": self._safe_str(row, "WCLab_ID"), # Data columns - "Analyte": val("Analyte"), - "Symbol": val("Symbol"), - "SampleValue": float_val("SampleValue"), - "Units": val("Units"), - "Uncertainty": float_val("Uncertainty"), - "AnalysisMethod": val("AnalysisMethod"), + "Analyte": self._safe_str(row, "Analyte"), + "Symbol": self._safe_str(row, "Symbol"), + "SampleValue": self._safe_float(row, "SampleValue"), + "Units": self._safe_str(row, "Units"), + "Uncertainty": self._safe_float(row, "Uncertainty"), + "AnalysisMethod": self._safe_str(row, "AnalysisMethod"), "AnalysisDate": analysis_date, - "Notes": val("Notes"), - "Volume": int_val("Volume"), - "VolumeUnit": val("VolumeUnit"), - "AnalysesAgency": val("AnalysesAgency"), + "Notes": self._safe_str(row, "Notes"), + "Volume": self._safe_int(row, "Volume"), + "VolumeUnit": self._safe_str(row, "VolumeUnit"), + "AnalysesAgency": self._safe_str(row, "AnalysesAgency"), } - def _dedupe_rows( - self, rows: list[dict[str, Any]], key: str - ) -> list[dict[str, Any]]: - """Dedupe rows by unique key to avoid ON CONFLICT loops. Later rows win.""" - deduped = {} - for row in rows: - gid = row.get(key) - if gid is None: - continue - deduped[gid] = row - return list(deduped.values()) - - def _uuid_val(self, value: Any) -> Optional[UUID]: - if value is None or pd.isna(value): - return None - if isinstance(value, UUID): - return value - if isinstance(value, str): - try: - return UUID(value) - except ValueError: - return None - return None - def run(batch_size: int = 1000) -> None: """Entrypoint to execute the transfer.""" diff --git a/transfers/minor_trace_chemistry_transfer.py b/transfers/minor_trace_chemistry_transfer.py index 5f84bfda..ed1d16da 100644 --- a/transfers/minor_trace_chemistry_transfer.py +++ b/transfers/minor_trace_chemistry_transfer.py @@ -219,16 +219,6 @@ def _row_to_dict(self, row) -> Optional[dict[str, Any]]: } return row_dict - def _dedupe_rows(self, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Dedupe rows by unique key to avoid ON CONFLICT loops. Later rows win.""" - deduped = {} - for row in rows: - key = row.get("nma_GlobalID") - if key is None: - continue - deduped[key] = row - return list(deduped.values()) - def _safe_str(self, row, attr: str) -> Optional[str]: """Safely get a string value, returning None for NaN.""" val = getattr(row, attr, None) diff --git a/transfers/ngwmn_views.py b/transfers/ngwmn_views.py index 7470f602..ffad1139 100644 --- a/transfers/ngwmn_views.py +++ b/transfers/ngwmn_views.py @@ -50,7 +50,9 @@ def _get_dfs(self) -> tuple[pd.DataFrame, pd.DataFrame]: def _transfer_hook(self, session: Session) -> None: rows = self._dedupe_rows( - [self._row_dict(row) for row in self.cleaned_df.to_dict("records")] + [self._row_dict(row) for row in self.cleaned_df.to_dict("records")], + key=self._conflict_columns(), + include_missing=True, ) for i in range(0, len(rows), self.batch_size): @@ -103,25 +105,6 @@ def _conflict_columns(self) -> list[str]: def _upsert_set_clause(self) -> dict[str, Any]: raise NotImplementedError("_upsert_set_clause must be implemented") - def _dedupe_rows(self, rows: list[dict[str, Any]]) -> list[dict[str, Any]]: - """ - Deduplicate rows within a batch on conflict columns to avoid ON CONFLICT loops. - Later rows win. - """ - keys = self._conflict_columns() - deduped: dict[tuple, dict[str, Any]] = {} - passthrough: list[dict[str, Any]] = [] - - for row in rows: - key_tuple = tuple(row.get(k) for k in keys) - # If any part of the conflict key is missing, don't dedupe—let it pass through. - if any(k is None for k in key_tuple): - passthrough.append(row) - else: - deduped[key_tuple] = row - - return list(deduped.values()) + passthrough - class NGWMNWellConstructionTransferer(_BaseNGWMNTransferer): source_table = "view_NGWMN_WellConstruction" diff --git a/transfers/radionuclides.py b/transfers/radionuclides.py index 24723508..8b4ad9df 100644 --- a/transfers/radionuclides.py +++ b/transfers/radionuclides.py @@ -30,20 +30,18 @@ from datetime import datetime from typing import Any, Optional -from uuid import UUID import pandas as pd from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session -from db import NMA_Chemistry_SampleInfo, NMA_Radionuclides -from db.engine import session_ctx +from db import NMA_Radionuclides from transfers.logger import logger -from transfers.transferer import Transferer +from transfers.transferer import ChemistryTransferer from transfers.util import read_csv -class RadionuclidesTransferer(Transferer): +class RadionuclidesTransferer(ChemistryTransferer): """ Transfer for the legacy Radionuclides table. @@ -54,56 +52,17 @@ class RadionuclidesTransferer(Transferer): def __init__(self, *args, batch_size: int = 1000, **kwargs): super().__init__(*args, **kwargs) - self.batch_size = batch_size - # Cache: legacy UUID -> Integer chemistry_sample_info_id - self._sample_info_cache: dict[UUID, int] = {} - self._build_sample_info_cache() - - def _build_sample_info_cache(self) -> None: - """Build cache of nma_sample_pt_id -> chemistry_sample_info_id for FK lookups.""" - with session_ctx() as session: - sample_infos = ( - session.query( - NMA_Chemistry_SampleInfo.nma_sample_pt_id, - NMA_Chemistry_SampleInfo.id, - ) - .filter(NMA_Chemistry_SampleInfo.nma_sample_pt_id.isnot(None)) - .all() - ) - self._sample_info_cache = { - nma_sample_pt_id: csi_id for nma_sample_pt_id, csi_id in sample_infos - } - logger.info( - f"Built ChemistrySampleInfo cache with {len(self._sample_info_cache)} entries" - ) + self._parse_dates = ["AnalysisDate"] def _get_dfs(self) -> tuple[pd.DataFrame, pd.DataFrame]: - input_df = read_csv(self.source_table, parse_dates=["AnalysisDate"]) + input_df = read_csv(self.source_table, parse_dates=self._parse_dates) cleaned_df = self._filter_to_valid_sample_infos(input_df) return input_df, cleaned_df - def _filter_to_valid_sample_infos(self, df: pd.DataFrame) -> pd.DataFrame: - valid_sample_pt_ids = set(self._sample_info_cache.keys()) - mask = df["SamplePtID"].apply( - lambda value: self._uuid_val(value) in valid_sample_pt_ids - ) - before_count = len(df) - filtered_df = df[mask].copy() - after_count = len(filtered_df) - - if before_count > after_count: - skipped = before_count - after_count - logger.warning( - f"Filtered out {skipped} Radionuclides records without matching " - f"ChemistrySampleInfo ({after_count} valid, {skipped} orphan records prevented)" - ) - - return filtered_df - def _transfer_hook(self, session: Session) -> None: row_dicts = [] skipped_global_id = 0 - for row in self.cleaned_df.to_dict("records"): + for row in self.cleaned_df.itertuples(): row_dict = self._row_dict(row) if row_dict is None: continue @@ -162,43 +121,22 @@ def _transfer_hook(self, session: Session) -> None: session.commit() session.expunge_all() - def _row_dict(self, row: dict[str, Any]) -> Optional[dict[str, Any]]: - def val(key: str) -> Optional[Any]: - v = row.get(key) - if pd.isna(v): - return None - return v - - def float_val(key: str) -> Optional[float]: - v = val(key) - if v is None: - return None - try: - return float(v) - except (TypeError, ValueError): - return None - - def int_val(key: str) -> Optional[int]: - v = val(key) - if v is None: - return None - try: - return int(v) - except (TypeError, ValueError): - return None - - analysis_date = val("AnalysisDate") + def _row_dict(self, row: Any) -> Optional[dict[str, Any]]: + analysis_date = getattr(row, "AnalysisDate", None) + if analysis_date is None or pd.isna(analysis_date): + analysis_date = None if hasattr(analysis_date, "to_pydatetime"): analysis_date = analysis_date.to_pydatetime() if isinstance(analysis_date, datetime): analysis_date = analysis_date.replace(tzinfo=None) # Get legacy UUID FK - legacy_sample_pt_id = self._uuid_val(val("SamplePtID")) + sample_pt_raw = getattr(row, "SamplePtID", None) + legacy_sample_pt_id = self._uuid_val(sample_pt_raw) if legacy_sample_pt_id is None: self._capture_error( - val("SamplePtID"), - f"Invalid SamplePtID: {val('SamplePtID')}", + sample_pt_raw, + f"Invalid SamplePtID: {sample_pt_raw}", "SamplePtID", ) return None @@ -206,7 +144,8 @@ def int_val(key: str) -> Optional[int]: # Look up Integer FK from cache chemistry_sample_info_id = self._sample_info_cache.get(legacy_sample_pt_id) - nma_global_id = self._uuid_val(val("GlobalID")) + global_id_raw = getattr(row, "GlobalID", None) + nma_global_id = self._uuid_val(global_id_raw) return { # Legacy UUID PK -> nma_global_id (unique audit column) @@ -215,50 +154,23 @@ def int_val(key: str) -> Optional[int]: "chemistry_sample_info_id": chemistry_sample_info_id, # Legacy ID columns (renamed with nma_ prefix) "nma_SamplePtID": legacy_sample_pt_id, - "nma_SamplePointID": val("SamplePointID"), - "nma_OBJECTID": val("OBJECTID"), - "nma_WCLab_ID": val("WCLab_ID"), + "nma_SamplePointID": self._safe_str(row, "SamplePointID"), + "nma_OBJECTID": self._safe_int(row, "OBJECTID"), + "nma_WCLab_ID": self._safe_str(row, "WCLab_ID"), # Data columns - "Analyte": val("Analyte"), - "Symbol": val("Symbol"), - "SampleValue": float_val("SampleValue"), - "Units": val("Units"), - "Uncertainty": float_val("Uncertainty"), - "AnalysisMethod": val("AnalysisMethod"), + "Analyte": self._safe_str(row, "Analyte"), + "Symbol": self._safe_str(row, "Symbol"), + "SampleValue": self._safe_float(row, "SampleValue"), + "Units": self._safe_str(row, "Units"), + "Uncertainty": self._safe_float(row, "Uncertainty"), + "AnalysisMethod": self._safe_str(row, "AnalysisMethod"), "AnalysisDate": analysis_date, - "Notes": val("Notes"), - "Volume": int_val("Volume"), - "VolumeUnit": val("VolumeUnit"), - "AnalysesAgency": val("AnalysesAgency"), + "Notes": self._safe_str(row, "Notes"), + "Volume": self._safe_int(row, "Volume"), + "VolumeUnit": self._safe_str(row, "VolumeUnit"), + "AnalysesAgency": self._safe_str(row, "AnalysesAgency"), } - def _uuid_val(self, value: Any) -> Optional[UUID]: - if value is None or pd.isna(value): - return None - if isinstance(value, UUID): - return value - if isinstance(value, str): - try: - return UUID(value) - except ValueError: - return None - return None - - def _dedupe_rows( - self, rows: list[dict[str, Any]], key: str - ) -> list[dict[str, Any]]: - """ - Deduplicate rows within a batch by the given key to avoid ON CONFLICT loops. - Later rows win. - """ - deduped = {} - for row in rows: - row_key = row.get(key) - if row_key is None: - continue - deduped[row_key] = row - return list(deduped.values()) - def run(batch_size: int = 1000) -> None: """Entrypoint to execute the transfer.""" diff --git a/transfers/surface_water_data.py b/transfers/surface_water_data.py index 9821bf41..9b4a6e32 100644 --- a/transfers/surface_water_data.py +++ b/transfers/surface_water_data.py @@ -70,7 +70,7 @@ def _transfer_hook(self, session: Session) -> None: continue rows.append(record) - rows = self._dedupe_rows(rows, key="OBJECTID") + rows = self._dedupe_rows(rows, key="OBJECTID", include_missing=True) if skipped_missing_thing: logger.warning( @@ -160,23 +160,6 @@ def to_uuid(v: Any) -> Optional[uuid.UUID]: "thing_id": thing_id, } - def _dedupe_rows( - self, rows: list[dict[str, Any]], key: str - ) -> list[dict[str, Any]]: - """ - Deduplicate rows within a batch by the given key to avoid ON CONFLICT loops. - Later rows win. - """ - deduped: dict[Any, dict[str, Any]] = {} - passthrough: list[dict[str, Any]] = [] - for row in rows: - row_key = row.get(key) - if row_key is None: - passthrough.append(row) - else: - deduped[row_key] = row - return list(deduped.values()) + passthrough - def _resolve_thing_id(self, location_id: Optional[uuid.UUID]) -> Optional[int]: if location_id is None: return None diff --git a/transfers/surface_water_photos.py b/transfers/surface_water_photos.py index 43f11581..12d9c589 100644 --- a/transfers/surface_water_photos.py +++ b/transfers/surface_water_photos.py @@ -83,18 +83,6 @@ def _row_dict(self, row: dict[str, Any]) -> dict[str, Any]: "GlobalID": self._uuid_val(row.get("GlobalID")), } - def _dedupe_rows( - self, rows: list[dict[str, Any]], key: str - ) -> list[dict[str, Any]]: - """Dedupe rows by unique key to avoid ON CONFLICT loops. Later rows win.""" - deduped = {} - for row in rows: - global_id = row.get(key) - if global_id is None: - continue - deduped[global_id] = row - return list(deduped.values()) - def _uuid_val(self, value: Any) -> Optional[UUID]: if value is None or pd.isna(value): return None diff --git a/transfers/transferer.py b/transfers/transferer.py index 47826b0f..e6fe93e3 100644 --- a/transfers/transferer.py +++ b/transfers/transferer.py @@ -14,6 +14,8 @@ # limitations under the License. # =============================================================================== import time +from typing import Any, Optional +from uuid import UUID import pandas as pd from pandas import DataFrame @@ -21,7 +23,7 @@ from sqlalchemy.exc import DatabaseError from sqlalchemy.orm import Session -from db import Thing, Base +from db import Thing, Base, NMA_Chemistry_SampleInfo from db.engine import session_ctx from transfers.logger import logger from transfers.util import chunk_by_size, read_csv @@ -141,6 +143,40 @@ def _read_csv(self, name: str, dtype: dict | None = None, **kw) -> pd.DataFrame: return pd.read_csv(csv_path, **kw) return read_csv(name, dtype=dtype, **kw) + def _dedupe_rows( + self, + rows: list[dict[str, Any]], + key: str | list[str] = "nma_GlobalID", + include_missing: bool = False, + ) -> list[dict[str, Any]]: + """Dedupe rows by unique key(s) to avoid ON CONFLICT loops. Later rows win.""" + deduped: dict[Any, dict[str, Any]] = {} + passthrough: list[dict[str, Any]] = [] + key_list = key if isinstance(key, list) else [key] + + for row in rows: + if len(key_list) == 1: + row_key = row.get(key_list[0]) + else: + row_key = tuple(row.get(k) for k in key_list) + + # Treat None and any pd.isna(...) value (e.g., NaN) as missing keys + if isinstance(row_key, tuple): + is_missing = any(pd.isna(k) for k in row_key) + else: + is_missing = pd.isna(row_key) + + if is_missing: + if include_missing: + passthrough.append(row) + continue + + deduped[row_key] = row + + if include_missing: + return list(deduped.values()) + passthrough + return list(deduped.values()) + class ChunkTransferer(Transferer): def __init__(self, *args, **kwargs): @@ -250,4 +286,108 @@ def _get_db_item(self, session, index) -> Thing: return session.query(Thing).filter(Thing.name == pointid).first() +class ChemistryTransferer(Transferer): + def __init__(self, *args, batch_size: int = 1000, **kwargs): + super().__init__(*args, **kwargs) + self.batch_size = batch_size + # Cache: legacy UUID -> Integer id + self._sample_info_cache: dict[UUID, int] = {} + self._build_sample_info_cache() + self._parse_dates = None + + def _build_sample_info_cache(self) -> None: + """Build cache of nma_sample_pt_id -> id for FK lookups.""" + with session_ctx() as session: + sample_infos = ( + session.query( + NMA_Chemistry_SampleInfo.nma_sample_pt_id, + NMA_Chemistry_SampleInfo.id, + ) + .filter(NMA_Chemistry_SampleInfo.nma_sample_pt_id.isnot(None)) + .all() + ) + self._sample_info_cache = { + nma_sample_pt_id: csi_id for nma_sample_pt_id, csi_id in sample_infos + } + logger.info( + f"Built ChemistrySampleInfo cache with {len(self._sample_info_cache)} entries" + ) + + def _get_dfs(self) -> tuple[pd.DataFrame, pd.DataFrame]: + input_df = read_csv(self.source_table, parse_dates=self._parse_dates) + cleaned_df = self._filter_to_valid_sample_infos(input_df) + return input_df, cleaned_df + + def _filter_to_valid_sample_infos(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Filter to only include rows where SamplePtID matches a ChemistrySampleInfo. + + This prevents orphan records and ensures the FK constraint will be satisfied. + """ + valid_sample_pt_ids = set(self._sample_info_cache.keys()) + before_count = len(df) + parsed_sample_pt_ids = df["SamplePtID"].map(self._uuid_val) + mask = parsed_sample_pt_ids.isin(valid_sample_pt_ids) + filtered_df = df[mask].copy() + inverted_df = df[~mask].copy() + if not inverted_df.empty: + for _, row in inverted_df.iterrows(): + pointid = row["SamplePointID"] + self._capture_error( + pointid, + f"No matching ChemistrySampleInfo for SamplePtID: {pointid}", + "SamplePtID", + ) + + after_count = len(filtered_df) + + if before_count > after_count: + skipped = before_count - after_count + logger.warning( + f"Filtered out {skipped} FieldParameters records without matching " + f"ChemistrySampleInfo ({after_count} valid, {skipped} orphan records prevented)" + ) + + return filtered_df + + def _safe_str(self, row, attr: str) -> Optional[str]: + """Safely get a string value, returning None for NaN.""" + val = getattr(row, attr, None) + if val is None or pd.isna(val): + return None + return str(val) + + def _safe_float(self, row, attr: str) -> Optional[float]: + """Safely get a float value, returning None for NaN.""" + val = getattr(row, attr, None) + if val is None or pd.isna(val): + return None + try: + return float(val) + except (TypeError, ValueError): + return None + + def _safe_int(self, row, attr: str) -> Optional[int]: + """Safely get an int value, returning None for NaN.""" + val = getattr(row, attr, None) + if val is None or pd.isna(val): + return None + try: + return int(val) + except (TypeError, ValueError): + return None + + def _uuid_val(self, value: Any) -> Optional[UUID]: + if value is None or pd.isna(value): + return None + if isinstance(value, UUID): + return value + if isinstance(value, str): + try: + return UUID(value) + except ValueError: + return None + return None + + # ============= EOF ============================================= diff --git a/transfers/waterlevelscontinuous_pressure_daily.py b/transfers/waterlevelscontinuous_pressure_daily.py index 6caa348c..0c364697 100644 --- a/transfers/waterlevelscontinuous_pressure_daily.py +++ b/transfers/waterlevelscontinuous_pressure_daily.py @@ -148,21 +148,6 @@ def val(key: str) -> Optional[Any]: "CONDDL (mS/cm)": val("CONDDL (mS/cm)"), } - def _dedupe_rows( - self, rows: list[dict[str, Any]], key: str - ) -> list[dict[str, Any]]: - """ - Deduplicate rows within a batch by the given key to avoid ON CONFLICT loops. - Later rows win. - """ - deduped = {} - for row in rows: - gid = row.get(key) - if gid is None: - continue - deduped[gid] = row - return list(deduped.values()) - def run(batch_size: int = 1000) -> None: """Entrypoint to execute the transfer.""" diff --git a/transfers/weather_data.py b/transfers/weather_data.py index 4d75d1b4..9be3f157 100644 --- a/transfers/weather_data.py +++ b/transfers/weather_data.py @@ -48,6 +48,7 @@ def _transfer_hook(self, session: Session) -> None: rows = self._dedupe_rows( [self._row_dict(row) for row in self.cleaned_df.to_dict("records")], key="OBJECTID", + include_missing=True, ) insert_stmt = insert(NMA_WeatherData) @@ -94,23 +95,6 @@ def to_uuid(v: Any) -> Optional[uuid.UUID]: "OBJECTID": val("OBJECTID"), } - def _dedupe_rows( - self, rows: list[dict[str, Any]], key: str - ) -> list[dict[str, Any]]: - """ - Deduplicate rows within a batch by the given key to avoid ON CONFLICT loops. - Later rows win. - """ - deduped: dict[Any, dict[str, Any]] = {} - passthrough: list[dict[str, Any]] = [] - for row in rows: - row_key = row.get(key) - if row_key is None: - passthrough.append(row) - else: - deduped[row_key] = row - return list(deduped.values()) + passthrough - def run(batch_size: int = 1000) -> None: """Entrypoint to execute the transfer.""" diff --git a/transfers/weather_photos.py b/transfers/weather_photos.py index a223c42a..1a204f8a 100644 --- a/transfers/weather_photos.py +++ b/transfers/weather_photos.py @@ -83,18 +83,6 @@ def _row_dict(self, row: dict[str, Any]) -> dict[str, Any]: "GlobalID": self._uuid_val(row.get("GlobalID")), } - def _dedupe_rows( - self, rows: list[dict[str, Any]], key: str - ) -> list[dict[str, Any]]: - """Dedupe rows by unique key to avoid ON CONFLICT loops. Later rows win.""" - deduped = {} - for row in rows: - global_id = row.get(key) - if global_id is None: - continue - deduped[global_id] = row - return list(deduped.values()) - def _uuid_val(self, value: Any) -> Optional[UUID]: if value is None or pd.isna(value): return None