Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 192 additions & 28 deletions core/query_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from core.ast.node import (
Node, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode,
LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode,
OrderByNode, OrderByItemNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode
LiteralNode, DataTypeNode, TimeUnitNode, IntervalNode,
CaseNode, WhenThenNode,
OperatorNode, FunctionNode, GroupByNode, HavingNode,
OrderByNode, OrderByItemNode, LimitNode, OffsetNode, SubqueryNode,
VarNode, VarSetNode, JoinNode, ListNode
)
# TODO: implement VarNode, VarSetNode
from core.ast.enums import JoinType, SortOrder
Expand All @@ -14,6 +17,8 @@ class QueryParser:
'eq': '=', 'neq': '!=', 'ne': '!=',
'gt': '>', 'gte': '>=', 'lt': '<', 'lte': '<=',
'and': 'AND', 'or': 'OR', 'in': 'IN',
'add': '+', 'sub': '-', 'mul': '*', 'div': '/',
'is': 'IS', 'missing': 'MISSING',
}
_LIST_OPERATOR_KEYS = frozenset(_OPERATOR_KEY_TO_NAME.keys())

Expand Down Expand Up @@ -45,7 +50,7 @@ def parse(self, query: str) -> QueryNode:
mosql_ast = mosql.parse(query)
return self.parse_query_dict(mosql_ast, aliases={})

def parse_select(self, select_list: list, aliases: dict) -> SelectNode:
def parse_select(self, select_list: list, aliases: dict, distinct: bool = False, distinct_on_expr = None) -> SelectNode:
items = []
for item in select_list:
if isinstance(item, dict) and 'value' in item:
Expand All @@ -71,8 +76,16 @@ def parse_select(self, select_list: list, aliases: dict) -> SelectNode:
# Handle direct expression (string, int, etc.)
expression = self.parse_expression(item, aliases)
items.append(expression)

return SelectNode(items)

# Handle DISTINCT ON (PostgreSQL-style)
distinct_on_node = None
if distinct_on_expr is not None:
# mo_sql_parsing gives a single expression in 'distinct_on'.
expr = self.parse_expression(distinct_on_expr, aliases)
# Wrap in ListNode so it matches _distinct_on=ListNode([...]) in expected ASTs.
distinct_on_node = ListNode([expr])

return SelectNode(items, _distinct=distinct, _distinct_on=distinct_on_node)

def parse_from(self, from_list: list, aliases: dict) -> FromNode:
sources = []
Expand Down Expand Up @@ -161,6 +174,17 @@ def parse_from(self, from_list: list, aliases: dict) -> FromNode:
else:
# Multiple tables without explicit JOIN (cross join)
sources.append(table_node)
# Subquery in FROM specified directly as a query dict (with 'select').
elif 'select' in item:
alias = item.get('name')
subquery_query = self.parse_query_dict(item, aliases={})
subquery_node = SubqueryNode(subquery_query, alias)
if alias:
aliases[alias] = subquery_node
if left_source is None:
left_source = subquery_node
else:
sources.append(subquery_node)
elif isinstance(item, str):
# Simple string table name
table_node = TableNode(item)
Expand All @@ -177,22 +201,32 @@ def parse_from(self, from_list: list, aliases: dict) -> FromNode:

def parse_where(self, where_dict: dict, aliases: dict) -> WhereNode:
    """Convert the parsed WHERE condition into a WhereNode.

    The whole condition is handed to ``parse_expression`` as one expression
    and becomes the single predicate of the returned node. The WHERE text is
    followed as written: an alias token arrives as a plain string and is
    left to ``parse_expression`` to resolve against ``aliases``; any other
    column or expression yields a freshly parsed node.

    Args:
        where_dict: The WHERE condition as produced by the SQL parser.
        aliases: Mapping of SELECT alias name to its aliased node.

    Returns:
        A WhereNode wrapping a one-element predicate list.
    """
    condition = self.parse_expression(where_dict, aliases)
    return WhereNode([condition])

def parse_group_by(self, group_by_list: list, aliases: dict) -> GroupByNode:
    """Build a GroupByNode from the GROUP BY entries of a parsed query.

    Each entry is either a dict carrying its expression under ``'value'``
    or a bare expression (string, int, ...). An entry that is exactly the
    name of a SELECT alias reuses the node stored in ``aliases``; anything
    else is parsed literally with ``parse_expression``.

    Args:
        group_by_list: The GROUP BY entries to convert.
        aliases: Mapping of SELECT alias name to its aliased node.

    Returns:
        A GroupByNode holding exactly one node per entry, in input order.
    """
    items = []
    for item in group_by_list:
        # Unwrap the {'value': ...} form so both entry shapes share one path.
        value = item['value'] if isinstance(item, dict) and 'value' in item else item
        if isinstance(value, str) and value in aliases:
            # Bare alias token (e.g. 'dept_name'): reuse the node from SELECT.
            items.append(aliases[value])
        else:
            # Not an alias token: parse the expression as written.
            items.append(self.parse_expression(value, aliases))
    return GroupByNode(items)

