Skip to content

Commit 944fdcc

Browse files
committed
2 parents 34eeddc + 0f1d727 commit 944fdcc

File tree

4 files changed

+85
-52
lines changed

4 files changed

+85
-52
lines changed

backend/persister.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import io
1818
import os
1919
import shutil
20-
from pprint import pprint
20+
from pprint import pprint
2121
import json
2222

2323
import pandas as pd
@@ -37,6 +37,7 @@ class BasePersister(Loggable):
3737
Class to persist the data to a file or cloud storage.
3838
If persisting to a file, the output directory is created by config._make_output_path()
3939
"""
40+
4041
add_extension: str = "csv"
4142

4243
def __init__(self):
@@ -107,13 +108,14 @@ def add_extension(self, path: str):
107108

108109
def _write(self, path: str, records):
109110
raise NotImplementedError
110-
111+
111112
def _dump_timeseries(self, path: str, timeseries: list):
112113
raise NotImplementedError
113114

114115
def _make_output_directory(self, output_directory: str):
115116
os.mkdir(output_directory)
116117

118+
117119
def write_csv_file(path, func, records):
118120
with open(path, "w", newline="") as f:
119121
func(csv.writer(f), records)
@@ -221,19 +223,25 @@ def _write(self, path: str, records: list):
221223
"type": "Feature",
222224
"geometry": {
223225
"type": "Point",
224-
"coordinates": [record.get("longitude"), record.get("latitude"), record.get("elevation")],
226+
"coordinates": [
227+
record.get("longitude"),
228+
record.get("latitude"),
229+
record.get("elevation"),
230+
],
231+
},
232+
"properties": {
233+
k: record.get(k)
234+
for k in record.keys
235+
if k not in ["latitude", "longitude", "elevation"]
225236
},
226-
"properties": {k: record.get(k) for k in record.keys if k not in ["latitude", "longitude", "elevation"]},
227237
}
228238
for record in records
229239
]
230240
feature_collection["features"].extend(features)
231241

232-
233242
with open(path, "w") as f:
234243
json.dump(feature_collection, f, indent=4)
235244

236-
237245
def _get_gdal_type(self, dtype):
238246
"""
239247
Map pandas dtypes to GDAL-compatible types for the schema.
@@ -249,6 +257,7 @@ def _get_gdal_type(self, dtype):
249257
else:
250258
return "str" # Default to string for unsupported types
251259

260+
252261
# class ST2Persister(BasePersister):
253262
# extension = "st2"
254263
#

backend/record.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,37 +30,37 @@ def to_csv(self):
3030

3131
def __init__(self, payload):
3232
self._payload = payload
33-
33+
3434
def get(self, attr):
35-
# v = self._payload.get(attr)
36-
# if v is None and self.defaults:
37-
# v = self.defaults.get(attr)
38-
v = self.__getattr__(attr)
39-
40-
field_sigfigs = [
41-
("elevation", 2),
42-
("well_depth", 2),
43-
("latitude", 6),
44-
("longitude", 6),
45-
("min", 2),
46-
("max", 2),
47-
("mean", 2),
48-
]
49-
50-
# both analyte and water level tables have the same fields, but the
51-
# rounding should only occur for water level tables
52-
if isinstance(self, WaterLevelRecord):
53-
field_sigfigs.append((PARAMETER_VALUE, 2))
54-
55-
for field, sigfigs in field_sigfigs:
56-
if v is not None and field == attr:
57-
try:
58-
v = round(v, sigfigs)
59-
except TypeError as e:
60-
print(field, attr)
61-
raise e
62-
break
63-
return v
35+
# v = self._payload.get(attr)
36+
# if v is None and self.defaults:
37+
# v = self.defaults.get(attr)
38+
v = self.__getattr__(attr)
39+
40+
field_sigfigs = [
41+
("elevation", 2),
42+
("well_depth", 2),
43+
("latitude", 6),
44+
("longitude", 6),
45+
("min", 2),
46+
("max", 2),
47+
("mean", 2),
48+
]
49+
50+
# both analyte and water level tables have the same fields, but the
51+
# rounding should only occur for water level tables
52+
if isinstance(self, WaterLevelRecord):
53+
field_sigfigs.append((PARAMETER_VALUE, 2))
54+
55+
for field, sigfigs in field_sigfigs:
56+
if v is not None and field == attr:
57+
try:
58+
v = round(v, sigfigs)
59+
except TypeError as e:
60+
print(field, attr)
61+
raise e
62+
break
63+
return v
6464

6565
def to_row(self):
6666
return [self.get(k) for k in self.keys]

backend/unifier.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ def _perister_factory(config):
114114
# persister.save(config.output_path)
115115

116116

117-
def _site_wrapper(site_source, parameter_source, sites_summary_persister, timeseries_persister, config):
117+
def _site_wrapper(
118+
site_source, parameter_source, sites_summary_persister, timeseries_persister, config
119+
):
118120

