diff --git a/.ruff.toml b/.ruff.toml new file mode 100644 index 0000000..489184b --- /dev/null +++ b/.ruff.toml @@ -0,0 +1,3 @@ +[format] +# Like Black, use double quotes for strings. +quote-style = "single" \ No newline at end of file diff --git a/src/queryparser/common/common.py b/src/queryparser/common/common.py index 0564b61..c2343b8 100644 --- a/src/queryparser/common/common.py +++ b/src/queryparser/common/common.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # All listeners that are with minor modifications shared between PostgreSQL # and MySQL. -from __future__ import (absolute_import, print_function) +from __future__ import absolute_import, print_function import logging import re @@ -31,7 +31,7 @@ def parse_alias(alias, quote_char): def process_column_name(column_name_listener, walker, ctx, quote_char): - ''' + """ A helper function that strips the quote characters from the column names. The returned list includes: @@ -44,7 +44,7 @@ def process_column_name(column_name_listener, walker, ctx, quote_char): column_name_listener object :param walker: - antlr walker object + antlr walker object :param ctx: antlr context to walk through @@ -52,7 +52,7 @@ def process_column_name(column_name_listener, walker, ctx, quote_char): :param quote_char: which quote character are we expecting? - ''' + """ cn = [] column_name_listener.column_name = [] walker.walk(column_name_listener, ctx) @@ -81,12 +81,12 @@ def process_column_name(column_name_listener, walker, ctx, quote_char): def get_column_name_listener(base): - class ColumnNameListener(base): """ Get all column names. """ + def __init__(self): self.column_name = [] self.column_as_array = [] @@ -105,12 +105,12 @@ def enterColumn_spec(self, ctx): def get_table_name_listener(base, quote_char): - class TableNameListener(base): """ Get table names. """ + def __init__(self): self.table_names = [] self.table_aliases = [] @@ -126,9 +126,7 @@ def enterAlias(self, ctx): def get_schema_name_listener(base, quote_char): - class SchemaNameListener(base): - def __init__(self, replace_schema_name): self.replace_schema_name = replace_schema_name @@ -141,9 +139,13 @@ def enterSchema_name(self, ctx): nsn = unicode(nsn, 'utf-8') except NameError: pass - nsn = re.sub(r'(|{})(?!{})[\S]*[^{}](|{})'.format( - quote_char, quote_char, quote_char, quote_char), - r'\1{}\2'.format(nsn), sn) + nsn = re.sub( + r'(|{})(?!{})[\S]*[^{}](|{})'.format( + quote_char, quote_char, quote_char, quote_char + ), + r'\1{}\2'.format(nsn), + sn, + ) ctx.getTokens(ttype)[0].getSymbol().text = nsn except KeyError: pass @@ -152,42 +154,46 @@ def enterSchema_name(self, ctx): def get_remove_subqueries_listener(base, base_parser): - class RemoveSubqueriesListener(base): """ Remove nested select_expressions. """ + def __init__(self, depth): self.depth = depth def enterSelect_expression(self, ctx): parent = ctx.parentCtx.parentCtx - if isinstance(parent, base_parser.SubqueryContext) and \ - ctx.depth() > self.depth: + if ( + isinstance(parent, base_parser.SubqueryContext) + and ctx.depth() > self.depth + ): # we need to remove all Select_expression instances, not # just the last one so we loop over until we get all of them # out - seinstances = [isinstance(i, - base_parser.Select_expressionContext) - for i in ctx.parentCtx.children] + seinstances = [ + isinstance(i, base_parser.Select_expressionContext) + for i in ctx.parentCtx.children + ] while True in seinstances: ctx.parentCtx.removeLastChild() - seinstances = [isinstance(i, - base_parser.Select_expressionContext) - for i in ctx.parentCtx.children] + seinstances = [ + isinstance(i, base_parser.Select_expressionContext) + for i in ctx.parentCtx.children + ] return RemoveSubqueriesListener def get_query_listener(base, base_parser, quote_char): - class QueryListener(base): """ Extract all select_expressions. """ + def __init__(self): self.select_expressions = [] self.select_list = None @@ -219,12 +225,12 @@ def enterSelect_list(self, ctx): def get_column_keyword_function_listener(base, quote_char): - class ColumnKeywordFunctionListener(base): """ Extract columns, keywords and functions. """ + def __init__(self): self.tables = [] self.columns = [] @@ -232,8 +238,7 @@ def __init__(self): self.keywords = [] self.functions = [] self.column_name_listener = get_column_name_listener(base)() - self.table_name_listener = get_table_name_listener( - base, quote_char)() + self.table_name_listener = get_table_name_listener(base, quote_char)() self.walker = antlr4.ParseTreeWalker() self.data = [] @@ -247,8 +252,9 @@ def _process_alias(self, ctx): return alias def _extract_column(self, ctx, append=True, join_columns=False): - cn = process_column_name(self.column_name_listener, self.walker, - ctx, quote_char) + cn = process_column_name( + self.column_name_listener, self.walker, ctx, quote_char + ) alias = self._process_alias(ctx) if len(cn) > 1: @@ -293,15 +299,20 @@ def enterTable_atom(self, ctx): tn[1] = ts.table_name().getText().replace(quote_char, '') self.tables.append((alias, tn, ctx.depth())) - logging.info((ctx.depth(), ctx.__class__.__name__, - [tn, alias])) + logging.info((ctx.depth(), ctx.__class__.__name__, [tn, alias])) self.data.append([ctx.depth(), ctx, [tn, alias]]) def enterDisplayed_column(self, ctx): - logging.info((ctx.depth(), ctx.__class__.__name__, - self._extract_column(ctx, append=False)[1])) - self.data.append([ctx.depth(), ctx, - self._extract_column(ctx, append=False)[1]]) + logging.info( + ( + ctx.depth(), + ctx.__class__.__name__, + self._extract_column(ctx, append=False)[1], + ) + ) + self.data.append( + [ctx.depth(), ctx, self._extract_column(ctx, append=False)[1]] + ) self._extract_column(ctx) if ctx.ASTERISK(): self.keywords.append('*') @@ -312,10 +323,10 @@ def enterSelect_expression(self, ctx): def enterSelect_list(self, ctx): if ctx.ASTERISK(): - logging.info((ctx.depth(), ctx.__class__.__name__, - [[None, None, '*'], None])) - self.data.append([ctx.depth(), ctx, [[[None, None, '*'], - None]]]) + logging.info( + (ctx.depth(), ctx.__class__.__name__, [[None, None, '*'], None]) + ) + self.data.append([ctx.depth(), ctx, [[[None, None, '*'], None]]]) self.columns.append(('*', None)) self.keywords.append('*') @@ -330,36 +341,60 @@ def enterGroupby_clause(self, ctx): col = self._extract_column(ctx, append=False) if col[1][0][0][2] not in self.column_aliases: self._extract_column(ctx) - logging.info((ctx.depth(), ctx.__class__.__name__, - self._extract_column(ctx, append=False)[1])) - self.data.append([ctx.depth(), ctx, - self._extract_column(ctx, append=False)[1]]) + logging.info( + ( + ctx.depth(), + ctx.__class__.__name__, + self._extract_column(ctx, append=False)[1], + ) + ) + self.data.append( + [ctx.depth(), ctx, self._extract_column(ctx, append=False)[1]] + ) def enterWhere_clause(self, ctx): self.keywords.append('where') self._extract_column(ctx) - logging.info((ctx.depth(), ctx.__class__.__name__, - self._extract_column(ctx, append=False)[1])) - self.data.append([ctx.depth(), ctx, - self._extract_column(ctx, append=False)[1]]) + logging.info( + ( + ctx.depth(), + ctx.__class__.__name__, + self._extract_column(ctx, append=False)[1], + ) + ) + self.data.append( + [ctx.depth(), ctx, self._extract_column(ctx, append=False)[1]] + ) def enterHaving_clause(self, ctx): self.keywords.append('having') self._extract_column(ctx) - logging.info((ctx.depth(), ctx.__class__.__name__, - self._extract_column(ctx, append=False)[1])) - self.data.append([ctx.depth(), ctx, - self._extract_column(ctx, append=False)[1]]) + logging.info( + ( + ctx.depth(), + ctx.__class__.__name__, + self._extract_column(ctx, append=False)[1], + ) + ) + self.data.append( + [ctx.depth(), ctx, self._extract_column(ctx, append=False)[1]] + ) def enterOrderby_clause(self, ctx): self.keywords.append('order by') col = self._extract_column(ctx, append=False) if col[1][0][0][2] not in self.column_aliases: self._extract_column(ctx) - logging.info((ctx.depth(), ctx.__class__.__name__, - self._extract_column(ctx, append=False)[1])) - self.data.append([ctx.depth(), ctx, - self._extract_column(ctx, append=False)[1]]) + logging.info( + ( + ctx.depth(), + ctx.__class__.__name__, + self._extract_column(ctx, append=False)[1], + ) + ) + self.data.append( + [ctx.depth(), ctx, self._extract_column(ctx, append=False)[1]] + ) def enterLimit_clause(self, ctx): self.keywords.append('limit') @@ -367,10 +402,16 @@ def enterLimit_clause(self, ctx): def enterJoin_condition(self, ctx): self.keywords.append('join') self._extract_column(ctx, join_columns=ctx) - logging.info((ctx.depth(), ctx.__class__.__name__, - self._extract_column(ctx, append=False)[1])) - self.data.append([ctx.depth(), ctx, - self._extract_column(ctx, append=False)[1]]) + logging.info( + ( + ctx.depth(), + ctx.__class__.__name__, + self._extract_column(ctx, append=False)[1], + ) + ) + self.data.append( + [ctx.depth(), ctx, self._extract_column(ctx, append=False)[1]] + ) def enterSpoint(self, ctx): self.functions.append('spoint') @@ -437,8 +478,16 @@ class SQLQueryProcessor(object): other types of listeners can be added. """ - def __init__(self, base_lexer, base_parser, base_parser_listener, - quote_char, query=None, base_sphere_listener=None): + + def __init__( + self, + base_lexer, + base_parser, + base_parser_listener, + quote_char, + query=None, + base_sphere_listener=None, + ): self.lexer = base_lexer self.parser = base_parser self.parser_listener = base_parser_listener @@ -495,12 +544,12 @@ def _extract_instances(self, column_keyword_function_listener): if isinstance(i[1], self.parser.Select_listContext): if len(i) == 3: - select_list_columns.append([[i[2][0][0] + [i[1]], - i[2][0][1]]]) + select_list_columns.append([[i[2][0][0] + [i[1]], i[2][0][1]]]) ctx_stack.append(i) - if isinstance(i[1], self.parser.Where_clauseContext) or\ - isinstance(i[1], self.parser.Having_clauseContext): + if isinstance(i[1], self.parser.Where_clauseContext) or isinstance( + i[1], self.parser.Having_clauseContext + ): if len(i[2]) > 1: for j in i[2]: other_columns.append([j]) @@ -514,15 +563,23 @@ def _extract_instances(self, column_keyword_function_listener): if i[1].USING_SYM(): for ctx in ctx_stack[::-1]: - if not isinstance(ctx[1], - self.parser.Table_atomContext): + if not isinstance(ctx[1], self.parser.Table_atomContext): break for ju in join_using: if ju[0][1] is None: - other_columns.append([[[ctx[2][0][0], - ctx[2][0][1], - ju[0][2], - ctx[1]], None]]) + other_columns.append( + [ + [ + [ + ctx[2][0][0], + ctx[2][0][1], + ju[0][2], + ctx[1], + ], + None, + ] + ] + ) elif i[1].ON(): if len(i[2]) > 1: for j in i[2]: @@ -546,9 +603,16 @@ def _extract_instances(self, column_keyword_function_listener): go_columns.append(i[2]) ctx_stack.append(i) - return select_list_columns, select_list_tables,\ - select_list_table_references, other_columns, go_columns, join,\ - join_using, column_aliases + return ( + select_list_columns, + select_list_tables, + select_list_table_references, + other_columns, + go_columns, + join, + join_using, + column_aliases, + ) def _get_budget_column(self, c, tab, ref): cname = c[0][2] @@ -577,10 +641,17 @@ def _get_budget_column(self, c, tab, ref): return cname, cctx, calias, column_found, t - def _extract_columns(self, columns, select_list_tables, ref_dict, join, - budget, column_aliases, touched_columns=None, - subquery_contents=None): - + def _extract_columns( + self, + columns, + select_list_tables, + ref_dict, + join, + budget, + column_aliases, + touched_columns=None, + subquery_contents=None, + ): # Here we store all columns that might have references somewhere # higher up in the tree structure. We'll revisit them later. missing_columns = [] @@ -595,11 +666,11 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, calias = c[1] # if * is selected we don't care too much - if c[0][0] is None and c[0][1] is None and c[0][2] == '*'\ - and not join: + if c[0][0] is None and c[0][1] is None and c[0][2] == '*' and not join: for slt in select_list_tables: - extra_columns.append([[slt[0][0][0], slt[0][0][1], cname, - c[0][3]], calias]) + extra_columns.append( + [[slt[0][0][0], slt[0][0][1], cname, c[0][3]], calias] + ) remove_column_idxs.append(i) continue @@ -616,15 +687,17 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, # We have to check if we also have a join on the same level # and we are actually touching a column from the joined table - if join and c[0][2] != '*' and\ - (tab[1] != c[0][1] or - (tab[1] is None and c[0][1] is None)): - cname, cctx, calias, column_found, tab =\ - self._get_budget_column(c, tab, budget[-1][2]) + if ( + join + and c[0][2] != '*' + and (tab[1] != c[0][1] or (tab[1] is None and c[0][1] is None)) + ): + cname, cctx, calias, column_found, tab = self._get_budget_column( + c, tab, budget[-1][2] + ) # raise an ambiguous column if column_found and c[0][1] is None: - raise QueryError("Column '%s' is possibly ambiguous." - % c[0][2]) + raise QueryError("Column '%s' is possibly ambiguous." % c[0][2]) except IndexError: pass @@ -636,14 +709,18 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, if isinstance(ref[0], int): # ref is a budget column - cname, cctx, calias, column_found, tab =\ - self._get_budget_column(c, tab, ref[2]) + cname, cctx, calias, column_found, tab = self._get_budget_column( + c, tab, ref[2] + ) ref_cols = [j[0][2] for j in ref[2]] - if not column_found and c[0][1] is not None\ - and c[0][1] != tab[0][1] and '*' not in ref_cols: - raise QueryError("Unknown column '%s.%s'." % (c[0][1], - c[0][2])) + if ( + not column_found + and c[0][1] is not None + and c[0][1] != tab[0][1] + and '*' not in ref_cols + ): + raise QueryError("Unknown column '%s.%s'." % (c[0][1], c[0][2])) else: # ref is a table @@ -662,12 +739,12 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, if subquery_contents is not None: try: contents = subquery_contents[c[0][1]] - cname, cctx, calias, column_found, tab =\ + cname, cctx, calias, column_found, tab = ( self._get_budget_column(c, tab, contents) + ) except KeyError: - tabs = [j[0][0][:2] for j in - subquery_contents.values()] + tabs = [j[0][0][:2] for j in subquery_contents.values()] tabs += [j[0][0] for j in select_list_tables] column_found = False for t in tabs: @@ -686,18 +763,24 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, continue else: if tab[0][1] == c[0][1]: - columns[i] = [[tab[0][0], tab[0][1], - c[0][2], c[0][3]], c[1]] + columns[i] = [ + [tab[0][0], tab[0][1], c[0][2], c[0][3]], + c[1], + ] else: - missing_columns.append(c) columns[i] = c if touched_columns is not None: touched_columns.append(c) continue - elif c[0][2] is not None and c[0][2] != '*' and c[0][1] is \ - None and len(ref_dict.keys()) > 1 and not join: + elif ( + c[0][2] is not None + and c[0][2] != '*' + and c[0][1] is None + and len(ref_dict.keys()) > 1 + and not join + ): raise QueryError("Column '%s' is ambiguous." % c[0][2]) elif len(budget) and tab[0][0] is None and tab[0][1] is None: @@ -705,18 +788,21 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, column_found = False if isinstance(ref[0], int): - cname, cctx, calias, column_found, tab =\ - self._get_budget_column(c, tab, ref[2]) + cname, cctx, calias, column_found, tab = ( + self._get_budget_column(c, tab, ref[2]) + ) # We allow None.None columns because they are produced # by count(*) - if not column_found and c[0][2] is not None\ - and c[0][2] not in column_aliases: + if ( + not column_found + and c[0][2] is not None + and c[0][2] not in column_aliases + ): raise QueryError("Unknown column '%s'." % c[0][2]) if touched_columns is not None: - touched_columns.append([[tab[0][0], tab[0][1], cname, cctx], - calias]) + touched_columns.append([[tab[0][0], tab[0][1], cname, cctx], calias]) else: columns[i] = [[tab[0][0], tab[0][1], cname, c[0][3]], calias] @@ -726,7 +812,27 @@ def _extract_columns(self, columns, select_list_tables, ref_dict, join, columns.extend(extra_columns) return missing_columns - def process_query(self, replace_schema_name=None, indexed_objects=None): + @staticmethod + def _match_and_replace_function_name(query, function_name, i): + """ + This very roughly checks if the function name is present in the query. + We check for a space, the function name, and an opening parenthesis. + """ + pattern = r'\s' + re.escape(function_name) + r'\(' + match = re.search(pattern, query) + if match: + start, end = match.span() + # Replace the matched function name with UDF_{i} + query = query[: start + 1] + f'UDF_{i}' + query[end - 1 :] + + return match, query + + def process_query( + self, + replace_schema_name=None, + replace_function_names=None, + indexed_objects=None, + ): """ Parses and processes the query. After a successful run it fills up columns, keywords, functions and syntax_errors lists. @@ -737,6 +843,21 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): :param indexed_objects: Deprecated """ + self.replaced_functions = {} + + if replace_function_names: + if (n := len(replace_function_names)) > 10: + raise ValueError( + f'Too many function names to replace (you passed {n}). Maximum: 10' + ) + for i, function_name in enumerate(replace_function_names): + match, query = self._match_and_replace_function_name( + self.query, function_name, i + ) + if match: + self.replaced_functions[i] = function_name + self.set_query(query) + # Antlr objects inpt = antlr4.InputStream(self.query) lexer = self.lexer(inpt) @@ -752,12 +873,14 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): if replace_schema_name is not None: schema_name_listener = get_schema_name_listener( - self.parser_listener, self.quote_char)(replace_schema_name) + self.parser_listener, self.quote_char + )(replace_schema_name) self.walker.walk(schema_name_listener, tree) self._query = stream.getText() - query_listener = get_query_listener(self.parser_listener, - self.parser, self.quote_char)() + query_listener = get_query_listener( + self.parser_listener, self.parser, self.quote_char + )() subquery_aliases = [None] keywords = [] functions = [] @@ -784,10 +907,11 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): # Iterate through subqueries starting with the lowest level for ccc, ctx in enumerate(query_listener.select_expressions[::-1]): remove_subquieries_listener = get_remove_subqueries_listener( - self.parser_listener, self.parser)(ctx.depth()) - column_keyword_function_listener = \ - get_column_keyword_function_listener( - self.parser_listener, self.quote_char)() + self.parser_listener, self.parser + )(ctx.depth()) + column_keyword_function_listener = get_column_keyword_function_listener( + self.parser_listener, self.quote_char + )() # Remove nested subqueries from select_expressions self.walker.walk(remove_subquieries_listener, ctx) @@ -809,10 +933,16 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): # We get the columns from the select list along with all # other touched columns and any possible join conditions column_aliases_from_previous = [i for i in column_aliases] - select_list_columns, select_list_tables,\ - select_list_table_references, other_columns, go_columns, join,\ - join_using, column_aliases =\ - self._extract_instances(column_keyword_function_listener) + ( + select_list_columns, + select_list_tables, + select_list_table_references, + other_columns, + go_columns, + join, + join_using, + column_aliases, + ) = self._extract_instances(column_keyword_function_listener) tables.extend([i[0] for i in select_list_tables]) @@ -837,9 +967,14 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): for table in select_list_tables: ref_dict[table[0][0][1]] = table - mc = self._extract_columns(select_list_columns, select_list_tables, - ref_dict, join, budget, - column_aliases_from_previous) + mc = self._extract_columns( + select_list_columns, + select_list_tables, + ref_dict, + join, + budget, + column_aliases_from_previous, + ) missing_columns.extend([[i] for i in mc]) touched_columns.extend(select_list_columns) @@ -851,10 +986,15 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): if col[0][0][2] not in aliases: other_columns.append(col) - mc = self._extract_columns(other_columns, select_list_tables, - ref_dict, join, budget, - column_aliases_from_previous, - touched_columns) + mc = self._extract_columns( + other_columns, + select_list_tables, + ref_dict, + join, + budget, + column_aliases_from_previous, + touched_columns, + ) missing_columns.extend([[i] for i in mc]) @@ -863,8 +1003,9 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): join_columns.append(budget.pop(-1)) if len(join_using) == 1: for tab in select_list_tables: - touched_columns.append([[tab[0][0][0], tab[0][0][1], - join_using[0][0][2]], None]) + touched_columns.append( + [[tab[0][0][0], tab[0][0][1], join_using[0][0][2]], None] + ) bp = [] for b in budget[::-1]: if b[0] > current_depth: @@ -876,26 +1017,38 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): subquery_contents[subquery_alias] = current_columns if len(missing_columns): - mc = self._extract_columns(missing_columns, select_list_tables, - ref_dict, join, budget, - column_aliases_from_previous, - touched_columns, subquery_contents) + mc = self._extract_columns( + missing_columns, + select_list_tables, + ref_dict, + join, + budget, + column_aliases_from_previous, + touched_columns, + subquery_contents, + ) if len(mc): - unref_cols = "', '".join(['.'.join([j for j in i[0][:3] if j]) - for i in mc]) + unref_cols = "', '".join( + ['.'.join([j for j in i[0][:3] if j]) for i in mc] + ) raise QueryError("Unreferenced column(s): '%s'." % unref_cols) touched_columns = set([tuple(i[0]) for i in touched_columns]) # extract display_columns display_columns = [] - mc = self._extract_columns([[i] for i in budget[-1][2]], - select_list_tables, ref_dict, join, budget, - column_aliases_from_previous, - display_columns, subquery_contents) - - display_columns = [[i[1] if i[1] else i[0][2], i[0]] - for i in display_columns] + mc = self._extract_columns( + [[i] for i in budget[-1][2]], + select_list_tables, + ref_dict, + join, + budget, + column_aliases_from_previous, + display_columns, + subquery_contents, + ) + + display_columns = [[i[1] if i[1] else i[0][2], i[0]] for i in display_columns] # Let's get rid of all columns that are already covered by # db.tab.*. Figure out a better way to do it and replace the code @@ -908,23 +1061,41 @@ def process_query(self, replace_schema_name=None, indexed_objects=None): for acol in asterisk_columns: for col in touched_columns: - if acol[0] == col[0] and acol[1] == col[1] and \ - acol[2] != col[2]: + if acol[0] == col[0] and acol[1] == col[1] and acol[2] != col[2]: del_columns.append(col) columns = list(set(touched_columns).difference(del_columns)) self.columns = list(set([self._strip_column(i) for i in columns])) self.keywords = list(set(keywords)) self.functions = list(set(functions)) - self.display_columns = [(i[0].lstrip('"').rstrip('"'), - list(self._strip_column(i[1]))) - for i in display_columns] - - self.tables = list(set([tuple([i[0][0].lstrip('"').rstrip('"') - if i[0][0] is not None else i[0][0], - i[0][1].lstrip('"').rstrip('"') - if i[0][1] is not None else i[0][1]]) - for i in tables])) + self.display_columns = [ + (i[0].lstrip('"').rstrip('"'), list(self._strip_column(i[1]))) + for i in display_columns + ] + + self.tables = list( + set( + [ + tuple( + [ + i[0][0].lstrip('"').rstrip('"') + if i[0][0] is not None + else i[0][0], + i[0][1].lstrip('"').rstrip('"') + if i[0][1] is not None + else i[0][1], + ] + ) + for i in tables + ] + ) + ) + + if len(self.replaced_functions) > 0: + for i, function_name in self.replaced_functions.items(): + self._query = self.query.replace(f'UDF_{i}', function_name) + self.functions.remove(f'UDF_{i}') + self.functions.append(function_name) @property def query(self): diff --git a/src/queryparser/postgresql/PostgreSQLLexer.g4 b/src/queryparser/postgresql/PostgreSQLLexer.g4 index 2532ca5..904c543 100644 --- a/src/queryparser/postgresql/PostgreSQLLexer.g4 +++ b/src/queryparser/postgresql/PostgreSQLLexer.g4 @@ -170,6 +170,16 @@ TIME_SYM : T_ I_ M_ E_ ; TIMESTAMP : T_ I_ M_ E_ S_ T_ A_ M_ P_ ; TRUE_SYM : T_ R_ U_ E_ ; TRUNCATE : T_ R_ U_ N_ C_ A_ T_ E_ ; +UDF_0 : U_ D_ F_ '_' '0' ; +UDF_1 : U_ D_ F_ '_' '1' ; +UDF_2 : U_ D_ F_ '_' '2' ; +UDF_3 : U_ D_ F_ '_' '3' ; +UDF_4 : U_ D_ F_ '_' '4' ; +UDF_5 : U_ D_ F_ '_' '5' ; +UDF_6 : U_ D_ F_ '_' '6' ; +UDF_7 : U_ D_ F_ '_' '7' ; +UDF_8 : U_ D_ F_ '_' '8' ; +UDF_9 : U_ D_ F_ '_' '9' ; UNION_SYM : U_ N_ I_ O_ N_ ; UNSIGNED_SYM : U_ N_ S_ I_ G_ N_ E_ D_ ; UPDATE : U_ P_ D_ A_ T_ E_ ; diff --git a/src/queryparser/postgresql/PostgreSQLParser.g4 b/src/queryparser/postgresql/PostgreSQLParser.g4 index 209248c..b3aadc5 100644 --- a/src/queryparser/postgresql/PostgreSQLParser.g4 +++ b/src/queryparser/postgresql/PostgreSQLParser.g4 @@ -63,7 +63,7 @@ array_functions: ARRAY_LENGTH ; custom_functions: - GAIA_HEALPIX_INDEX | PDIST ; + GAIA_HEALPIX_INDEX | PDIST | UDF_0 | UDF_1 | UDF_2 | UDF_3 | UDF_4 | UDF_5 | UDF_6 | UDF_7 | UDF_8 | UDF_9 ; pg_sphere_functions: AREA ; diff --git a/src/queryparser/postgresql/postgresqlprocessor.py b/src/queryparser/postgresql/postgresqlprocessor.py index a15b7f6..e400745 100644 --- a/src/queryparser/postgresql/postgresqlprocessor.py +++ b/src/queryparser/postgresql/postgresqlprocessor.py @@ -6,18 +6,18 @@ """ -from __future__ import (absolute_import, print_function) +from __future__ import absolute_import, print_function __all__ = ["PostgreSQLQueryProcessor"] +from ..common import SQLQueryProcessor from .PostgreSQLLexer import PostgreSQLLexer from .PostgreSQLParser import PostgreSQLParser from .PostgreSQLParserListener import PostgreSQLParserListener -from ..common import SQLQueryProcessor - class PostgreSQLQueryProcessor(SQLQueryProcessor): def __init__(self, query=None): - super().__init__(PostgreSQLLexer, PostgreSQLParser, - PostgreSQLParserListener, '"', query) + super().__init__( + PostgreSQLLexer, PostgreSQLParser, PostgreSQLParserListener, '"', query + ) diff --git a/src/queryparser/testing/tests.yaml b/src/queryparser/testing/tests.yaml index f33812e..fb093f6 100644 --- a/src/queryparser/testing/tests.yaml +++ b/src/queryparser/testing/tests.yaml @@ -16,6 +16,7 @@ common_tests: - - ['col1: db.tab.a'] - ['db.tab'] + - - - SELECT t.a FROM db.tab1 as t, db.tab2; @@ -24,6 +25,7 @@ common_tests: - - ['a: db.tab1.a'] - ['db.tab1', 'db.tab2'] + - - - SELECT COUNT(*), a*2, b, 100 FROM db.tab; @@ -32,6 +34,7 @@ common_tests: - ['COUNT'] - ['a: db.tab.a', 'b: db.tab.b'] - ['db.tab'] + - - - SELECT (((((((1+2)*3)/4)^5)%6)&7)>>8) FROM db.tab; @@ -40,6 +43,7 @@ common_tests: - - - ['db.tab'] + - - - SELECT ABS(a),AVG(b) FROM db.tab; @@ -48,6 +52,7 @@ common_tests: - ['AVG', 'ABS'] - ['a: db.tab.a', 'b: db.tab.b'] - ['db.tab'] + - - - SELECT AVG(((((b & a) << 1) + 1) / a) ^ 4.5) FROM db.tab; @@ -56,6 +61,7 @@ common_tests: - ['AVG'] - - ['db.tab'] + - - - SELECT A.a,B.* FROM db.tab1 A,db.tab2 AS B LIMIT 10; @@ -64,6 +70,7 @@ common_tests: - - ['a: db.tab1.a', '*: db.tab2.*'] - ['db.tab1', 'db.tab2'] + - - - SELECT fofid, x, y, z, vx, vy, vz @@ -76,6 +83,7 @@ common_tests: - - ['fofid: MDR1.FOF.fofid', 'x: MDR1.FOF.x', 'y: MDR1.FOF.y', 'z: MDR1.FOF.z', 'vx: MDR1.FOF.vx', 'vy: MDR1.FOF.vy', 'vz: MDR1.FOF.vz'] - ['MDR1.FOF'] + - - - SELECT article, dealer, price @@ -86,6 +94,7 @@ common_tests: - ['MAX'] - ['article: world.shop.article', 'dealer: world.shop.dealer', 'price: world.shop.price'] - ['world.shop', 'universe.shop'] + - - - SELECT dealer, price @@ -99,6 +108,7 @@ common_tests: - ['MAX'] - ['price: db.shop.price', 'dealer: db.shop.dealer'] - ['db.shop', 'db.warehouse'] + - - - SELECT A.*, B.* @@ -110,6 +120,7 @@ common_tests: - - ['*: db1.table1.*', '*: db2.table1.*'] - ['db1.table1', 'db2.table1'] + - - - SELECT * FROM mmm.products @@ -120,6 +131,7 @@ common_tests: - - ['*: mmm.products.*'] - ['mmm.products'] + - - - SELECT t.table_name AS tname, t.description AS tdesc, @@ -153,6 +165,7 @@ common_tests: 'jcol: tap_schema.cols.column_name', 'kcol: tap_schema.cols.column_name'] - ['tap_schema.tabs', 'tap_schema.cols'] + - - - SELECT t1.a FROM d.tab t1 @@ -162,6 +175,7 @@ common_tests: - ['a: foo.tab.a'] - ['foo.tab'] - 'd': 'foo' + - - - SELECT DISTINCT t.table_name @@ -175,6 +189,7 @@ common_tests: - - ['table_name: tap_schema.tabs.table_name'] - ['tap_schema.tabs', 'tap_schema.cols'] + - - - SELECT s.* FROM db.person p INNER JOIN db.shirt s @@ -186,6 +201,7 @@ common_tests: - - ['*: db.shirt.*'] - ['db.shirt', 'db.person'] + - - - SELECT x, y, z, mass @@ -200,6 +216,7 @@ common_tests: - ['x: MDR1.FOF.x', 'y: MDR1.FOF.y', 'z: MDR1.FOF.z', 'mass: MDR1.FOF.mass'] - ['MDR1.FOF'] + - - - SELECT h.Mvir, h.spin, g.diskMassStellar, @@ -219,6 +236,7 @@ common_tests: 'diskMassStellar: MDPL2.Galacticus.diskMassStellar', 'spin: MDPL2.Rockstar.spin'] - ['MDPL2.Rockstar', 'MDPL2.Galacticus'] + - - - SELECT bdmId, Rbin, mass, dens @@ -247,6 +265,7 @@ common_tests: - ['bdmId: Bolshoi.BDMVProf.bdmId', 'Rbin: Bolshoi.BDMVProf.Rbin', 'mass: Bolshoi.BDMVProf.mass', 'dens: Bolshoi.BDMVProf.dens'] - ['Bolshoi.BDMVProf', 'Bolshoi.BDMV'] + - - - SELECT t.RAVE_OBS_ID AS c1, t.HEALPix AS c2, @@ -274,6 +293,7 @@ common_tests: 'c4: RAVEPUB_DR5.RAVE_ON.TEFF'] - ['RAVEPUB_DR5.RAVE_DR5', 'RAVEPUB_DR5.RAVE_Gravity_SC', 'RAVEPUB_DR5.RAVE_ON'] + - - - SELECT db.tab.a FROM db.tab; @@ -282,6 +302,7 @@ common_tests: - - ['a: db.tab.a'] - ['db.tab'] + - - - SELECT COUNT(*) AS n, id, mra, mlem AS qqq, blem @@ -316,6 +337,7 @@ common_tests: - ['n: None.None.None', 'id: db.bar.id', 'mra: db.tab.ra', 'qqq: db.bar.mlem', 'blem: None.None.blem'] - ['db.tab', 'db.bar', 'db.gaia'] + - - - SELECT @@ -361,6 +383,7 @@ common_tests: 'n: None.None.None'] - ['gaiadr1.tgas_source', 'gaiadr1.tmass_best_neighbour', 'gaiadr1.tmass_original_valid'] + - - - SELECT ra, sub.qqq, t1.bar @@ -379,6 +402,7 @@ common_tests: - - - ['db.tab', 'db.blem'] + - - - SELECT t1.a, t2.b, t3.c, t4.z @@ -392,6 +416,7 @@ common_tests: 'd': 'foo' 'db2': 'bar' 'foo': 'bas' + - - - SELECT *, AVG(par) as apar FROM db.tab; @@ -400,6 +425,7 @@ common_tests: - ['AVG'] - ['*: db.tab.*', 'apar: db.tab.par'] - ['db.tab'] + - - - SELECT q.ra, q.de, tab2.par @@ -414,6 +440,7 @@ common_tests: - ['MAX'] - ['ra: db.tab.ra', 'de: db.tab.de', 'par: db.tab2.par'] - ['db.tab', 'db.tab2', 'db.undef'] + - - - SELECT a, b @@ -427,6 +454,7 @@ common_tests: - - ['a: db.tab1.a', 'b: db.tab1.b'] - ['db.tab1', 'db.tab2'] + - - - SELECT a FROM db.tab HAVING b > 0 @@ -435,6 +463,7 @@ common_tests: - - ['a: db.tab.a'] - ['db.tab'] + - - - SELECT a FROM db.tab WHERE EXISTS ( @@ -445,7 +474,8 @@ common_tests: - - ['a: db.tab.a'] - ['db.tab', 'db.foo'] - + - + - - SELECT * FROM ( @@ -458,6 +488,7 @@ common_tests: - - - ['db.a', 'db.b', 'db.c', 'db.d', 'db.x', 'db.y'] + - - - SELECT * @@ -471,6 +502,7 @@ common_tests: - - - ['db.a', 'db.b', 'db.c', 'db.d', 'db.x', 'db.y'] + - - - SELECT A.*, B.* @@ -482,6 +514,7 @@ common_tests: - - ['*: db1.table1.*', '*: db2.table1.*'] - ['db1.table1', 'db2.table1'] + - common_translation_tests: @@ -493,6 +526,7 @@ common_translation_tests: - - - + - mysql_tests: @@ -508,6 +542,7 @@ mysql_tests: - - ['fi@1: db.test_table.fi@1', 'fi2: db.test_table.fi2'] - ['db.test_table', 'bd.test_table'] + - - - SELECT `fi@1`, fi2 @@ -521,6 +556,7 @@ mysql_tests: - - ['fi@1: db.test_table.fi@1', 'fi2: db.test_table.fi2'] - ['db.test_table', 'bd.test_table'] + - - - SELECT log10(mass)/sqrt(x) AS logM @@ -530,6 +566,7 @@ mysql_tests: - ['log10', 'sqrt'] - - ['MDR1.FOF'] + - - - SELECT log10(ABS(x)) AS log_x @@ -539,6 +576,7 @@ mysql_tests: - ['log10', 'ABS'] - ['log_x: MDR1.FOF.x'] - ['MDR1.FOF'] + - - - SELECT DEGREES(sdist(spoint(RADIANS(0.0), RADIANS(0.0)), @@ -550,6 +588,7 @@ mysql_tests: - ['DEGREES', 'RADIANS', 'sdist', 'spoint'] - - ['db.VII/233/xsc'] + - - - SELECT Data FROM db.Users @@ -559,6 +598,7 @@ mysql_tests: - - ['Data: db.Users.Data'] - ['db.Users'] + - - - SELECT CONVERT(ra, DECIMAL(12,9)) as ra2, ra as ra1 @@ -570,6 +610,7 @@ mysql_tests: - - ['ra1: GDR1.gaia_source.ra', 'ra2: GDR1.gaia_source.ra'] - ['GDR1.gaia_source'] + - - - SELECT DEGREES(sdist(spoint(RADIANS(ra), RADIANS(dec)), @@ -587,6 +628,7 @@ mysql_tests: 'DEGREES'] - - ['GDR1.gaia_source'] + - - - SELECT x, y, z, mass @@ -598,6 +640,7 @@ mysql_tests: - ['x: MDR1.FOF.x', 'y: MDR1.FOF.y', 'z: MDR1.FOF.z', 'mass: MDR1.FOF.mass'] - ['MDR1.FOF'] + - - - SELECT bdmId, Rbin, mass, dens FROM Bolshoi.BDMVProf @@ -618,6 +661,7 @@ mysql_tests: - ['bdmId: Bolshoi.BDMVProf.bdmId', 'Rbin: Bolshoi.BDMVProf.Rbin', 'mass: Bolshoi.BDMVProf.mass', 'dens: Bolshoi.BDMVProf.dens'] - ['Bolshoi.BDMVProf', 'Bolshoi.BDMV'] + - postgresql_tests: @@ -628,6 +672,7 @@ postgresql_tests: - ['pdist'] - - + - - - SELECT DISTINCT ON ("source"."tycho2_id") "tycho2_id", "source"."tycho2_dist" @@ -637,6 +682,7 @@ postgresql_tests: - - - + - - - SELECT ra, dec FROM gdr1.gaia_source @@ -647,6 +693,7 @@ postgresql_tests: - ['scircle', 'spoint'] - ['ra: gdr1.gaia_source.ra', 'dec: gdr1.gaia_source.dec'] - ['gdr1.gaia_source'] + - - - SELECT ra, dec FROM gdr1.gaia_source @@ -657,6 +704,7 @@ postgresql_tests: - ['scircle', 'spoint'] - ['ra: gdr1.gaia_source.ra', 'dec: gdr1.gaia_source.dec'] - ['gdr1.gaia_source'] + - - - SELECT * FROM gdr2.vari_cepheid AS v @@ -668,6 +716,7 @@ postgresql_tests: - ['scircle', 'spoint'] - ['*: gdr2.vari_cepheid.*'] - ['gdr2.gaia_source', 'gdr2.vari_cepheid'] + - - - SELECT curves.observation_time, @@ -696,6 +745,7 @@ postgresql_tests: 'observation_time: gdr1.phot_variable_time_series_gfov.observation_time', 'phase: gdr1.rrlyrae.p1'] - ['gdr1.phot_variable_time_series_gfov', 'gdr1.rrlyrae'] + - - - SELECT a @@ -706,6 +756,7 @@ postgresql_tests: - - ['a: db.tab.a'] - ['db.tab'] + - - - SELECT arr[1:3] FROM db.phot; @@ -714,6 +765,7 @@ postgresql_tests: - - ['arr: db.phot.arr'] - ['db.phot'] + - - - SELECT arr[1:3][1][2][3][4] FROM db.phot; @@ -722,6 +774,7 @@ postgresql_tests: - - ['arr: db.phot.arr'] - ['db.phot'] + - - - SELECT ra, dec FROM gdr1.gaia_source @@ -732,6 +785,7 @@ postgresql_tests: - ['scircle', 'spoint'] - ['ra: gdr1.gaia_source.ra', 'dec: gdr1.gaia_source.dec'] - ['gdr1.gaia_source'] + - - - SELECT q2.c / q1.c FROM ( @@ -749,6 +803,7 @@ postgresql_tests: - ['COUNT'] - - ['gdr1.tgas_source'] + - - - SELECT * FROM gdr2.vari_cepheid AS v @@ -760,6 +815,7 @@ postgresql_tests: - ['scircle', 'spoint'] - ['*: gdr2.vari_cepheid.*'] - ['gdr2.gaia_source', 'gdr2.vari_cepheid'] + - - - SELECT ra FROM gdr2.gaia_source AS gaia @@ -770,14 +826,19 @@ postgresql_tests: - ['RADIANS', 'spoint', 'scircle'] - ['ra: gdr2.gaia_source.ra'] - ['gdr2.gaia_source'] + - + - + - "SELECT specuid, ra, dec FROM dr1.spectrum WHERE QMOST_SPEC_IS_IN_SURVEY(specuid, '04');" + - ['dr1.spectrum.specuid', 'dr1.spectrum.ra', 'dr1.spectrum.dec'] + - ["where"] + - ["QMOST_SPEC_IS_IN_SURVEY"] + - ["specuid: dr1.spectrum.specuid", "ra: dr1.spectrum.ra", "dec: dr1.spectrum.dec"] + - ["dr1.spectrum"] + - + - ["QMOST_SPEC_IS_IN_SURVEY"] -# Each test below consists of: -# -# - ADQL query string -# - translated query string - adql_mysql_tests: - - SELECT POINT('icrs', 10, 10) AS "p" FROM "db".tab diff --git a/src/queryparser/testing/utils.py b/src/queryparser/testing/utils.py index c89ad7f..012e40a 100644 --- a/src/queryparser/testing/utils.py +++ b/src/queryparser/testing/utils.py @@ -6,12 +6,28 @@ def _test_parsing(query_processor, test, translate=False): - if len(test) == 6: - query, columns, keywords, functions, display_columns, tables = test + if len(test) == 7: + ( + query, + columns, + keywords, + functions, + display_columns, + tables, + replace_function_names, + ) = test replace_schema_name = None - elif len(test) == 7: - query, columns, keywords, functions, display_columns, tables,\ - replace_schema_name = test + elif len(test) == 8: + ( + query, + columns, + keywords, + functions, + display_columns, + tables, + replace_schema_name, + replace_function_names, + ) = test if translate: adt = ADQLQueryTranslator() @@ -21,20 +37,30 @@ def _test_parsing(query_processor, test, translate=False): elif query_processor == PostgreSQLQueryProcessor: query = adt.to_postgresql() - if replace_schema_name is None: - qp = query_processor(query) - else: - qp = query_processor() - qp.set_query(query) - qp.process_query(replace_schema_name=replace_schema_name) - - qp_columns = ['.'.join([str(j) for j in i[:3]]) for i in qp.columns - if i[0] is not None and i[1] is not None] - qp_display_columns = ['%s: %s' % (str(i[0]), - '.'.join([str(j) for j in i[1]])) - for i in qp.display_columns] - qp_tables = ['.'.join([str(j) for j in i]) for i in qp.tables - if i[0] is not None and i[1] is not None] + if replace_function_names is None: + replace_function_names = [] + + qp = query_processor() + qp.set_query(query) + qp.process_query( + replace_schema_name=replace_schema_name, + replace_function_names=replace_function_names, + ) + + qp_columns = [ + '.'.join([str(j) for j in i[:3]]) + for i in qp.columns + if i[0] is not None and i[1] is not None + ] + qp_display_columns = [ + '%s: %s' % (str(i[0]), '.'.join([str(j) for j in i[1]])) + for i in qp.display_columns + ] + qp_tables = [ + '.'.join([str(j) for j in i]) + for i in qp.tables + if i[0] is not None and i[1] is not None + ] if columns is not None: assert set(columns) == set(qp_columns) @@ -50,4 +76,3 @@ def _test_parsing(query_processor, test, translate=False): if tables is not None: assert set(tables) == set(qp_tables) -