Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 107 additions & 23 deletions core/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@
from pathlib import Path

from fastapi_pagination import add_pagination
from sqlalchemy import text
from sqlalchemy import text, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import DatabaseError

from db import Base
from db.engine import session_ctx
from db.lexicon import (
LexiconCategory,
LexiconTerm,
LexiconTermCategoryAssociation,
)
from db.parameter import Parameter
from services.lexicon_helper import add_lexicon_term, add_lexicon_category


def init_parameter(path: str = None) -> None:
Expand Down Expand Up @@ -77,33 +82,112 @@ def init_lexicon(path: str = None) -> None:

default_lexicon = json.load(f)

# populate lexicon

with session_ctx() as session:
terms = default_lexicon["terms"]
categories = default_lexicon["categories"]
for category in categories:
try:
add_lexicon_category(session, category["name"], category["description"])
except DatabaseError as e:
print(f"Failed to add category {category['name']}: error: {e}")
session.rollback()
continue

for term_dict in terms:
try:
add_lexicon_term(
session,
term_dict["term"],
term_dict["definition"],
term_dict["categories"],
category_names = [category["name"] for category in categories]
existing_categories = dict(
session.execute(
select(LexiconCategory.name, LexiconCategory.id).where(
LexiconCategory.name.in_(category_names)
)
except DatabaseError as e:
print(
f"Failed to add term {term_dict['term']}: {term_dict['definition']} error: {e}"
).all()
)
category_rows = [
{"name": category["name"], "description": category["description"]}
for category in categories
if category["name"] not in existing_categories
]
if category_rows:
session.execute(
insert(LexiconCategory)
.values(category_rows)
.on_conflict_do_nothing(index_elements=["name"])
)
session.commit()
existing_categories = dict(
session.execute(
select(LexiconCategory.name, LexiconCategory.id).where(
LexiconCategory.name.in_(category_names)
)
).all()
)

term_names = [term_dict["term"] for term_dict in terms]
existing_terms = dict(
session.execute(
select(LexiconTerm.term, LexiconTerm.id).where(
LexiconTerm.term.in_(term_names)
)
).all()
)
term_rows = [
{"term": term_dict["term"], "definition": term_dict["definition"]}
for term_dict in terms
if term_dict["term"] not in existing_terms
]
if term_rows:
session.execute(
insert(LexiconTerm)
.values(term_rows)
.on_conflict_do_nothing(index_elements=["term"])
)
session.commit()
existing_terms = dict(
session.execute(
select(LexiconTerm.term, LexiconTerm.id).where(
LexiconTerm.term.in_(term_names)
)
).all()
)

term_ids = [existing_terms.get(term_name) for term_name in term_names]
category_ids = [
existing_categories.get(category_name) for category_name in category_names
]
existing_links = set()
if term_ids and category_ids:
existing_links = set(
session.execute(
select(
LexiconTermCategoryAssociation.term_id,
LexiconTermCategoryAssociation.category_id,
).where(
LexiconTermCategoryAssociation.term_id.in_(
[term_id for term_id in term_ids if term_id is not None]
),
LexiconTermCategoryAssociation.category_id.in_(
[
category_id
for category_id in category_ids
if category_id is not None
]
),
)
).all()
)

association_rows = []
for term_dict in terms:
term_id = existing_terms.get(term_dict["term"])
if term_id is None:
continue
for category in term_dict["categories"]:
category_id = existing_categories.get(category)
if category_id is None:
continue
key = (term_id, category_id)
if key in existing_links:
continue
association_rows.append(
{"term_id": term_id, "category_id": category_id}
)

session.rollback()
if association_rows:
session.execute(
insert(LexiconTermCategoryAssociation).values(association_rows)
)
session.commit()
Comment on lines +170 to +190
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

association_rows can contain duplicate (term_id, category_id) pairs within the same run (e.g., if the JSON repeats a category for a term, or terms are duplicated), and this insert has no conflict handling—potentially causing a unique-constraint failure. Consider deduping in-memory with a seen set while building association_rows, or using an ON CONFLICT DO NOTHING strategy if the association table has a unique constraint.

Copilot uses AI. Check for mistakes.


def register_routes(app):
Expand Down
12 changes: 0 additions & 12 deletions transfers/associated_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,6 @@ def _normalize_point_id(value: str) -> str:
def _normalize_location_id(value: str) -> str:
return value.strip().lower()

def _dedupe_rows(
self, rows: list[dict[str, Any]], key: str
) -> list[dict[str, Any]]:
"""Dedupe rows by unique key to avoid ON CONFLICT loops. Later rows win."""
deduped = {}
for row in rows:
assoc_id = row.get(key)
if assoc_id is None:
continue
deduped[assoc_id] = row
return list(deduped.values())

def _uuid_val(self, value: Any) -> Optional[UUID]:
if value is None or pd.isna(value):
return None
Expand Down
15 changes: 0 additions & 15 deletions transfers/chemistry_sampleinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,21 +361,6 @@ def bool_val(key: str) -> Optional[bool]:
"SampleNotes": str_val("SampleNotes"),
}

def _dedupe_rows(
self, rows: list[dict[str, Any]], key: str
) -> list[dict[str, Any]]:
"""
Deduplicate rows within a batch by the given key to avoid ON CONFLICT loops.
Later rows win.
"""
deduped = {}
for row in rows:
oid = row.get(key)
if oid is None:
continue
deduped[oid] = row
return list(deduped.values())


def run(batch_size: int = 1000) -> None:
"""Entrypoint to execute the transfer."""
Expand Down
111 changes: 3 additions & 108 deletions transfers/field_parameters_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,17 @@
from __future__ import annotations

from typing import Any, Optional
from uuid import UUID

import pandas as pd
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session

from db import NMA_Chemistry_SampleInfo, NMA_FieldParameters
from db.engine import session_ctx
from db import NMA_FieldParameters
from transfers.logger import logger
from transfers.transferer import Transferer
from transfers.util import read_csv
from transfers.transferer import ChemistryTransferer


class FieldParametersTransferer(Transferer):
class FieldParametersTransferer(ChemistryTransferer):
"""
Transfer FieldParameters records to NMA_FieldParameters.

Expand All @@ -54,59 +51,6 @@ class FieldParametersTransferer(Transferer):

source_table = "FieldParameters"

def __init__(self, *args, batch_size: int = 1000, **kwargs):
    """Initialize the transferer and pre-warm the sample-info FK cache.

    Args:
        batch_size: Number of rows per bulk upsert batch.
        *args, **kwargs: Forwarded unchanged to the base class initializer.
    """
    super().__init__(*args, **kwargs)
    self.batch_size = batch_size
    # Cache: legacy UUID -> Integer id
    # Populated immediately below so FK lookups never hit the DB per row.
    self._sample_info_cache: dict[UUID, int] = {}
    self._build_sample_info_cache()

def _build_sample_info_cache(self) -> None:
    """Cache nma_sample_pt_id (legacy UUID) -> integer id for FK lookups."""
    with session_ctx() as session:
        id_pairs = (
            session.query(
                NMA_Chemistry_SampleInfo.nma_sample_pt_id,
                NMA_Chemistry_SampleInfo.id,
            )
            .filter(NMA_Chemistry_SampleInfo.nma_sample_pt_id.isnot(None))
            .all()
        )
        # Each row is a (nma_sample_pt_id, id) pair, so dict() maps them directly.
        self._sample_info_cache = dict(id_pairs)
        logger.info(
            f"Built ChemistrySampleInfo cache with {len(self._sample_info_cache)} entries"
        )

def _get_dfs(self) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load the source CSV and return (raw frame, frame filtered to valid FKs)."""
    raw_df = read_csv(self.source_table)
    return raw_df, self._filter_to_valid_sample_infos(raw_df)

def _filter_to_valid_sample_infos(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Filter to only include rows where SamplePtID matches a ChemistrySampleInfo.

    This prevents orphan records and ensures the FK constraint will be satisfied.
    """
    known_ids = set(self._sample_info_cache.keys())
    keep_mask = df["SamplePtID"].apply(
        lambda raw: self._uuid_val(raw) in known_ids
    )
    kept_df = df[keep_mask].copy()

    dropped = len(df) - len(kept_df)
    if dropped > 0:
        logger.warning(
            f"Filtered out {dropped} FieldParameters records without matching "
            f"ChemistrySampleInfo ({len(kept_df)} valid, {dropped} orphan records prevented)"
        )

    return kept_df

def _transfer_hook(self, session: Session) -> None:
"""
Override transfer hook to use batch upsert for idempotent transfers.
Expand Down Expand Up @@ -206,55 +150,6 @@ def _row_to_dict(self, row) -> Optional[dict[str, Any]]:
"AnalysesAgency": self._safe_str(row, "AnalysesAgency"),
}

def _dedupe_rows(self, rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Dedupe rows by unique key to avoid ON CONFLICT loops. Later rows win."""
deduped = {}
for row in rows:
key = row.get("nma_GlobalID")
if key is None:
continue
deduped[key] = row
return list(deduped.values())

def _safe_str(self, row, attr: str) -> Optional[str]:
"""Safely get a string value, returning None for NaN."""
val = getattr(row, attr, None)
if val is None or pd.isna(val):
return None
return str(val)

def _safe_float(self, row, attr: str) -> Optional[float]:
"""Safely get a float value, returning None for NaN."""
val = getattr(row, attr, None)
if val is None or pd.isna(val):
return None
try:
return float(val)
except (TypeError, ValueError):
return None

def _safe_int(self, row, attr: str) -> Optional[int]:
"""Safely get an int value, returning None for NaN."""
val = getattr(row, attr, None)
if val is None or pd.isna(val):
return None
try:
return int(val)
except (TypeError, ValueError):
return None

def _uuid_val(self, value: Any) -> Optional[UUID]:
if value is None or pd.isna(value):
return None
if isinstance(value, UUID):
return value
if isinstance(value, str):
try:
return UUID(value)
except ValueError:
return None
return None


def run(flags: dict = None) -> tuple[pd.DataFrame, pd.DataFrame, list]:
"""Entrypoint to execute the transfer."""
Expand Down
17 changes: 1 addition & 16 deletions transfers/hydraulicsdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _transfer_hook(self, session: Session) -> None:
f"(orphan prevention)"
)

rows = self._dedupe_rows(row_dicts, key="nma_GlobalID")
rows = self._dedupe_rows(row_dicts)

insert_stmt = insert(NMA_HydraulicsData)
excluded = insert_stmt.excluded
Expand Down Expand Up @@ -198,21 +198,6 @@ def as_int(key: str) -> Optional[int]:
"Data Source": val("Data Source"),
}

def _dedupe_rows(
self, rows: list[dict[str, Any]], key: str
) -> list[dict[str, Any]]:
"""
Deduplicate rows within a batch by the given key to avoid ON CONFLICT loops.
Later rows win.
"""
deduped = {}
for row in rows:
gid = row.get(key)
if gid is None:
continue
deduped[gid] = row
return list(deduped.values())


def run(batch_size: int = 1000) -> None:
"""Entrypoint to execute the transfer."""
Expand Down
Loading
Loading