119121
try:
120122
# TODO: fully develop checks/discoveries below
@@ -203,16 +205,25 @@ def _site_wrapper(site_source, parameter_source, sites_summary_persister, timese
203205
# num_sites_to_remove from the length of the list
204206
# to remove the last num_sites_to_remove sites
205207
if use_summarize:
206-
sites_summary_persister.records = sites_summary_persister.records[
207-
: len(sites_summary_persister.records) - num_sites_to_remove
208-
]
208+
sites_summary_persister.records = (
209+
sites_summary_persister.records[
210+
: len(sites_summary_persister.records)
211+
- num_sites_to_remove
212+
]
213+
)
209214
else:
210-
timeseries_persister.timeseries = timeseries_persister.timeseries[
211-
: len(timeseries_persister.timeseries) - num_sites_to_remove
212-
]
213-
sites_summary_persister.sites = sites_summary_persister.sites[
214-
: len(sites_summary_persister.sites) - num_sites_to_remove
215-
]
215+
timeseries_persister.timeseries = (
216+
timeseries_persister.timeseries[
217+
: len(timeseries_persister.timeseries)
218+
- num_sites_to_remove
219+
]
220+
)
221+
sites_summary_persister.sites = (
222+
sites_summary_persister.sites[
223+
: len(sites_summary_persister.sites)
224+
- num_sites_to_remove
225+
]
226+
)
216227
break
217228

218229
except BaseException:
@@ -230,7 +241,13 @@ def _unify_parameter(
230241
sites_summary_persister = _perister_factory(config)
231242
timeseries_persister = CSVPersister()
232243
for site_source, parameter_source in sources:
233-
_site_wrapper(site_source, parameter_source, sites_summary_persister, timeseries_persister, config)
244+
_site_wrapper(
245+
site_source,
246+
parameter_source,
247+
sites_summary_persister,
248+
timeseries_persister,
249+
config,
250+
)
234251

235252
if config.output_summary:
236253
sites_summary_persister.dump_summary(config.output_path)

tests/test_sources/__init__.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414
EXCLUDED_GEOJSON_KEYS = ["latitude", "longitude", "elevation"]
1515

1616
SUMMARY_RECORD_CSV_HEADERS = list(SummaryRecord.keys)
17-
SUMMARY_RECORD_GEOJSON_KEYS = [k for k in SUMMARY_RECORD_CSV_HEADERS if k not in EXCLUDED_GEOJSON_KEYS]
17+
SUMMARY_RECORD_GEOJSON_KEYS = [
18+
k for k in SUMMARY_RECORD_CSV_HEADERS if k not in EXCLUDED_GEOJSON_KEYS
19+
]
1820

1921
SITE_RECORD_CSV_HEADERS = list(SiteRecord.keys)
20-
SITE_RECORD_GEOJSON_KEYS = [k for k in SITE_RECORD_CSV_HEADERS if k not in EXCLUDED_GEOJSON_KEYS]
22+
SITE_RECORD_GEOJSON_KEYS = [
23+
k for k in SITE_RECORD_CSV_HEADERS if k not in EXCLUDED_GEOJSON_KEYS
24+
]
2125

2226
PARAMETER_RECORD_HEADERS = list(ParameterRecord.keys)
2327

@@ -96,7 +100,9 @@ def _check_summary_file(self, extension: str):
96100
for feature in summary["features"]:
97101
assert feature["geometry"]["type"] == "Point"
98102
assert len(feature["geometry"]["coordinates"]) == 3
99-
assert sorted(feature["properties"].keys()) == sorted(SUMMARY_RECORD_GEOJSON_KEYS)
103+
assert sorted(feature["properties"].keys()) == sorted(
104+
SUMMARY_RECORD_GEOJSON_KEYS
105+
)
100106
assert summary["features"][0]["type"] == "Feature"
101107
else:
102108
raise ValueError(f"Unsupported file extension: {extension}")
@@ -122,7 +128,9 @@ def _check_sites_file(self, extension: str):
122128
for feature in sites["features"]:
123129
assert feature["geometry"]["type"] == "Point"
124130
assert len(feature["geometry"]["coordinates"]) == 3
125-
assert sorted(feature["properties"].keys()) == sorted(SITE_RECORD_GEOJSON_KEYS)
131+
assert sorted(feature["properties"].keys()) == sorted(
132+
SITE_RECORD_GEOJSON_KEYS
133+
)
126134
assert sites["features"][0]["type"] == "Feature"
127135
else:
128136
raise ValueError(f"Unsupported file extension: {extension}")
@@ -236,7 +244,6 @@ def test_timeseries_separated_geojson(self):
236244

237245
for timeseries_file in timeseries_dir.iterdir():
238246
self._check_timeseries_file(timeseries_dir, timeseries_file.name)
239-
240247

241248
@pytest.mark.skip(reason="test_date_range not implemented yet")
242249
def test_date_range(self):

0 commit comments

Comments
 (0)