Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions dripline/implementations/postgres_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,20 @@ class SQLTable(Endpoint):
'''
A class for making calls to _insert_with_return
'''
def __init__(self, table_name,
def __init__(self,
table_name,
schema=None,
required_insert_names=[],
return_col_names=[],
optional_insert_names=[],
default_insert_values={},
required_insert_names=(),
insert_return_col_names=(),
optional_insert_names=(),
default_insert_values=None,
*args,
**kwargs):
**kwargs):
'''
table_name (str): name of the table within the database
schema (str): name of the schema where the table is located
required_insert_names (list): list of names (str||dict) of the table columns which must be included on every requested insert (if dict: keys are 'column' and 'payload_key', if string it is assumed that both are that value)
return_col_names (list): list of names (str) of columns whose values should be returned on completion of the insert
insert_return_col_names (list): list of names (str) of columns whose values should be returned on completion of the insert
optional_insert_names (list): list of names (str||dict) of columns which the user may specify on an insert request, but which may be omitted (if dict: keys are 'column' and 'payload_key', if string it is assumed that both are that value)
default_insert_values (dict): dictionary of {column_names: values} to serve as defaults when inserting, any values provided explicitly on the insert request will override these values
'''
Expand All @@ -98,11 +99,14 @@ def __init__(self, table_name,
self.table = None
self.table_name = table_name
self.schema = schema
self._return_names = return_col_names
self._return_names = insert_return_col_names
self._column_map = {}
self._required_insert_names = self._ensure_col_key_map(required_insert_names)
self._optional_insert_names = self._ensure_col_key_map(optional_insert_names)
self._default_insert_dict = default_insert_values
if default_insert_values is None:
self._default_insert_dict = {}
else:
self._default_insert_dict = default_insert_values

def _ensure_col_key_map(self, column_list):
to_return = []
Expand All @@ -111,15 +115,15 @@ def _ensure_col_key_map(self, column_list):
to_return.append({'column': a_col, 'payload_key': a_col})
self._column_map[a_col] = a_col
elif isinstance(a_col, dict):
if not 'column' in a_col or not 'payload_key' in a_col:
if not 'column' in a_col and not 'payload_key' in a_col:
raise KeyError(f"column insert map <{a_col}> does not contain the required keys, ['column', 'payload_key']")
to_return.append(a_col)
self._column_map[a_col['payload_key']] = a_col['column']
else:
raise TypeError(f"column info <{a_col}> is not of an expected type")
return to_return

def do_select(self, return_cols=[], where_eq_dict={}, where_lt_dict={}, where_gt_dict={}):
def do_select(self, return_cols=(), where_eq_dict=None, where_lt_dict=None, where_gt_dict=None):
'''
return_cols (list of str): string names of columns, internally converted to sql reference; if evaluates as false, all columns are returned
where_eq_dict (dict): keys are column names (str), and values are tested with '=='
Expand All @@ -134,12 +138,15 @@ def do_select(self, return_cols=[], where_eq_dict={}, where_lt_dict={}, where_gt
this_select = sqlalchemy.select(self.table)
else:
this_select = sqlalchemy.select(*[getattr(self.table.c,col) for col in return_cols])
for c,v in where_eq_dict.items():
this_select = this_select.where(getattr(self.table.c,c)==v)
for c,v in where_lt_dict.items():
this_select = this_select.where(getattr(self.table.c,c)<v)
for c,v in where_gt_dict.items():
this_select = this_select.where(getattr(self.table.c,c)>v)
if where_eq_dict is not None:
for c,v in where_eq_dict.items():
this_select = this_select.where(getattr(self.table.c,c)==v)
if where_lt_dict is not None:
for c,v in where_lt_dict.items():
this_select = this_select.where(getattr(self.table.c,c)<v)
if where_gt_dict is not None:
for c,v in where_gt_dict.items():
this_select = this_select.where(getattr(self.table.c,c)>v)
with self.service.engine.connect() as conn:
result = conn.execute(this_select)
return (result.keys(), [i for i in result])
Expand All @@ -157,7 +164,7 @@ def _insert_with_return(self, insert_kv_dict, return_col_names_list):
return_values = []
return dict(zip(return_col_names_list, return_values))

def do_insert(self, *args, **kwargs):
def do_insert(self, **kwargs):
'''
'''
# make sure that all provided insert values are expected
Expand Down
56 changes: 52 additions & 4 deletions dripline/implementations/postgres_sensor_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, insertion_table_endpoint_name, **kwargs):
AlertConsumer.__init__(self, add_endpoints_now=False, **kwargs)
PostgreSQLInterface.__init__(self, **kwargs)

self.insertion_table_endpoint_name = insertion_table_endpoint_name
self.default_insertion_table = insertion_table_endpoint_name

self.connect_to_db(self.auth)

Expand All @@ -42,14 +42,16 @@ def add_child(self, endpoint):
AlertConsumer.add_child(self, endpoint)
self.add_child_table(endpoint)

def process_payload(self, a_payload, a_routing_key_data, a_message_timestamp):
def process_payload(self, a_payload, a_routing_key_data, a_message_timestamp, target_table=None):
try:
this_data_table = self.sync_children[self.insertion_table_endpoint_name]
if target_table is None:
target_table = self.default_insertion_table
this_data_table = self.sync_children[target_table]
# combine data sources
insert_data = {'timestamp': a_message_timestamp}
insert_data.update(a_routing_key_data)
insert_data.update(a_payload.to_python())
logger.info(f"Inserting from endpoint {self.insertion_table_endpoint_name}; data are:\n{insert_data}")
logger.info(f"Inserting to endpoint {target_table}; data are:\n{insert_data}")
# do the insert
insert_return = this_data_table.do_insert(**insert_data)
logger.debug(f"Return from insertion: {insert_return}")
Expand All @@ -58,3 +60,49 @@ def process_payload(self, a_payload, a_routing_key_data, a_message_timestamp):
logger.critical(f'Received SQL error while doing insert: {err}')
except Exception as err:
logger.critical(f'An exception was raised while processing a payload to insert: {err}')



__all__.append('PostgresMappedSensorLogger')
class PostgresMappedSensorLogger(PostgresSensorLogger):
'''
Add-on to PostgresSensorLogger using traditional database structure with an endpoint_map
'''
def __init__(self, sensor_type_map_table, data_tables_dict, **kwargs):
'''
Args:
sensor_type_map_table (str): name of endpoint (table) mapping endpoint names to types
data_tables (dict): mapping of data type to endpoint (table) names
* All endpoint names should be accounted between the above
'''
# map table supercedes need for insertion table
if 'insertion_table_endpoint_name' not in kwargs:
kwargs.update( {'insertion_table_endpoint_name' : None} )
PostgresSensorLogger.__init__(self, **kwargs)

# verify table values map to endpoints
if sensor_type_map_table not in self.sync_children:
raise ValueError(f'sensor_type_map_table ({sensor_type_map_table}) not in endpoint tables ({self.sync_children.keys()})')
self._sensor_type_map_table = sensor_type_map_table
for typekey, data_table in data_tables_dict.items():
if data_table not in self.sync_children:
raise ValueError(f'data table target ({data_table}) not in endpoint tables ({self.sync_children.keys()})')
self._data_tables = data_tables_dict

def process_payload(self, a_payload, a_routing_key_data, a_message_timestamp):
'''
method is wrapped to map data insert into correct table
'''
# get the type and table for the sensor
this_type = self.sync_children[self._sensor_type_map_table].do_select(return_cols=["type"],
where_eq_dict=a_routing_key_data)
logger.debug(f'Map query returned {this_type}')
# if the key is not contained in the table, generate meaningful error message
try:
table_name = self._data_tables[this_type[1][0][0]]
except IndexError:
logger.critical(f"{a_routing_key_data} is not in database, see {this_type}")
return
logger.info(f'Found {a_routing_key_data} in table {table_name}')

super().process_payload(a_payload, a_routing_key_data, a_message_timestamp, target_table=table_name)