From 2e5135357277ae6b624fff6c28d3e42370e179f9 Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Fri, 20 Mar 2026 13:32:25 -0700 Subject: [PATCH] formatter test cases fixes --- core/query_formatter.py | 110 ++++++++++++++++++++++++++++------ tests/test_query_formatter.py | 84 ++++++++++++-------------- 2 files changed, 130 insertions(+), 64 deletions(-) diff --git a/core/query_formatter.py b/core/query_formatter.py index ee6c610..b27ed1b 100644 --- a/core/query_formatter.py +++ b/core/query_formatter.py @@ -1,3 +1,4 @@ +import re import mo_sql_parsing as mosql from core.ast.node import ( QueryNode, SelectNode, FromNode, WhereNode, TableNode, GroupByNode, HavingNode, @@ -13,7 +14,10 @@ def format(self, query: QueryNode) -> str: # [2] Any (JSON) -> str sql = mosql.format(json_query) - + + # Fixes edge case where formatting json with INTERVAL '0' SECOND into SQL adds quotes + sql = re.sub(r"INTERVAL '(\d+)'", r'INTERVAL \1', sql) + return sql def ast_to_json(node: QueryNode) -> dict: @@ -23,7 +27,8 @@ def ast_to_json(node: QueryNode) -> dict: # process each clause in the query for child in node.children: if child.type == NodeType.SELECT: - result['select'] = format_select(child) + select_result = format_select(child) + result.update(select_result) elif child.type == NodeType.FROM: result['from'] = format_from(child) elif child.type == NodeType.WHERE: @@ -42,11 +47,21 @@ def ast_to_json(node: QueryNode) -> dict: return result -def format_select(select_node: SelectNode) -> list: - """Format SELECT clause""" - items = [] +def format_select(select_node: SelectNode) -> dict: + """Format SELECT clause, returning dict with select/select_distinct and optional distinct_on keys""" + result = {} - for child in select_node.children: + children = list(select_node.children) + if select_node.distinct_on is not None: + children = children[:-1] + distinct_on_items = [format_expression(item) for item in select_node.distinct_on.children] + if len(distinct_on_items) == 1: + result['distinct_on'] = {'value': distinct_on_items[0]} + else: + result['distinct_on'] = [{'value': item} for item in distinct_on_items] + + items = [] + for child in children: if child.type == NodeType.COLUMN: if child.alias: items.append({'name': child.alias, 'value': format_expression(child)}) @@ -61,7 +76,9 @@ def format_select(select_node: SelectNode) -> list: else: items.append({'value': format_expression(child)}) - return items + select_key = 'select_distinct' if select_node.distinct else 'select' + result[select_key] = items + return result def format_from(from_node: FromNode) -> list: @@ -81,8 +98,8 @@ def format_from(from_node: FromNode) -> list: sources.extend(join_sources) else: sources.append(join_sources) - elif child.type == NodeType.TABLE: - sources.append(format_table(child)) + else: + sources.append(format_source(child)) return sources @@ -104,9 +121,9 @@ def format_join(join_node: JoinNode) -> list: if left_node.type == NodeType.JOIN: # Nested join - recursively format result.extend(format_join(left_node)) - elif left_node.type == NodeType.TABLE: + else: # Simple table - this becomes the FROM table - result.append(format_table(left_node)) + result.append(format_source(left_node)) # Format the join itself join_dict = {} @@ -122,7 +139,7 @@ def format_join(join_node: JoinNode) -> list: } join_key = join_type_map.get(join_node.join_type, 'join') - join_dict[join_key] = format_table(right_node) + join_dict[join_key] = format_source(right_node) # Add join condition if it exists if join_condition: @@ -133,6 +150,19 @@ def format_join(join_node: JoinNode) -> list: return result +def format_source(node: Node) -> dict: + """Format a table or subquery reference for use in FROM/JOIN""" + if node.type == NodeType.TABLE: + return format_table(node) + elif node.type == NodeType.SUBQUERY: + subquery_child = list(node.children)[0] + result = {'value': ast_to_json(subquery_child)} + if node.alias: + result['name'] = node.alias + return result + raise ValueError(f"Unsupported source type: {node.type}") + + def format_table(table_node: TableNode) -> dict: """Format a table reference""" result = {'value': table_node.name} @@ -202,15 +232,27 @@ def format_expression(node: Node): return node.name elif node.type == NodeType.LITERAL: + if node.value is None: + return {'null': {}} if isinstance(node.value, str): return {'literal': node.value} - return node.value elif node.type == NodeType.FUNCTION: # format: {'function_name': args} func_name = node.name.lower() - args = [format_expression(arg) for arg in node.children] + children = list(node.children) + + if len(children) == 1 and children[0].type == NodeType.FUNCTION and children[0].name.upper() == 'DISTINCT': + distinct_args = [format_expression(a) for a in children[0].children] + return {'distinct': True, func_name: distinct_args[0] if len(distinct_args) == 1 else distinct_args} + + if func_name == 'extract': + keyword = children[0].value.lower() if hasattr(children[0], 'value') else format_expression(children[0]) + from_expr = format_expression(children[1]) + return {'extract': [keyword, from_expr]} + + args = [format_expression(arg) for arg in children] return {func_name: args[0] if len(args) == 1 else args} elif node.type == NodeType.SUBQUERY: @@ -225,9 +267,15 @@ def format_expression(node: Node): '>=': 'gte', '<=': 'lte', '=': 'eq', - '!=': 'ne', + '!=': 'neq', 'AND': 'and', 'OR': 'or', + 'IN': 'in', + 'LIKE': 'like', + '+': 'add', + '-': 'sub', + '*': 'mul', + '/': 'div', } children = list(node.children) @@ -244,8 +292,17 @@ def format_expression(node: Node): } op_name = unary_op_map.get(node.name.upper(), node.name.lower()) return {op_name: operand} - - op_name = op_map.get(node.name.upper(), node.name.lower()) + + if node.name.upper() == 'IS' and len(children) == 2: + right = children[1] + if right.type == NodeType.LITERAL and right.value is None: + return {'missing': format_expression(children[0])} + + op_name = op_map.get(node.name, op_map.get(node.name.upper(), node.name.lower())) + + if op_name == 'sub' and len(children) == 2 and children[0].type == NodeType.LITERAL and children[0].value == 0: + return {'neg': format_expression(children[1])} + left = format_expression(children[0]) if len(children) == 2: @@ -259,5 +316,24 @@ def format_expression(node: Node): elif node.type == NodeType.TABLE: return format_table(node) + elif node.type == NodeType.DATA_TYPE: + return {node.name.lower(): {}} + + elif node.type == NodeType.LIST: + return [format_expression(item) for item in node.children] + + elif node.type == NodeType.CASE: + case_list = [] + for wt in node.whens: + case_list.append({'when': format_expression(wt.when), 'then': format_expression(wt.then)}) + if node.else_val is not None: + case_list.append(format_expression(node.else_val)) + return {'case': case_list} + + elif node.type == NodeType.INTERVAL: + value = format_expression(node.value) if isinstance(node.value, Node) else node.value + unit = node.unit.name.lower() + return {'interval': [value, unit]} + else: raise ValueError(f"Unsupported node type in expression: {node.type}") \ No newline at end of file diff --git a/tests/test_query_formatter.py b/tests/test_query_formatter.py index 3544402..ca7a205 100644 --- a/tests/test_query_formatter.py +++ b/tests/test_query_formatter.py @@ -38,15 +38,15 @@ def test_subquery_format(): def test_query_1(): """Query 1: Remove Cast Date Match Twice.""" query = get_query(1) - #sql = formatter.format(get_ast(1)) - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(1)) + assert parse(sql) == parse(query["pattern"]) def test_query_2(): """Query 2: Remove Cast Date Match Once.""" query = get_query(2) - #sql = formatter.format(get_ast(2)) - #assert parse(sql) == parse(query["rewrite"]) + sql = formatter.format(get_ast(2)) + assert parse(sql) == parse(query["rewrite"]) # query 3 has the exact same query as query 2, so I skipped it @@ -55,8 +55,8 @@ def test_query_2(): def test_query_4(): """Query 4.""" query = get_query(4) - #sql = formatter.format(get_ast(4)) - #assert parse(sql) == parse(query["rewrite"]) + sql = formatter.format(get_ast(4)) + assert parse(sql) == parse(query["rewrite"]) # query 5 has the exact same query as query 4, so I skipped it @@ -97,29 +97,28 @@ def test_query_11(): """Query 11: Subquery to Join Match 3.""" query = get_query(11) sql = formatter.format(get_ast(11)) - # TODO: Rewrite has SELECT DISTINCT (not supported by parser yet) - #assert parse(sql) == parse(query["rewrite"]) + assert parse(sql) == parse(query["rewrite"]) def test_query_12(): """Query 12: Join to Filter Match 1.""" query = get_query(12) - #sql = formatter.format(get_ast(12)) - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(12)) + assert parse(sql) == parse(query["pattern"]) def test_query_13(): """Query 13: Join to Filter Match 2.""" query = get_query(13) - #sql = formatter.format(get_ast(13)) - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(13)) + assert parse(sql) == parse(query["pattern"]) def test_query_14(): """Query 14: Test Rule Wetune 90 Match.""" query = get_query(14) - #sql = formatter.format(get_ast(14)) - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(14)) + assert parse(sql) == parse(query["pattern"]) # TODO: Query 15 uses UNION, which is not supported by parser yet @@ -128,24 +127,22 @@ def test_query_14(): def test_query_16(): """Query 16: Remove Max Distinct.""" query = get_query(16) - #sql = formatter.format(get_ast(16)) - # TODO: DISTINCT is not supported by parser yet - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(16)) + assert parse(sql) == parse(query["pattern"]) def test_query_17(): """Query 17.""" query = get_query(17) - #sql = formatter.format(get_ast(17)) - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(17)) + assert parse(sql) == parse(query["pattern"]) def test_query_18(): """Query 18 (parser drops SELECT for SELECT DISTINCT with comma join).""" query = get_query(18) - #sql = formatter.format(get_ast(18)) - # TODO: DISTINCT is not supported by parser yet - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(18)) + assert parse(sql) == parse(query["pattern"]) def test_query_19(): @@ -158,8 +155,8 @@ def test_query_19(): def test_query_20(): """Query 20: Partial Matching Base Case 2.""" query = get_query(20) - #sql = formatter.format(get_ast(20)) - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(20)) + assert parse(sql) == parse(query["pattern"]) def test_query_21(): @@ -208,8 +205,7 @@ def test_query_27(): """Query 27: Remove Where True.""" query = get_query(27) sql = formatter.format(get_ast(27)) - # TODO: parser does not support arithmetic expressions yet - #assert parse(sql) == parse(query["pattern"]) + assert parse(sql) == parse(query["pattern"]) def test_query_28(): @@ -232,9 +228,8 @@ def test_query_30(): def test_query_31(): """Query 31: Aggregation to Subquery.""" query = get_query(31) - #sql = formatter.format(get_ast(31)) - # TODO: CASE not cleanly supported yet - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(31)) + assert parse(sql) == parse(query["pattern"]) # TODO: Query 32: UNION not supported by parser @@ -258,8 +253,7 @@ def test_query_35(): """Query 35: Spreadsheet ID 9.""" query = get_query(35) sql = formatter.format(get_ast(35)) - # TODO: DISTINCT is not supported by parser yet - #assert parse(sql) == parse(query["pattern"]) + assert parse(sql) == parse(query["pattern"]) def test_query_36(): @@ -279,44 +273,40 @@ def test_query_37(): def test_query_38(): """Query 38: Spreadsheet ID 12.""" query = get_query(38) - #sql = formatter.format(get_ast(38)) - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(38)) + assert parse(sql) == parse(query["pattern"]) def test_query_39(): """Query 39: Spreadsheet ID 15.""" query = get_query(39) - #sql = formatter.format(get_ast(39)) - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(39)) + assert parse(sql) == parse(query["pattern"]) def test_query_40(): """Query 40.""" query = get_query(40) - #sql = formatter.format(get_ast(40)) - # TODO: DISTINCT ON is not supported by parser yet - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(40)) + assert parse(sql) == parse(query["pattern"]) def test_query_41(): """Query 41: Spreadsheet ID 20.""" query = get_query(41) - #sql = formatter.format(get_ast(41)) - # TODO: NULL keyword and IS NULL not fully supported yet - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(41)) + assert parse(sql) == parse(query["pattern"]) def test_query_42(): """Query 42: PostgreSQL Test.""" query = get_query(42) - #sql = formatter.format(get_ast(42)) - # TODO: Special query, please double check the AST - #assert parse(sql) == parse(query["pattern"]) + sql = formatter.format(get_ast(42)) + assert parse(sql) == parse(query["pattern"]) def test_query_43(): """Query 43: MySQL Test.""" query = get_query(43) - #sql = formatter.format(get_ast(43)) - # TODO: INTERVAL unit keyword not fully supported - #assert parse(sql) == parse(query["pattern"]) \ No newline at end of file + sql = formatter.format(get_ast(43)) + assert parse(sql) == parse(query["pattern"]) \ No newline at end of file