Skip to content

Commit fbdc18b

Browse files
committed
fix(well-transfer): preload shared elevation cache before parallel workers
Load the cached elevations once on the main thread before starting the parallel well-transfer workers. This prevents multiple workers from racing to load the same cache concurrently and ensures they all operate on a single shared in-memory copy during the transfer.
1 parent 623206e commit fbdc18b

2 files changed

Lines changed: 64 additions & 0 deletions

File tree

tests/test_well_transfer.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import threading
2+
from contextlib import contextmanager
23
from types import SimpleNamespace
34

5+
import pandas as pd
46
import pytest
57
from sqlalchemy.exc import IntegrityError
68

@@ -191,3 +193,61 @@ def fake_map_value(value):
191193

192194
assert session.begin_nested_calls == 1
193195
assert session.rollback_calls == 0
196+
197+
198+
def test_transfer_parallel_preloads_cached_elevations_before_worker_submission(
    monkeypatch,
):
    """transfer_parallel must warm the shared elevation cache before any worker runs.

    The executor stub asserts, at submit time, that the transferer's cache has
    already been populated; the two spy lists confirm the cache is loaded
    exactly once and dumped back unchanged after the transfer completes.
    """

    class _StubSession:
        # Minimal chainable query interface: query(...) returns self so
        # `session.query(Model).all()` works, and yields no cached rows.
        def query(self, _model):
            return self

        def all(self):
            return []

        def expunge_all(self):
            pass

    class _StubFuture:
        def result(self):
            return {"errors": []}

    class _StubExecutor:
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def submit(self, fn, idx, batch):
            # The cache must already be in place when work is handed out.
            assert transferer._cached_elevations == {"source": "preloaded"}
            return _StubFuture()

    @contextmanager
    def fake_session_ctx():
        yield _StubSession()

    elevation_loads = []
    dump_args = []

    def fake_get_cached_elevations():
        elevation_loads.append("load")
        return {"source": "preloaded"}

    def fake_dump_cached_elevations(cache):
        dump_args.append(cache)

    transferer = wt.WellTransferer()
    frame = pd.DataFrame([{"PointID": "AR0001"}])

    monkeypatch.setattr(wt, "session_ctx", fake_session_ctx)
    monkeypatch.setattr(wt, "get_cached_elevations", fake_get_cached_elevations)
    monkeypatch.setattr(wt, "dump_cached_elevations", fake_dump_cached_elevations)
    monkeypatch.setattr(wt, "ThreadPoolExecutor", lambda max_workers: _StubExecutor())
    monkeypatch.setattr(wt, "as_completed", lambda futures: list(futures))
    monkeypatch.setattr(transferer, "_get_dfs", lambda: (frame, frame.copy()))

    transferer.transfer_parallel(num_workers=2)

    assert elevation_loads == ["load"]
    assert dump_args == [{"source": "preloaded"}]

transfers/well_transfer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ def transfer_parallel(self, num_workers: int = None) -> None:
183183
logger.info("No wells to transfer")
184184
return
185185

186+
# Pre-load shared cached elevations on the main thread so workers
187+
# mutate a single cache instance instead of racing lazy initialization.
188+
self._get_cached_elevations()
189+
186190
# Calculate batch size
187191
batch_size = max(100, n // num_workers)
188192
batches = [df.iloc[i : i + batch_size] for i in range(0, n, batch_size)]

0 commit comments

Comments
 (0)