 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ===============================================================================
+import csv
+from collections import defaultdict
+from io import StringIO
 from typing import Any

 import pandas as pd
 from pandas import Timestamp
-from pydantic import ValidationError
-from sqlalchemy import insert
 from sqlalchemy.exc import DatabaseError
 from sqlalchemy.orm import Session

 from db import Thing, Deployment, Sensor
 from db.transducer import TransducerObservation, TransducerObservationBlock
-from schemas.transducer import CreateTransducerObservation
 from transfers.logger import logger
 from transfers.transferer import Transferer
 from transfers.util import (
@@ -43,6 +43,11 @@ def __init__(self, *args, **kw):
         self.groundwater_parameter_id = get_groundwater_parameter_id()
         self._itertuples_field_map = {}
         self._df_columns = set()
+        self._deployment_lookup_chunk_size = int(
+            self.flags.get("DEPLOYMENT_LOOKUP_CHUNK_SIZE", 2000)
+        )
+        self._copy_chunk_size = int(self.flags.get("COPY_CHUNK_SIZE", 10000))
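+        # USE_COPY_INSERT remains a flag for visibility; setting it to False
+        # now raises in _insert_observations rather than falling back.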
+        self._use_copy_insert = bool(self.flags.get("USE_COPY_INSERT", True))
         self._observation_columns = {
             column.key for column in TransducerObservation.__table__.columns
         }
@@ -68,23 +73,16 @@ def _get_dfs(self):
         return input_df, cleaned_df

     def _transfer_hook(self, session: Session) -> None:
-        gwd = self.cleaned_df.groupby(["PointID"])
-        n = len(gwd)
+        gwd = self.cleaned_df.groupby("PointID", sort=False)
+        n = gwd.ngroups
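+        # One bulk deployment lookup up front replaces a per-PointID query inside the loop.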
+        deployments_by_pointid = self._prefetch_deployments(session)
         nodeployments = {}
-        for i, (index, group) in enumerate(gwd):
-            pointid = index[0]
+        for i, (pointid, group) in enumerate(gwd):
             logger.info(
                 f"Processing PointID: {pointid}. {i + 1}/{n} ({100 * (i + 1) / n:0.2f}) completed."
             )

-            deployments = (
-                session.query(Deployment)
-                .join(Thing)
-                .join(Sensor)
-                .where(Sensor.sensor_type.in_(self._sensor_types))
-                .where(Thing.name == pointid)
-                .all()
-            )
+            deployments = deployments_by_pointid.get(pointid, [])

             # sort rows by date measured
             group = group.sort_values(by="DateMeasured")
@@ -103,6 +101,7 @@ def _transfer_hook(self, session: Session) -> None:

             # Get thing_id from the first deployment
             thing_id = deployments[0].thing_id
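+            # Deployments arrive pre-sorted by installation date from the prefetch.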
+            deps_sorted = deployments

             qced_block = TransducerObservationBlock(
                 thing_id=thing_id,
@@ -119,54 +118,46 @@ def _transfer_hook(self, session: Session) -> None:
                 (qced_block, qced, "public"),
                 (notqced_block, notqced, "private"),
             ):
-                block.start_datetime = rows.DateMeasured.min()
-                block.end_datetime = rows.DateMeasured.max()
-
                 if rows.empty:
                     logger.info(f"no {release_status} records for pointid {pointid}")
                     continue

-                def _install_ts(value):
-                    if isinstance(value, Timestamp):
-                        return value
-                    if hasattr(value, "date"):
-                        return Timestamp(value)
-                    return Timestamp(pd.to_datetime(value, errors="coerce"))
-
-                deps_sorted = sorted(
-                    deployments, key=lambda d: _install_ts(d.installation_date)
-                )
-
-                observations = [
-                    self._make_observation(
-                        pointid, row, release_status, deps_sorted, nodeployments
+                block.start_datetime = rows.DateMeasured.iloc[0]
+                block.end_datetime = rows.DateMeasured.iloc[-1]
+                if block.end_datetime <= block.start_datetime:
+                    # DB check constraint requires end > start, even for singleton blocks.
+                    block.end_datetime = block.start_datetime + pd.Timedelta(
+                        microseconds=1
                     )
-                    for row in rows.itertuples()
-                ]
-
-                observations = [obs for obs in observations if obs is not None]
-                if observations:
-                    filtered_observations = [
+                deployment_matcher = _DeploymentMatcher(deps_sorted)
+
+                observations = []
+                for row in rows.itertuples():
+                    obs = self._make_observation(
+                        pointid,
+                        row,
+                        release_status,
+                        deployment_matcher,
+                        nodeployments,
+                    )
+                    if obs is None:
+                        continue
+                    observations.append(
                         {k: v for k, v in obs.items() if k in self._observation_columns}
-                        for obs in observations
-                    ]
-                    session.execute(
-                        insert(TransducerObservation),
-                        filtered_observations,
                     )
+                if observations:
+                    self._insert_observations(session, observations)
                 block = self._get_or_create_block(session, block)
                 logger.info(
                     f"Added {len(observations)} water levels {release_status} block"
                 )
-                try:
-                    session.commit()
-                except DatabaseError as e:
-                    session.rollback()
-                    logger.critical(
-                        f"Error committing water levels {release_status} block: {e}"
-                    )
-                    self._capture_database_error(pointid, e)
-                    continue
+            try:
+                session.commit()
+            except DatabaseError as e:
+                session.rollback()
+                logger.critical(f"Error committing water levels for {pointid}: {e}")
+                self._capture_database_error(pointid, e)
+                continue

         # convert nodeployments to errors
         for pointid, (min_date, max_date) in nodeployments.items():
@@ -176,15 +167,42 @@ def _install_ts(value):
176167 "DateMeasured" ,
177168 )
178169
170+ def _prefetch_deployments (self , session : Session ) -> dict [str , list [Deployment ]]:
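+        """Bulk-load deployments for every PointID, chunked to bound the IN() list."""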
+        pointids = self.cleaned_df["PointID"].dropna().unique().tolist()
+        deployments_by_pointid: dict[str, list[Deployment]] = defaultdict(list)
+        if not pointids:
+            return {}
+
+        for i in range(0, len(pointids), self._deployment_lookup_chunk_size):
+            chunk = pointids[i : i + self._deployment_lookup_chunk_size]
+            deployment_rows = (
+                session.query(Thing.name, Deployment)
+                .join(Deployment, Deployment.thing_id == Thing.id)
+                .join(Sensor, Sensor.id == Deployment.sensor_id)
+                .where(Thing.name.in_(chunk))
+                .where(Sensor.sensor_type.in_(self._sensor_types))
+                .all()
+            )
+            for pointid, deployment in deployment_rows:
+                deployments_by_pointid[pointid].append(deployment)
+
+        for pointid in deployments_by_pointid:
+            deployments_by_pointid[pointid].sort(
+                key=lambda deployment: _installation_timestamp(
+                    deployment.installation_date
+                )
+            )
+        return dict(deployments_by_pointid)
+
     def _make_observation(
         self,
         pointid: str,
         row: pd.Series,
         release_status: str,
-        deps_sorted: list,
+        deployment_matcher: "_DeploymentMatcher",
         nodeployments: dict,
     ) -> dict | None:
-        deployment = _find_deployment(row.DateMeasured, deps_sorted)
+        deployment = deployment_matcher.find(row.DateMeasured)

         if deployment is None:
             if pointid not in nodeployments:
@@ -210,15 +228,58 @@ def _make_observation(
                 value=row.DepthToWaterBGS,
                 release_status=release_status,
             )
-            obspayload = CreateTransducerObservation.model_validate(
-                payload
-            ).model_dump()
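+            # Lightweight NULL/NaN guard replaces the removed pydantic validation step.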
+            if payload["value"] is None or pd.isna(payload["value"]):
+                self._capture_error(
+                    pointid,
+                    "DepthToWaterBGS is NULL",
+                    "DepthToWaterBGS",
+                )
+                return None
+            payload["value"] = float(payload["value"])
             legacy_payload = self._legacy_payload(row)
-            return {**obspayload, **legacy_payload}
+            return {**payload, **legacy_payload}
+
+        except (TypeError, ValueError) as e:
+            logger.critical(f"Observation build error: {e}")
+            self._capture_error(pointid, str(e), "DepthToWaterBGS")

-        except ValidationError as e:
-            logger.critical(f"Observation validation error: {e.errors()}")
-            self._capture_validation_error(pointid, e)
+    def _insert_observations(
+        self, session: Session, observations: list[dict[str, Any]]
+    ) -> None:
+        if not observations:
+            return
+
+        if not self._use_copy_insert:
+            raise RuntimeError(
+                "USE_COPY_INSERT=False is not supported; transducer observations now require COPY inserts."
+            )
+        self._copy_insert_observations(session, observations)
+
+    def _copy_insert_observations(
+        self, session: Session, observations: list[dict[str, Any]]
+    ) -> None:
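+        """Stream observations into Postgres with COPY ... FROM STDIN in CSV chunks."""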
+        raw_connection = session.connection().connection
+        cursor = raw_connection.cursor()
+        table_name = TransducerObservation.__table__.name
+        columns = [
+            key for key in observations[0].keys() if key in self._observation_columns
+        ]
+        if not columns:
+            return
+
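+        # NULL '\N' makes COPY read the \N sentinel emitted by _copy_cell as SQL NULL.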
+        copy_sql = (
+            f"COPY {table_name} ({', '.join(columns)}) "
+            "FROM STDIN WITH (FORMAT csv, NULL '\\N')"
+        )
+
+        for i in range(0, len(observations), self._copy_chunk_size):
+            chunk = observations[i : i + self._copy_chunk_size]
+            stream = StringIO()
+            writer = csv.writer(stream, lineterminator="\n")
+            for row in chunk:
+                writer.writerow([_copy_cell(row.get(column)) for column in columns])
+            stream.seek(0)
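+            # Assumes the raw DB-API cursor accepts a stream kwarg for COPY
+            # statements (pg8000-style); psycopg2 would use copy_expert instead.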
+            cursor.execute(copy_sql, stream=stream)

     def _legacy_payload(self, row: pd.Series) -> dict:
         return {}
@@ -356,13 +417,71 @@ def _legacy_payload(self, row: pd.Series) -> dict:
         }


-def _find_deployment(ts, deployments):
+def _installation_timestamp(value: Any) -> Timestamp:
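+    """Coerce an installation_date into a sortable Timestamp; None sorts earliest."""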
+    if value is None:
+        return Timestamp.min
+    if isinstance(value, Timestamp):
+        return value
+    if hasattr(value, "date"):
+        return Timestamp(value)
+    return Timestamp(pd.to_datetime(value, errors="coerce"))
+
+
+def _copy_cell(value: Any) -> Any:
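+    r"""Serialize one value into a CSV cell for COPY; \N marks SQL NULL."""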
+    if value is None:
+        return r"\N"
+    if isinstance(value, Timestamp):
+        if pd.isna(value):
+            return r"\N"
+        return value.to_pydatetime().isoformat(sep=" ")
+    try:
+        if pd.isna(value):
+            return r"\N"
+    except TypeError:
+        pass
+    if isinstance(value, bool):
+        return "t" if value else "f"
+    if hasattr(value, "isoformat"):
+        return value.isoformat()
+    return value
+
+
+class _DeploymentMatcher:
+    """
+    Cursor-based matcher for monotonic time-series rows.
+    Assumes rows are processed in ascending DateMeasured order.
+    """
+
+    def __init__(self, deployments: list[Deployment]):
+        self._deployments = deployments
+        self._cursor = 0
+
+    def find(self, ts: Any) -> Deployment | None:
+        date = _to_date(ts)
+        n = len(self._deployments)
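+        # The cursor never rewinds: once a deployment ends before the current
+        # row's date, it cannot match any later (sorted) row either.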
+        while self._cursor < n:
+            deployment = self._deployments[self._cursor]
+            start = deployment.installation_date or Timestamp.min.date()
+            end = deployment.removal_date or Timestamp.max.date()
+            if date < start:
+                return None
+            if date <= end:
+                return deployment
+            self._cursor += 1
+        return None
+
+
+def _to_date(ts: Any):
     if hasattr(ts, "date"):
-        date = ts.date()
-    else:
-        date = pd.Timestamp(ts).date()
+        return ts.date()
+    return pd.Timestamp(ts).date()
+
+
+def _find_deployment(ts, deployments):
+    date = _to_date(ts)
     for d in deployments:
-        if d.installation_date > date:
+        start = d.installation_date or Timestamp.min.date()
+        if start > date:
             break  # because sorted by start
         end = d.removal_date if d.removal_date else Timestamp.max.date()
         if end >= date: