Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 93 additions & 17 deletions core/query_formatter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import mo_sql_parsing as mosql
from core.ast.node import (
QueryNode, SelectNode, FromNode, WhereNode, TableNode, GroupByNode, HavingNode,
Expand All @@ -13,7 +14,10 @@ def format(self, query: QueryNode) -> str:

# [2] Any (JSON) -> str
sql = mosql.format(json_query)


# Fixes edge case where formatting json with INTERVAL '0' SECOND into SQL adds quotes
sql = re.sub(r"INTERVAL '(\d+)'", r'INTERVAL \1', sql)

return sql

def ast_to_json(node: QueryNode) -> dict:
Expand All @@ -23,7 +27,8 @@ def ast_to_json(node: QueryNode) -> dict:
# process each clause in the query
for child in node.children:
if child.type == NodeType.SELECT:
result['select'] = format_select(child)
select_result = format_select(child)
result.update(select_result)
elif child.type == NodeType.FROM:
result['from'] = format_from(child)
elif child.type == NodeType.WHERE:
Expand All @@ -42,11 +47,21 @@ def ast_to_json(node: QueryNode) -> dict:
return result


def format_select(select_node: SelectNode) -> list:
"""Format SELECT clause"""
items = []
def format_select(select_node: SelectNode) -> dict:
"""Format SELECT clause, returning dict with select/select_distinct and optional distinct_on keys"""
result = {}

for child in select_node.children:
children = list(select_node.children)
if select_node.distinct_on is not None:
children = children[:-1]
distinct_on_items = [format_expression(item) for item in select_node.distinct_on.children]
if len(distinct_on_items) == 1:
result['distinct_on'] = {'value': distinct_on_items[0]}
else:
result['distinct_on'] = [{'value': item} for item in distinct_on_items]

items = []
for child in children:
if child.type == NodeType.COLUMN:
if child.alias:
items.append({'name': child.alias, 'value': format_expression(child)})
Expand All @@ -61,7 +76,9 @@ def format_select(select_node: SelectNode) -> list:
else:
items.append({'value': format_expression(child)})

return items
select_key = 'select_distinct' if select_node.distinct else 'select'
result[select_key] = items
return result


def format_from(from_node: FromNode) -> list:
Expand All @@ -81,8 +98,8 @@ def format_from(from_node: FromNode) -> list:
sources.extend(join_sources)
else:
sources.append(join_sources)
elif child.type == NodeType.TABLE:
sources.append(format_table(child))
else:
sources.append(format_source(child))

return sources

Expand All @@ -104,9 +121,9 @@ def format_join(join_node: JoinNode) -> list:
if left_node.type == NodeType.JOIN:
# Nested join - recursively format
result.extend(format_join(left_node))
elif left_node.type == NodeType.TABLE:
else:
# Simple table - this becomes the FROM table
result.append(format_table(left_node))
result.append(format_source(left_node))

# Format the join itself
join_dict = {}
Expand All @@ -122,7 +139,7 @@ def format_join(join_node: JoinNode) -> list:
}

join_key = join_type_map.get(join_node.join_type, 'join')
join_dict[join_key] = format_table(right_node)
join_dict[join_key] = format_source(right_node)

# Add join condition if it exists
if join_condition:
Expand All @@ -133,6 +150,19 @@ def format_join(join_node: JoinNode) -> list:
return result


def format_source(node: Node) -> dict:
"""Format a table or subquery reference for use in FROM/JOIN"""
if node.type == NodeType.TABLE:
return format_table(node)
elif node.type == NodeType.SUBQUERY:
subquery_child = list(node.children)[0]
result = {'value': ast_to_json(subquery_child)}
if node.alias:
result['name'] = node.alias
return result
raise ValueError(f"Unsupported source type: {node.type}")


def format_table(table_node: TableNode) -> dict:
"""Format a table reference"""
result = {'value': table_node.name}
Expand Down Expand Up @@ -202,15 +232,27 @@ def format_expression(node: Node):
return node.name

elif node.type == NodeType.LITERAL:
if node.value is None:
return {'null': {}}
if isinstance(node.value, str):
return {'literal': node.value}

return node.value

elif node.type == NodeType.FUNCTION:
# format: {'function_name': args}
func_name = node.name.lower()
args = [format_expression(arg) for arg in node.children]
children = list(node.children)

if len(children) == 1 and children[0].type == NodeType.FUNCTION and children[0].name.upper() == 'DISTINCT':
distinct_args = [format_expression(a) for a in children[0].children]
return {'distinct': True, func_name: distinct_args[0] if len(distinct_args) == 1 else distinct_args}

if func_name == 'extract':
keyword = children[0].value.lower() if hasattr(children[0], 'value') else format_expression(children[0])
from_expr = format_expression(children[1])
return {'extract': [keyword, from_expr]}

args = [format_expression(arg) for arg in children]
return {func_name: args[0] if len(args) == 1 else args}

elif node.type == NodeType.SUBQUERY:
Expand All @@ -225,9 +267,15 @@ def format_expression(node: Node):
'>=': 'gte',
'<=': 'lte',
'=': 'eq',
'!=': 'ne',
'!=': 'neq',
'AND': 'and',
'OR': 'or',
'IN': 'in',
'LIKE': 'like',
'+': 'add',
'-': 'sub',
'*': 'mul',
'/': 'div',
}

children = list(node.children)
Expand All @@ -244,8 +292,17 @@ def format_expression(node: Node):
}
op_name = unary_op_map.get(node.name.upper(), node.name.lower())
return {op_name: operand}

op_name = op_map.get(node.name.upper(), node.name.lower())

if node.name.upper() == 'IS' and len(children) == 2:
right = children[1]
if right.type == NodeType.LITERAL and right.value is None:
return {'missing': format_expression(children[0])}

op_name = op_map.get(node.name, op_map.get(node.name.upper(), node.name.lower()))

if op_name == 'sub' and len(children) == 2 and children[0].type == NodeType.LITERAL and children[0].value == 0:
return {'neg': format_expression(children[1])}

left = format_expression(children[0])

if len(children) == 2:
Expand All @@ -259,5 +316,24 @@ def format_expression(node: Node):
elif node.type == NodeType.TABLE:
return format_table(node)

elif node.type == NodeType.DATA_TYPE:
return {node.name.lower(): {}}

elif node.type == NodeType.LIST:
return [format_expression(item) for item in node.children]

elif node.type == NodeType.CASE:
case_list = []
for wt in node.whens:
case_list.append({'when': format_expression(wt.when), 'then': format_expression(wt.then)})
if node.else_val is not None:
case_list.append(format_expression(node.else_val))
return {'case': case_list}

elif node.type == NodeType.INTERVAL:
value = format_expression(node.value) if isinstance(node.value, Node) else node.value
unit = node.unit.name.lower()
return {'interval': [value, unit]}

else:
raise ValueError(f"Unsupported node type in expression: {node.type}")
Loading
Loading