From 552c2d90916fbede22bc083390aab75d09e548b6 Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Fri, 10 Apr 2026 16:36:40 -0700 Subject: [PATCH 1/5] Add post-processing SQL over semantic query results Support arbitrary SQL (CASE, window functions, arithmetic, etc.) on top of semantic query results via subquery wrapping. The rewriter now walks the query tree recursively so nested subqueries and JOIN subqueries that reference semantic models are compiled correctly. Also adds a post_process parameter to compile() and query() for the Python API path, with automatic CTE hoisting. --- sidemantic/core/semantic_layer.py | 39 +++- sidemantic/sql/query_rewriter.py | 69 ++++--- tests/queries/test_sql_rewriter.py | 288 ++++++++++++++++++++++++++++- 3 files changed, 364 insertions(+), 32 deletions(-) diff --git a/sidemantic/core/semantic_layer.py b/sidemantic/core/semantic_layer.py index 1372944..a543164 100644 --- a/sidemantic/core/semantic_layer.py +++ b/sidemantic/core/semantic_layer.py @@ -435,6 +435,7 @@ def query( ungrouped: bool = False, parameters: dict[str, any] | None = None, use_preaggregations: bool | None = None, + post_process: str | None = None, ): """Execute a query against the semantic layer. @@ -448,6 +449,9 @@ def query( ungrouped: If True, return raw rows without aggregation (no GROUP BY) parameters: Template parameters for Jinja2 rendering use_preaggregations: Override pre-aggregation routing setting for this query + post_process: Optional SQL to wrap around the semantic query result. + Use {inner} as a placeholder for the compiled semantic query, e.g.: + "SELECT *, revenue / count AS avg_value FROM ({inner})" Returns: DuckDB relation object (can convert to DataFrame with .df() or .to_df()) @@ -462,6 +466,7 @@ def query( ungrouped=ungrouped, parameters=parameters, use_preaggregations=use_preaggregations, + post_process=post_process, ) return self.adapter.execute(sql) @@ -479,6 +484,7 @@ def compile( ungrouped: bool = False, parameters: dict[str, any] | None = None, use_preaggregations: bool | None = None, + post_process: str | None = None, ) -> str: """Compile a query to SQL without executing. @@ -493,6 +499,9 @@ def compile( dialect: SQL dialect override (defaults to layer's dialect) ungrouped: If True, return raw rows without aggregation (no GROUP BY) use_preaggregations: Override pre-aggregation routing setting for this query + post_process: Optional SQL to wrap around the semantic query result. + Use {inner} as a placeholder for the compiled semantic query, e.g.: + "SELECT *, revenue / count AS avg_value FROM ({inner})" Returns: SQL query string @@ -520,7 +529,7 @@ def compile( preagg_schema=self.preagg_schema, ) - return generator.generate( + inner_sql = generator.generate( metrics=metrics, dimensions=dimensions, filters=filters, @@ -533,6 +542,34 @@ def compile( use_preaggregations=use_preaggs, ) + if post_process is not None: + if "{inner}" not in post_process: + raise ValueError("post_process must contain a {inner} placeholder") + + # Strip sidemantic instrumentation comment + stripped = inner_sql.rstrip() + last_line = stripped.split("\n")[-1].strip() + if last_line.startswith("-- sidemantic:"): + stripped = "\n".join(stripped.split("\n")[:-1]) + + # If inner SQL starts with WITH (CTEs), hoist them outside + # the subquery position so the SQL is valid. + if stripped.lstrip().upper().startswith("WITH "): + import sqlglot + + target_dialect = dialect or self.dialect + parsed_inner = sqlglot.parse_one(stripped, dialect=target_dialect) + with_clause = parsed_inner.args.get("with") + if with_clause: + parsed_inner.set("with", None) + body = parsed_inner.sql(dialect=target_dialect) + ctes = with_clause.sql(dialect=target_dialect) + return ctes + "\n" + post_process.replace("{inner}", body) + + return post_process.replace("{inner}", stripped) + + return inner_sql + def explain( self, metrics: list[str] | None = None, diff --git a/sidemantic/sql/query_rewriter.py b/sidemantic/sql/query_rewriter.py index 457636f..7efc424 100644 --- a/sidemantic/sql/query_rewriter.py +++ b/sidemantic/sql/query_rewriter.py @@ -118,8 +118,9 @@ def rewrite(self, sql: str, strict: bool = True) -> str: # Check if this is a CTE-based query or has subqueries has_ctes = parsed.args.get("with") is not None has_subquery_in_from = self._has_subquery_in_from(parsed) + has_subquery_in_joins = any(isinstance(join.this, exp.Subquery) for join in (parsed.args.get("joins") or [])) - if has_ctes or has_subquery_in_from: + if has_ctes or has_subquery_in_from or has_subquery_in_joins: # Handle CTEs and subqueries return self._rewrite_with_ctes_or_subqueries(parsed) @@ -1851,41 +1852,51 @@ def _has_subquery_in_from(self, select: exp.Select) -> bool: def _rewrite_with_ctes_or_subqueries(self, parsed: exp.Select) -> str: """Rewrite query that contains CTEs or subqueries. - Strategy: - 1. Rewrite each CTE that references semantic models - 2. Rewrite subqueries in FROM clause - 3. Return the modified SQL + Recursively walks the query tree bottom-up, rewriting any + SELECT whose FROM target resolves to a semantic model. + Outer queries are left as plain SQL, so post-processing + (CASE, window functions, arithmetic, etc.) works naturally. """ - # Handle CTEs - if parsed.args.get("with"): - with_clause = parsed.args["with"] - for cte in with_clause.expressions: - # Each CTE has a name (alias) and a query (this) + self._rewrite_select_tree(parsed) + return parsed.sql(dialect=self.dialect) + + def _rewrite_select_tree(self, select: exp.Select): + """Recursively rewrite semantic subqueries and CTEs (bottom-up). + + At each level: recurse into children first, then rewrite this + node if it directly references a semantic model. + """ + # Recurse into CTEs + if select.args.get("with"): + for cte in select.args["with"].expressions: cte_query = cte.this if isinstance(cte_query, exp.Select): - # Check if this CTE references a semantic model + self._rewrite_select_tree(cte_query) if self._references_semantic_model(cte_query): - # Rewrite the CTE query - rewritten_cte_sql = self._rewrite_simple_query(cte_query) - # Parse the rewritten SQL and replace the CTE query - rewritten_cte = sqlglot.parse_one(rewritten_cte_sql, dialect=self.dialect) - cte.set("this", rewritten_cte) - - # Handle subquery in FROM - from_clause = parsed.args.get("from") + rewritten_sql = self._rewrite_simple_query(cte_query) + cte.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect)) + + # Recurse into FROM subquery + from_clause = select.args.get("from") if from_clause and isinstance(from_clause.this, exp.Subquery): subquery = from_clause.this subquery_select = subquery.this - if isinstance(subquery_select, exp.Select) and self._references_semantic_model(subquery_select): - # Rewrite the subquery - rewritten_subquery_sql = self._rewrite_simple_query(subquery_select) - rewritten_subquery = sqlglot.parse_one(rewritten_subquery_sql, dialect=self.dialect) - subquery.set("this", rewritten_subquery) - - # Return the modified SQL - # Note: Individual CTEs/subqueries are already instrumented by _rewrite_simple_query -> generator - # The outer query wrapper doesn't need separate instrumentation - return parsed.sql(dialect=self.dialect) + if isinstance(subquery_select, exp.Select): + self._rewrite_select_tree(subquery_select) + if self._references_semantic_model(subquery_select): + rewritten_sql = self._rewrite_simple_query(subquery_select) + subquery.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect)) + + # Recurse into JOIN subqueries + for join in select.args.get("joins") or []: + join_expr = join.this + if isinstance(join_expr, exp.Subquery): + join_select = join_expr.this + if isinstance(join_select, exp.Select): + self._rewrite_select_tree(join_select) + if self._references_semantic_model(join_select): + rewritten_sql = self._rewrite_simple_query(join_select) + join_expr.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect)) def _references_semantic_model(self, select: exp.Select) -> bool: """Check if a SELECT statement references any semantic models.""" diff --git a/tests/queries/test_sql_rewriter.py b/tests/queries/test_sql_rewriter.py index 9a707ae..04c16c3 100644 --- a/tests/queries/test_sql_rewriter.py +++ b/tests/queries/test_sql_rewriter.py @@ -684,8 +684,6 @@ def test_cte_with_limit_in_inner_query(semantic_layer): def test_nested_subquery(semantic_layer): """Test filtering semantic query results in outer query.""" - # Note: Deep nesting of subqueries (subquery within subquery) is not currently supported - # This test demonstrates single-level subquery with filtering sql = """ SELECT * FROM ( SELECT revenue, status FROM orders @@ -1080,3 +1078,289 @@ def test_granularity_on_non_time_dimension(semantic_layer): # Should work - status is a valid categorical dimension assert len(rows) == 2 # Two status groups + + +# --- Post-processing SQL over semantic query results --- + + +def test_postprocess_case_expression(semantic_layer): + """Test CASE expression in outer query over semantic subquery.""" + sql = """ + SELECT + status, + revenue, + CASE WHEN revenue > 200 THEN 'high' ELSE 'low' END AS tier + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + columns = _columns(result) + + assert "tier" in columns + assert len(rows) == 2 + for row in rows: + if row["revenue"] > 200: + assert row["tier"] == "high" + else: + assert row["tier"] == "low" + + +def test_postprocess_arithmetic(semantic_layer): + """Test arithmetic between metrics in outer query.""" + sql = """ + SELECT + status, + revenue, + count, + revenue / count AS avg_order_value + FROM ( + SELECT orders.revenue, orders.count, orders.status FROM orders + ) AS sq + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + columns = _columns(result) + + assert "avg_order_value" in columns + for row in rows: + assert row["avg_order_value"] == row["revenue"] / row["count"] + + +def test_postprocess_window_function(semantic_layer): + """Test window functions in outer query over semantic results.""" + sql = """ + SELECT + status, + revenue, + LAG(revenue) OVER (ORDER BY revenue DESC) AS next_lower_revenue + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + columns = _columns(result) + + assert "next_lower_revenue" in columns + assert len(rows) == 2 + + +def test_postprocess_coalesce(semantic_layer): + """Test COALESCE in outer query.""" + sql = """ + SELECT + status, + COALESCE(revenue, 0) AS safe_revenue + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + columns = _columns(result) + + assert "safe_revenue" in columns + assert all(row["safe_revenue"] is not None for row in rows) + + +def test_postprocess_having(semantic_layer): + """Test filtering with WHERE in outer query (equivalent to HAVING on aggregated results).""" + sql = """ + SELECT status, revenue + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + WHERE revenue > 200 + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + + assert len(rows) == 1 + assert rows[0]["revenue"] > 200 + + +def test_postprocess_order_by_in_outer(semantic_layer): + """Test ORDER BY in outer query over semantic results.""" + sql = """ + SELECT status, revenue + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + ORDER BY revenue DESC + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + + assert len(rows) == 2 + assert rows[0]["revenue"] >= rows[1]["revenue"] + + +def test_postprocess_limit_in_outer(semantic_layer): + """Test LIMIT in outer query over semantic results.""" + sql = """ + SELECT status, revenue + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + ORDER BY revenue DESC + LIMIT 1 + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + + assert len(rows) == 1 + assert rows[0]["revenue"] == 250.00 + + +def test_postprocess_cross_model_subquery(semantic_layer): + """Test post-processing over cross-model semantic subquery.""" + sql = """ + SELECT + region, + revenue, + CASE WHEN revenue > 200 THEN 'big' ELSE 'small' END AS market_size + FROM ( + SELECT orders.revenue, customers.region FROM orders + ) AS sq + """ + + result = semantic_layer.sql(sql) + _rows(result) + columns = _columns(result) + + assert "market_size" in columns + assert "region" in columns + assert "revenue" in columns + + +def test_deeply_nested_subquery(semantic_layer): + """Test double-wrapped subquery: outer(plain) -> middle(plain) -> inner(semantic).""" + sql = """ + SELECT + status, + revenue, + tier + FROM ( + SELECT + status, + revenue, + CASE WHEN revenue > 200 THEN 'high' ELSE 'low' END AS tier + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS inner_sq + ) AS outer_sq + WHERE tier = 'high' + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + + assert len(rows) == 1 + assert rows[0]["tier"] == "high" + assert rows[0]["revenue"] > 200 + + +def test_subquery_in_join(semantic_layer): + """Test semantic subquery used in a JOIN.""" + semantic_layer.conn.execute(""" + CREATE TABLE IF NOT EXISTS targets AS + SELECT 'completed' as status, 200 as target + UNION ALL + SELECT 'pending', 150 + """) + + sql = """ + SELECT + sq.status, + sq.revenue, + t.target, + sq.revenue - t.target AS delta + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + JOIN targets t ON sq.status = t.status + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + columns = _columns(result) + + assert "delta" in columns + assert len(rows) == 2 + for row in rows: + assert row["delta"] == row["revenue"] - row["target"] + + +def test_compile_post_process(semantic_layer): + """Test post_process parameter on compile().""" + outer_sql = semantic_layer.compile( + metrics=["orders.revenue"], + dimensions=["orders.status"], + post_process="SELECT *, CASE WHEN revenue > 200 THEN 'high' ELSE 'low' END AS tier FROM ({inner})", + ) + + # CTEs are hoisted; outer query wraps only the SELECT body + assert "CASE" in outer_sql + assert "tier" in outer_sql + assert "orders_cte" in outer_sql + # Should not have double WITH + assert "WITH WITH" not in outer_sql + + +def test_query_post_process(semantic_layer): + """Test post_process parameter on query().""" + result = semantic_layer.query( + metrics=["orders.revenue"], + dimensions=["orders.status"], + post_process="SELECT *, CASE WHEN revenue > 200 THEN 'high' ELSE 'low' END AS tier FROM ({inner})", + ) + + rows = _rows(result) + columns = _columns(result) + + assert "tier" in columns + for row in rows: + if row["revenue"] > 200: + assert row["tier"] == "high" + else: + assert row["tier"] == "low" + + +def test_post_process_missing_placeholder(semantic_layer): + """Test that post_process without {inner} raises ValueError.""" + with pytest.raises(ValueError, match="\\{inner\\}"): + semantic_layer.compile( + metrics=["orders.revenue"], + post_process="SELECT * FROM results", + ) + + +def test_dry_run_with_postprocess_subquery(semantic_layer): + """Test that rewriter.rewrite() returns composed SQL for subquery wrapping.""" + rewriter = QueryRewriter(semantic_layer.graph) + sql = """ + SELECT + status, + revenue, + CASE WHEN revenue > 200 THEN 'high' ELSE 'low' END AS tier + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + """ + + rewritten = rewriter.rewrite(sql) + + # The rewritten SQL should contain the semantic layer compilation + assert "CASE" in rewritten + assert "tier" in rewritten + # Inner query should be compiled (has CTE-style SQL from generator) + assert "AS" in rewritten From c3565e5730fab5d81fa421623e889fc96326c09d Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Fri, 10 Apr 2026 17:12:51 -0700 Subject: [PATCH 2/5] Preserve root semantic rewrite when JOIN has subquery When the root FROM references a semantic model and a JOIN contains a subquery, the routing into _rewrite_with_ctes_or_subqueries must still apply _rewrite_simple_query to the root SELECT so the explicit JOIN guard and semantic rewriting are not bypassed. --- sidemantic/sql/query_rewriter.py | 7 +++++++ tests/queries/test_sql_rewriter.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/sidemantic/sql/query_rewriter.py b/sidemantic/sql/query_rewriter.py index 7efc424..27f3c3e 100644 --- a/sidemantic/sql/query_rewriter.py +++ b/sidemantic/sql/query_rewriter.py @@ -1858,6 +1858,13 @@ def _rewrite_with_ctes_or_subqueries(self, parsed: exp.Select) -> str: (CASE, window functions, arithmetic, etc.) works naturally. """ self._rewrite_select_tree(parsed) + + # If the root SELECT itself references a semantic model, it must + # still go through _rewrite_simple_query (which enforces the + # explicit JOIN guard and performs semantic rewriting). + if self._references_semantic_model(parsed): + return self._rewrite_simple_query(parsed) + return parsed.sql(dialect=self.dialect) def _rewrite_select_tree(self, select: exp.Select): diff --git a/tests/queries/test_sql_rewriter.py b/tests/queries/test_sql_rewriter.py index 04c16c3..a4da442 100644 --- a/tests/queries/test_sql_rewriter.py +++ b/tests/queries/test_sql_rewriter.py @@ -1364,3 +1364,17 @@ def test_dry_run_with_postprocess_subquery(semantic_layer): assert "tier" in rewritten # Inner query should be compiled (has CTE-style SQL from generator) assert "AS" in rewritten + + +def test_semantic_root_with_join_subquery_rejected(semantic_layer): + """Explicit JOINs on semantic models are rejected even when JOIN has a subquery.""" + semantic_layer.conn.execute(""" + CREATE TABLE IF NOT EXISTS lookup AS SELECT 1 AS id, 'x' AS val + """) + sql = """ + SELECT orders.revenue + FROM orders + JOIN (SELECT * FROM lookup) AS lk ON 1 = 1 + """ + with pytest.raises(ValueError, match="Explicit JOIN syntax is not supported"): + semantic_layer.sql(sql) From 02194ec3a607a1f8f78faf781598378618913c8f Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sat, 11 Apr 2026 07:32:58 -0700 Subject: [PATCH 3/5] Keep CTE scope in root semantic rewrite and merge post_process CTEs Fix two issues from PR review: 1. When the root SELECT references a semantic model and has user-defined CTEs, _rewrite_simple_query was discarding them. Now user CTEs are saved before rewriting and merged back into the generated SQL so filter references like IN (SELECT ... FROM user_cte) remain valid. 2. When post_process SQL has its own WITH clause and the inner semantic query also produces CTEs, the two WITH clauses are now merged into one instead of producing invalid double-WITH SQL. --- sidemantic/core/semantic_layer.py | 18 +++++++++++---- sidemantic/sql/query_rewriter.py | 21 ++++++++++++++++- tests/queries/test_sql_rewriter.py | 36 ++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/sidemantic/core/semantic_layer.py b/sidemantic/core/semantic_layer.py index a543164..b5c570a 100644 --- a/sidemantic/core/semantic_layer.py +++ b/sidemantic/core/semantic_layer.py @@ -559,12 +559,22 @@ def compile( target_dialect = dialect or self.dialect parsed_inner = sqlglot.parse_one(stripped, dialect=target_dialect) - with_clause = parsed_inner.args.get("with") - if with_clause: + inner_with = parsed_inner.args.get("with") + if inner_with: parsed_inner.set("with", None) body = parsed_inner.sql(dialect=target_dialect) - ctes = with_clause.sql(dialect=target_dialect) - return ctes + "\n" + post_process.replace("{inner}", body) + + # Substitute body into post_process, then merge CTEs + outer_sql = post_process.replace("{inner}", body) + outer_parsed = sqlglot.parse_one(outer_sql, dialect=target_dialect) + outer_with = outer_parsed.args.get("with") + if outer_with: + # Prepend inner CTEs before outer CTEs + merged = list(inner_with.expressions) + list(outer_with.expressions) + outer_with.set("expressions", merged) + else: + outer_parsed.set("with", inner_with) + return outer_parsed.sql(dialect=target_dialect) return post_process.replace("{inner}", stripped) diff --git a/sidemantic/sql/query_rewriter.py b/sidemantic/sql/query_rewriter.py index 27f3c3e..301ba16 100644 --- a/sidemantic/sql/query_rewriter.py +++ b/sidemantic/sql/query_rewriter.py @@ -1863,7 +1863,26 @@ def _rewrite_with_ctes_or_subqueries(self, parsed: exp.Select) -> str: # still go through _rewrite_simple_query (which enforces the # explicit JOIN guard and performs semantic rewriting). if self._references_semantic_model(parsed): - return self._rewrite_simple_query(parsed) + # Save user-defined CTEs before _rewrite_simple_query replaces + # the entire query with fresh generator output. + original_with = parsed.args.get("with") + + rewritten_sql = self._rewrite_simple_query(parsed) + + if original_with: + # Merge user CTEs into the generated SQL so references + # from filters/expressions (e.g. IN (SELECT ... FROM cte)) + # remain valid. + rewritten = sqlglot.parse_one(rewritten_sql, dialect=self.dialect) + gen_with = rewritten.args.get("with") + if gen_with: + user_ctes = [cte.copy() for cte in original_with.expressions] + gen_with.set("expressions", user_ctes + list(gen_with.expressions)) + else: + rewritten.set("with", original_with.copy()) + return rewritten.sql(dialect=self.dialect) + + return rewritten_sql return parsed.sql(dialect=self.dialect) diff --git a/tests/queries/test_sql_rewriter.py b/tests/queries/test_sql_rewriter.py index a4da442..7c48699 100644 --- a/tests/queries/test_sql_rewriter.py +++ b/tests/queries/test_sql_rewriter.py @@ -1378,3 +1378,39 @@ def test_semantic_root_with_join_subquery_rejected(semantic_layer): """ with pytest.raises(ValueError, match="Explicit JOIN syntax is not supported"): semantic_layer.sql(sql) + + +def test_semantic_root_with_user_cte_preserved(semantic_layer): + """User-defined CTEs are preserved when root query is semantic.""" + sql = """ + WITH allowed_statuses AS ( + SELECT 'completed' AS status + ) + SELECT orders.revenue + FROM orders + WHERE orders.status IN (SELECT status FROM allowed_statuses) + """ + result = semantic_layer.sql(sql) + rows = _rows(result) + + assert len(rows) == 1 + assert rows[0]["revenue"] == 250.00 + + +def test_post_process_with_own_ctes(semantic_layer): + """post_process SQL with its own CTEs merges with inner CTEs.""" + result = semantic_layer.query( + metrics=["orders.revenue"], + dimensions=["orders.status"], + post_process=""" + WITH thresholds AS (SELECT 200 AS min_rev) + SELECT sq.*, t.min_rev + FROM ({inner}) sq + CROSS JOIN thresholds t + WHERE sq.revenue >= t.min_rev + """, + ) + rows = _rows(result) + + assert len(rows) >= 1 + assert all(row["revenue"] >= 200 for row in rows) From 6cc395828904304f86608f9a3fc0160a297d250d Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sat, 11 Apr 2026 07:40:16 -0700 Subject: [PATCH 4/5] Preserve WITH RECURSIVE when merging root CTEs Propagate the recursive flag from user-defined CTEs to the merged WITH clause so self-referencing CTEs execute correctly. --- sidemantic/sql/query_rewriter.py | 3 +++ tests/queries/test_sql_rewriter.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/sidemantic/sql/query_rewriter.py b/sidemantic/sql/query_rewriter.py index 301ba16..f09d416 100644 --- a/sidemantic/sql/query_rewriter.py +++ b/sidemantic/sql/query_rewriter.py @@ -1878,6 +1878,9 @@ def _rewrite_with_ctes_or_subqueries(self, parsed: exp.Select) -> str: if gen_with: user_ctes = [cte.copy() for cte in original_with.expressions] gen_with.set("expressions", user_ctes + list(gen_with.expressions)) + # Preserve WITH RECURSIVE from the original query + if original_with.args.get("recursive"): + gen_with.set("recursive", True) else: rewritten.set("with", original_with.copy()) return rewritten.sql(dialect=self.dialect) diff --git a/tests/queries/test_sql_rewriter.py b/tests/queries/test_sql_rewriter.py index 7c48699..a9c031c 100644 --- a/tests/queries/test_sql_rewriter.py +++ b/tests/queries/test_sql_rewriter.py @@ -1397,6 +1397,25 @@ def test_semantic_root_with_user_cte_preserved(semantic_layer): assert rows[0]["revenue"] == 250.00 +def test_semantic_root_with_recursive_cte_preserved(semantic_layer): + """WITH RECURSIVE flag is preserved when merging user CTEs.""" + sql = """ + WITH RECURSIVE status_chain(status, depth) AS ( + SELECT 'completed', 1 + UNION ALL + SELECT 'pending', depth + 1 FROM status_chain WHERE depth < 2 + ) + SELECT orders.revenue, orders.status + FROM orders + WHERE orders.status IN (SELECT status FROM status_chain) + """ + result = semantic_layer.sql(sql) + rows = _rows(result) + + # Both completed ($250) and pending ($200) statuses are in the recursive CTE + assert len(rows) == 2 + + def test_post_process_with_own_ctes(semantic_layer): """post_process SQL with its own CTEs merges with inner CTEs.""" result = semantic_layer.query( From b542ae720bd8af11451b6395681f0440e911381f Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Sat, 11 Apr 2026 07:54:51 -0700 Subject: [PATCH 5/5] Handle CTE name collisions between user and generated CTEs Two changes: 1. post_process path: remove CTE hoisting entirely. Inner SQL (with CTEs) is placed directly in the subquery position. CTEs inside subqueries are valid in all target databases and naturally scoped, so name collisions with post_process CTEs cannot occur. 2. Root semantic + user CTEs: detect name collisions between user CTEs and generated CTEs, raising a clear error instead of producing invalid SQL. Walk-based renaming was too aggressive (renamed user CTE references inside filter subqueries). --- sidemantic/core/semantic_layer.py | 28 ++++---------------------- sidemantic/sql/query_rewriter.py | 9 +++++++++ tests/queries/test_sql_rewriter.py | 32 ++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 24 deletions(-) diff --git a/sidemantic/core/semantic_layer.py b/sidemantic/core/semantic_layer.py index b5c570a..e27a69f 100644 --- a/sidemantic/core/semantic_layer.py +++ b/sidemantic/core/semantic_layer.py @@ -552,30 +552,10 @@ def compile( if last_line.startswith("-- sidemantic:"): stripped = "\n".join(stripped.split("\n")[:-1]) - # If inner SQL starts with WITH (CTEs), hoist them outside - # the subquery position so the SQL is valid. - if stripped.lstrip().upper().startswith("WITH "): - import sqlglot - - target_dialect = dialect or self.dialect - parsed_inner = sqlglot.parse_one(stripped, dialect=target_dialect) - inner_with = parsed_inner.args.get("with") - if inner_with: - parsed_inner.set("with", None) - body = parsed_inner.sql(dialect=target_dialect) - - # Substitute body into post_process, then merge CTEs - outer_sql = post_process.replace("{inner}", body) - outer_parsed = sqlglot.parse_one(outer_sql, dialect=target_dialect) - outer_with = outer_parsed.args.get("with") - if outer_with: - # Prepend inner CTEs before outer CTEs - merged = list(inner_with.expressions) + list(outer_with.expressions) - outer_with.set("expressions", merged) - else: - outer_parsed.set("with", inner_with) - return outer_parsed.sql(dialect=target_dialect) - + # Inner SQL (including any CTEs) is placed directly in the + # subquery position. CTEs inside subqueries are valid SQL in + # all target databases and naturally scoped, avoiding name + # collisions with CTEs in the post_process SQL. return post_process.replace("{inner}", stripped) return inner_sql diff --git a/sidemantic/sql/query_rewriter.py b/sidemantic/sql/query_rewriter.py index f09d416..f791e52 100644 --- a/sidemantic/sql/query_rewriter.py +++ b/sidemantic/sql/query_rewriter.py @@ -1876,6 +1876,15 @@ def _rewrite_with_ctes_or_subqueries(self, parsed: exp.Select) -> str: rewritten = sqlglot.parse_one(rewritten_sql, dialect=self.dialect) gen_with = rewritten.args.get("with") if gen_with: + # Check for CTE name collisions between user and generated CTEs + user_names = {cte.alias for cte in original_with.expressions} + for gen_cte in gen_with.expressions: + if gen_cte.alias in user_names: + raise ValueError( + f"CTE name '{gen_cte.alias}' conflicts with an internally " + f"generated name. Please choose a different CTE name." + ) + user_ctes = [cte.copy() for cte in original_with.expressions] gen_with.set("expressions", user_ctes + list(gen_with.expressions)) # Preserve WITH RECURSIVE from the original query diff --git a/tests/queries/test_sql_rewriter.py b/tests/queries/test_sql_rewriter.py index a9c031c..121fd3f 100644 --- a/tests/queries/test_sql_rewriter.py +++ b/tests/queries/test_sql_rewriter.py @@ -1433,3 +1433,35 @@ def test_post_process_with_own_ctes(semantic_layer): assert len(rows) >= 1 assert all(row["revenue"] >= 200 for row in rows) + + +def test_post_process_cte_name_collision(semantic_layer): + """post_process CTE with same name as generated CTE doesn't collide.""" + result = semantic_layer.query( + metrics=["orders.revenue"], + dimensions=["orders.status"], + post_process=""" + WITH orders_cte AS (SELECT 'custom' AS source) + SELECT sq.*, oc.source + FROM ({inner}) sq + CROSS JOIN orders_cte oc + """, + ) + rows = _rows(result) + + assert len(rows) >= 1 + assert all(row["source"] == "custom" for row in rows) + + +def test_root_semantic_cte_name_collision(semantic_layer): + """User CTE with same name as generated CTE raises a clear error.""" + sql = """ + WITH orders_cte AS ( + SELECT 'completed' AS status + ) + SELECT orders.revenue + FROM orders + WHERE orders.status IN (SELECT status FROM orders_cte) + """ + with pytest.raises(ValueError, match="conflicts with an internally generated name"): + semantic_layer.sql(sql)