Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion sidemantic/core/semantic_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def query(
ungrouped: bool = False,
parameters: dict[str, any] | None = None,
use_preaggregations: bool | None = None,
post_process: str | None = None,
):
"""Execute a query against the semantic layer.

Expand All @@ -448,6 +449,9 @@ def query(
ungrouped: If True, return raw rows without aggregation (no GROUP BY)
parameters: Template parameters for Jinja2 rendering
use_preaggregations: Override pre-aggregation routing setting for this query
post_process: Optional SQL to wrap around the semantic query result.
Use {inner} as a placeholder for the compiled semantic query, e.g.:
"SELECT *, revenue / count AS avg_value FROM ({inner})"

Returns:
DuckDB relation object (can convert to DataFrame with .df() or .to_df())
Expand All @@ -462,6 +466,7 @@ def query(
ungrouped=ungrouped,
parameters=parameters,
use_preaggregations=use_preaggregations,
post_process=post_process,
)

return self.adapter.execute(sql)
Expand All @@ -479,6 +484,7 @@ def compile(
ungrouped: bool = False,
parameters: dict[str, any] | None = None,
use_preaggregations: bool | None = None,
post_process: str | None = None,
) -> str:
"""Compile a query to SQL without executing.

Expand All @@ -493,6 +499,9 @@ def compile(
dialect: SQL dialect override (defaults to layer's dialect)
ungrouped: If True, return raw rows without aggregation (no GROUP BY)
use_preaggregations: Override pre-aggregation routing setting for this query
post_process: Optional SQL to wrap around the semantic query result.
Use {inner} as a placeholder for the compiled semantic query, e.g.:
"SELECT *, revenue / count AS avg_value FROM ({inner})"

Returns:
SQL query string
Expand Down Expand Up @@ -520,7 +529,7 @@ def compile(
preagg_schema=self.preagg_schema,
)

return generator.generate(
inner_sql = generator.generate(
metrics=metrics,
dimensions=dimensions,
filters=filters,
Expand All @@ -533,6 +542,24 @@ def compile(
use_preaggregations=use_preaggs,
)

if post_process is not None:
if "{inner}" not in post_process:
raise ValueError("post_process must contain a {inner} placeholder")

# Strip sidemantic instrumentation comment
stripped = inner_sql.rstrip()
last_line = stripped.split("\n")[-1].strip()
if last_line.startswith("-- sidemantic:"):
stripped = "\n".join(stripped.split("\n")[:-1])

# Inner SQL (including any CTEs) is placed directly in the
# subquery position. CTEs inside subqueries are valid SQL in
# all target databases and naturally scoped, avoiding name
# collisions with CTEs in the post_process SQL.
return post_process.replace("{inner}", stripped)

return inner_sql

def explain(
self,
metrics: list[str] | None = None,
Expand Down
107 changes: 78 additions & 29 deletions sidemantic/sql/query_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ def rewrite(self, sql: str, strict: bool = True) -> str:
# Check if this is a CTE-based query or has subqueries
has_ctes = parsed.args.get("with") is not None
has_subquery_in_from = self._has_subquery_in_from(parsed)
has_subquery_in_joins = any(isinstance(join.this, exp.Subquery) for join in (parsed.args.get("joins") or []))

if has_ctes or has_subquery_in_from:
if has_ctes or has_subquery_in_from or has_subquery_in_joins:
Comment on lines +121 to +123
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve root semantic rewrite when JOIN has subquery

Routing every query with a JOIN subquery through _rewrite_with_ctes_or_subqueries skips _rewrite_simple_query for the root SELECT, because _rewrite_select_tree only rewrites child scopes. For queries whose root FROM is a semantic model (for example, FROM orders o JOIN (SELECT ...) s), metric references like o.revenue are no longer rewritten through the semantic pipeline or rejected by the explicit JOIN guard, so execution can now fail with binder errors or silently return raw row-level columns instead of semantic aggregates.

Useful? React with 👍 / 👎.

# Handle CTEs and subqueries
return self._rewrite_with_ctes_or_subqueries(parsed)

Expand Down Expand Up @@ -1851,41 +1852,89 @@ def _has_subquery_in_from(self, select: exp.Select) -> bool:
def _rewrite_with_ctes_or_subqueries(self, parsed: exp.Select) -> str:
"""Rewrite query that contains CTEs or subqueries.

Strategy:
1. Rewrite each CTE that references semantic models
2. Rewrite subqueries in FROM clause
3. Return the modified SQL
Recursively walks the query tree bottom-up, rewriting any
SELECT whose FROM target resolves to a semantic model.
Outer queries are left as plain SQL, so post-processing
(CASE, window functions, arithmetic, etc.) works naturally.
"""
# Handle CTEs
if parsed.args.get("with"):
with_clause = parsed.args["with"]
for cte in with_clause.expressions:
# Each CTE has a name (alias) and a query (this)
self._rewrite_select_tree(parsed)

# If the root SELECT itself references a semantic model, it must
# still go through _rewrite_simple_query (which enforces the
# explicit JOIN guard and performs semantic rewriting).
if self._references_semantic_model(parsed):
# Save user-defined CTEs before _rewrite_simple_query replaces
# the entire query with fresh generator output.
original_with = parsed.args.get("with")

rewritten_sql = self._rewrite_simple_query(parsed)

if original_with:
# Merge user CTEs into the generated SQL so references
# from filters/expressions (e.g. IN (SELECT ... FROM cte))
# remain valid.
rewritten = sqlglot.parse_one(rewritten_sql, dialect=self.dialect)
gen_with = rewritten.args.get("with")
if gen_with:
# Check for CTE name collisions between user and generated CTEs
user_names = {cte.alias for cte in original_with.expressions}
for gen_cte in gen_with.expressions:
if gen_cte.alias in user_names:
raise ValueError(
f"CTE name '{gen_cte.alias}' conflicts with an internally "
f"generated name. Please choose a different CTE name."
)

user_ctes = [cte.copy() for cte in original_with.expressions]
gen_with.set("expressions", user_ctes + list(gen_with.expressions))
Comment on lines +1888 to +1889
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve WITH RECURSIVE when merging root CTEs

When a root semantic query includes WITH RECURSIVE, this merge path copies only original_with.expressions into the generated WITH and drops the recursive flag from the original clause. The rewritten SQL becomes WITH ... instead of WITH RECURSIVE ..., so self-referencing CTEs fail at execution (e.g., DuckDB reports a circular CTE reference). I reproduced this with a root FROM orders query filtered by a recursive CTE; rewrite output omitted RECURSIVE.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid duplicate CTE aliases when merging root WITH clauses

When a root semantic query already has a user CTE name that matches an internally generated CTE (for example orders_cte), this merge creates duplicate aliases and the rewritten SQL fails to parse. I reproduced this with WITH orders_cte AS (...) SELECT orders.revenue FROM orders ..., which rewrote to two orders_cte entries and raised a parser error. Since generated CTE names are implementation details, users can hit this unintentionally; the merge needs collision handling (rename or namespace).

Useful? React with 👍 / 👎.

# Preserve WITH RECURSIVE from the original query
if original_with.args.get("recursive"):
gen_with.set("recursive", True)
else:
rewritten.set("with", original_with.copy())
return rewritten.sql(dialect=self.dialect)

return rewritten_sql

return parsed.sql(dialect=self.dialect)

def _rewrite_select_tree(self, select: exp.Select):
"""Recursively rewrite semantic subqueries and CTEs (bottom-up).

At each level: recurse into children first, then rewrite this
node if it directly references a semantic model.
"""
# Recurse into CTEs
if select.args.get("with"):
for cte in select.args["with"].expressions:
cte_query = cte.this
if isinstance(cte_query, exp.Select):
# Check if this CTE references a semantic model
self._rewrite_select_tree(cte_query)
if self._references_semantic_model(cte_query):
# Rewrite the CTE query
rewritten_cte_sql = self._rewrite_simple_query(cte_query)
# Parse the rewritten SQL and replace the CTE query
rewritten_cte = sqlglot.parse_one(rewritten_cte_sql, dialect=self.dialect)
cte.set("this", rewritten_cte)

# Handle subquery in FROM
from_clause = parsed.args.get("from")
rewritten_sql = self._rewrite_simple_query(cte_query)
cte.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect))

# Recurse into FROM subquery
from_clause = select.args.get("from")
if from_clause and isinstance(from_clause.this, exp.Subquery):
subquery = from_clause.this
subquery_select = subquery.this
if isinstance(subquery_select, exp.Select) and self._references_semantic_model(subquery_select):
# Rewrite the subquery
rewritten_subquery_sql = self._rewrite_simple_query(subquery_select)
rewritten_subquery = sqlglot.parse_one(rewritten_subquery_sql, dialect=self.dialect)
subquery.set("this", rewritten_subquery)

# Return the modified SQL
# Note: Individual CTEs/subqueries are already instrumented by _rewrite_simple_query -> generator
# The outer query wrapper doesn't need separate instrumentation
return parsed.sql(dialect=self.dialect)
if isinstance(subquery_select, exp.Select):
self._rewrite_select_tree(subquery_select)
if self._references_semantic_model(subquery_select):
rewritten_sql = self._rewrite_simple_query(subquery_select)
subquery.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect))

# Recurse into JOIN subqueries
for join in select.args.get("joins") or []:
join_expr = join.this
if isinstance(join_expr, exp.Subquery):
join_select = join_expr.this
if isinstance(join_select, exp.Select):
self._rewrite_select_tree(join_select)
if self._references_semantic_model(join_select):
rewritten_sql = self._rewrite_simple_query(join_select)
join_expr.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect))

def _references_semantic_model(self, select: exp.Select) -> bool:
"""Check if a SELECT statement references any semantic models."""
Expand Down
Loading
Loading