Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,23 @@ jobs:
- name: Install dependencies
run: uv sync --extra snowflake --extra dev

- name: Install fakesnow
id: fakesnow
continue-on-error: true
run: uv pip install "fakesnow>=0.9.0"

- name: Run Snowflake integration tests
if: steps.fakesnow.outcome == 'success'
env:
SNOWFLAKE_TEST: "1"
run: uv run pytest -m integration tests/db/test_snowflake_integration.py -v

- name: Report fakesnow incompatibility
if: steps.fakesnow.outcome == 'failure'
run: |
echo "::warning::fakesnow could not be installed (likely sqlglot version mismatch). Snowflake integration tests were skipped."
exit 0

clickhouse-integration:
runs-on: ubuntu-latest

Expand Down
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ license = {file = "LICENSE"}
requires-python = ">=3.11"
dependencies = [
"antlr4-python3-runtime>=4.13.2",
"sqlglot==27.12.0",
"sqlglot>=30.1.0",
"pyyaml>=6.0",
"pydantic>=2.0.0",
"jinja2>=3.1.0",
Expand All @@ -29,10 +29,12 @@ dev = [
"httpx>=0.28.0",
"pyarrow>=14.0.0",
"uvicorn>=0.34.0",
"fakesnow>=0.9.0", # For Snowflake integration tests
# fakesnow temporarily excluded until it supports sqlglot 30.x
# "fakesnow>=0.9.0", # For Snowflake integration tests
"lkml>=1.3.7",
"inflect>=7.0.0",
"antlr4-python3-runtime>=4.13.2",
"sqlglot[c]>=30.1.0",
]
workbench = [
"textual[syntax]>=1.0.0",
Expand All @@ -49,6 +51,9 @@ charts = [
"altair>=5.0.0",
"vl-convert-python>=1.0.0",
]
fast = [
"sqlglot[c]>=30.1.0",
]
serve = [
"riffq>=0.1.0",
"pyarrow>=14.0.0",
Expand Down
114 changes: 60 additions & 54 deletions sidemantic/adapters/yardstick.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
"""Yardstick adapter for importing SQL models with AS MEASURE semantics."""
"""Yardstick adapter for importing SQL models with AS MEASURE semantics.

Compatible with sqlglot's mypyc C extension. Uses the tokenizer to
identify ``AS MEASURE <alias>`` sequences, strips the ``MEASURE``
keyword, parses with the configured dialect, then tags the
corresponding alias nodes.
"""

from functools import lru_cache
from pathlib import Path
from typing import Literal, get_args, get_origin

import sqlglot
from sqlglot import Dialect, exp
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType

from sidemantic.adapters.base import BaseAdapter
from sidemantic.core.dimension import Dimension
Expand All @@ -30,55 +38,40 @@ def _supported_metric_aggs() -> set[str]:
return _extract_literal_strings(annotation)


@lru_cache(maxsize=8)
def _yardstick_dialect(base_dialect_name: str = "duckdb") -> type:
"""Create a Yardstick dialect class extending any sqlglot dialect.
def _strip_measure_tokens(sql: str, dialect: str = "duckdb") -> tuple[str, set[str]]:
"""Remove MEASURE keyword from ``AS MEASURE <alias>`` sequences.

The returned class adds ``AS MEASURE`` alias recognition to the base
dialect's parser. Results are cached so repeated calls with the same
dialect name return the same class object.
Uses sqlglot's tokenizer so string literals and comments are handled
correctly. Returns the cleaned SQL and the set of measure alias names.
"""
base_instance = Dialect.get_or_raise(base_dialect_name)
base_cls = type(base_instance) if not isinstance(base_instance, type) else base_instance

class _YardstickParser(base_cls.Parser):
"""Parser extension for Yardstick's measure alias syntax.

Delegates to the base dialect's ``_parse_alias`` so that
dialect-specific alias behaviour (e.g. ClickHouse ``APPLY``)
is preserved. After the base parser runs, we detect the
``AS MEASURE <name>`` pattern: the base parser will have
consumed ``MEASURE`` as the alias identifier, so we replace
it with the real alias name that follows.
"""

def _parse_alias(self, this: exp.Expression | None, explicit: bool = False) -> exp.Expression | None:
result = super()._parse_alias(this, explicit)

if (
isinstance(result, exp.Alias)
and isinstance(result.args.get("alias"), exp.Identifier)
and not result.args["alias"].quoted
and result.args["alias"].name.upper() == "MEASURE"
):
actual_alias = self._parse_id_var(True, tokens=self.ALIAS_TOKENS) or (
self.STRING_ALIASES and self._parse_string_as_identifier()
)
if actual_alias:
result.set("alias", actual_alias)
result.set("yardstick_measure", True)

return result

class _YardstickDialect(base_cls):
class Parser(_YardstickParser):
pass
dialect_instance = Dialect.get_or_raise(dialect)
tokens = list(dialect_instance.tokenize(sql))
measure_names: set[str] = set()
remove_indices: set[int] = set()

return _YardstickDialect
for i in range(len(tokens) - 2):
if (
tokens[i].token_type == TokenType.ALIAS
and tokens[i + 1].token_type == TokenType.VAR
and tokens[i + 1].text.upper() == "MEASURE"
and tokens[i + 2].token_type in (TokenType.VAR, TokenType.STRING)
):
measure_names.add(tokens[i + 2].text.strip('"'))
remove_indices.add(i + 1)

if not remove_indices:
return sql, set()

# Backward-compatible alias: the default DuckDB-based dialect.
YardstickDialect = _yardstick_dialect("duckdb")
# Rebuild SQL by replacing MEASURE token spans with whitespace
# to preserve character positions for error messages.
result = list(sql)
for idx in remove_indices:
tok = tokens[idx]
start = tok.start
end = tok.end + 1
for j in range(start, min(end, len(result))):
result[j] = " "
return "".join(result), measure_names


class YardstickAdapter(BaseAdapter):
Expand All @@ -88,9 +81,6 @@ class YardstickAdapter(BaseAdapter):
`AGG(expr) AS MEASURE measure_name`.
"""

def __init__(self, dialect: str = "duckdb"):
self.dialect = dialect

_SIMPLE_AGGREGATIONS: dict[type[exp.Expression], str] = {
exp.Sum: "sum",
exp.Avg: "avg",
Expand All @@ -104,6 +94,9 @@ def __init__(self, dialect: str = "duckdb"):
}
_ANONYMOUS_AGGREGATIONS: set[str] = {"mode"}

def __init__(self, dialect: str = "duckdb"):
self.dialect = dialect

def parse(self, source: str | Path) -> SemanticGraph:
"""Parse Yardstick SQL files into a semantic graph."""
source_path = Path(source)
Expand Down Expand Up @@ -144,7 +137,20 @@ def _parse_sql_file(self, path: Path, graph: SemanticGraph) -> None:
graph.add_model(model)

def _parse_statements(self, sql: str) -> list[exp.Expression | None]:
return sqlglot.parse(sql, read=_yardstick_dialect(self.dialect))
cleaned, measure_names = _strip_measure_tokens(sql, dialect=self.dialect)
statements = sqlglot.parse(cleaned, read=self.dialect)

if measure_names:
for stmt in statements:
if stmt:
# Only tag aliases in SELECT projections of CREATE VIEW
select = stmt.expression if isinstance(stmt, exp.Create) else stmt
if isinstance(select, exp.Select):
for proj in select.expressions:
if isinstance(proj, exp.Alias) and proj.output_name in measure_names:
proj.set("yardstick_measure", True)

return statements

def _model_from_create_view(self, create_stmt: exp.Create, select: exp.Select) -> Model | None:
measure_aliases = {
Expand Down Expand Up @@ -237,10 +243,10 @@ def _metric_from_expression(self, name: str, expression: exp.Expression, all_mea
return metric

def _extract_model_source(self, select: exp.Select) -> tuple[str | None, str | None]:
from_clause = select.args.get("from")
from_clause = select.args.get("from_")
joins = select.args.get("joins") or []
where_clause = select.args.get("where")
with_clause = select.args.get("with")
with_clause = select.args.get("with_")

if (
isinstance(from_clause, exp.From)
Expand All @@ -259,8 +265,8 @@ def _extract_model_source(self, select: exp.Select) -> tuple[str | None, str | N

base_relation = exp.select("*")
if with_clause is not None:
base_relation.set("with", with_clause.copy())
base_relation.set("from", from_clause.copy())
base_relation.set("with_", with_clause.copy())
base_relation.set("from_", from_clause.copy())
if joins:
base_relation.set("joins", [join.copy() for join in joins])
if where_clause is not None:
Expand Down
Loading
Loading