Expand Down Expand Up @@ -250,28 +284,22 @@ def resolve_aliases(self, expr: Node, aliases: dict) -> Node:
else:
raise ValueError(f"OperatorNode has {len(expr.children)} children, expected 2 for binary operators or 1 for unary operators")
elif isinstance(expr, FunctionNode):
# Check if this function matches an aliased function from SELECT
# Resolve HAVING aggregates like COUNT(*), SUM(...) to their
# aliased counterparts from SELECT when they are structurally equal.
if expr.alias is None:
for alias, aliased_expr in aliases.items():
if isinstance(aliased_expr, FunctionNode):
if (expr.name == aliased_expr.name and
if (expr.name == aliased_expr.name and
len(expr.children) == len(aliased_expr.children) and
all(expr.children[i] == aliased_expr.children[i]
all(expr.children[i] == aliased_expr.children[i]
for i in range(len(expr.children)))):
# This function matches an aliased one, use the alias
expr.alias = alias
break
return expr
elif isinstance(expr, ColumnNode):
# Check if this column matches an aliased column from SELECT
if expr.alias is None:
for alias, aliased_expr in aliases.items():
if isinstance(aliased_expr, ColumnNode):
if (expr.name == aliased_expr.name and
expr.parent_alias == aliased_expr.parent_alias):
# This column matches an aliased one, use the alias
expr.alias = alias
break
# Do not propagate column aliases into other clauses based solely
# on structural equality; if the SQL text used an alias token,
# parse_order_by/parse_expression will already resolve it.
return expr
else:
return expr
Expand All @@ -281,6 +309,10 @@ def parse_expression(self, expr, aliases: dict = None) -> Node:
aliases = {}

if isinstance(expr, str):
# Alias reference: if a later clause uses a SELECT alias token
# (e.g. ORDER BY dept_name), reuse the aliased expression node.
if expr in aliases:
return aliases[expr]
# Column reference
if '.' in expr:
parts = expr.split('.', 1)
Expand All @@ -307,10 +339,41 @@ def parse_expression(self, expr, aliases: dict = None) -> Node:
if 'all_columns' in expr:
return ColumnNode('*')
if 'literal' in expr:
return LiteralNode(expr['literal'])
# mo_sql_parsing uses {'literal': [..]} for IN literal lists and
# {'literal': 'value'} for scalar literals.
value = expr['literal']
if isinstance(value, list):
# List of simple literals -> ListNode of LiteralNode
items = [LiteralNode(v) for v in value]
return ListNode(items)
return LiteralNode(value)
# mo_sql_parsing represents NULL keyword as a function-like dict.
# Normalize it to a LiteralNode(None) so it matches expected ASTs.
if 'null' in expr and not expr['null']:
return LiteralNode(None)

# Data type nodes used in CAST expressions (e.g. TEXT, DATE).
if len(expr) == 1:
only_key = next(iter(expr.keys()))
only_val = expr[only_key]
key_lower_single = only_key.lower()
if key_lower_single in ('text', 'date') and only_val == {}:
type_name = key_lower_single.upper()
return DataTypeNode(type_name)

# Skip metadata keys
skip_keys = {'value', 'name', 'on', 'sort'}

# Handle DISTINCT aggregates like {'distinct': True, 'max': {...}}
if expr.get('distinct') is True:
agg_keys = [k for k in expr.keys() if k not in skip_keys and k != 'distinct']
if len(agg_keys) == 1:
agg_key = agg_keys[0]
agg_value = expr[agg_key]
agg_name = self.normalize_operator_name(agg_key)
arg = self.parse_expression(agg_value, aliases)
distinct_arg = FunctionNode("DISTINCT", _args=[arg])
return FunctionNode(agg_name, _args=[distinct_arg])

# Find the operator/function key
for key in expr.keys():
Expand All @@ -321,6 +384,96 @@ def parse_expression(self, expr, aliases: dict = None) -> Node:
op_name = self.normalize_operator_name(key)
key_lower = key.lower()

# Pattern 0: IS NULL / MISSING operator
# mo_sql_parsing can emit {"missing": <expr>} for "expr IS NULL".
if key_lower == 'missing':
# Value may be a single expression or a list containing it.
target_expr = value
if isinstance(value, list) and value:
target_expr = value[0]
target = self.parse_expression(target_expr, aliases)
return OperatorNode(target, 'IS', LiteralNode(None))

# Special handling for IN so that the right-hand side becomes either
# a SubqueryNode or a ListNode, matching expected ASTs.
if key_lower == 'in':
# mo_sql_parsing patterns:
# {'in': [lhs, rhs]}
# {'in': ['col', {'literal': [...]}]}
# {'in': ['col', [v1, v2, ...]]}
if isinstance(value, list) and len(value) >= 2:
left_raw = value[0]
right_raw = value[1]
else:
# Fallback: treat as binary with parsed children
operands = [self.parse_expression(v, aliases) for v in self.normalize_to_list(value)]
if len(operands) == 1:
return operands[0]
if len(operands) == 2:
return OperatorNode(operands[0], op_name, operands[1])
result = operands[0]
for operand in operands[1:]:
result = OperatorNode(result, op_name, operand)
return result

left = self.parse_expression(left_raw, aliases)

# Subquery RHS
if isinstance(right_raw, dict) and 'select' in right_raw:
right = self.parse_expression(right_raw, aliases)
# Literal-list RHS
elif isinstance(right_raw, dict) and 'literal' in right_raw:
# parse_expression on this dict will return a ListNode or LiteralNode
right = self.parse_expression(right_raw, aliases)
elif isinstance(right_raw, list):
items = [self.parse_expression(item, aliases) for item in right_raw]
right = ListNode(items)
else:
right = self.parse_expression(right_raw, aliases)

return OperatorNode(left, op_name, right)

# Unary negation (e.g. -EXTRACT(...)) -> model as 0 - expr.
# TODO: verify that modelling unary minus as `0 - expr` matches the expected AST (see query 42).
if key_lower == 'neg':
inner = self.parse_expression(value, aliases)
return OperatorNode(LiteralNode(0), '-', inner)

# CASE expressions: {'case': [{'when': ..., 'then': ...}, else_expr]}
if key_lower == 'case':
if not isinstance(value, list) or len(value) == 0:
return LiteralNode(None)
*when_parts, else_part = value if len(value) > 1 else (value, None)
whens: list[WhenThenNode] = []
for branch in when_parts:
when_expr = self.parse_expression(branch['when'], aliases)
then_expr = self.parse_expression(branch['then'], aliases)
whens.append(WhenThenNode(when_expr, then_expr))
else_node = self.parse_expression(else_part, aliases) if else_part is not None else None
return CaseNode(whens, else_node)

# INTERVAL literals: {'interval': [value, 'unit']}
if key_lower == 'interval':
if isinstance(value, list) and len(value) == 2:
num_raw, unit_raw = value
num_node = self.parse_expression(num_raw, aliases)
unit_name = str(unit_raw).upper()
unit_node = TimeUnitNode(unit_name)
return IntervalNode(num_node, unit_node)

# DATE(...) function: {'date': 't1.data'}
if key_lower == 'date':
arg = self.parse_expression(value, aliases)
return FunctionNode('DATE', _args=[arg])

# EXTRACT(field FROM expr): represented as {'extract': ['dow', 'tweets.created_at']}
if key_lower == 'extract':
if isinstance(value, list) and len(value) == 2:
field_raw, expr_raw = value
field_node = LiteralNode(str(field_raw).upper())
expr_node = self.parse_expression(expr_raw, aliases)
return FunctionNode('EXTRACT', _args=[field_node, expr_node])

# Pattern 1: List value (either n-ary operator or multi-arg function)
if isinstance(value, list):
if len(value) == 0:
Expand Down Expand Up @@ -376,9 +529,20 @@ def parse_query_dict(self, query_dict: dict, aliases: dict) -> QueryNode:
order_by_clause = None
limit_clause = None
offset_clause = None
# DISTINCT and DISTINCT ON
distinct = False
distinct_on_expr = None
if 'select_distinct' in query_dict:
distinct = True
select_source = query_dict['select_distinct']
else:
select_source = query_dict.get('select')
if 'distinct_on' in query_dict:
# mo_sql_parsing uses a single expression under 'distinct_on'.
distinct_on_expr = query_dict['distinct_on'].get('value', query_dict['distinct_on'])

if 'select' in query_dict:
select_clause = self.parse_select(self.normalize_to_list(query_dict['select']), aliases)
if select_source is not None:
select_clause = self.parse_select(self.normalize_to_list(select_source), aliases, distinct=distinct, distinct_on_expr=distinct_on_expr)
if 'from' in query_dict:
from_clause = self.parse_from(self.normalize_to_list(query_dict['from']), aliases)
if 'where' in query_dict:
Expand Down
1 change: 1 addition & 0 deletions data/asts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,7 @@ def _ast_query_42() -> QueryNode:
where_condition = OperatorNode(where_123, "AND", strpos_cond)
where_clause = WhereNode([where_condition])
# GROUP BY 1, 2
# TODO: replace the positional literals 1 and 2 with the actual column nodes they refer to.
group_by_clause = GroupByNode([LiteralNode(1), LiteralNode(2)])
return QueryNode(
_select=select_clause,
Expand Down
Loading
Loading