Commit 982a63c
feat: optimize water level data transfer by implementing chunked deployment prefetching and COPY insert method
1 parent b710200 commit 982a63c

1 file changed: transfers/waterlevels_transducer_transfer.py (185 additions, 66 deletions)
@@ -13,18 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ===============================================================================
+import csv
+from collections import defaultdict
+from io import StringIO
 from typing import Any
 
 import pandas as pd
 from pandas import Timestamp
-from pydantic import ValidationError
-from sqlalchemy import insert
 from sqlalchemy.exc import DatabaseError
 from sqlalchemy.orm import Session
 
 from db import Thing, Deployment, Sensor
 from db.transducer import TransducerObservation, TransducerObservationBlock
-from schemas.transducer import CreateTransducerObservation
 from transfers.logger import logger
 from transfers.transferer import Transferer
 from transfers.util import (
@@ -43,6 +43,11 @@ def __init__(self, *args, **kw):
         self.groundwater_parameter_id = get_groundwater_parameter_id()
         self._itertuples_field_map = {}
         self._df_columns = set()
+        self._deployment_lookup_chunk_size = int(
+            self.flags.get("DEPLOYMENT_LOOKUP_CHUNK_SIZE", 2000)
+        )
+        self._copy_chunk_size = int(self.flags.get("COPY_CHUNK_SIZE", 10000))
+        self._use_copy_insert = bool(self.flags.get("USE_COPY_INSERT", True))
         self._observation_columns = {
             column.key for column in TransducerObservation.__table__.columns
         }
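Note: the three new attributes are read from `self.flags`, so batch behavior is tunable per run without code changes. A minimal sketch of the overrides, assuming `flags` is a plain mapping handed to the transferer (the surrounding wiring is not shown in this diff):

# Hypothetical flag overrides; the keys match the ones read in __init__ above.
flags = {
    "DEPLOYMENT_LOOKUP_CHUNK_SIZE": 500,  # shorter IN (...) lists per prefetch query
    "COPY_CHUNK_SIZE": 50_000,            # rows per COPY batch
    "USE_COPY_INSERT": True,              # setting this to False now raises RuntimeError
}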
@@ -68,23 +73,16 @@ def _get_dfs(self):
         return input_df, cleaned_df
 
     def _transfer_hook(self, session: Session) -> None:
-        gwd = self.cleaned_df.groupby(["PointID"])
-        n = len(gwd)
+        gwd = self.cleaned_df.groupby("PointID", sort=False)
+        n = gwd.ngroups
+        deployments_by_pointid = self._prefetch_deployments(session)
         nodeployments = {}
-        for i, (index, group) in enumerate(gwd):
-            pointid = index[0]
+        for i, (pointid, group) in enumerate(gwd):
             logger.info(
                 f"Processing PointID: {pointid}. {i + 1}/{n} ({100*(i+1)/n:0.2f}) completed."
             )
 
-            deployments = (
-                session.query(Deployment)
-                .join(Thing)
-                .join(Sensor)
-                .where(Sensor.sensor_type.in_(self._sensor_types))
-                .where(Thing.name == pointid)
-                .all()
-            )
+            deployments = deployments_by_pointid.get(pointid, [])
 
             # sort rows by date measured
             group = group.sort_values(by="DateMeasured")
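Note: the groupby change is subtle. On recent pandas, grouping by a length-1 list yields 1-tuple group labels (hence the old `index[0]`), while a scalar key yields the bare value, and `sort=False` skips an unneeded sort of the group keys. A standalone illustration of that behavior:

import pandas as pd

df = pd.DataFrame({"PointID": ["A", "A", "B"], "v": [1, 2, 3]})

# List key: the label is a 1-tuple on recent pandas
# (scalar labels for length-1 list keys were deprecated in pandas 2.2).
for label, _ in df.groupby(["PointID"]):
    print(label)

# Scalar key: the label is the bare value; sort=False preserves first-seen order.
for label, _ in df.groupby("PointID", sort=False):
    print(label)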
@@ -103,6 +101,7 @@ def _transfer_hook(self, session: Session) -> None:
 
             # Get thing_id from the first deployment
             thing_id = deployments[0].thing_id
+            deps_sorted = deployments
 
             qced_block = TransducerObservationBlock(
                 thing_id=thing_id,
@@ -119,54 +118,46 @@ def _transfer_hook(self, session: Session) -> None:
                 (qced_block, qced, "public"),
                 (notqced_block, notqced, "private"),
             ):
-                block.start_datetime = rows.DateMeasured.min()
-                block.end_datetime = rows.DateMeasured.max()
-
                 if rows.empty:
                     logger.info(f"no {release_status} records for pointid {pointid}")
                     continue
 
-                def _install_ts(value):
-                    if isinstance(value, Timestamp):
-                        return value
-                    if hasattr(value, "date"):
-                        return Timestamp(value)
-                    return Timestamp(pd.to_datetime(value, errors="coerce"))
-
-                deps_sorted = sorted(
-                    deployments, key=lambda d: _install_ts(d.installation_date)
-                )
-
-                observations = [
-                    self._make_observation(
-                        pointid, row, release_status, deps_sorted, nodeployments
+                block.start_datetime = rows.DateMeasured.iloc[0]
+                block.end_datetime = rows.DateMeasured.iloc[-1]
+                if block.end_datetime <= block.start_datetime:
+                    # DB check constraint requires end > start, even for singleton blocks.
+                    block.end_datetime = block.start_datetime + pd.Timedelta(
+                        microseconds=1
                     )
-                    for row in rows.itertuples()
-                ]
-
-                observations = [obs for obs in observations if obs is not None]
-                if observations:
-                    filtered_observations = [
+                deployment_matcher = _DeploymentMatcher(deps_sorted)
+
+                observations = []
+                for row in rows.itertuples():
+                    obs = self._make_observation(
+                        pointid,
+                        row,
+                        release_status,
+                        deployment_matcher,
+                        nodeployments,
+                    )
+                    if obs is None:
+                        continue
+                    observations.append(
                         {k: v for k, v in obs.items() if k in self._observation_columns}
-                        for obs in observations
-                    ]
-                    session.execute(
-                        insert(TransducerObservation),
-                        filtered_observations,
                     )
+                if observations:
+                    self._insert_observations(session, observations)
                 block = self._get_or_create_block(session, block)
                 logger.info(
                     f"Added {len(observations)} water levels {release_status} block"
                 )
-                try:
-                    session.commit()
-                except DatabaseError as e:
-                    session.rollback()
-                    logger.critical(
-                        f"Error committing water levels {release_status} block: {e}"
-                    )
-                    self._capture_database_error(pointid, e)
-                    continue
+            try:
+                session.commit()
+            except DatabaseError as e:
+                session.rollback()
+                logger.critical(f"Error committing water levels for {pointid}: {e}")
+                self._capture_database_error(pointid, e)
+                continue
 
             # convert nodeployments to errors
             for pointid, (min_date, max_date) in nodeployments.items():
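Note: two things happen in the block above. Since `group` was already sorted by `DateMeasured`, `iloc[0]`/`iloc[-1]` give the block bounds without the full-column `min()`/`max()` scans, and the microsecond nudge satisfies a database check constraint requiring `end > start`. A minimal sketch of the kind of constraint the inline comment refers to, expressed as SQLAlchemy DDL with a hypothetical table name (the real definition lives in `db.transducer`):

from sqlalchemy import CheckConstraint, Column, DateTime, Integer, MetaData, Table

metadata = MetaData()

# Hypothetical table and columns; only the shape of the CHECK matters here.
block_table = Table(
    "transducer_observation_block",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("start_datetime", DateTime, nullable=False),
    Column("end_datetime", DateTime, nullable=False),
    CheckConstraint("end_datetime > start_datetime", name="block_end_after_start"),
)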
@@ -176,15 +167,42 @@ def _install_ts(value):
                 "DateMeasured",
             )
 
+    def _prefetch_deployments(self, session: Session) -> dict[str, list[Deployment]]:
+        pointids = self.cleaned_df["PointID"].dropna().unique().tolist()
+        deployments_by_pointid: dict[str, list[Deployment]] = defaultdict(list)
+        if not pointids:
+            return {}
+
+        for i in range(0, len(pointids), self._deployment_lookup_chunk_size):
+            chunk = pointids[i : i + self._deployment_lookup_chunk_size]
+            deployment_rows = (
+                session.query(Thing.name, Deployment)
+                .join(Deployment, Deployment.thing_id == Thing.id)
+                .join(Sensor, Sensor.id == Deployment.sensor_id)
+                .where(Thing.name.in_(chunk))
+                .where(Sensor.sensor_type.in_(self._sensor_types))
+                .all()
+            )
+            for pointid, deployment in deployment_rows:
+                deployments_by_pointid[pointid].append(deployment)
+
+        for pointid in deployments_by_pointid:
+            deployments_by_pointid[pointid].sort(
+                key=lambda deployment: _installation_timestamp(
+                    deployment.installation_date
+                )
+            )
+        return dict(deployments_by_pointid)
+
     def _make_observation(
         self,
         pointid: str,
         row: pd.Series,
         release_status: str,
-        deps_sorted: list,
+        deployment_matcher: "_DeploymentMatcher",
         nodeployments: dict,
     ) -> dict | None:
-        deployment = _find_deployment(row.DateMeasured, deps_sorted)
+        deployment = deployment_matcher.find(row.DateMeasured)
 
         if deployment is None:
             if pointid not in nodeployments:
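Note: prefetching turns one deployment query per PointID into ceil(N / chunk_size) queries, and chunking keeps each `IN (...)` list bounded, since drivers and servers commonly cap bind-parameter counts. The same batching pattern in isolation, with hypothetical names:

from itertools import islice

def chunked(items, size):
    # Yield successive lists of at most `size` items.
    it = iter(items)
    while batch := list(islice(it, size)):
        yield batch

pointids = [f"PT-{i}" for i in range(4500)]
# With the default chunk size of 2000 this is 3 queries instead of 4500.
print(sum(1 for _ in chunked(pointids, 2000)))  # -> 3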
@@ -210,15 +228,58 @@ def _make_observation(
                 value=row.DepthToWaterBGS,
                 release_status=release_status,
             )
-            obspayload = CreateTransducerObservation.model_validate(
-                payload
-            ).model_dump()
+            if payload["value"] is None or pd.isna(payload["value"]):
+                self._capture_error(
+                    pointid,
+                    "DepthToWaterBGS is NULL",
+                    "DepthToWaterBGS",
+                )
+                return None
+            payload["value"] = float(payload["value"])
             legacy_payload = self._legacy_payload(row)
-            return {**obspayload, **legacy_payload}
+            return {**payload, **legacy_payload}
+
+        except (TypeError, ValueError) as e:
+            logger.critical(f"Observation build error: {e}")
+            self._capture_error(pointid, str(e), "DepthToWaterBGS")
 
-        except ValidationError as e:
-            logger.critical(f"Observation validation error: {e.errors()}")
-            self._capture_validation_error(pointid, e)
+    def _insert_observations(
+        self, session: Session, observations: list[dict[str, Any]]
+    ) -> None:
+        if not observations:
+            return
+
+        if not self._use_copy_insert:
+            raise RuntimeError(
+                "USE_COPY_INSERT=False is not supported; transducer observations now require COPY inserts."
+            )
+        self._copy_insert_observations(session, observations)
+
+    def _copy_insert_observations(
+        self, session: Session, observations: list[dict[str, Any]]
+    ) -> None:
+        raw_connection = session.connection().connection
+        cursor = raw_connection.cursor()
+        table_name = TransducerObservation.__table__.name
+        columns = [
+            key for key in observations[0].keys() if key in self._observation_columns
+        ]
+        if not columns:
+            return
+
+        copy_sql = (
+            f"COPY {table_name} ({', '.join(columns)}) "
+            "FROM STDIN WITH (FORMAT csv, NULL '\\N')"
+        )
+
+        for i in range(0, len(observations), self._copy_chunk_size):
+            chunk = observations[i : i + self._copy_chunk_size]
+            stream = StringIO()
+            writer = csv.writer(stream, lineterminator="\n")
+            for row in chunk:
+                writer.writerow([_copy_cell(row.get(column)) for column in columns])
+            stream.seek(0)
+            cursor.execute(copy_sql, stream=stream)
 
     def _legacy_payload(self, row: pd.Series) -> dict:
         return {}
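Note: COPY streams a CSV body to the server in one round trip per chunk, which is where most of the speedup over row-wise INSERT comes from. The `stream=` keyword on `cursor.execute` matches pg8000's COPY API; under psycopg2 the equivalent would be `cursor.copy_expert(sql, file)`. A minimal standalone sketch of the same technique, with a hypothetical table and column list:

import csv
from io import StringIO

def copy_rows(cursor, rows):
    # Assumes a pg8000 DBAPI cursor, whose execute() accepts a `stream`
    # keyword for COPY statements.
    buf = StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    for value, status in rows:
        # r"\N" is the NULL sentinel declared in the COPY statement below.
        writer.writerow([r"\N" if value is None else value, status])
    buf.seek(0)
    cursor.execute(
        "COPY transducer_observation (value, release_status) "
        "FROM STDIN WITH (FORMAT csv, NULL '\\N')",
        stream=buf,
    )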
@@ -356,13 +417,71 @@ def _legacy_payload(self, row: pd.Series) -> dict:
        }
 
 
-def _find_deployment(ts, deployments):
+def _installation_timestamp(value: Any) -> Timestamp:
+    if value is None:
+        return Timestamp.min
+    if isinstance(value, Timestamp):
+        return value
+    if hasattr(value, "date"):
+        return Timestamp(value)
+    return Timestamp(pd.to_datetime(value, errors="coerce"))
+
+
+def _copy_cell(value: Any) -> Any:
+    if value is None:
+        return r"\N"
+    if isinstance(value, Timestamp):
+        if pd.isna(value):
+            return r"\N"
+        return value.to_pydatetime().isoformat(sep=" ")
+    try:
+        if pd.isna(value):
+            return r"\N"
+    except TypeError:
+        pass
+    if isinstance(value, bool):
+        return "t" if value else "f"
+    if hasattr(value, "isoformat"):
+        return value.isoformat()
+    return value
+
+
+class _DeploymentMatcher:
+    """
+    Cursor-based matcher for monotonic time-series rows.
+    Assumes rows are processed in ascending DateMeasured order.
+    """
+
+    def __init__(self, deployments: list[Deployment]):
+        self._deployments = deployments
+        self._cursor = 0
+
+    def find(self, ts: Any) -> Deployment | None:
+        date = _to_date(ts)
+        n = len(self._deployments)
+        while self._cursor < n:
+            deployment = self._deployments[self._cursor]
+            start = deployment.installation_date or Timestamp.min.date()
+            end = deployment.removal_date or Timestamp.max.date()
+            if date < start:
+                return None
+            if date <= end:
+                return deployment
+            self._cursor += 1
+        return None
+
+
+def _to_date(ts: Any):
     if hasattr(ts, "date"):
-        date = ts.date()
-    else:
-        date = pd.Timestamp(ts).date()
+        return ts.date()
+    return pd.Timestamp(ts).date()
+
+
+def _find_deployment(ts, deployments):
+    date = _to_date(ts)
     for d in deployments:
-        if d.installation_date > date:
+        start = d.installation_date or Timestamp.min.date()
+        if start > date:
             break  # because sorted by start
         end = d.removal_date if d.removal_date else Timestamp.max.date()
         if end >= date:
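Note: the matcher relies on both sides being sorted: deployments ascend by installation date (sorted once in `_prefetch_deployments`) and each group's rows ascend by `DateMeasured`, so the cursor only moves forward and matching is O(rows + deployments) per point instead of the old per-row rescan in `_find_deployment`. A small usage sketch with stand-in objects, assuming the module's `_DeploymentMatcher` is in scope:

from dataclasses import dataclass
from datetime import date

@dataclass
class FakeDeployment:
    # Stand-in exposing the two fields the matcher reads.
    installation_date: date
    removal_date: date | None

deployments = [
    FakeDeployment(date(2020, 1, 1), date(2020, 6, 30)),
    FakeDeployment(date(2020, 7, 1), None),  # open-ended deployment
]

matcher = _DeploymentMatcher(deployments)
print(matcher.find(date(2020, 3, 15)) is deployments[0])  # True
print(matcher.find(date(2021, 1, 1)) is deployments[1])   # True
# Timestamps must arrive in ascending order; the cursor never rewinds.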

0 commit comments