diff --git a/sidemantic/core/semantic_layer.py b/sidemantic/core/semantic_layer.py index 1372944..e27a69f 100644 --- a/sidemantic/core/semantic_layer.py +++ b/sidemantic/core/semantic_layer.py @@ -435,6 +435,7 @@ def query( ungrouped: bool = False, parameters: dict[str, any] | None = None, use_preaggregations: bool | None = None, + post_process: str | None = None, ): """Execute a query against the semantic layer. @@ -448,6 +449,9 @@ def query( ungrouped: If True, return raw rows without aggregation (no GROUP BY) parameters: Template parameters for Jinja2 rendering use_preaggregations: Override pre-aggregation routing setting for this query + post_process: Optional SQL to wrap around the semantic query result. + Use {inner} as a placeholder for the compiled semantic query, e.g.: + "SELECT *, revenue / count AS avg_value FROM ({inner})" Returns: DuckDB relation object (can convert to DataFrame with .df() or .to_df()) @@ -462,6 +466,7 @@ def query( ungrouped=ungrouped, parameters=parameters, use_preaggregations=use_preaggregations, + post_process=post_process, ) return self.adapter.execute(sql) @@ -479,6 +484,7 @@ def compile( ungrouped: bool = False, parameters: dict[str, any] | None = None, use_preaggregations: bool | None = None, + post_process: str | None = None, ) -> str: """Compile a query to SQL without executing. @@ -493,6 +499,9 @@ def compile( dialect: SQL dialect override (defaults to layer's dialect) ungrouped: If True, return raw rows without aggregation (no GROUP BY) use_preaggregations: Override pre-aggregation routing setting for this query + post_process: Optional SQL to wrap around the semantic query result. 
+ Use {inner} as a placeholder for the compiled semantic query, e.g.: + "SELECT *, revenue / count AS avg_value FROM ({inner})" Returns: SQL query string @@ -520,7 +529,7 @@ def compile( preagg_schema=self.preagg_schema, ) - return generator.generate( + inner_sql = generator.generate( metrics=metrics, dimensions=dimensions, filters=filters, @@ -533,6 +542,24 @@ def compile( use_preaggregations=use_preaggs, ) + if post_process is not None: + if "{inner}" not in post_process: + raise ValueError("post_process must contain a {inner} placeholder") + + # Strip sidemantic instrumentation comment + stripped = inner_sql.rstrip() + last_line = stripped.split("\n")[-1].strip() + if last_line.startswith("-- sidemantic:"): + stripped = "\n".join(stripped.split("\n")[:-1]) + + # Inner SQL (including any CTEs) is placed directly in the + # subquery position. CTEs inside subqueries are valid SQL in + # all target databases and naturally scoped, avoiding name + # collisions with CTEs in the post_process SQL. 
+ return post_process.replace("{inner}", stripped) + + return inner_sql + def explain( self, metrics: list[str] | None = None, diff --git a/sidemantic/sql/query_rewriter.py b/sidemantic/sql/query_rewriter.py index 457636f..f791e52 100644 --- a/sidemantic/sql/query_rewriter.py +++ b/sidemantic/sql/query_rewriter.py @@ -118,8 +118,9 @@ def rewrite(self, sql: str, strict: bool = True) -> str: # Check if this is a CTE-based query or has subqueries has_ctes = parsed.args.get("with") is not None has_subquery_in_from = self._has_subquery_in_from(parsed) + has_subquery_in_joins = any(isinstance(join.this, exp.Subquery) for join in (parsed.args.get("joins") or [])) - if has_ctes or has_subquery_in_from: + if has_ctes or has_subquery_in_from or has_subquery_in_joins: # Handle CTEs and subqueries return self._rewrite_with_ctes_or_subqueries(parsed) @@ -1851,41 +1852,89 @@ def _has_subquery_in_from(self, select: exp.Select) -> bool: def _rewrite_with_ctes_or_subqueries(self, parsed: exp.Select) -> str: """Rewrite query that contains CTEs or subqueries. - Strategy: - 1. Rewrite each CTE that references semantic models - 2. Rewrite subqueries in FROM clause - 3. Return the modified SQL + Recursively walks the query tree bottom-up, rewriting any + SELECT whose FROM target resolves to a semantic model. + Outer queries are left as plain SQL, so post-processing + (CASE, window functions, arithmetic, etc.) works naturally. """ - # Handle CTEs - if parsed.args.get("with"): - with_clause = parsed.args["with"] - for cte in with_clause.expressions: - # Each CTE has a name (alias) and a query (this) + self._rewrite_select_tree(parsed) + + # If the root SELECT itself references a semantic model, it must + # still go through _rewrite_simple_query (which enforces the + # explicit JOIN guard and performs semantic rewriting). + if self._references_semantic_model(parsed): + # Save user-defined CTEs before _rewrite_simple_query replaces + # the entire query with fresh generator output. 
+ original_with = parsed.args.get("with") + + rewritten_sql = self._rewrite_simple_query(parsed) + + if original_with: + # Merge user CTEs into the generated SQL so references + # from filters/expressions (e.g. IN (SELECT ... FROM cte)) + # remain valid. + rewritten = sqlglot.parse_one(rewritten_sql, dialect=self.dialect) + gen_with = rewritten.args.get("with") + if gen_with: + # Check for CTE name collisions between user and generated CTEs + user_names = {cte.alias for cte in original_with.expressions} + for gen_cte in gen_with.expressions: + if gen_cte.alias in user_names: + raise ValueError( + f"CTE name '{gen_cte.alias}' conflicts with an internally " + f"generated name. Please choose a different CTE name." + ) + + user_ctes = [cte.copy() for cte in original_with.expressions] + gen_with.set("expressions", user_ctes + list(gen_with.expressions)) + # Preserve WITH RECURSIVE from the original query + if original_with.args.get("recursive"): + gen_with.set("recursive", True) + else: + rewritten.set("with", original_with.copy()) + return rewritten.sql(dialect=self.dialect) + + return rewritten_sql + + return parsed.sql(dialect=self.dialect) + + def _rewrite_select_tree(self, select: exp.Select): + """Recursively rewrite semantic subqueries and CTEs (bottom-up). + + For each child SELECT (CTE, FROM/JOIN subquery): recurse into it, then + rewrite it if it references a semantic model; `select` is left to the caller. 
+ """ + # Recurse into CTEs + if select.args.get("with"): + for cte in select.args["with"].expressions: cte_query = cte.this if isinstance(cte_query, exp.Select): - # Check if this CTE references a semantic model + self._rewrite_select_tree(cte_query) if self._references_semantic_model(cte_query): - # Rewrite the CTE query - rewritten_cte_sql = self._rewrite_simple_query(cte_query) - # Parse the rewritten SQL and replace the CTE query - rewritten_cte = sqlglot.parse_one(rewritten_cte_sql, dialect=self.dialect) - cte.set("this", rewritten_cte) - - # Handle subquery in FROM - from_clause = parsed.args.get("from") + rewritten_sql = self._rewrite_simple_query(cte_query) + cte.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect)) + + # Recurse into FROM subquery + from_clause = select.args.get("from") if from_clause and isinstance(from_clause.this, exp.Subquery): subquery = from_clause.this subquery_select = subquery.this - if isinstance(subquery_select, exp.Select) and self._references_semantic_model(subquery_select): - # Rewrite the subquery - rewritten_subquery_sql = self._rewrite_simple_query(subquery_select) - rewritten_subquery = sqlglot.parse_one(rewritten_subquery_sql, dialect=self.dialect) - subquery.set("this", rewritten_subquery) - - # Return the modified SQL - # Note: Individual CTEs/subqueries are already instrumented by _rewrite_simple_query -> generator - # The outer query wrapper doesn't need separate instrumentation - return parsed.sql(dialect=self.dialect) + if isinstance(subquery_select, exp.Select): + self._rewrite_select_tree(subquery_select) + if self._references_semantic_model(subquery_select): + rewritten_sql = self._rewrite_simple_query(subquery_select) + subquery.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect)) + + # Recurse into JOIN subqueries + for join in select.args.get("joins") or []: + join_expr = join.this + if isinstance(join_expr, exp.Subquery): + join_select = join_expr.this + if 
isinstance(join_select, exp.Select): + self._rewrite_select_tree(join_select) + if self._references_semantic_model(join_select): + rewritten_sql = self._rewrite_simple_query(join_select) + join_expr.set("this", sqlglot.parse_one(rewritten_sql, dialect=self.dialect)) def _references_semantic_model(self, select: exp.Select) -> bool: """Check if a SELECT statement references any semantic models.""" diff --git a/tests/queries/test_sql_rewriter.py b/tests/queries/test_sql_rewriter.py index 9a707ae..121fd3f 100644 --- a/tests/queries/test_sql_rewriter.py +++ b/tests/queries/test_sql_rewriter.py @@ -684,8 +684,6 @@ def test_cte_with_limit_in_inner_query(semantic_layer): def test_nested_subquery(semantic_layer): """Test filtering semantic query results in outer query.""" - # Note: Deep nesting of subqueries (subquery within subquery) is not currently supported - # This test demonstrates single-level subquery with filtering sql = """ SELECT * FROM ( SELECT revenue, status FROM orders @@ -1080,3 +1078,390 @@ def test_granularity_on_non_time_dimension(semantic_layer): # Should work - status is a valid categorical dimension assert len(rows) == 2 # Two status groups + + +# --- Post-processing SQL over semantic query results --- + + +def test_postprocess_case_expression(semantic_layer): + """Test CASE expression in outer query over semantic subquery.""" + sql = """ + SELECT + status, + revenue, + CASE WHEN revenue > 200 THEN 'high' ELSE 'low' END AS tier + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + columns = _columns(result) + + assert "tier" in columns + assert len(rows) == 2 + for row in rows: + if row["revenue"] > 200: + assert row["tier"] == "high" + else: + assert row["tier"] == "low" + + +def test_postprocess_arithmetic(semantic_layer): + """Test arithmetic between metrics in outer query.""" + sql = """ + SELECT + status, + revenue, + count, + revenue / count AS avg_order_value + 
FROM ( + SELECT orders.revenue, orders.count, orders.status FROM orders + ) AS sq + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + columns = _columns(result) + + assert "avg_order_value" in columns + for row in rows: + assert row["avg_order_value"] == row["revenue"] / row["count"] + + +def test_postprocess_window_function(semantic_layer): + """Test window functions in outer query over semantic results.""" + sql = """ + SELECT + status, + revenue, + LAG(revenue) OVER (ORDER BY revenue DESC) AS next_lower_revenue + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + columns = _columns(result) + + assert "next_lower_revenue" in columns + assert len(rows) == 2 + + +def test_postprocess_coalesce(semantic_layer): + """Test COALESCE in outer query.""" + sql = """ + SELECT + status, + COALESCE(revenue, 0) AS safe_revenue + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + columns = _columns(result) + + assert "safe_revenue" in columns + assert all(row["safe_revenue"] is not None for row in rows) + + +def test_postprocess_having(semantic_layer): + """Test filtering with WHERE in outer query (equivalent to HAVING on aggregated results).""" + sql = """ + SELECT status, revenue + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + WHERE revenue > 200 + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + + assert len(rows) == 1 + assert rows[0]["revenue"] > 200 + + +def test_postprocess_order_by_in_outer(semantic_layer): + """Test ORDER BY in outer query over semantic results.""" + sql = """ + SELECT status, revenue + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + ORDER BY revenue DESC + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + + assert len(rows) == 2 + assert rows[0]["revenue"] >= rows[1]["revenue"] + + +def 
test_postprocess_limit_in_outer(semantic_layer): + """Test LIMIT in outer query over semantic results.""" + sql = """ + SELECT status, revenue + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + ORDER BY revenue DESC + LIMIT 1 + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + + assert len(rows) == 1 + assert rows[0]["revenue"] == 250.00 + + +def test_postprocess_cross_model_subquery(semantic_layer): + """Test post-processing over cross-model semantic subquery.""" + sql = """ + SELECT + region, + revenue, + CASE WHEN revenue > 200 THEN 'big' ELSE 'small' END AS market_size + FROM ( + SELECT orders.revenue, customers.region FROM orders + ) AS sq + """ + + result = semantic_layer.sql(sql) + _rows(result) + columns = _columns(result) + + assert "market_size" in columns + assert "region" in columns + assert "revenue" in columns + + +def test_deeply_nested_subquery(semantic_layer): + """Test double-wrapped subquery: outer(plain) -> middle(plain) -> inner(semantic).""" + sql = """ + SELECT + status, + revenue, + tier + FROM ( + SELECT + status, + revenue, + CASE WHEN revenue > 200 THEN 'high' ELSE 'low' END AS tier + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS inner_sq + ) AS outer_sq + WHERE tier = 'high' + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + + assert len(rows) == 1 + assert rows[0]["tier"] == "high" + assert rows[0]["revenue"] > 200 + + +def test_subquery_in_join(semantic_layer): + """Test semantic subquery used in a JOIN.""" + semantic_layer.conn.execute(""" + CREATE TABLE IF NOT EXISTS targets AS + SELECT 'completed' as status, 200 as target + UNION ALL + SELECT 'pending', 150 + """) + + sql = """ + SELECT + sq.status, + sq.revenue, + t.target, + sq.revenue - t.target AS delta + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + JOIN targets t ON sq.status = t.status + """ + + result = semantic_layer.sql(sql) + rows = _rows(result) + columns = _columns(result) + 
+ assert "delta" in columns + assert len(rows) == 2 + for row in rows: + assert row["delta"] == row["revenue"] - row["target"] + + +def test_compile_post_process(semantic_layer): + """Test post_process parameter on compile().""" + outer_sql = semantic_layer.compile( + metrics=["orders.revenue"], + dimensions=["orders.status"], + post_process="SELECT *, CASE WHEN revenue > 200 THEN 'high' ELSE 'low' END AS tier FROM ({inner})", + ) + + # Inner SQL (CTEs included) is embedded whole in the subquery; nothing is hoisted + assert "CASE" in outer_sql + assert "tier" in outer_sql + assert "orders_cte" in outer_sql + # Should not have double WITH + assert "WITH WITH" not in outer_sql + + +def test_query_post_process(semantic_layer): + """Test post_process parameter on query().""" + result = semantic_layer.query( + metrics=["orders.revenue"], + dimensions=["orders.status"], + post_process="SELECT *, CASE WHEN revenue > 200 THEN 'high' ELSE 'low' END AS tier FROM ({inner})", + ) + + rows = _rows(result) + columns = _columns(result) + + assert "tier" in columns + for row in rows: + if row["revenue"] > 200: + assert row["tier"] == "high" + else: + assert row["tier"] == "low" + + +def test_post_process_missing_placeholder(semantic_layer): + """Test that post_process without {inner} raises ValueError.""" + with pytest.raises(ValueError, match="\\{inner\\}"): + semantic_layer.compile( + metrics=["orders.revenue"], + post_process="SELECT * FROM results", + ) + + +def test_dry_run_with_postprocess_subquery(semantic_layer): + """Test that rewriter.rewrite() returns composed SQL for subquery wrapping.""" + rewriter = QueryRewriter(semantic_layer.graph) + sql = """ + SELECT + status, + revenue, + CASE WHEN revenue > 200 THEN 'high' ELSE 'low' END AS tier + FROM ( + SELECT orders.revenue, orders.status FROM orders + ) AS sq + """ + + rewritten = rewriter.rewrite(sql) + + # The rewritten SQL should contain the semantic layer compilation + assert "CASE" in rewritten + assert "tier" in rewritten + # Inner query 
should be compiled (has CTE-style SQL from generator) + assert "AS" in rewritten + + +def test_semantic_root_with_join_subquery_rejected(semantic_layer): + """Explicit JOINs on semantic models are rejected even when JOIN has a subquery.""" + semantic_layer.conn.execute(""" + CREATE TABLE IF NOT EXISTS lookup AS SELECT 1 AS id, 'x' AS val + """) + sql = """ + SELECT orders.revenue + FROM orders + JOIN (SELECT * FROM lookup) AS lk ON 1 = 1 + """ + with pytest.raises(ValueError, match="Explicit JOIN syntax is not supported"): + semantic_layer.sql(sql) + + +def test_semantic_root_with_user_cte_preserved(semantic_layer): + """User-defined CTEs are preserved when root query is semantic.""" + sql = """ + WITH allowed_statuses AS ( + SELECT 'completed' AS status + ) + SELECT orders.revenue + FROM orders + WHERE orders.status IN (SELECT status FROM allowed_statuses) + """ + result = semantic_layer.sql(sql) + rows = _rows(result) + + assert len(rows) == 1 + assert rows[0]["revenue"] == 250.00 + + +def test_semantic_root_with_recursive_cte_preserved(semantic_layer): + """WITH RECURSIVE flag is preserved when merging user CTEs.""" + sql = """ + WITH RECURSIVE status_chain(status, depth) AS ( + SELECT 'completed', 1 + UNION ALL + SELECT 'pending', depth + 1 FROM status_chain WHERE depth < 2 + ) + SELECT orders.revenue, orders.status + FROM orders + WHERE orders.status IN (SELECT status FROM status_chain) + """ + result = semantic_layer.sql(sql) + rows = _rows(result) + + # Both completed ($250) and pending ($200) statuses are in the recursive CTE + assert len(rows) == 2 + + +def test_post_process_with_own_ctes(semantic_layer): + """post_process SQL with its own CTEs works alongside the subquery-scoped inner CTEs.""" + result = semantic_layer.query( + metrics=["orders.revenue"], + dimensions=["orders.status"], + post_process=""" + WITH thresholds AS (SELECT 200 AS min_rev) + SELECT sq.*, t.min_rev + FROM ({inner}) sq + CROSS JOIN thresholds t + WHERE sq.revenue >= t.min_rev + """, + ) + rows = 
_rows(result) + + assert len(rows) >= 1 + assert all(row["revenue"] >= 200 for row in rows) + + +def test_post_process_cte_name_collision(semantic_layer): + """post_process CTE with same name as generated CTE doesn't collide.""" + result = semantic_layer.query( + metrics=["orders.revenue"], + dimensions=["orders.status"], + post_process=""" + WITH orders_cte AS (SELECT 'custom' AS source) + SELECT sq.*, oc.source + FROM ({inner}) sq + CROSS JOIN orders_cte oc + """, + ) + rows = _rows(result) + + assert len(rows) >= 1 + assert all(row["source"] == "custom" for row in rows) + + +def test_root_semantic_cte_name_collision(semantic_layer): + """User CTE with same name as generated CTE raises a clear error.""" + sql = """ + WITH orders_cte AS ( + SELECT 'completed' AS status + ) + SELECT orders.revenue + FROM orders + WHERE orders.status IN (SELECT status FROM orders_cte) + """ + with pytest.raises(ValueError, match="conflicts with an internally generated name"): + semantic_layer.sql(sql)