From 29f58bd10320c4d7e429dcfea32c52ddb52b4000 Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Tue, 16 Dec 2025 16:52:15 +0100 Subject: [PATCH 01/10] Refactored snapshots from mixin to snapshots/runtime module --- pyproject.toml | 2 +- src/fastflowtransform/executors/base.py | 36 +- .../executors/bigquery/base.py | 47 ++- src/fastflowtransform/executors/common.py | 6 + .../executors/databricks_spark.py | 325 +--------------- src/fastflowtransform/executors/duckdb.py | 78 ++-- src/fastflowtransform/executors/postgres.py | 77 ++-- .../executors/snowflake_snowpark.py | 88 ++--- src/fastflowtransform/snapshots/__init__.py | 0 .../{snapshots.py => snapshots/core.py} | 0 .../snapshots/runtime/__init__.py | 15 + .../runtime/base.py} | 222 ++++++----- .../snapshots/runtime/bigquery.py | 52 +++ .../snapshots/runtime/databricks_spark.py | 357 ++++++++++++++++++ .../snapshots/runtime/duckdb.py | 48 +++ .../snapshots/runtime/postgres.py | 50 +++ .../snapshots/runtime/snowflake_snowpark.py | 54 +++ tests/common/fixtures.py | 7 + tests/common/snapshot_helpers.py | 12 +- .../test_snapshots_bigquery_integration.py | 12 +- ..._snapshots_databricks_spark_integration.py | 6 +- ...napshots_snowflake_snowpark_integration.py | 12 +- tests/unit/executors/test_duckdb_exec_unit.py | 7 +- .../unit/executors/test_postgres_exec_unit.py | 1 - .../executors/test_snowflake_snowpark_exec.py | 1 - uv.lock | 2 +- 26 files changed, 897 insertions(+), 620 deletions(-) create mode 100644 src/fastflowtransform/executors/common.py create mode 100644 src/fastflowtransform/snapshots/__init__.py rename src/fastflowtransform/{snapshots.py => snapshots/core.py} (100%) create mode 100644 src/fastflowtransform/snapshots/runtime/__init__.py rename src/fastflowtransform/{executors/_snapshot_sql_mixin.py => snapshots/runtime/base.py} (85%) create mode 100644 src/fastflowtransform/snapshots/runtime/bigquery.py create mode 100644 src/fastflowtransform/snapshots/runtime/databricks_spark.py create mode 100644 src/fastflowtransform/snapshots/runtime/duckdb.py create mode 100644 src/fastflowtransform/snapshots/runtime/postgres.py create mode 100644 src/fastflowtransform/snapshots/runtime/snowflake_snowpark.py diff --git a/pyproject.toml b/pyproject.toml index 74cfd22..9aa566d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fastflowtransform" -version = "0.6.13" +version = "0.6.14" description = "Python framework for SQL & Python data transformation, ETL pipelines, and dbt-style data modeling" readme = "README.md" license = { text = "Apache-2.0" } diff --git a/src/fastflowtransform/executors/base.py b/src/fastflowtransform/executors/base.py index 32147a7..899c54f 100644 --- a/src/fastflowtransform/executors/base.py +++ b/src/fastflowtransform/executors/base.py @@ -54,8 +54,6 @@ def _python_incremental_merge_default( return combined combined = pd.concat([df_old, df_new], ignore_index=True) - - # Nur Update-Spalten verwenden, die es wirklich gibt update_cols = [c for c in update_cols if c in combined.columns] sort_cols = unique_key + update_cols if update_cols else unique_key @@ -128,16 +126,15 @@ class BaseExecutor[TFrame](ABC): - (optional) _frame_name """ - # Standard meta columns used by snapshot materialization. 
- SNAPSHOT_VALID_FROM_COL = "_ff_valid_from" - SNAPSHOT_VALID_TO_COL = "_ff_valid_to" - SNAPSHOT_IS_CURRENT_COL = "_ff_is_current" - SNAPSHOT_HASH_COL = "_ff_snapshot_hash" - SNAPSHOT_UPDATED_AT_COL = "_ff_updated_at" + ENGINE_NAME: str = "generic" _ff_contracts: Mapping[str, ContractsFileModel] | None = None _ff_project_contracts: ProjectContractsModel | None = None + @property + def engine_name(self) -> str: + return getattr(self, "ENGINE_NAME", "generic") + def configure_contracts( self, contracts: Mapping[str, ContractsFileModel] | None, @@ -1126,23 +1123,6 @@ def _meta_is_incremental(meta: Mapping[str, Any] | None) -> bool: return bool(incremental_cfg) # ── Snapshot API ────────────────────────────────────────────────── - def snapshot_prune( - self, - relation: str, - unique_key: list[str], - keep_last: int, - *, - dry_run: bool = False, - ) -> None: # pragma: no cover - abstract - """ - Prune old snapshot versions for the given relation. - - Engines may implement this in a best-effort manner. Default: not supported. - """ - raise NotImplementedError( - f"Snapshot pruning is not implemented for engine '{self.engine_name}'." - ) - @staticmethod def _meta_is_snapshot(meta: Mapping[str, Any] | None) -> bool: """ @@ -1225,9 +1205,3 @@ def load_seed( raise NotImplementedError( f"Seeding is not implemented for executor engine '{self.engine_name}'." ) - - ENGINE_NAME: str = "generic" - - @property - def engine_name(self) -> str: - return getattr(self, "ENGINE_NAME", "generic") diff --git a/src/fastflowtransform/executors/bigquery/base.py b/src/fastflowtransform/executors/bigquery/base.py index d3e1661..7d6e708 100644 --- a/src/fastflowtransform/executors/bigquery/base.py +++ b/src/fastflowtransform/executors/bigquery/base.py @@ -6,19 +6,19 @@ from fastflowtransform.core import Node, relation_for from fastflowtransform.executors._budget_runner import run_sql_with_budget -from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.budget import BudgetGuard from fastflowtransform.executors.query_stats import _TrackedQueryJob from fastflowtransform.meta import ensure_meta_table, upsert_meta +from fastflowtransform.snapshots.runtime.bigquery import BigQuerySnapshotRuntime from fastflowtransform.typing import BadRequest, Client, NotFound, bigquery TFrame = TypeVar("TFrame") -class BigQueryBaseExecutor(SqlIdentifierMixin, SnapshotSqlMixin, BaseExecutor[TFrame]): +class BigQueryBaseExecutor(SqlIdentifierMixin, BaseExecutor[TFrame]): """ Shared BigQuery executor logic (SQL, incremental, meta, DQ helpers). 
@@ -55,6 +55,7 @@ def __init__( project=self.project, location=self.location, ) + self.snapshot_runtime = BigQuerySnapshotRuntime(self) # ---- Identifier helpers ---- def _bq_quote(self, value: str) -> str: @@ -290,29 +291,6 @@ def _create_or_replace_view_from_table( self._ensure_dataset() self._execute_sql(f"CREATE OR REPLACE VIEW {view_id} AS SELECT * FROM {back_id}").result() - # ---- Snapshot mixin hooks ---- - def _snapshot_prepare_target(self) -> None: - self._ensure_dataset() - - def _snapshot_target_identifier(self, rel_name: str) -> str: - return self._qualified_identifier(rel_name) - - def _snapshot_current_timestamp(self) -> str: - return "CURRENT_TIMESTAMP()" - - def _snapshot_null_timestamp(self) -> str: - return "CAST(NULL AS TIMESTAMP)" - - def _snapshot_null_hash(self) -> str: - return "CAST(NULL AS STRING)" - - def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: - concat_expr = self._snapshot_concat_expr(check_cols, src_alias) - return f"TO_HEX(MD5({concat_expr}))" - - def _snapshot_cast_as_string(self, expr: str) -> str: - return f"CAST({expr} AS STRING)" - # ---- Meta hook ---- def on_node_built(self, node: Node, relation: str, fingerprint: str) -> None: """ @@ -456,6 +434,25 @@ def execute_hook_sql(self, sql: str) -> None: """ self._execute_sql(sql).result() + # ---- Snapshot runtime delegation (shared for pandas + BigFrames) ---- + def run_snapshot_sql(self, node: Node, env: Any) -> None: + self.snapshot_runtime.run_snapshot_sql(node, env) + + def snapshot_prune( + self, + relation: str, + unique_key: list[str], + keep_last: int, + *, + dry_run: bool = False, + ) -> None: + self.snapshot_runtime.snapshot_prune( + relation, + unique_key, + keep_last, + dry_run=dry_run, + ) + def _introspect_columns_metadata( self, table: str, diff --git a/src/fastflowtransform/executors/common.py b/src/fastflowtransform/executors/common.py new file mode 100644 index 0000000..97594ed --- /dev/null +++ b/src/fastflowtransform/executors/common.py @@ -0,0 +1,6 @@ +# fastflowtransform/executors/common.py + + +def _q_ident(ident: str) -> str: + # Simple, safe quoting for identifiers + return '"' + ident.replace('"', '""') + '"' diff --git a/src/fastflowtransform/executors/databricks_spark.py b/src/fastflowtransform/executors/databricks_spark.py index 39a1f26..27e86a6 100644 --- a/src/fastflowtransform/executors/databricks_spark.py +++ b/src/fastflowtransform/executors/databricks_spark.py @@ -3,7 +3,6 @@ from collections.abc import Callable, Iterable from contextlib import suppress -from functools import reduce from pathlib import Path from time import perf_counter from typing import Any, cast @@ -18,16 +17,12 @@ from fastflowtransform.errors import ModelExecutionError from fastflowtransform.executors._budget_runner import run_sql_with_budget from fastflowtransform.executors._query_stats_adapter import SparkDataFrameStatsAdapter -from fastflowtransform.executors._spark_imports import ( - get_spark_functions, - get_spark_window, -) from fastflowtransform.executors._test_utils import make_fetchable, rows_to_tuples from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.budget import BudgetGuard -from fastflowtransform.logging import echo, echo_debug +from fastflowtransform.logging import echo_debug from fastflowtransform.meta import ensure_meta_table, upsert_meta -from fastflowtransform.snapshots import resolve_snapshot_config +from fastflowtransform.snapshots.runtime.databricks_spark import DatabricksSparkSnapshotRuntime from 
fastflowtransform.table_formats import get_spark_format_handler from fastflowtransform.table_formats.base import SparkFormatHandler from fastflowtransform.typing import SDF, DataType, SparkSession @@ -188,6 +183,7 @@ class DatabricksSparkExecutor(BaseExecutor[SDF]): ENGINE_NAME: str = "databricks_spark" runtime_contracts: DatabricksSparkRuntimeContracts + snapshot_runtime: DatabricksSparkSnapshotRuntime _BUDGET_GUARD = BudgetGuard( env_var="FF_SPK_MAX_BYTES", estimator_attr="_estimate_query_bytes", @@ -315,6 +311,7 @@ def __init__( self._spark_default_size = self._detect_default_size() self.runtime_contracts = DatabricksSparkRuntimeContracts(self) + self.snapshot_runtime = DatabricksSparkSnapshotRuntime(self) # ---------- Cost estimation & central execution ---------- @@ -478,6 +475,9 @@ def _exec() -> SDF: ) # ---------- Frame hooks (required) ---------- + def _quote_identifier(self, ident: str) -> str: + return self._format_handler.qualify_identifier(ident, database=self.database) + def _read_relation(self, relation: str, node: Node, deps: Iterable[str]) -> SDF: # relation may optionally be "db.table" (via source()/ref()) physical = self._format_handler.qualify_identifier(relation, database=self.database) @@ -944,252 +944,9 @@ def _spark_sql_type(dt: DataType) -> str: table_sql = self._sql_identifier(relation) self._execute_sql(f"ALTER TABLE {table_sql} ADD COLUMNS ({cols_sql})") - # ── Snapshot API ───────────────────────────────────────────────────── - + # ── Snapshot runtime delegation ────────────────────────────────────── def run_snapshot_sql(self, node: Node, env: Environment) -> None: - """ - Snapshot materialization for Spark/Databricks. - """ - F = get_spark_functions() - - meta = self._validate_snapshot_node(node) - cfg = resolve_snapshot_config(node, meta) - - strategy = cfg.strategy - unique_key = cfg.unique_key - updated_at = cfg.updated_at - check_cols = cfg.check_cols - - body, rel_name, physical = self._snapshot_sql_body(node, env) - - vf = BaseExecutor.SNAPSHOT_VALID_FROM_COL - vt = BaseExecutor.SNAPSHOT_VALID_TO_COL - is_cur = BaseExecutor.SNAPSHOT_IS_CURRENT_COL - hash_col = BaseExecutor.SNAPSHOT_HASH_COL - upd_meta = BaseExecutor.SNAPSHOT_UPDATED_AT_COL - - if not self.exists_relation(rel_name): - self._snapshot_first_run( - node=node, - rel_name=rel_name, - body=body, - strategy=strategy, - updated_at=updated_at, - check_cols=check_cols, - F=F, - vf=vf, - vt=vt, - is_cur=is_cur, - hash_col=hash_col, - upd_meta=upd_meta, - ) - return - - self._snapshot_incremental_run( - node=node, - body=body, - rel_name=rel_name, - physical=physical, - strategy=strategy, - unique_key=unique_key, - updated_at=updated_at, - check_cols=check_cols, - F=F, - vf=vf, - vt=vt, - is_cur=is_cur, - hash_col=hash_col, - upd_meta=upd_meta, - ) - - def _validate_snapshot_node(self, node: Node) -> dict[str, Any]: - if node.kind != "sql": - raise TypeError( - f"Snapshot materialization is only supported for SQL models, " - f"got kind={node.kind!r} for {node.name}." 
- ) - - meta = getattr(node, "meta", {}) or {} - if not self._meta_is_snapshot(meta): - raise ValueError(f"Node {node.name} is not configured with materialized='snapshot'.") - return meta - - def _snapshot_sql_body( - self, - node: Node, - env: Environment, - ) -> tuple[str, str, str]: - sql_rendered = self.render_sql( - node, - env, - ref_resolver=lambda name: self._resolve_ref(name, env), - source_resolver=self._resolve_source, - ) - sql_clean = self._strip_leading_config(sql_rendered).strip() - body = self._selectable_body(sql_clean).rstrip(" ;\n\t") - - rel_name = relation_for(node.name) - physical = self._physical_identifier(rel_name) - return body, rel_name, physical - - def _snapshot_first_run( - self, - *, - node: Node, - rel_name: str, - body: str, - strategy: str, - updated_at: str | None, - check_cols: list[str], - F: Any, - vf: str, - vt: str, - is_cur: str, - hash_col: str, - upd_meta: str, - ) -> None: - src_df = self._execute_sql(body) - - echo_debug(f"[snapshot] first run for {rel_name} (strategy={strategy})") - - if strategy == "timestamp": - assert updated_at is not None, ( - "timestamp snapshots require a non-null updated_at column" - ) - df_snap = ( - src_df.withColumn(upd_meta, F.col(updated_at)) - .withColumn(vf, F.col(updated_at)) - .withColumn(vt, F.lit(None).cast("timestamp")) - .withColumn(is_cur, F.lit(True)) - .withColumn(hash_col, F.lit(None).cast("string")) - ) - else: - cols_expr = [F.coalesce(F.col(c).cast("string"), F.lit("")) for c in check_cols] - concat_expr = F.concat_ws("||", *cols_expr) - hash_expr = F.md5(concat_expr).cast("string") - upd_expr = F.col(updated_at) if updated_at else F.current_timestamp() - - df_snap = ( - src_df.withColumn(upd_meta, upd_expr) - .withColumn(vf, F.current_timestamp()) - .withColumn(vt, F.lit(None).cast("timestamp")) - .withColumn(is_cur, F.lit(True)) - .withColumn(hash_col, hash_expr) - ) - - storage_meta = self._storage_meta(node, rel_name) - self._save_df_as_table(rel_name, df_snap, storage=storage_meta) - - def _snapshot_incremental_run( - self, - *, - node: Node, - body: str, - rel_name: str, - physical: str, - strategy: str, - unique_key: list[str], - updated_at: str | None, - check_cols: list[str], - F: Any, - vf: str, - vt: str, - is_cur: str, - hash_col: str, - upd_meta: str, - ) -> None: - echo_debug(f"[snapshot] incremental run for {rel_name} (strategy={strategy})") - - existing = self.spark.table(physical) - src_df = self._execute_sql(body) - - missing_keys_src = [k for k in unique_key if k not in src_df.columns] - missing_keys_snap = [k for k in unique_key if k not in existing.columns] - if missing_keys_src or missing_keys_snap: - raise ValueError( - f"{node.path}: snapshot unique_key columns must exist on both source and " - f"snapshot table. Missing on source={missing_keys_src}, " - f"on snapshot={missing_keys_snap}." 
- ) - - if strategy == "check": - cols_expr = [F.coalesce(F.col(c).cast("string"), F.lit("")) for c in check_cols] - concat_expr = F.concat_ws("||", *cols_expr) - src_df = src_df.withColumn("__ff_new_hash", F.md5(concat_expr).cast("string")) - - current_df = existing.filter(F.col(is_cur) == True) # noqa: E712 - - s_alias = src_df.alias("s") - t_alias = current_df.alias("t") - joined = s_alias.join(t_alias, on=unique_key, how="left") - - if strategy == "timestamp": - assert updated_at is not None, ( - "timestamp snapshots require a non-null updated_at column" - ) - s_upd = F.col(f"s.{updated_at}") - t_upd = F.col(f"t.{upd_meta}") - cond_new = t_upd.isNull() - cond_changed = t_upd.isNotNull() & (s_upd > t_upd) - changed_or_new = cond_new | cond_changed - else: - s_hash = F.col("s.__ff_new_hash") - t_hash = F.col(f"t.{hash_col}") - cond_new = t_hash.isNull() - cond_changed = t_hash.isNotNull() & (s_hash != F.coalesce(t_hash, F.lit(""))) - changed_or_new = cond_new | cond_changed - - changed_keys = ( - joined.filter(changed_or_new) - .select(*[F.col(f"s.{k}").alias(k) for k in unique_key]) - .dropDuplicates() - ) - - prev_noncurrent = existing.filter(F.col(is_cur) == False) # noqa: E712 - preserved_current = current_df.join(changed_keys, on=unique_key, how="left_anti") - - closed_prev = ( - current_df.join(changed_keys, on=unique_key, how="inner") - .withColumn(vt, F.current_timestamp()) - .withColumn(is_cur, F.lit(False)) - ) - - new_src = src_df.join(changed_keys, on=unique_key, how="inner") - if strategy == "timestamp": - assert updated_at is not None, ( - "timestamp snapshots require a non-null updated_at column" - ) - new_versions = ( - new_src.withColumn(upd_meta, F.col(updated_at)) - .withColumn(vf, F.col(updated_at)) - .withColumn(vt, F.lit(None).cast("timestamp")) - .withColumn(is_cur, F.lit(True)) - .withColumn(hash_col, F.lit(None).cast("string")) - ) - else: - upd_expr = F.col(updated_at) if updated_at else F.current_timestamp() - new_versions = ( - new_src.withColumn(upd_meta, upd_expr) - .withColumn(vf, F.current_timestamp()) - .withColumn(vt, F.lit(None).cast("timestamp")) - .withColumn(is_cur, F.lit(True)) - .withColumn(hash_col, F.col("__ff_new_hash")) - ) - - parts = [prev_noncurrent, preserved_current, closed_prev, new_versions] - snapshot_df = reduce(lambda a, b: a.unionByName(b, allowMissingColumns=True), parts) - if "__ff_new_hash" in snapshot_df.columns: - snapshot_df = snapshot_df.drop("__ff_new_hash") - - # Break lineage so Spark doesn't see this as "read from and overwrite the same table" - try: - snapshot_df = snapshot_df.localCheckpoint(eager=True) - except Exception: - snapshot_df = snapshot_df.cache() - snapshot_df.count() - - storage_meta = self._storage_meta(node, rel_name) - self._save_df_as_table(rel_name, snapshot_df, storage=storage_meta) + self.snapshot_runtime.run_snapshot_sql(node, env) def snapshot_prune( self, @@ -1199,64 +956,12 @@ def snapshot_prune( *, dry_run: bool = False, ) -> None: - """ - Delete older snapshot versions while keeping the most recent `keep_last` - rows per business key (including the current row), implemented as a - DataFrame overwrite (no in-place DELETE). 
- """ - if keep_last <= 0: - return - - Window = get_spark_window() - F = get_spark_functions() - - if not unique_key: - return - - vf = BaseExecutor.SNAPSHOT_VALID_FROM_COL - - try: - physical = self._physical_identifier(relation) - df = self.spark.table(physical) - except Exception: - return - - w = Window.partitionBy(*[F.col(k) for k in unique_key]).orderBy(F.col(vf).desc()) - ranked = df.withColumn("__ff_rn", F.row_number().over(w)) - - if dry_run: - cnt = ranked.filter(F.col("__ff_rn") > int(keep_last)).count() - - echo( - f"[DRY-RUN] snapshot_prune({relation}): would delete {cnt} row(s) " - f"(keep_last={keep_last})" - ) - return - - pruned = ranked.filter(F.col("__ff_rn") <= int(keep_last)).drop("__ff_rn") - - # Materialize before overwrite to avoid Spark's - # [UNSUPPORTED_OVERWRITE.TABLE] "target that is also being read from". - materialized: list[SDF] = [] - - def _materialize(df: SDF) -> SDF: - try: - cp = df.localCheckpoint(eager=True) - materialized.append(cp) - return cp - except Exception: - cached = df.cache() - cached.count() - materialized.append(cached) - return cached - - try: - out = _materialize(pruned) - self._save_df_as_table(relation, out) - finally: - for handle in materialized: - with suppress(Exception): - handle.unpersist() + self.snapshot_runtime.snapshot_prune( + relation, + unique_key, + keep_last, + dry_run=dry_run, + ) def execute_hook_sql(self, sql: str) -> None: """ diff --git a/src/fastflowtransform/executors/duckdb.py b/src/fastflowtransform/executors/duckdb.py index 4eaa43a..0d62175 100644 --- a/src/fastflowtransform/executors/duckdb.py +++ b/src/fastflowtransform/executors/duckdb.py @@ -4,7 +4,7 @@ import json import re import uuid -from collections.abc import Callable, Iterable +from collections.abc import Iterable from contextlib import suppress from pathlib import Path from typing import Any, ClassVar @@ -12,25 +12,24 @@ import duckdb import pandas as pd from duckdb import CatalogException +from jinja2 import Environment from fastflowtransform.contracts.runtime.duckdb import DuckRuntimeContracts from fastflowtransform.core import Node from fastflowtransform.executors._budget_runner import run_sql_with_budget -from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable from fastflowtransform.executors.base import BaseExecutor, _scalar from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.common import _q_ident from fastflowtransform.meta import ensure_meta_table, upsert_meta +from fastflowtransform.snapshots.runtime.duckdb import DuckSnapshotRuntime -def _q(ident: str) -> str: - return '"' + ident.replace('"', '""') + '"' - - -class DuckExecutor(SqlIdentifierMixin, SnapshotSqlMixin, BaseExecutor[pd.DataFrame]): +class DuckExecutor(SqlIdentifierMixin, BaseExecutor[pd.DataFrame]): ENGINE_NAME: str = "duckdb" runtime_contracts: DuckRuntimeContracts + snapshot_runtime: DuckSnapshotRuntime _FIXED_TYPE_SIZES: ClassVar[dict[str, int]] = { "boolean": 1, @@ -85,11 +84,12 @@ def __init__( else: self.catalog = self._detect_catalog() if self.schema: - safe_schema = _q(self.schema) + safe_schema = _q_ident(self.schema) self._execute_sql(f"create schema if not exists {safe_schema}") self._execute_sql(f"set schema '{self.schema}'") self.runtime_contracts = DuckRuntimeContracts(self) + self.snapshot_runtime = DuckSnapshotRuntime(self) def execute_test_sql(self, stmt: 
Any) -> Any: """ @@ -436,8 +436,10 @@ def _apply_catalog_override(self, name: str) -> bool: if self.db_path != ":memory:": resolved = str(Path(self.db_path).resolve()) with suppress(Exception): - self._execute_sql(f"detach database {_q(alias)}") - self._execute_sql(f"attach database '{resolved}' as {_q(alias)} (READ_ONLY FALSE)") + self._execute_sql(f"detach database {_q_ident(alias)}") + self._execute_sql( + f"attach database '{resolved}' as {_q_ident(alias)} (READ_ONLY FALSE)" + ) self._execute_sql(f"set catalog '{alias}'") return True except Exception: @@ -471,7 +473,7 @@ def _exec_many(self, sql: str) -> None: # ---- Frame hooks ---- def _quote_identifier(self, ident: str) -> str: - return _q(ident) + return _q_ident(ident) def _should_include_catalog( self, catalog: str | None, schema: str | None, *, explicit: bool @@ -630,7 +632,7 @@ def alter_table_sync_schema( } add = [c for c in cols if c not in existing] for c in add: - col = _q(c) + col = _q_ident(c) target = self._qualified(relation) try: self._execute_sql(f"alter table {target} add column {col} varchar") @@ -645,37 +647,27 @@ def execute_hook_sql(self, sql: str) -> None: """ self._exec_many(sql) - # ---- Snapshot mixin hooks ---- - def _snapshot_target_identifier(self, rel_name: str) -> str: - return self._qualified(rel_name) - - def _snapshot_current_timestamp(self) -> str: - return "current_timestamp" - - def _snapshot_null_timestamp(self) -> str: - return "cast(null as timestamp)" - - def _snapshot_null_hash(self) -> str: - return "cast(null as varchar)" - - def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: - concat_expr = self._snapshot_concat_expr(check_cols, src_alias) - return f"cast(md5({concat_expr}) as varchar)" - - def _snapshot_cast_as_string(self, expr: str) -> str: - return f"cast({expr} as varchar)" - - def _snapshot_source_ref( - self, rel_name: str, select_body: str - ) -> tuple[str, Callable[[], None]]: - src_view_name = f"__ff_snapshot_src_{rel_name}".replace(".", "_") - src_quoted = _q(src_view_name) - self._execute_sql(f"create or replace temp view {src_quoted} as {select_body}") - - def _cleanup() -> None: - self._execute_sql(f"drop view if exists {src_quoted}") + # ---- Snapshot runtime delegation ---- + def run_snapshot_sql(self, node: Node, env: Environment) -> None: + """ + Delegate snapshot materialization to the DuckDB snapshot runtime. + """ + self.snapshot_runtime.run_snapshot_sql(node, env) - return src_quoted, _cleanup + def snapshot_prune( + self, + relation: str, + unique_key: list[str], + keep_last: int, + *, + dry_run: bool = False, + ) -> None: + self.snapshot_runtime.snapshot_prune( + relation, + unique_key, + keep_last, + dry_run=dry_run, + ) # ---- Unit-test helpers ------------------------------------------------- @@ -783,7 +775,7 @@ def load_seed( qualified = self._qualify_identifier(table, schema=target_schema, catalog=self.catalog) if target_schema and "." 
not in table: - safe_schema = _q(target_schema) + safe_schema = _q_ident(target_schema) self._execute_sql(f"create schema if not exists {safe_schema}") created_schema = True diff --git a/src/fastflowtransform/executors/postgres.py b/src/fastflowtransform/executors/postgres.py index c6f80bd..6341430 100644 --- a/src/fastflowtransform/executors/postgres.py +++ b/src/fastflowtransform/executors/postgres.py @@ -1,11 +1,12 @@ # fastflowtransform/executors/postgres.py import json import re -from collections.abc import Callable, Iterable +from collections.abc import Iterable from time import perf_counter from typing import Any, cast import pandas as pd +from jinja2 import Environment from sqlalchemy import create_engine, text from sqlalchemy.engine import Connection, Engine from sqlalchemy.exc import ProgrammingError, SQLAlchemyError @@ -16,13 +17,14 @@ from fastflowtransform.core import Node from fastflowtransform.errors import ModelExecutionError, ProfileConfigError from fastflowtransform.executors._budget_runner import run_sql_with_budget -from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable from fastflowtransform.executors.base import BaseExecutor, _scalar from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.common import _q_ident from fastflowtransform.executors.query_stats import QueryStats from fastflowtransform.meta import ensure_meta_table, upsert_meta +from fastflowtransform.snapshots.runtime.postgres import PostgresSnapshotRuntime def _base_type(t: str) -> str: @@ -32,9 +34,10 @@ def _base_type(t: str) -> str: return s -class PostgresExecutor(SqlIdentifierMixin, SnapshotSqlMixin, BaseExecutor[pd.DataFrame]): +class PostgresExecutor(SqlIdentifierMixin, BaseExecutor[pd.DataFrame]): ENGINE_NAME: str = "postgres" runtime_contracts: PostgresRuntimeContracts + snapshot_runtime: PostgresSnapshotRuntime _DEFAULT_PG_ROW_WIDTH = 128 _BUDGET_GUARD = BudgetGuard( env_var="FF_PG_MAX_BYTES", @@ -60,7 +63,7 @@ def __init__(self, dsn: str, schema: str | None = None): if self.schema: try: with self.engine.begin() as conn: - conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {self._q_ident(self.schema)}")) + conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {_q_ident(self.schema)}")) except SQLAlchemyError as exc: raise ProfileConfigError( f"Failed to ensure schema '{self.schema}' exists: {exc}" @@ -68,6 +71,7 @@ def __init__(self, dsn: str, schema: str | None = None): # Enable runtime contracts (cast/verify) for SQL and pandas models. 
self.runtime_contracts = PostgresRuntimeContracts(self) + self.snapshot_runtime = PostgresSnapshotRuntime(self) def execute_test_sql(self, stmt: Any) -> Any: """ @@ -310,19 +314,15 @@ def _to_int(node: dict[str, Any], keys: tuple[str, ...]) -> int | None: return int(candidate) # --- Helpers --------------------------------------------------------- - def _q_ident(self, ident: str) -> str: - # Simple, safe quoting for identifiers - return '"' + ident.replace('"', '""') + '"' - def _quote_identifier(self, ident: str) -> str: - return self._q_ident(ident) + return _q_ident(ident) - def _qualified(self, relname: str, schema: str | None = None) -> str: - return self._format_identifier(relname, purpose="physical", schema=schema) + def _qualified(self, relation: str, schema: str | None = None, *, quoted: bool = True) -> str: + return self._format_identifier(relation, purpose="physical", schema=schema, quote=quoted) def _set_search_path(self, conn: Connection) -> None: if self.schema: - conn.execute(text(f"SET LOCAL search_path = {self._q_ident(self.schema)}")) + conn.execute(text(f"SET LOCAL search_path = {_q_ident(self.schema)}")) def _extract_select_like(self, sql_or_body: str) -> str: """ @@ -546,37 +546,26 @@ def alter_table_sync_schema( self._execute_sql(f'alter table {qrel} add column "{c}" text', conn=conn) # ── Snapshot API ────────────────────────────────────────────────────── - def _snapshot_target_identifier(self, rel_name: str) -> str: - return self._qualified(rel_name) - - def _snapshot_current_timestamp(self) -> str: - return "current_timestamp" - - def _snapshot_null_timestamp(self) -> str: - return "cast(null as timestamp)" - - def _snapshot_null_hash(self) -> str: - return "cast(null as text)" - - def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: - concat_expr = self._snapshot_concat_expr(check_cols, src_alias) - return f"md5({concat_expr})" - - def _snapshot_cast_as_string(self, expr: str) -> str: - return f"cast({expr} as text)" - - def _snapshot_source_ref( - self, rel_name: str, select_body: str - ) -> tuple[str, Callable[[], None]]: - src_name = f"__ff_snapshot_src_{rel_name}".replace(".", "_") - src_q = self._q_ident(src_name) - self._execute_sql(f"drop table if exists {src_q}") - self._execute_sql(f"create temporary table {src_q} as {select_body}") - - def _cleanup() -> None: - self._execute_sql(f"drop table if exists {src_q}") + def run_snapshot_sql(self, node: Node, env: Environment) -> None: + """ + Delegate snapshot materialization to the Postgres snapshot runtime. 
+ """ + self.snapshot_runtime.run_snapshot_sql(node, env) - return src_q, _cleanup + def snapshot_prune( + self, + relation: str, + unique_key: list[str], + keep_last: int, + *, + dry_run: bool = False, + ) -> None: + self.snapshot_runtime.snapshot_prune( + relation, + unique_key, + keep_last, + dry_run=dry_run, + ) def execute_hook_sql(self, sql: str) -> None: """ @@ -624,8 +613,8 @@ def utest_load_relation_from_rows(self, relation: str, rows: list[dict]) -> None ) cols = list(first.keys()) - col_list_sql = ", ".join(self._q_ident(c) for c in cols) - select_exprs = ", ".join(f":{c} AS {self._q_ident(c)}" for c in cols) + col_list_sql = ", ".join(_q_ident(c) for c in cols) + select_exprs = ", ".join(f":{c} AS {_q_ident(c)}" for c in cols) insert_values_sql = ", ".join(f":{c}" for c in cols) try: diff --git a/src/fastflowtransform/executors/snowflake_snowpark.py b/src/fastflowtransform/executors/snowflake_snowpark.py index 0bc9b7f..c3e963e 100644 --- a/src/fastflowtransform/executors/snowflake_snowpark.py +++ b/src/fastflowtransform/executors/snowflake_snowpark.py @@ -2,7 +2,7 @@ from __future__ import annotations import json -from collections.abc import Callable, Iterable +from collections.abc import Iterable from contextlib import suppress from time import perf_counter from typing import Any, cast @@ -12,19 +12,23 @@ from fastflowtransform.contracts.runtime.snowflake_snowpark import SnowflakeSnowparkRuntimeContracts from fastflowtransform.core import Node, relation_for from fastflowtransform.executors._budget_runner import run_sql_with_budget -from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable, rows_to_tuples from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.common import _q_ident from fastflowtransform.executors.query_stats import QueryStats from fastflowtransform.meta import ensure_meta_table, upsert_meta +from fastflowtransform.snapshots.runtime.snowflake_snowpark import ( + SnowflakeSnowparkSnapshotRuntime, +) from fastflowtransform.typing import SNDF, SnowparkSession as Session -class SnowflakeSnowparkExecutor(SqlIdentifierMixin, SnapshotSqlMixin, BaseExecutor[SNDF]): +class SnowflakeSnowparkExecutor(SqlIdentifierMixin, BaseExecutor[SNDF]): ENGINE_NAME: str = "snowflake_snowpark" runtime_contracts: SnowflakeSnowparkRuntimeContracts + snapshot_runtime: SnowflakeSnowparkSnapshotRuntime """Snowflake executor operating on Snowpark DataFrames (no pandas).""" _BUDGET_GUARD = BudgetGuard( env_var="FF_SF_MAX_BYTES", @@ -42,6 +46,7 @@ def __init__(self, cfg: dict): self.allow_create_schema: bool = bool(cfg["allow_create_schema"]) self._ensure_schema() self.runtime_contracts = SnowflakeSnowparkRuntimeContracts(self) + self.snapshot_runtime = SnowflakeSnowparkSnapshotRuntime(self) def execute_test_sql(self, stmt: Any) -> Any: """ @@ -174,9 +179,6 @@ def _exec_many(self, sql: str) -> None: self._execute_sql(stmt).collect() # ---------- Helpers ---------- - def _q(self, s: str) -> str: - return '"' + s.replace('"', '""') + '"' - def _quote_identifier(self, ident: str) -> str: # Keep identifiers unquoted to match legacy Snowflake behaviour. return ident @@ -193,9 +195,9 @@ def _should_include_catalog( # Always include database when present; Snowflake expects DB.SCHEMA.TABLE. 
return bool(catalog) - def _qualified(self, rel: str) -> str: + def _qualified(self, relation: str, *, quoted: bool = False) -> str: # DATABASE.SCHEMA.TABLE (no quotes) - return self._format_identifier(rel, purpose="physical", quote=False) + return self._format_identifier(relation, purpose="physical", quote=quoted) def _ensure_schema(self) -> None: """ @@ -211,8 +213,8 @@ def _ensure_schema(self) -> None: # Misconfigured; let downstream errors surface naturally. return - db = self._q(self.database) - sch = self._q(self.schema) + db = _q_ident(self.database) + sch = _q_ident(self.schema) with suppress(Exception): # Fully qualified CREATE SCHEMA is allowed in Snowflake. self.session.sql(f"CREATE SCHEMA IF NOT EXISTS {db}.{sch}").collect() @@ -389,8 +391,8 @@ def load_seed( created_schema = False if target_db and target_schema and getattr(self, "allow_create_schema", False): - db_ident = self._q(target_db) - schema_ident = self._q(target_schema) + db_ident = _q_ident(target_db) + schema_ident = _q_ident(target_schema) try: self.session.sql(f"CREATE SCHEMA IF NOT EXISTS {db_ident}.{schema_ident}").collect() created_schema = True @@ -481,7 +483,7 @@ def on_node_built(self, node: Node, relation: str, fingerprint: str) -> None: # ── Incremental API (parity with DuckDB/PG) ────────────────────────── def exists_relation(self, relation: str) -> bool: """Check existence via information_schema.tables.""" - db = self._q(self.database) + db = _q_ident(self.database) schema_lit = f"'{self.schema.upper()}'" rel_lit = f"'{relation.upper()}'" q = f""" @@ -546,7 +548,7 @@ def alter_table_sync_schema( qrel = self._qualified(relation) # Use identifiers in FROM, but *string literals* in WHERE - db_ident = self._q(self.database) + db_ident = _q_ident(self.database) schema_lit = self.schema.replace("'", "''") rel_lit = relation.replace("'", "''") @@ -575,45 +577,27 @@ def alter_table_sync_schema( return # Column names are identifiers → _q is correct here - cols_sql = ", ".join(f"{self._q(c)} STRING" for c in to_add) + cols_sql = ", ".join(f"{_q_ident(c)} STRING" for c in to_add) self._execute_sql(f"ALTER TABLE {qrel} ADD COLUMN {cols_sql}").collect() - # ---- Snapshot API (mixin hooks) -------------------------------------- - def _snapshot_target_identifier(self, rel_name: str) -> str: - return self._qualified(rel_name) - - def _snapshot_current_timestamp(self) -> str: - return "CURRENT_TIMESTAMP()" - - def _snapshot_create_keyword(self) -> str: - return "CREATE OR REPLACE TABLE" - - def _snapshot_null_timestamp(self) -> str: - return "CAST(NULL AS TIMESTAMP)" - - def _snapshot_null_hash(self) -> str: - return "CAST(NULL AS VARCHAR)" - - def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: - concat_expr = self._snapshot_concat_expr(check_cols, src_alias) - return f"CAST(MD5({concat_expr}) AS VARCHAR)" - - def _snapshot_cast_as_string(self, expr: str) -> str: - return f"CAST({expr} AS VARCHAR)" + # ---- Snapshot runtime delegation -------------------------------------- + def run_snapshot_sql(self, node: Node, env: Any) -> None: + self.snapshot_runtime.run_snapshot_sql(node, env) - def _snapshot_source_ref( - self, rel_name: str, select_body: str - ) -> tuple[str, Callable[[], None]]: - src_name = f"__ff_snapshot_src_{rel_name}".replace(".", "_") - src_quoted = self._q(src_name) - self._execute_sql( - f"CREATE OR REPLACE TEMPORARY VIEW {src_quoted} AS {select_body}" - ).collect() - - def _cleanup() -> None: - self._execute_sql(f"DROP VIEW IF EXISTS {src_quoted}").collect() - - return 
src_quoted, _cleanup + def snapshot_prune( + self, + relation: str, + unique_key: list[str], + keep_last: int, + *, + dry_run: bool = False, + ) -> None: + self.snapshot_runtime.snapshot_prune( + relation, + unique_key, + keep_last, + dry_run=dry_run, + ) def execute_hook_sql(self, sql: str) -> None: """ @@ -778,7 +762,7 @@ def introspect_table_physical_schema(self, table: str) -> dict[str, str]: """ db, sch, tbl = self._normalize_table_parts_for_introspection(table) - db_ident = self._q(db) + db_ident = _q_ident(db) schema_lit = sch.replace("'", "''").upper() table_lit = tbl.replace("'", "''").upper() @@ -814,7 +798,7 @@ def introspect_column_physical_type(self, table: str, column: str) -> str | None """ db, sch, tbl = self._normalize_table_parts_for_introspection(table) - db_ident = self._q(db) + db_ident = _q_ident(db) schema_lit = sch.replace("'", "''").upper() table_lit = tbl.replace("'", "''").upper() col_lit = (column or "").replace("'", "''").upper() diff --git a/src/fastflowtransform/snapshots/__init__.py b/src/fastflowtransform/snapshots/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/fastflowtransform/snapshots.py b/src/fastflowtransform/snapshots/core.py similarity index 100% rename from src/fastflowtransform/snapshots.py rename to src/fastflowtransform/snapshots/core.py diff --git a/src/fastflowtransform/snapshots/runtime/__init__.py b/src/fastflowtransform/snapshots/runtime/__init__.py new file mode 100644 index 0000000..e3cb377 --- /dev/null +++ b/src/fastflowtransform/snapshots/runtime/__init__.py @@ -0,0 +1,15 @@ +from fastflowtransform.snapshots.runtime.base import BaseSnapshotRuntime +from fastflowtransform.snapshots.runtime.bigquery import BigQuerySnapshotRuntime +from fastflowtransform.snapshots.runtime.databricks_spark import DatabricksSparkSnapshotRuntime +from fastflowtransform.snapshots.runtime.duckdb import DuckSnapshotRuntime +from fastflowtransform.snapshots.runtime.postgres import PostgresSnapshotRuntime +from fastflowtransform.snapshots.runtime.snowflake_snowpark import SnowflakeSnowparkSnapshotRuntime + +__all__ = [ + "BaseSnapshotRuntime", + "BigQuerySnapshotRuntime", + "DatabricksSparkSnapshotRuntime", + "DuckSnapshotRuntime", + "PostgresSnapshotRuntime", + "SnowflakeSnowparkSnapshotRuntime", +] diff --git a/src/fastflowtransform/executors/_snapshot_sql_mixin.py b/src/fastflowtransform/snapshots/runtime/base.py similarity index 85% rename from src/fastflowtransform/executors/_snapshot_sql_mixin.py rename to src/fastflowtransform/snapshots/runtime/base.py index 0e15da5..cb08ca5 100644 --- a/src/fastflowtransform/executors/_snapshot_sql_mixin.py +++ b/src/fastflowtransform/snapshots/runtime/base.py @@ -2,29 +2,71 @@ from collections.abc import Callable, Iterable from contextlib import suppress -from typing import TYPE_CHECKING, Any, cast +from typing import Any, Protocol, TypeVar from jinja2 import Environment from fastflowtransform.core import Node, relation_for from fastflowtransform.logging import echo -from fastflowtransform.snapshots import resolve_snapshot_config +from fastflowtransform.snapshots.core import resolve_snapshot_config -if TYPE_CHECKING: - # Adjust this import to your actual path - from fastflowtransform.executors.base import BaseExecutor +class SnapshotExecutor(Protocol): + """ + Minimal surface required by the snapshot runtime. 
+ """ + + def render_sql( + self, + node: Node, + env: Environment, + ref_resolver: Callable[[str], str] | None = None, + source_resolver: Callable[[str, str], str] | None = None, + ) -> str: ... + + def _resolve_ref(self, name: str, env: Environment) -> str: ... + + def _resolve_source(self, source_name: str, table_name: str) -> str: ... + + def _strip_leading_config(self, sql: str) -> str: ... + + def _selectable_body(self, sql: str) -> str: ... + + def exists_relation(self, relation: str) -> bool: ... + + def _execute_sql(self, sql: str, *args: Any, **kwargs: Any) -> Any: ... -class SnapshotSqlMixin: + def _meta_is_snapshot(self, meta: dict[str, Any] | None) -> bool: ... + + def _quote_identifier(self, ident: str) -> str: ... + + +E = TypeVar("E", bound=SnapshotExecutor) + + +class BaseSnapshotRuntime[E: SnapshotExecutor]: """ - Shared SQL snapshot materialization (timestamp + check strategies). + Base snapshot runtime mirroring the contracts runtime pattern. Engines provide small hooks for identifier qualification, expressions, - staging, and execution. All column names come from BaseExecutor constants. + staging, and execution. All column names come from the executor constants. """ + # Standard snapshot metadata column names (single source of truth for runtimes). + SNAPSHOT_VALID_FROM_COL = "_ff_valid_from" + SNAPSHOT_VALID_TO_COL = "_ff_valid_to" + SNAPSHOT_IS_CURRENT_COL = "_ff_is_current" + SNAPSHOT_HASH_COL = "_ff_snapshot_hash" + SNAPSHOT_UPDATED_AT_COL = "_ff_updated_at" + + executor: E + + def __init__(self, executor: E): + self.executor = executor + + # ---- Public entrypoints ------------------------------------------------- def run_snapshot_sql(self, node: Node, env: Environment) -> None: - ex = cast("BaseExecutor[Any]", self) + ex = self.executor meta = self._snapshot_validate_node(node) cfg = resolve_snapshot_config(node, meta) @@ -35,11 +77,11 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: if not cfg.unique_key: raise ValueError(f"{node.path}: snapshot models require a non-empty unique_key list.") - vf = self.SNAPSHOT_VALID_FROM_COL # type: ignore[attr-defined] - vt = self.SNAPSHOT_VALID_TO_COL # type: ignore[attr-defined] - is_cur = self.SNAPSHOT_IS_CURRENT_COL # type: ignore[attr-defined] - hash_col = self.SNAPSHOT_HASH_COL # type: ignore[attr-defined] - upd_meta = self.SNAPSHOT_UPDATED_AT_COL # type: ignore[attr-defined] + vf = self.SNAPSHOT_VALID_FROM_COL + vt = self.SNAPSHOT_VALID_TO_COL + is_cur = self.SNAPSHOT_IS_CURRENT_COL + hash_col = self.SNAPSHOT_HASH_COL + upd_meta = self.SNAPSHOT_UPDATED_AT_COL self._snapshot_prepare_target() @@ -116,6 +158,74 @@ def run_snapshot_sql(self, node: Node, env: Environment) -> None: with suppress(Exception): cleanup() + def snapshot_prune( + self, + relation: str, + unique_key: list[str], + keep_last: int, + *, + dry_run: bool = False, + ) -> None: + """ + Delete older snapshot versions while keeping the most recent `keep_last` + rows per business key (including the current row). 
+ """ + ex = self.executor + + if keep_last <= 0: + return + + keys = [k for k in unique_key if k] + if not keys: + return + + target = self._snapshot_target_identifier(relation) + vf = self.SNAPSHOT_VALID_FROM_COL + + key_select = ", ".join(keys) + part_by = ", ".join(keys) + + ranked_sql = f""" +SELECT + {key_select}, + {vf}, + ROW_NUMBER() OVER ( + PARTITION BY {part_by} + ORDER BY {vf} DESC + ) AS rn +FROM {target} +""" + + if dry_run: + sql = f""" +WITH ranked AS ( + {ranked_sql} +) +SELECT COUNT(*) AS rows_to_delete +FROM ranked +WHERE rn > {int(keep_last)} +""" + res = ex._execute_sql(sql) + count = self._snapshot_fetch_count(res) + echo( + f"[DRY-RUN] snapshot_prune({relation}): would delete {count} row(s) " + f"(keep_last={keep_last})" + ) + return + + join_pred = " AND ".join([f"t.{k} = r.{k}" for k in keys]) + delete_sql = f""" +DELETE FROM {target} t +USING ( + {ranked_sql} +) r +WHERE + r.rn > {int(keep_last)} + AND {join_pred} + AND t.{vf} = r.{vf} +""" + ex._execute_sql(delete_sql) + # ---- Core SQL builders ------------------------------------------------- def _snapshot_first_run_sql( self, @@ -228,78 +338,9 @@ def _snapshot_insert_sql( OR {change_condition} """ - # ---- Pruning ----------------------------------------------------------- - def snapshot_prune( - self, - relation: str, - unique_key: list[str], - keep_last: int, - *, - dry_run: bool = False, - ) -> None: - """ - Delete older snapshot versions while keeping the most recent `keep_last` - rows per business key (including the current row). - """ - ex = cast("BaseExecutor[Any]", self) - - if keep_last <= 0: - return - - keys = [k for k in unique_key if k] - if not keys: - return - - target = self._snapshot_target_identifier(relation) - vf = self.SNAPSHOT_VALID_FROM_COL # type: ignore[attr-defined] - - key_select = ", ".join(keys) - part_by = ", ".join(keys) - - ranked_sql = f""" -SELECT - {key_select}, - {vf}, - ROW_NUMBER() OVER ( - PARTITION BY {part_by} - ORDER BY {vf} DESC - ) AS rn -FROM {target} -""" - - if dry_run: - sql = f""" -WITH ranked AS ( - {ranked_sql} -) -SELECT COUNT(*) AS rows_to_delete -FROM ranked -WHERE rn > {int(keep_last)} -""" - res = ex._execute_sql(sql) - count = self._snapshot_fetch_count(res) - echo( - f"[DRY-RUN] snapshot_prune({relation}): would delete {count} row(s) " - f"(keep_last={keep_last})" - ) - return - - join_pred = " AND ".join([f"t.{k} = r.{k}" for k in keys]) - delete_sql = f""" -DELETE FROM {target} t -USING ( - {ranked_sql} -) r -WHERE - r.rn > {int(keep_last)} - AND {join_pred} - AND t.{vf} = r.{vf} -""" - ex._execute_sql(delete_sql) - # ---- Rendering helpers ------------------------------------------------- def _snapshot_render_body(self, node: Node, env: Environment) -> str: - ex = cast("BaseExecutor[Any]", self) + ex = self.executor sql_rendered = ex.render_sql( node, @@ -311,6 +352,8 @@ def _snapshot_render_body(self, node: Node, env: Environment) -> str: return ex._selectable_body(sql_clean).rstrip(" ;\n\t") def _snapshot_validate_node(self, node: Node) -> dict[str, Any]: + ex = self.executor + if node.kind != "sql": raise TypeError( f"Snapshot materialization is only supported for SQL models, " @@ -318,7 +361,7 @@ def _snapshot_validate_node(self, node: Node) -> dict[str, Any]: ) meta = getattr(node, "meta", {}) or {} - if not self._meta_is_snapshot(meta): # type: ignore[attr-defined] + if not ex._meta_is_snapshot(meta): raise ValueError(f"Node {node.name} is not configured with materialized='snapshot'.") return meta @@ -348,6 +391,10 @@ def 
_snapshot_null_hash(self) -> str: # pragma: no cover - abstract def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: # pragma: no cover raise NotImplementedError + # ---- Optional overrides ----------------------------------------------- + def _snapshot_cast_as_string(self, expr: str) -> str: + return f"CAST({expr} AS STRING)" + def _snapshot_updated_at_expr(self, updated_at: str, src_alias: str) -> str: return f"{src_alias}.{updated_at}" @@ -359,7 +406,7 @@ def _snapshot_exec_and_wait(self, sql: str) -> None: """ Execute SQL and, if necessary, wait for completion (jobs, lazy DataFrames). """ - res = self._execute_sql(sql) # type: ignore[attr-defined] + res = self.executor._execute_sql(sql) if res is None: return for attr in ("result", "collect"): @@ -377,9 +424,6 @@ def _snapshot_concat_expr(self, columns: list[str], src_alias: str) -> str: ] return " || '||' || ".join(parts) if parts else "''" - def _snapshot_cast_as_string(self, expr: str) -> str: - return f"CAST({expr} AS STRING)" - def _snapshot_coalesce(self, expr: str, default: str) -> str: return f"COALESCE({expr}, {default})" diff --git a/src/fastflowtransform/snapshots/runtime/bigquery.py b/src/fastflowtransform/snapshots/runtime/bigquery.py new file mode 100644 index 0000000..280f4ad --- /dev/null +++ b/src/fastflowtransform/snapshots/runtime/bigquery.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import Protocol + +from fastflowtransform.snapshots.runtime.base import BaseSnapshotRuntime, SnapshotExecutor + + +class BigQuerySnapshotExecutor(SnapshotExecutor, Protocol): + project: str + dataset: str + location: str | None + + def _ensure_dataset(self) -> None: ... + + def _qualified_identifier( + self, relation: str, project: str | None = None, dataset: str | None = None + ) -> str: ... + + +class BigQuerySnapshotRuntime(BaseSnapshotRuntime[BigQuerySnapshotExecutor]): + """ + Snapshot runtime for BigQuery, matching the legacy mixin hooks. 
+ """ + + # ---- Engine hooks ----------------------------------------------------- + def _snapshot_prepare_target(self) -> None: + self.executor._ensure_dataset() + + def _snapshot_target_identifier(self, rel_name: str) -> str: + return self.executor._qualified_identifier( + rel_name, + project=getattr(self.executor, "project", None), + dataset=getattr(self.executor, "dataset", None), + ) + + def _snapshot_current_timestamp(self) -> str: + return "CURRENT_TIMESTAMP()" + + def _snapshot_null_timestamp(self) -> str: + return "CAST(NULL AS TIMESTAMP)" + + def _snapshot_null_hash(self) -> str: + return "CAST(NULL AS STRING)" + + def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: + concat_expr = self._snapshot_concat_expr(check_cols, src_alias) + return f"TO_HEX(MD5({concat_expr}))" + + def _snapshot_cast_as_string(self, expr: str) -> str: + return f"CAST({expr} AS STRING)" + + # BigQuery uses inline source (default), so no override for _snapshot_source_ref diff --git a/src/fastflowtransform/snapshots/runtime/databricks_spark.py b/src/fastflowtransform/snapshots/runtime/databricks_spark.py new file mode 100644 index 0000000..c11462b --- /dev/null +++ b/src/fastflowtransform/snapshots/runtime/databricks_spark.py @@ -0,0 +1,357 @@ +from __future__ import annotations + +from contextlib import suppress +from functools import reduce +from typing import Any, Protocol + +from fastflowtransform.core import Node, relation_for +from fastflowtransform.executors._spark_imports import get_spark_functions, get_spark_window +from fastflowtransform.logging import echo, echo_debug +from fastflowtransform.snapshots.core import resolve_snapshot_config +from fastflowtransform.snapshots.runtime.base import SnapshotExecutor +from fastflowtransform.typing import SDF, SparkSession + + +class DatabricksSnapshotExecutor(SnapshotExecutor, Protocol): + spark: SparkSession + + def _physical_identifier(self, identifier: str, *, database: str | None = None) -> str: ... + + def _storage_meta(self, node: Node | None, relation: str) -> dict[str, Any]: ... + + def _save_df_as_table( + self, identifier: str, df: SDF, *, storage: dict[str, Any] | None = None + ) -> None: ... + + +class DatabricksSparkSnapshotRuntime: + """ + Snapshot runtime for Databricks/Spark (Delta/Parquet/Iceberg), extracted + from the executor. Uses Spark DataFrame operations instead of SQL strings. 
+ """ + + SNAPSHOT_VALID_FROM_COL = "_ff_valid_from" + SNAPSHOT_VALID_TO_COL = "_ff_valid_to" + SNAPSHOT_IS_CURRENT_COL = "_ff_is_current" + SNAPSHOT_HASH_COL = "_ff_snapshot_hash" + SNAPSHOT_UPDATED_AT_COL = "_ff_updated_at" + + executor: DatabricksSnapshotExecutor + + def __init__(self, executor: DatabricksSnapshotExecutor): + self.executor = executor + + def run_snapshot_sql(self, node: Node, env: Any) -> None: + ex = self.executor + F = get_spark_functions() + + meta = self._validate_snapshot_node(node) + cfg = resolve_snapshot_config(node, meta) + + strategy = cfg.strategy + unique_key = cfg.unique_key + updated_at = cfg.updated_at + check_cols = cfg.check_cols + + body, rel_name, physical = self._snapshot_sql_body(node, env) + + vf = self.SNAPSHOT_VALID_FROM_COL + vt = self.SNAPSHOT_VALID_TO_COL + is_cur = self.SNAPSHOT_IS_CURRENT_COL + hash_col = self.SNAPSHOT_HASH_COL + upd_meta = self.SNAPSHOT_UPDATED_AT_COL + + if not ex.exists_relation(rel_name): + self._snapshot_first_run( + node=node, + rel_name=rel_name, + body=body, + strategy=strategy, + updated_at=updated_at, + check_cols=check_cols, + F=F, + vf=vf, + vt=vt, + is_cur=is_cur, + hash_col=hash_col, + upd_meta=upd_meta, + ) + return + + self._snapshot_incremental_run( + node=node, + body=body, + rel_name=rel_name, + physical=physical, + strategy=strategy, + unique_key=unique_key, + updated_at=updated_at, + check_cols=check_cols, + F=F, + vf=vf, + vt=vt, + is_cur=is_cur, + hash_col=hash_col, + upd_meta=upd_meta, + ) + + def snapshot_prune( + self, + relation: str, + unique_key: list[str], + keep_last: int, + *, + dry_run: bool = False, + ) -> None: + """ + Delete older snapshot versions while keeping the most recent `keep_last` + rows per business key (including the current row), implemented as a + DataFrame overwrite (no in-place DELETE). + """ + if keep_last <= 0: + return + + Window = get_spark_window() + F = get_spark_functions() + ex = self.executor + + if not unique_key: + return + + vf = self.SNAPSHOT_VALID_FROM_COL + + try: + physical = ex._physical_identifier(relation) + df = ex.spark.table(physical) + except Exception: + return + + w = Window.partitionBy(*[F.col(k) for k in unique_key]).orderBy(F.col(vf).desc()) + ranked = df.withColumn("__ff_rn", F.row_number().over(w)) + + if dry_run: + cnt = ranked.filter(F.col("__ff_rn") > int(keep_last)).count() + + echo( + f"[DRY-RUN] snapshot_prune({relation}): would delete {cnt} row(s) " + f"(keep_last={keep_last})" + ) + return + + pruned = ranked.filter(F.col("__ff_rn") <= int(keep_last)).drop("__ff_rn") + + # Materialize before overwrite to avoid Spark's self-read/overwrite issues. + materialized: list[Any] = [] + + def _materialize(df_any: Any) -> Any: + try: + cp = df_any.localCheckpoint(eager=True) + materialized.append(cp) + return cp + except Exception: + cached = df_any.cache() + cached.count() + materialized.append(cached) + return cached + + try: + out = _materialize(pruned) + ex._save_df_as_table(relation, out) + finally: + for handle in materialized: + with suppress(Exception): + handle.unpersist() + + # ---- Helpers --------------------------------------------------------- + def _validate_snapshot_node(self, node: Node) -> dict[str, Any]: + ex = self.executor + if node.kind != "sql": + raise TypeError( + f"Snapshot materialization is only supported for SQL models, " + f"got kind={node.kind!r} for {node.name}." 
+ ) + + meta = getattr(node, "meta", {}) or {} + if not ex._meta_is_snapshot(meta): + raise ValueError(f"Node {node.name} is not configured with materialized='snapshot'.") + return meta + + def _snapshot_sql_body( + self, + node: Node, + env: Any, + ) -> tuple[str, str, str]: + ex = self.executor + sql_rendered = ex.render_sql( + node, + env, + ref_resolver=lambda name: ex._resolve_ref(name, env), + source_resolver=ex._resolve_source, + ) + sql_clean = ex._strip_leading_config(sql_rendered).strip() + body = ex._selectable_body(sql_clean).rstrip(" ;\n\t") + + rel_name = relation_for(node.name) + physical = ex._physical_identifier(rel_name) + return body, rel_name, physical + + def _snapshot_first_run( + self, + *, + node: Node, + rel_name: str, + body: str, + strategy: str, + updated_at: str | None, + check_cols: list[str], + F: Any, + vf: str, + vt: str, + is_cur: str, + hash_col: str, + upd_meta: str, + ) -> None: + ex = self.executor + src_df = ex._execute_sql(body) + + echo_debug(f"[snapshot] first run for {rel_name} (strategy={strategy})") + + if strategy == "timestamp": + assert updated_at is not None, ( + "timestamp snapshots require a non-null updated_at column" + ) + df_snap = ( + src_df.withColumn(upd_meta, F.col(updated_at)) + .withColumn(vf, F.col(updated_at)) + .withColumn(vt, F.lit(None).cast("timestamp")) + .withColumn(is_cur, F.lit(True)) + .withColumn(hash_col, F.lit(None).cast("string")) + ) + else: + cols_expr = [F.coalesce(F.col(c).cast("string"), F.lit("")) for c in check_cols] + concat_expr = F.concat_ws("||", *cols_expr) + hash_expr = F.md5(concat_expr).cast("string") + upd_expr = F.col(updated_at) if updated_at else F.current_timestamp() + + df_snap = ( + src_df.withColumn(upd_meta, upd_expr) + .withColumn(vf, F.current_timestamp()) + .withColumn(vt, F.lit(None).cast("timestamp")) + .withColumn(is_cur, F.lit(True)) + .withColumn(hash_col, hash_expr) + ) + + storage_meta = ex._storage_meta(node, rel_name) + ex._save_df_as_table(rel_name, df_snap, storage=storage_meta) + + def _snapshot_incremental_run( + self, + *, + node: Node, + body: str, + rel_name: str, + physical: str, + strategy: str, + unique_key: list[str], + updated_at: str | None, + check_cols: list[str], + F: Any, + vf: str, + vt: str, + is_cur: str, + hash_col: str, + upd_meta: str, + ) -> None: + ex = self.executor + echo_debug(f"[snapshot] incremental run for {rel_name} (strategy={strategy})") + + existing = ex.spark.table(physical) + src_df = ex._execute_sql(body) + + missing_keys_src = [k for k in unique_key if k not in src_df.columns] + missing_keys_snap = [k for k in unique_key if k not in existing.columns] + if missing_keys_src or missing_keys_snap: + raise ValueError( + f"{node.path}: snapshot unique_key columns must exist on both source and " + f"snapshot table. Missing on source={missing_keys_src}, " + f"on snapshot={missing_keys_snap}." 
+ ) + + if strategy == "check": + cols_expr = [F.coalesce(F.col(c).cast("string"), F.lit("")) for c in check_cols] + concat_expr = F.concat_ws("||", *cols_expr) + src_df = src_df.withColumn("__ff_new_hash", F.md5(concat_expr).cast("string")) + + current_df = existing.filter(F.col(is_cur) == True) # noqa: E712 + + s_alias = src_df.alias("s") + t_alias = current_df.alias("t") + joined = s_alias.join(t_alias, on=unique_key, how="left") + + if strategy == "timestamp": + assert updated_at is not None, ( + "timestamp snapshots require a non-null updated_at column" + ) + s_upd = F.col(f"s.{updated_at}") + t_upd = F.col(f"t.{upd_meta}") + cond_new = t_upd.isNull() + cond_changed = t_upd.isNotNull() & (s_upd > t_upd) + changed_or_new = cond_new | cond_changed + else: + s_hash = F.col("s.__ff_new_hash") + t_hash = F.col(f"t.{hash_col}") + cond_new = t_hash.isNull() + cond_changed = t_hash.isNotNull() & (s_hash != F.coalesce(t_hash, F.lit(""))) + changed_or_new = cond_new | cond_changed + + changed_keys = ( + joined.filter(changed_or_new) + .select(*[F.col(f"s.{k}").alias(k) for k in unique_key]) + .dropDuplicates() + ) + + prev_noncurrent = existing.filter(F.col(is_cur) == False) # noqa: E712 + preserved_current = current_df.join(changed_keys, on=unique_key, how="left_anti") + + closed_prev = ( + current_df.join(changed_keys, on=unique_key, how="inner") + .withColumn(vt, F.current_timestamp()) + .withColumn(is_cur, F.lit(False)) + ) + + new_src = src_df.join(changed_keys, on=unique_key, how="inner") + if strategy == "timestamp": + assert updated_at is not None, ( + "timestamp snapshots require a non-null updated_at column" + ) + new_versions = ( + new_src.withColumn(upd_meta, F.col(updated_at)) + .withColumn(vf, F.col(updated_at)) + .withColumn(vt, F.lit(None).cast("timestamp")) + .withColumn(is_cur, F.lit(True)) + .withColumn(hash_col, F.lit(None).cast("string")) + ) + else: + upd_expr = F.col(updated_at) if updated_at else F.current_timestamp() + new_versions = ( + new_src.withColumn(upd_meta, upd_expr) + .withColumn(vf, F.current_timestamp()) + .withColumn(vt, F.lit(None).cast("timestamp")) + .withColumn(is_cur, F.lit(True)) + .withColumn(hash_col, F.col("__ff_new_hash")) + ) + + parts = [prev_noncurrent, preserved_current, closed_prev, new_versions] + snapshot_df = reduce(lambda a, b: a.unionByName(b, allowMissingColumns=True), parts) + if "__ff_new_hash" in snapshot_df.columns: + snapshot_df = snapshot_df.drop("__ff_new_hash") + + # Break lineage so Spark doesn't see this as "read from and overwrite the same table" + try: + snapshot_df = snapshot_df.localCheckpoint(eager=True) + except Exception: + snapshot_df = snapshot_df.cache() + snapshot_df.count() + + storage_meta = ex._storage_meta(node, rel_name) + ex._save_df_as_table(rel_name, snapshot_df, storage=storage_meta) diff --git a/src/fastflowtransform/snapshots/runtime/duckdb.py b/src/fastflowtransform/snapshots/runtime/duckdb.py new file mode 100644 index 0000000..de0b931 --- /dev/null +++ b/src/fastflowtransform/snapshots/runtime/duckdb.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Protocol + +from fastflowtransform.snapshots.runtime.base import BaseSnapshotRuntime, SnapshotExecutor + + +class DuckSnapshotExecutor(SnapshotExecutor, Protocol): + def _qualified(self, relation: str, *, quoted: bool = True) -> str: ... + + +class DuckSnapshotRuntime(BaseSnapshotRuntime[DuckSnapshotExecutor]): + """ + Snapshot runtime for DuckDB, extracted from the old SnapshotSqlMixin. 
+ """ + + # ---- Engine hooks ----------------------------------------------------- + def _snapshot_target_identifier(self, rel_name: str) -> str: + return self.executor._qualified(rel_name) + + def _snapshot_current_timestamp(self) -> str: + return "current_timestamp" + + def _snapshot_null_timestamp(self) -> str: + return "cast(null as timestamp)" + + def _snapshot_null_hash(self) -> str: + return "cast(null as varchar)" + + def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: + concat_expr = self._snapshot_concat_expr(check_cols, src_alias) + return f"cast(md5({concat_expr}) as varchar)" + + def _snapshot_cast_as_string(self, expr: str) -> str: + return f"cast({expr} as varchar)" + + def _snapshot_source_ref( + self, rel_name: str, select_body: str + ) -> tuple[str, Callable[[], None]]: + src_view_name = f"__ff_snapshot_src_{rel_name}".replace(".", "_") + src_quoted = self.executor._quote_identifier(src_view_name) + self.executor._execute_sql(f"create or replace temp view {src_quoted} as {select_body}") + + def _cleanup() -> None: + self.executor._execute_sql(f"drop view if exists {src_quoted}") + + return src_quoted, _cleanup diff --git a/src/fastflowtransform/snapshots/runtime/postgres.py b/src/fastflowtransform/snapshots/runtime/postgres.py new file mode 100644 index 0000000..9b086d0 --- /dev/null +++ b/src/fastflowtransform/snapshots/runtime/postgres.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Protocol + +from fastflowtransform.executors.common import _q_ident +from fastflowtransform.snapshots.runtime.base import BaseSnapshotRuntime, SnapshotExecutor + + +class PostgresSnapshotExecutor(SnapshotExecutor, Protocol): + def _qualified(self, relation: str, *, quoted: bool = True) -> str: ... + + +class PostgresSnapshotRuntime(BaseSnapshotRuntime[PostgresSnapshotExecutor]): + """ + Snapshot runtime for Postgres, extracted from the legacy mixin hooks. 
+ """ + + # ---- Engine hooks ----------------------------------------------------- + def _snapshot_target_identifier(self, rel_name: str) -> str: + return self.executor._qualified(rel_name) + + def _snapshot_current_timestamp(self) -> str: + return "current_timestamp" + + def _snapshot_null_timestamp(self) -> str: + return "cast(null as timestamp)" + + def _snapshot_null_hash(self) -> str: + return "cast(null as text)" + + def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: + concat_expr = self._snapshot_concat_expr(check_cols, src_alias) + return f"md5({concat_expr})" + + def _snapshot_cast_as_string(self, expr: str) -> str: + return f"cast({expr} as text)" + + def _snapshot_source_ref( + self, rel_name: str, select_body: str + ) -> tuple[str, Callable[[], None]]: + src_name = f"__ff_snapshot_src_{rel_name}".replace(".", "_") + src_q = _q_ident(src_name) + self.executor._execute_sql(f"drop table if exists {src_q}") + self.executor._execute_sql(f"create temporary table {src_q} as {select_body}") + + def _cleanup() -> None: + self.executor._execute_sql(f"drop table if exists {src_q}") + + return src_q, _cleanup diff --git a/src/fastflowtransform/snapshots/runtime/snowflake_snowpark.py b/src/fastflowtransform/snapshots/runtime/snowflake_snowpark.py new file mode 100644 index 0000000..9c37f27 --- /dev/null +++ b/src/fastflowtransform/snapshots/runtime/snowflake_snowpark.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Protocol + +from fastflowtransform.executors.common import _q_ident +from fastflowtransform.snapshots.runtime.base import BaseSnapshotRuntime, SnapshotExecutor + + +class SnowflakeSnapshotExecutor(SnapshotExecutor, Protocol): + def _qualified(self, relation: str, *, quoted: bool = True) -> str: ... + + +class SnowflakeSnowparkSnapshotRuntime(BaseSnapshotRuntime[SnowflakeSnapshotExecutor]): + """ + Snapshot runtime for Snowflake Snowpark, matching legacy mixin hooks. 
+ """ + + # ---- Engine hooks ----------------------------------------------------- + def _snapshot_target_identifier(self, rel_name: str) -> str: + return self.executor._qualified(rel_name) + + def _snapshot_current_timestamp(self) -> str: + return "CURRENT_TIMESTAMP()" + + def _snapshot_create_keyword(self) -> str: + return "CREATE OR REPLACE TABLE" + + def _snapshot_null_timestamp(self) -> str: + return "CAST(NULL AS TIMESTAMP)" + + def _snapshot_null_hash(self) -> str: + return "CAST(NULL AS VARCHAR)" + + def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: + concat_expr = self._snapshot_concat_expr(check_cols, src_alias) + return f"CAST(MD5({concat_expr}) AS VARCHAR)" + + def _snapshot_cast_as_string(self, expr: str) -> str: + return f"CAST({expr} AS VARCHAR)" + + def _snapshot_source_ref( + self, rel_name: str, select_body: str + ) -> tuple[str, Callable[[], None]]: + src_name = f"__ff_snapshot_src_{rel_name}".replace(".", "_") + src_quoted = _q_ident(src_name) + self.executor._execute_sql( + f"CREATE OR REPLACE TEMPORARY VIEW {src_quoted} AS {select_body}" + ) + + def _cleanup() -> None: + self.executor._execute_sql(f"DROP VIEW IF EXISTS {src_quoted}") + + return src_quoted, _cleanup diff --git a/tests/common/fixtures.py b/tests/common/fixtures.py index c81788f..cbb8042 100644 --- a/tests/common/fixtures.py +++ b/tests/common/fixtures.py @@ -46,8 +46,12 @@ # Snowflake try: from fastflowtransform.executors.snowflake_snowpark import SnowflakeSnowparkExecutor + from fastflowtransform.snapshots.runtime.snowflake_snowpark import ( + SnowflakeSnowparkSnapshotRuntime, + ) except ModuleNotFoundError: # pragma: no cover SnowflakeSnowparkExecutor = None # type: ignore[assignment] + SnowflakeSnowparkSnapshotRuntime = None # type: ignore[assignment] # ---- Jinja env ---------------------------------------------------------------- @@ -447,6 +451,9 @@ def snowflake_executor_fake() -> Any: session = FakeSnowflakeSession() ex.session = session + # Wire snapshot runtime to mirror real executor setup. 
+ ex.snapshot_runtime = SnowflakeSnowparkSnapshotRuntime(ex) + return ex diff --git a/tests/common/snapshot_helpers.py b/tests/common/snapshot_helpers.py index 696eaf2..e29d829 100644 --- a/tests/common/snapshot_helpers.py +++ b/tests/common/snapshot_helpers.py @@ -8,15 +8,15 @@ import pandas as pd from fastflowtransform.core import relation_for -from fastflowtransform.executors.base import BaseExecutor +from fastflowtransform.snapshots.runtime.base import BaseSnapshotRuntime SnapshotReadFn = Callable[[Any, str], pd.DataFrame] -VF_COL = BaseExecutor.SNAPSHOT_VALID_FROM_COL -VT_COL = BaseExecutor.SNAPSHOT_VALID_TO_COL -IS_CUR_COL = BaseExecutor.SNAPSHOT_IS_CURRENT_COL -HASH_COL = BaseExecutor.SNAPSHOT_HASH_COL -UPD_META_COL = BaseExecutor.SNAPSHOT_UPDATED_AT_COL +VF_COL = BaseSnapshotRuntime.SNAPSHOT_VALID_FROM_COL +VT_COL = BaseSnapshotRuntime.SNAPSHOT_VALID_TO_COL +IS_CUR_COL = BaseSnapshotRuntime.SNAPSHOT_IS_CURRENT_COL +HASH_COL = BaseSnapshotRuntime.SNAPSHOT_HASH_COL +UPD_META_COL = BaseSnapshotRuntime.SNAPSHOT_UPDATED_AT_COL # ── Node factories ────────────────────────────────────────────────────────── diff --git a/tests/integration/executors/bigquery/test_snapshots_bigquery_integration.py b/tests/integration/executors/bigquery/test_snapshots_bigquery_integration.py index 5048795..664d570 100644 --- a/tests/integration/executors/bigquery/test_snapshots_bigquery_integration.py +++ b/tests/integration/executors/bigquery/test_snapshots_bigquery_integration.py @@ -9,13 +9,13 @@ ) from fastflowtransform.core import relation_for -from fastflowtransform.executors.base import BaseExecutor +from fastflowtransform.snapshots.runtime.base import BaseSnapshotRuntime -VF_COL = BaseExecutor.SNAPSHOT_VALID_FROM_COL -VT_COL = BaseExecutor.SNAPSHOT_VALID_TO_COL -IS_CUR_COL = BaseExecutor.SNAPSHOT_IS_CURRENT_COL -HASH_COL = BaseExecutor.SNAPSHOT_HASH_COL -UPD_META_COL = BaseExecutor.SNAPSHOT_UPDATED_AT_COL +VF_COL = BaseSnapshotRuntime.SNAPSHOT_VALID_FROM_COL +VT_COL = BaseSnapshotRuntime.SNAPSHOT_VALID_TO_COL +IS_CUR_COL = BaseSnapshotRuntime.SNAPSHOT_IS_CURRENT_COL +HASH_COL = BaseSnapshotRuntime.SNAPSHOT_HASH_COL +UPD_META_COL = BaseSnapshotRuntime.SNAPSHOT_UPDATED_AT_COL # Simple SQL bodies - they're never actually executed by a real engine, # we only inspect the resulting BigQuery SQL sent to the fake client. 
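Note (illustrative, not itself part of the diff): the helper and test hunks above all encode the same move — the SCD2 meta column names now live on BaseSnapshotRuntime rather than BaseExecutor, and executors carry a composed snapshot_runtime object instead of inheriting the old SnapshotSqlMixin. A minimal sketch of referencing the columns after this change; the tuple name is a placeholder, and the commented values are taken from the constants shown on the runtimes in this patch:

    from fastflowtransform.snapshots.runtime.base import BaseSnapshotRuntime

    # The five SCD2 meta columns a snapshot runtime adds to its target relation.
    SNAPSHOT_META_COLS = (
        BaseSnapshotRuntime.SNAPSHOT_VALID_FROM_COL,   # "_ff_valid_from"
        BaseSnapshotRuntime.SNAPSHOT_VALID_TO_COL,     # "_ff_valid_to"
        BaseSnapshotRuntime.SNAPSHOT_IS_CURRENT_COL,   # "_ff_is_current"
        BaseSnapshotRuntime.SNAPSHOT_HASH_COL,         # "_ff_snapshot_hash"
        BaseSnapshotRuntime.SNAPSHOT_UPDATED_AT_COL,   # "_ff_updated_at"
    )
    print(SNAPSHOT_META_COLS)
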
diff --git a/tests/integration/executors/databricks_spark/test_snapshots_databricks_spark_integration.py b/tests/integration/executors/databricks_spark/test_snapshots_databricks_spark_integration.py index 4190a14..66da747 100644 --- a/tests/integration/executors/databricks_spark/test_snapshots_databricks_spark_integration.py +++ b/tests/integration/executors/databricks_spark/test_snapshots_databricks_spark_integration.py @@ -59,7 +59,11 @@ def _reset_snapshot_table(executor, node_name: str) -> None: def _read_spark(ex: DatabricksSparkExecutor, relation: str): physical = ex._physical_identifier(relation) - return ex.spark.table(physical).toPandas().sort_values(["id", ex.SNAPSHOT_VALID_FROM_COL]) + return ( + ex.spark.table(physical) + .toPandas() + .sort_values(["id", ex.snapshot_runtime.SNAPSHOT_VALID_FROM_COL]) + ) @pytest.mark.databricks_spark diff --git a/tests/integration/executors/snowflake_snowpark/test_snapshots_snowflake_snowpark_integration.py b/tests/integration/executors/snowflake_snowpark/test_snapshots_snowflake_snowpark_integration.py index 4d8c57f..6327397 100644 --- a/tests/integration/executors/snowflake_snowpark/test_snapshots_snowflake_snowpark_integration.py +++ b/tests/integration/executors/snowflake_snowpark/test_snapshots_snowflake_snowpark_integration.py @@ -11,7 +11,7 @@ patch_render_sql, ) -from fastflowtransform.executors.base import BaseExecutor +from fastflowtransform.snapshots.runtime.base import BaseSnapshotRuntime SQL_TS_FIRST = """ select 1 as id, @@ -60,10 +60,10 @@ def test_snowflake_timestamp_snapshot_emits_create_table( sql = _all_sql(ex).upper() assert "CREATE OR REPLACE TABLE" in sql - assert BaseExecutor.SNAPSHOT_VALID_FROM_COL.upper() in sql - assert BaseExecutor.SNAPSHOT_VALID_TO_COL.upper() in sql - assert BaseExecutor.SNAPSHOT_IS_CURRENT_COL.upper() in sql - assert BaseExecutor.SNAPSHOT_UPDATED_AT_COL.upper() in sql + assert BaseSnapshotRuntime.SNAPSHOT_VALID_FROM_COL.upper() in sql + assert BaseSnapshotRuntime.SNAPSHOT_VALID_TO_COL.upper() in sql + assert BaseSnapshotRuntime.SNAPSHOT_IS_CURRENT_COL.upper() in sql + assert BaseSnapshotRuntime.SNAPSHOT_UPDATED_AT_COL.upper() in sql # timestamp strategy should not rely on a hash column # but we allow the implementation to include it if desired @@ -82,7 +82,7 @@ def test_snowflake_check_snapshot_emits_hash_column( sql = _all_sql(ex).upper() # first run should create the table with a hash column in the projection assert "CREATE OR REPLACE TABLE" in sql - assert BaseExecutor.SNAPSHOT_HASH_COL.upper() in sql + assert BaseSnapshotRuntime.SNAPSHOT_HASH_COL.upper() in sql @pytest.mark.snowflake_snowpark diff --git a/tests/unit/executors/test_duckdb_exec_unit.py b/tests/unit/executors/test_duckdb_exec_unit.py index 3f87354..d1e211c 100644 --- a/tests/unit/executors/test_duckdb_exec_unit.py +++ b/tests/unit/executors/test_duckdb_exec_unit.py @@ -7,7 +7,8 @@ import pytest from fastflowtransform.core import Node -from fastflowtransform.executors.duckdb import DuckExecutor, _q +from fastflowtransform.executors.common import _q_ident +from fastflowtransform.executors.duckdb import DuckExecutor @pytest.fixture @@ -101,14 +102,14 @@ def test_create_or_replace_view_from_table(duck_exec: DuckExecutor): def test_format_relation_for_ref(duck_exec: DuckExecutor): rel = duck_exec._format_relation_for_ref("my_model") # relation_for("my_model") → "my_model" - assert rel == _q("my_model") + assert rel == _q_ident("my_model") @pytest.mark.unit @pytest.mark.duckdb def 
test_format_relation_for_ref_with_schema(duck_exec_schema: DuckExecutor): rel = duck_exec_schema._format_relation_for_ref("my_model") - assert rel == f'"demo_schema".{_q("my_model")}' + assert rel == f'"demo_schema".{_q_ident("my_model")}' @pytest.mark.unit diff --git a/tests/unit/executors/test_postgres_exec_unit.py b/tests/unit/executors/test_postgres_exec_unit.py index 8b20c56..90d0326 100644 --- a/tests/unit/executors/test_postgres_exec_unit.py +++ b/tests/unit/executors/test_postgres_exec_unit.py @@ -142,7 +142,6 @@ def _fake_create_engine(dsn, future=True): @pytest.mark.postgres def test_q_ident_and_qualified(monkeypatch, fake_engine_and_conn): ex = PostgresExecutor("postgresql+psycopg://x", schema="public") - assert ex._q_ident('t"b') == '"t""b"' assert ex._qualified("tbl") == '"public"."tbl"' assert ex._qualified("tbl", schema="x") == '"x"."tbl"' # with no schema diff --git a/tests/unit/executors/test_snowflake_snowpark_exec.py b/tests/unit/executors/test_snowflake_snowpark_exec.py index 5dcc7a6..d6ad961 100644 --- a/tests/unit/executors/test_snowflake_snowpark_exec.py +++ b/tests/unit/executors/test_snowflake_snowpark_exec.py @@ -187,7 +187,6 @@ def test_init_sets_db_schema_and_con(sf_exec): @pytest.mark.unit @pytest.mark.snowflake_snowpark def test_q_and_qualified(sf_exec): - assert sf_exec._q("x") == '"x"' assert sf_exec._qualified("TBL") == "DB1.SC1.TBL" diff --git a/uv.lock b/uv.lock index 9cb4ee5..22707db 100644 --- a/uv.lock +++ b/uv.lock @@ -733,7 +733,7 @@ wheels = [ [[package]] name = "fastflowtransform" -version = "0.6.13" +version = "0.6.14" source = { editable = "." } dependencies = [ { name = "duckdb" }, From de4c91dbde3d59d49be5cc40d43bb28b3801aad1 Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Tue, 16 Dec 2025 18:58:41 +0100 Subject: [PATCH 02/10] Refactored query_stats and budgets for duckdb --- src/fastflowtransform/cli/run.py | 2 +- .../executors/_budget_runner.py | 4 +- .../executors/_query_stats_adapter.py | 2 +- src/fastflowtransform/executors/base.py | 4 +- .../executors/bigquery/base.py | 4 +- .../executors/bigquery/pandas.py | 2 +- .../executors/budget/__init__.py | 0 .../executors/{budget.py => budget/core.py} | 0 .../executors/budget/runtime/__init__.py | 4 + .../executors/budget/runtime/base.py | 87 ++++ .../executors/budget/runtime/duckdb.py | 315 ++++++++++++++ .../executors/databricks_spark.py | 2 +- src/fastflowtransform/executors/duckdb.py | 388 +++--------------- src/fastflowtransform/executors/postgres.py | 4 +- .../executors/query_stats/__init__.py | 0 .../{query_stats.py => query_stats/core.py} | 0 .../executors/query_stats/runtime/__init__.py | 4 + .../executors/query_stats/runtime/base.py | 91 ++++ .../executors/query_stats/runtime/duckdb.py | 33 ++ .../executors/snowflake_snowpark.py | 4 +- 20 files changed, 602 insertions(+), 348 deletions(-) create mode 100644 src/fastflowtransform/executors/budget/__init__.py rename src/fastflowtransform/executors/{budget.py => budget/core.py} (100%) create mode 100644 src/fastflowtransform/executors/budget/runtime/__init__.py create mode 100644 src/fastflowtransform/executors/budget/runtime/base.py create mode 100644 src/fastflowtransform/executors/budget/runtime/duckdb.py create mode 100644 src/fastflowtransform/executors/query_stats/__init__.py rename src/fastflowtransform/executors/{query_stats.py => query_stats/core.py} (100%) create mode 100644 src/fastflowtransform/executors/query_stats/runtime/__init__.py create mode 100644 src/fastflowtransform/executors/query_stats/runtime/base.py create mode 
100644 src/fastflowtransform/executors/query_stats/runtime/duckdb.py diff --git a/src/fastflowtransform/cli/run.py b/src/fastflowtransform/cli/run.py index 120ff07..412f170 100644 --- a/src/fastflowtransform/cli/run.py +++ b/src/fastflowtransform/cli/run.py @@ -65,7 +65,7 @@ from fastflowtransform.core import REGISTRY, Node, relation_for from fastflowtransform.dag import levels as dag_levels from fastflowtransform.executors.base import BaseExecutor -from fastflowtransform.executors.budget import format_bytes +from fastflowtransform.executors.budget.core import format_bytes from fastflowtransform.fingerprint import ( EnvCtx, build_env_ctx, diff --git a/src/fastflowtransform/executors/_budget_runner.py b/src/fastflowtransform/executors/_budget_runner.py index 8817f68..9568f8f 100644 --- a/src/fastflowtransform/executors/_budget_runner.py +++ b/src/fastflowtransform/executors/_budget_runner.py @@ -6,8 +6,8 @@ from typing import Any from fastflowtransform.executors._query_stats_adapter import QueryStatsAdapter, RowcountStatsAdapter -from fastflowtransform.executors.budget import BudgetGuard -from fastflowtransform.executors.query_stats import QueryStats +from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.query_stats.core import QueryStats def run_sql_with_budget( diff --git a/src/fastflowtransform/executors/_query_stats_adapter.py b/src/fastflowtransform/executors/_query_stats_adapter.py index f2a01c1..eacc54a 100644 --- a/src/fastflowtransform/executors/_query_stats_adapter.py +++ b/src/fastflowtransform/executors/_query_stats_adapter.py @@ -4,7 +4,7 @@ from collections.abc import Callable from typing import Any, Protocol -from fastflowtransform.executors.query_stats import QueryStats +from fastflowtransform.executors.query_stats.core import QueryStats class QueryStatsAdapter(Protocol): diff --git a/src/fastflowtransform/executors/base.py b/src/fastflowtransform/executors/base.py index 899c54f..ee55e66 100644 --- a/src/fastflowtransform/executors/base.py +++ b/src/fastflowtransform/executors/base.py @@ -21,8 +21,8 @@ from fastflowtransform.core import REGISTRY, Node, relation_for from fastflowtransform.errors import ModelExecutionError from fastflowtransform.executors._query_stats_adapter import JobStatsAdapter -from fastflowtransform.executors.budget import BudgetGuard -from fastflowtransform.executors.query_stats import QueryStats +from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.query_stats.core import QueryStats from fastflowtransform.incremental import _normalize_unique_key from fastflowtransform.logging import echo, echo_debug from fastflowtransform.validation import validate_required_columns diff --git a/src/fastflowtransform/executors/bigquery/base.py b/src/fastflowtransform/executors/bigquery/base.py index 7d6e708..5eeade5 100644 --- a/src/fastflowtransform/executors/bigquery/base.py +++ b/src/fastflowtransform/executors/bigquery/base.py @@ -9,8 +9,8 @@ from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable from fastflowtransform.executors.base import BaseExecutor -from fastflowtransform.executors.budget import BudgetGuard -from fastflowtransform.executors.query_stats import _TrackedQueryJob +from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.query_stats.core import _TrackedQueryJob from fastflowtransform.meta import ensure_meta_table, upsert_meta from 
fastflowtransform.snapshots.runtime.bigquery import BigQuerySnapshotRuntime from fastflowtransform.typing import BadRequest, Client, NotFound, bigquery diff --git a/src/fastflowtransform/executors/bigquery/pandas.py b/src/fastflowtransform/executors/bigquery/pandas.py index 0a23e0b..c672395 100644 --- a/src/fastflowtransform/executors/bigquery/pandas.py +++ b/src/fastflowtransform/executors/bigquery/pandas.py @@ -10,7 +10,7 @@ from fastflowtransform.contracts.runtime.bigquery import BigQueryRuntimeContracts from fastflowtransform.core import Node from fastflowtransform.executors.bigquery.base import BigQueryBaseExecutor -from fastflowtransform.executors.query_stats import QueryStats +from fastflowtransform.executors.query_stats.core import QueryStats from fastflowtransform.typing import BadRequest, Client, LoadJobConfig, NotFound, bigquery diff --git a/src/fastflowtransform/executors/budget/__init__.py b/src/fastflowtransform/executors/budget/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/fastflowtransform/executors/budget.py b/src/fastflowtransform/executors/budget/core.py similarity index 100% rename from src/fastflowtransform/executors/budget.py rename to src/fastflowtransform/executors/budget/core.py diff --git a/src/fastflowtransform/executors/budget/runtime/__init__.py b/src/fastflowtransform/executors/budget/runtime/__init__.py new file mode 100644 index 0000000..35b288f --- /dev/null +++ b/src/fastflowtransform/executors/budget/runtime/__init__.py @@ -0,0 +1,4 @@ +from fastflowtransform.executors.budget.runtime.base import BaseBudgetRuntime +from fastflowtransform.executors.budget.runtime.duckdb import DuckBudgetRuntime + +__all__ = ["BaseBudgetRuntime", "DuckBudgetRuntime"] diff --git a/src/fastflowtransform/executors/budget/runtime/base.py b/src/fastflowtransform/executors/budget/runtime/base.py new file mode 100644 index 0000000..d924390 --- /dev/null +++ b/src/fastflowtransform/executors/budget/runtime/base.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from collections.abc import Callable +from contextlib import suppress +from time import perf_counter +from typing import Any, Protocol, TypeVar + +from fastflowtransform.executors._query_stats_adapter import QueryStatsAdapter, RowcountStatsAdapter +from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.query_stats.runtime.base import BaseQueryStatsRuntime + + +class BudgetExecutor(Protocol): + """Minimal executor surface used by budget runtimes.""" + + def _apply_budget_guard(self, guard: BudgetGuard | None, sql: str) -> int | None: ... + def _is_budget_guard_active(self) -> bool: ... + + +E = TypeVar("E", bound=BudgetExecutor) + + +class BaseBudgetRuntime[E: BudgetExecutor]: + """ + Base runtime for per-query budget enforcement. + + Executors compose this (like runtime contracts) and delegate guarded + execution through it. 
+ """ + + executor: E + guard: BudgetGuard | None + + def __init__(self, executor: E, guard: BudgetGuard | None): + self.executor = executor + self.guard = guard + + def apply_guard(self, sql: str) -> int | None: + return self.executor._apply_budget_guard(self.guard, sql) + + def run_sql( + self, + sql: str, + *, + exec_fn: Callable[[], Any], + stats_runtime: BaseQueryStatsRuntime, + rowcount_extractor: Callable[[Any], int | None] | None = None, + extra_stats: Callable[[Any], Any] | None = None, + estimate_fn: Callable[[str], int | None] | None = None, + post_estimate_fn: Callable[[str, Any], int | None] | None = None, + record_stats: bool = True, + stats_adapter: QueryStatsAdapter | None = None, + ) -> Any: + estimated_bytes = self.apply_guard(sql) + if ( + estimated_bytes is None + and not self.executor._is_budget_guard_active() + and estimate_fn is not None + ): + with suppress(Exception): + estimated_bytes = estimate_fn(sql) + + if not record_stats: + return exec_fn() + + started = perf_counter() + result = exec_fn() + duration_ms = int((perf_counter() - started) * 1000) + + adapter = stats_adapter + if adapter is None and (rowcount_extractor or post_estimate_fn or extra_stats): + adapter = RowcountStatsAdapter( + rowcount_extractor=rowcount_extractor, + post_estimate_fn=post_estimate_fn, + extra_stats=extra_stats, + sql=sql, + ) + + stats_runtime.record_result( + result, + duration_ms=duration_ms, + estimated_bytes=estimated_bytes, + adapter=adapter, + sql=sql, + ) + + return result diff --git a/src/fastflowtransform/executors/budget/runtime/duckdb.py b/src/fastflowtransform/executors/budget/runtime/duckdb.py new file mode 100644 index 0000000..933b327 --- /dev/null +++ b/src/fastflowtransform/executors/budget/runtime/duckdb.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +import json +import re +from collections.abc import Iterable +from typing import Any, ClassVar, Protocol + +from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.budget.runtime.base import BaseBudgetRuntime, BudgetExecutor + + +class DuckBudgetExecutor(BudgetExecutor, Protocol): + schema: str | None + + def _execute_fetchall(self, sql: str, params: Any | None = None) -> list[Any]: ... + def _selectable_body(self, sql: str) -> str: ... + + +class DuckBudgetRuntime(BaseBudgetRuntime[DuckBudgetExecutor]): + """DuckDB-specific budget runtime with plan-based estimation.""" + + _FIXED_TYPE_SIZES: ClassVar[dict[str, int]] = { + "boolean": 1, + "bool": 1, + "tinyint": 1, + "smallint": 2, + "integer": 4, + "int": 4, + "bigint": 8, + "float": 4, + "real": 4, + "double": 8, + "double precision": 8, + "decimal": 16, + "numeric": 16, + "uuid": 16, + "json": 64, + "jsonb": 64, + "timestamp": 8, + "timestamp_ntz": 8, + "timestamp_ltz": 8, + "timestamptz": 8, + "date": 4, + "time": 4, + "interval": 16, + } + _VARCHAR_DEFAULT_WIDTH = 64 + _VARCHAR_MAX_WIDTH = 1024 + _DEFAULT_ROW_WIDTH = 128 + + def __init__(self, executor: DuckBudgetExecutor, guard: BudgetGuard | None): + super().__init__(executor, guard) + self._table_row_width_cache: dict[tuple[str | None, str], int] = {} + + # ------------------------------------------------------------------ # + # Cost estimation used by BudgetGuard # + # ------------------------------------------------------------------ # + + def estimate_query_bytes(self, sql: str) -> int | None: + """ + Estimate query size via DuckDB's EXPLAIN (FORMAT JSON). 
+ """ + # Try to normalize to a SELECT/CTE body if the executor exposes it + body = self.executor._selectable_body(sql).strip().rstrip(";\n\t ") + + lower = body.lower() + if not lower.startswith(("select", "with")): + return None + + explain_sql = f"EXPLAIN (FORMAT JSON) {body}" + try: + rows = self.executor._execute_fetchall(explain_sql) + except Exception: + return None + + if not rows: + return None + + fragments: list[str] = [] + for row in rows: + for cell in row: + if cell is None: + continue + fragments.append(str(cell)) + + if not fragments: + return None + + plan_text = "\n".join(fragments).strip() + start = plan_text.find("[") + end = plan_text.rfind("]") + if start == -1 or end == -1 or end <= start: + return None + + try: + plan_data = json.loads(plan_text[start : end + 1]) + except Exception: + return None + + estimate = self._max_cardinality(plan_data) + if estimate <= 0: + return None + + tables = self._collect_tables_from_plan( + plan_data if isinstance(plan_data, list) else [plan_data] + ) + row_width = self._row_width_for_tables(tables) + if row_width <= 0: + row_width = self._DEFAULT_ROW_WIDTH + + bytes_estimate = int(estimate * row_width) + return bytes_estimate if bytes_estimate > 0 else None + + def _max_cardinality(self, plan_data: Any) -> int: + def _to_int(value: Any) -> int | None: + if value is None: + return None + if isinstance(value, (int, float)): + try: + converted = int(value) + except Exception: + return None + return converted + text = str(value) + match = re.search(r"(\d+(?:\.\d+)?)", text) + if not match: + return None + try: + return int(float(match.group(1))) + except ValueError: + return None + + def _walk_node(node: dict[str, Any]) -> int: + best = 0 + extra = node.get("extra_info") or {} + for key in ( + "Estimated Cardinality", + "estimated_cardinality", + "Cardinality", + "cardinality", + ): + candidate = _to_int(extra.get(key)) + if candidate is not None: + best = max(best, candidate) + candidate = _to_int(node.get("cardinality")) + if candidate is not None: + best = max(best, candidate) + for child in node.get("children") or []: + if isinstance(child, dict): + best = max(best, _walk_node(child)) + return best + + nodes = plan_data if isinstance(plan_data, list) else [plan_data] + + estimate = 0 + for entry in nodes: + if isinstance(entry, dict): + estimate = max(estimate, _walk_node(entry)) + return estimate + + def _collect_tables_from_plan(self, nodes: list[dict[str, Any]]) -> set[tuple[str | None, str]]: + tables: set[tuple[str | None, str]] = set() + + def _walk(entry: dict[str, Any]) -> None: + extra = entry.get("extra_info") or {} + table_val = extra.get("Table") + schema_val = extra.get("Schema") or extra.get("Database") or extra.get("Catalog") + if isinstance(table_val, str) and table_val.strip(): + schema, table = self._split_identifier(table_val, schema_val) + if table: + tables.add((schema, table)) + for child in entry.get("children") or []: + if isinstance(child, dict): + _walk(child) + + for node in nodes: + if isinstance(node, dict): + _walk(node) + return tables + + def _split_identifier( + self, identifier: str, explicit_schema: str | None + ) -> tuple[str | None, str]: + parts = [part.strip() for part in identifier.split(".") if part.strip()] + if not parts: + return explicit_schema, identifier + if len(parts) >= 2: + schema_candidate = self._strip_quotes(parts[-2]) + table_candidate = self._strip_quotes(parts[-1]) + return schema_candidate or explicit_schema, table_candidate + return explicit_schema, 
self._strip_quotes(parts[-1]) + + def _strip_quotes(self, value: str) -> str: + if value.startswith('"') and value.endswith('"'): + return value[1:-1] + return value + + def _row_width_for_tables(self, tables: Iterable[tuple[str | None, str]]) -> int: + widths: list[int] = [] + for schema, table in tables: + width = self._row_width_for_table(schema, table) + if width > 0: + widths.append(width) + return max(widths) if widths else 0 + + def _row_width_for_table(self, schema: str | None, table: str) -> int: + key = (schema or "", table.lower()) + cached = self._table_row_width_cache.get(key) + if cached: + return cached + + columns = self._columns_for_table(table, schema) + width = sum(self._estimate_column_width(col) for col in columns) + if width <= 0: + width = self._DEFAULT_ROW_WIDTH + self._table_row_width_cache[key] = width + return width + + def _columns_for_table( + self, table: str, schema: str | None + ) -> list[tuple[str | None, int | None, int | None, int | None]]: + table_lower = table.lower() + columns: list[tuple[str | None, int | None, int | None, int | None]] = [] + seen_schemas: set[str | None] = set() + for candidate in self._schema_candidates(schema): + if candidate in seen_schemas: + continue + seen_schemas.add(candidate) + try: + if candidate is not None: + rows = self.executor._execute_fetchall( + """ + select lower(data_type) as dtype, + character_maximum_length, + numeric_precision, + numeric_scale + from information_schema.columns + where lower(table_name)=lower(?) + and lower(table_schema)=lower(?) + order by ordinal_position + """, + [table_lower, candidate.lower()], + ) + else: + rows = self.executor._execute_fetchall( + """ + select lower(data_type) as dtype, + character_maximum_length, + numeric_precision, + numeric_scale + from information_schema.columns + where lower(table_name)=lower(?) 
+ order by lower(table_schema), ordinal_position + """, + [table_lower], + ) + except Exception: + continue + if rows: + return rows + return columns + + def _schema_candidates(self, schema: str | None) -> list[str | None]: + candidates: list[str | None] = [] + + def _add(value: str | None) -> None: + normalized = self._normalize_schema(value) + if normalized not in candidates: + candidates.append(normalized) + + _add(schema) + _add(getattr(self.executor, "schema", None)) + for alt in ("main", "temp"): + _add(alt) + _add(None) + return candidates + + def _normalize_schema(self, schema: str | None) -> str | None: + if not schema: + return None + stripped = schema.strip() + return stripped or None + + def _estimate_column_width( + self, column_info: tuple[str | None, int | None, int | None, int | None] + ) -> int: + dtype_raw, char_max, numeric_precision, _ = column_info + dtype = self._normalize_data_type(dtype_raw) + if dtype and dtype in self._FIXED_TYPE_SIZES: + return self._FIXED_TYPE_SIZES[dtype] + + if dtype in {"character", "varchar", "char", "text", "string"}: + if char_max and char_max > 0: + return min(char_max, self._VARCHAR_MAX_WIDTH) + return self._VARCHAR_DEFAULT_WIDTH + + if dtype in {"varbinary", "blob", "binary"}: + if char_max and char_max > 0: + return min(char_max, self._VARCHAR_MAX_WIDTH) + return self._VARCHAR_DEFAULT_WIDTH + + if dtype in {"numeric", "decimal"} and numeric_precision and numeric_precision > 0: + return min(max(int(numeric_precision), 16), 128) + + return 16 + + def _normalize_data_type(self, dtype: str | None) -> str | None: + if not dtype: + return None + stripped = dtype.strip().lower() + if "(" in stripped: + stripped = stripped.split("(", 1)[0].strip() + if stripped.endswith("[]"): + stripped = stripped[:-2] + return stripped or None diff --git a/src/fastflowtransform/executors/databricks_spark.py b/src/fastflowtransform/executors/databricks_spark.py index 27e86a6..e2ba7e2 100644 --- a/src/fastflowtransform/executors/databricks_spark.py +++ b/src/fastflowtransform/executors/databricks_spark.py @@ -19,7 +19,7 @@ from fastflowtransform.executors._query_stats_adapter import SparkDataFrameStatsAdapter from fastflowtransform.executors._test_utils import make_fetchable, rows_to_tuples from fastflowtransform.executors.base import BaseExecutor -from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.budget.core import BudgetGuard from fastflowtransform.logging import echo_debug from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots.runtime.databricks_spark import DatabricksSparkSnapshotRuntime diff --git a/src/fastflowtransform/executors/duckdb.py b/src/fastflowtransform/executors/duckdb.py index 0d62175..c416a1d 100644 --- a/src/fastflowtransform/executors/duckdb.py +++ b/src/fastflowtransform/executors/duckdb.py @@ -1,13 +1,11 @@ # fastflowtransform/executors/duckdb.py from __future__ import annotations -import json -import re import uuid from collections.abc import Iterable from contextlib import suppress from pathlib import Path -from typing import Any, ClassVar +from typing import Any, cast import duckdb import pandas as pd @@ -16,12 +14,13 @@ from fastflowtransform.contracts.runtime.duckdb import DuckRuntimeContracts from fastflowtransform.core import Node -from fastflowtransform.executors._budget_runner import run_sql_with_budget from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable 
from fastflowtransform.executors.base import BaseExecutor, _scalar -from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.budget.runtime.duckdb import DuckBudgetRuntime from fastflowtransform.executors.common import _q_ident +from fastflowtransform.executors.query_stats.runtime.duckdb import DuckQueryStatsRuntime from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots.runtime.duckdb import DuckSnapshotRuntime @@ -29,36 +28,10 @@ class DuckExecutor(SqlIdentifierMixin, BaseExecutor[pd.DataFrame]): ENGINE_NAME: str = "duckdb" runtime_contracts: DuckRuntimeContracts + runtime_query_stats: DuckQueryStatsRuntime + runtime_budget: DuckBudgetRuntime snapshot_runtime: DuckSnapshotRuntime - _FIXED_TYPE_SIZES: ClassVar[dict[str, int]] = { - "boolean": 1, - "bool": 1, - "tinyint": 1, - "smallint": 2, - "integer": 4, - "int": 4, - "bigint": 8, - "float": 4, - "real": 4, - "double": 8, - "double precision": 8, - "decimal": 16, - "numeric": 16, - "uuid": 16, - "json": 64, - "jsonb": 64, - "timestamp": 8, - "timestamp_ntz": 8, - "timestamp_ltz": 8, - "timestamptz": 8, - "date": 4, - "time": 4, - "interval": 16, - } - _VARCHAR_DEFAULT_WIDTH = 64 - _VARCHAR_MAX_WIDTH = 1024 - _DEFAULT_ROW_WIDTH = 128 _BUDGET_GUARD = BudgetGuard( env_var="FF_DUCKDB_MAX_BYTES", estimator_attr="_estimate_query_bytes", @@ -77,19 +50,35 @@ def __init__( self.schema = schema.strip() if isinstance(schema, str) and schema.strip() else None catalog_override = catalog.strip() if isinstance(catalog, str) and catalog.strip() else None self.catalog = self._detect_catalog() - self._table_row_width_cache: dict[tuple[str | None, str], int] = {} if catalog_override: if self._apply_catalog_override(catalog_override): self.catalog = catalog_override else: self.catalog = self._detect_catalog() + self.runtime_query_stats = DuckQueryStatsRuntime(self) + self.runtime_budget = DuckBudgetRuntime(self, self._BUDGET_GUARD) + self.runtime_contracts = DuckRuntimeContracts(self) + self.snapshot_runtime = DuckSnapshotRuntime(self) + if self.schema: safe_schema = _q_ident(self.schema) - self._execute_sql(f"create schema if not exists {safe_schema}") - self._execute_sql(f"set schema '{self.schema}'") + self._execute_basic(f"create schema if not exists {safe_schema}") + self._execute_basic(f"set schema '{self.schema}'") - self.runtime_contracts = DuckRuntimeContracts(self) - self.snapshot_runtime = DuckSnapshotRuntime(self) + def _execute_basic(self, sql: str, params: Any | None = None) -> duckdb.DuckDBPyConnection: + """ + Minimal helper to execute a statement and return the DuckDB cursor. + Centralises raw connection use for test + runtime helpers. + """ + return self.con.execute(sql, params) if params is not None else self.con.execute(sql) + + def _execute_fetchall(self, sql: str, params: Any | None = None) -> list[Any]: + """ + Helper for runtimes that need full result sets without exposing cursors. 
+ """ + res = self._execute_basic(sql, params) + fetchall = getattr(res, "fetchall", None) + return list(cast(Iterable[Any], fetchall())) if callable(fetchall) else [] def execute_test_sql(self, stmt: Any) -> Any: """ @@ -104,15 +93,15 @@ def _run_one(s: Any) -> Any: and isinstance(s[0], str) and isinstance(s[1], dict) ): - return self.con.execute(s[0], s[1]) + return self._execute_basic(s[0], s[1]) if isinstance(s, str): - return self.con.execute(s) + return self._execute_basic(s) if isinstance(s, Iterable) and not isinstance(s, (bytes, bytearray, str)): res = None for item in s: res = _run_one(item) return res - return self.con.execute(str(s)) + return self._execute_basic(str(s)) return make_fetchable(_run_one(stmt)) @@ -138,16 +127,12 @@ def _exec() -> duckdb.DuckDBPyConnection: return self.con.execute(sql, *args, **kwargs) def _rows(result: Any) -> int | None: - rc = getattr(result, "rowcount", None) - if isinstance(rc, int) and rc >= 0: - return rc - return None + return self.runtime_query_stats.rowcount_from_result(result) - return run_sql_with_budget( - self, + return self.runtime_budget.run_sql( sql, - guard=self._BUDGET_GUARD, exec_fn=_exec, + stats_runtime=self.runtime_query_stats, rowcount_extractor=_rows, estimate_fn=self._estimate_query_bytes, ) @@ -155,277 +140,12 @@ def _rows(result: Any) -> int | None: # --- Cost estimation for the shared BudgetGuard ----------------- def _estimate_query_bytes(self, sql: str) -> int | None: - """ - Estimate query size via DuckDB's EXPLAIN (FORMAT JSON). - - The JSON plan exposes an \"Estimated Cardinality\" per node. - We walk the parsed tree, take the highest non-zero estimate and - return it as a byte-estimate surrogate (row count ≈ bytes) so the - cost guard can still make a meaningful decision without executing - the query. 
- """ - try: - body = self._selectable_body(sql).strip().rstrip(";\n\t ") - except AttributeError: - body = sql.strip().rstrip(";\n\t ") - - lower = body.lower() - if not lower.startswith(("select", "with")): - return None - - explain_sql = f"EXPLAIN (FORMAT JSON) {body}" - try: - rows = self.con.execute(explain_sql).fetchall() - except Exception: - return None - - if not rows: - return None - - fragments: list[str] = [] - for row in rows: - for cell in row: - if cell is None: - continue - fragments.append(str(cell)) - - if not fragments: - return None - - plan_text = "\n".join(fragments).strip() - start = plan_text.find("[") - end = plan_text.rfind("]") - if start == -1 or end == -1 or end <= start: - return None - - try: - plan_data = json.loads(plan_text[start : end + 1]) - except Exception: - return None - - def _to_int(value: Any) -> int | None: - if value is None: - return None - if isinstance(value, (int, float)): - try: - converted = int(value) - except Exception: - return None - return converted - text = str(value) - match = re.search(r"(\d+(?:\.\d+)?)", text) - if not match: - return None - try: - return int(float(match.group(1))) - except ValueError: - return None - - def _walk_node(node: dict[str, Any]) -> int: - best = 0 - extra = node.get("extra_info") or {} - for key in ( - "Estimated Cardinality", - "estimated_cardinality", - "Cardinality", - "cardinality", - ): - candidate = _to_int(extra.get(key)) - if candidate is not None: - best = max(best, candidate) - candidate = _to_int(node.get("cardinality")) - if candidate is not None: - best = max(best, candidate) - for child in node.get("children") or []: - if isinstance(child, dict): - best = max(best, _walk_node(child)) - return best - - nodes: list[Any] - nodes = plan_data if isinstance(plan_data, list) else [plan_data] - - estimate = 0 - for entry in nodes: - if isinstance(entry, dict): - estimate = max(estimate, _walk_node(entry)) - - if estimate <= 0: - return None - - tables = self._collect_tables_from_plan(nodes) - row_width = self._row_width_for_tables(tables) - if row_width <= 0: - row_width = self._DEFAULT_ROW_WIDTH - - bytes_estimate = int(estimate * row_width) - return bytes_estimate if bytes_estimate > 0 else None - - def _collect_tables_from_plan(self, nodes: list[dict[str, Any]]) -> set[tuple[str | None, str]]: - tables: set[tuple[str | None, str]] = set() - - def _walk(entry: dict[str, Any]) -> None: - extra = entry.get("extra_info") or {} - table_val = extra.get("Table") - schema_val = extra.get("Schema") or extra.get("Database") or extra.get("Catalog") - if isinstance(table_val, str) and table_val.strip(): - schema, table = self._split_identifier(table_val, schema_val) - if table: - tables.add((schema, table)) - for child in entry.get("children") or []: - if isinstance(child, dict): - _walk(child) - - for node in nodes: - if isinstance(node, dict): - _walk(node) - return tables - - def _split_identifier( - self, identifier: str, explicit_schema: str | None - ) -> tuple[str | None, str]: - parts = [part.strip() for part in identifier.split(".") if part.strip()] - if not parts: - return explicit_schema, identifier - if len(parts) >= 2: - schema_candidate = self._strip_quotes(parts[-2]) - table_candidate = self._strip_quotes(parts[-1]) - return schema_candidate or explicit_schema, table_candidate - return explicit_schema, self._strip_quotes(parts[-1]) - - def _strip_quotes(self, value: str) -> str: - if value.startswith('"') and value.endswith('"'): - return value[1:-1] - return value - - def 
_row_width_for_tables(self, tables: Iterable[tuple[str | None, str]]) -> int: - widths: list[int] = [] - for schema, table in tables: - width = self._row_width_for_table(schema, table) - if width > 0: - widths.append(width) - return max(widths) if widths else 0 - - def _row_width_for_table(self, schema: str | None, table: str) -> int: - key = (schema or "", table.lower()) - cached = self._table_row_width_cache.get(key) - if cached: - return cached - - columns = self._columns_for_table(table, schema) - width = sum(self._estimate_column_width(col) for col in columns) - if width <= 0: - width = self._DEFAULT_ROW_WIDTH - self._table_row_width_cache[key] = width - return width - - def _columns_for_table( - self, table: str, schema: str | None - ) -> list[tuple[str | None, int | None, int | None, int | None]]: - table_lower = table.lower() - columns: list[tuple[str | None, int | None, int | None, int | None]] = [] - seen_schemas: set[str | None] = set() - for candidate in self._schema_candidates(schema): - if candidate in seen_schemas: - continue - seen_schemas.add(candidate) - if candidate is not None: - try: - rows = self.con.execute( - """ - select lower(data_type) as dtype, - character_maximum_length, - numeric_precision, - numeric_scale - from information_schema.columns - where lower(table_name)=lower(?) - and lower(table_schema)=lower(?) - order by ordinal_position - """, - [table_lower, candidate.lower()], - ).fetchall() - except Exception: - continue - else: - try: - rows = self.con.execute( - """ - select lower(data_type) as dtype, - character_maximum_length, - numeric_precision, - numeric_scale - from information_schema.columns - where lower(table_name)=lower(?) - order by lower(table_schema), ordinal_position - """, - [table_lower], - ).fetchall() - except Exception: - continue - if rows: - return rows - return columns - - def _schema_candidates(self, schema: str | None) -> list[str | None]: - candidates: list[str | None] = [] - - def _add(value: str | None) -> None: - normalized = self._normalize_schema(value) - if normalized not in candidates: - candidates.append(normalized) - - _add(schema) - _add(self.schema) - for alt in ("main", "temp"): - _add(alt) - _add(None) - return candidates - - def _normalize_schema(self, schema: str | None) -> str | None: - if not schema: - return None - stripped = schema.strip() - return stripped or None - - def _estimate_column_width( - self, column_info: tuple[str | None, int | None, int | None, int | None] - ) -> int: - dtype_raw, char_max, numeric_precision, _ = column_info - dtype = self._normalize_data_type(dtype_raw) - if dtype and dtype in self._FIXED_TYPE_SIZES: - return self._FIXED_TYPE_SIZES[dtype] - - if dtype in {"character", "varchar", "char", "text", "string"}: - if char_max and char_max > 0: - return min(char_max, self._VARCHAR_MAX_WIDTH) - return self._VARCHAR_DEFAULT_WIDTH - - if dtype in {"varbinary", "blob", "binary"}: - if char_max and char_max > 0: - return min(char_max, self._VARCHAR_MAX_WIDTH) - return self._VARCHAR_DEFAULT_WIDTH - - if dtype in {"numeric", "decimal"} and numeric_precision and numeric_precision > 0: - return min(max(int(numeric_precision), 16), 128) - - return 16 - - def _normalize_data_type(self, dtype: str | None) -> str | None: - if not dtype: - return None - stripped = dtype.strip().lower() - if "(" in stripped: - stripped = stripped.split("(", 1)[0].strip() - if stripped.endswith("[]"): - stripped = stripped[:-2] - return stripped or None + return self.runtime_budget.estimate_query_bytes(sql) def 
_detect_catalog(self) -> str | None: - try: - rows = self._execute_sql("PRAGMA database_list").fetchall() - if rows: - return str(rows[0][1]) - except Exception: - return None + rows = self._execute_basic("PRAGMA database_list").fetchall() + if rows: + return str(rows[0][1]) return None def _apply_catalog_override(self, name: str) -> bool: @@ -436,11 +156,11 @@ def _apply_catalog_override(self, name: str) -> bool: if self.db_path != ":memory:": resolved = str(Path(self.db_path).resolve()) with suppress(Exception): - self._execute_sql(f"detach database {_q_ident(alias)}") - self._execute_sql( + self._execute_basic(f"detach database {_q_ident(alias)}") + self._execute_basic( f"attach database '{resolved}' as {_q_ident(alias)} (READ_ONLY FALSE)" ) - self._execute_sql(f"set catalog '{alias}'") + self._execute_basic(f"set catalog '{alias}'") return True except Exception: return False @@ -513,7 +233,7 @@ def _read_relation(self, relation: str, node: Node, deps: Iterable[str]) -> pd.D except CatalogException as e: existing = [ r[0] - for r in self._execute_sql( + for r in self._execute_basic( "select table_name from information_schema.tables " "where table_schema in ('main','temp')" ).fetchall() @@ -535,7 +255,7 @@ def _materialize_relation(self, relation: str, df: pd.DataFrame, node: Node) -> self.con.unregister(tmp) except Exception: # housekeeping only; stats here are not important but harmless if recorded - self._execute_sql(f'drop view if exists "{tmp}"') + self._execute_basic(f'drop view if exists "{tmp}"') def _create_or_replace_view_from_table( self, view_name: str, backing_table: str, node: Node @@ -576,10 +296,10 @@ def exists_relation(self, relation: str) -> bool: where_tables.append("table_schema in ('main','temp')") where = " AND ".join(where_tables) sql_tables = f"select 1 from information_schema.tables where {where} limit 1" - if self._execute_sql(sql_tables, params).fetchone(): + if self._execute_basic(sql_tables, params).fetchone(): return True sql_views = f"select 1 from information_schema.views where {where} limit 1" - return bool(self._execute_sql(sql_views, params).fetchone()) + return bool(self._execute_basic(sql_views, params).fetchone()) def create_table_as(self, relation: str, select_sql: str) -> None: # Use only the SELECT body and strip trailing semicolons for safety. @@ -619,11 +339,11 @@ def alter_table_sync_schema( """ # Probe: empty projection from the SELECT (cleaned to avoid parser issues). 
body = self._first_select_body(select_sql).strip().rstrip(";\n\t ") - probe = self._execute_sql(f"select * from ({body}) as q limit 0") + probe = self._execute_basic(f"select * from ({body}) as q limit 0") cols = [c[0] for c in probe.description or []] existing = { r[0] - for r in self._execute_sql( + for r in self._execute_basic( "select column_name from information_schema.columns " + "where lower(table_name)=lower(?)" + (" and lower(table_schema)=lower(?)" if self.schema else ""), @@ -635,9 +355,9 @@ def alter_table_sync_schema( col = _q_ident(c) target = self._qualified(relation) try: - self._execute_sql(f"alter table {target} add column {col} varchar") + self._execute_basic(f"alter table {target} add column {col} varchar") except Exception: - self._execute_sql(f"alter table {target} add column {col} varchar") + self._execute_basic(f"alter table {target} add column {col} varchar") def execute_hook_sql(self, sql: str) -> None: """ @@ -681,13 +401,13 @@ def utest_load_relation_from_rows(self, relation: str, rows: list[dict]) -> None self.con.register(tmp, df) try: target = self._qualified(relation) - self._execute_sql(f"create or replace table {target} as select * from {tmp}") + self._execute_basic(f"create or replace table {target} as select * from {tmp}") finally: with suppress(Exception): self.con.unregister(tmp) # Fallback for older DuckDB where unregister might not exist with suppress(Exception): - self._execute_sql(f'drop view if exists "{tmp}"') + self._execute_basic(f'drop view if exists "{tmp}"') def utest_read_relation(self, relation: str) -> pd.DataFrame: """ @@ -704,9 +424,9 @@ def utest_clean_target(self, relation: str) -> None: target = self._qualified(relation) # best-effort; ignore failures with suppress(Exception): - self._execute_sql(f"drop view if exists {target}") + self._execute_basic(f"drop view if exists {target}") with suppress(Exception): - self._execute_sql(f"drop table if exists {target}") + self._execute_basic(f"drop table if exists {target}") def _introspect_columns_metadata( self, @@ -745,7 +465,7 @@ def _introspect_columns_metadata( "order by table_schema, ordinal_position" ) - rows = self._execute_sql(sql, params).fetchall() + rows = self._execute_basic(sql, params).fetchall() # Normalize to plain strings return [(str(name), str(dtype)) for (name, dtype) in rows] @@ -787,6 +507,6 @@ def load_seed( with suppress(Exception): self.con.unregister(tmp) with suppress(Exception): - self._execute_sql(f'drop view if exists "{tmp}"') + self._execute_basic(f'drop view if exists "{tmp}"') return True, qualified, created_schema diff --git a/src/fastflowtransform/executors/postgres.py b/src/fastflowtransform/executors/postgres.py index 6341430..f87f459 100644 --- a/src/fastflowtransform/executors/postgres.py +++ b/src/fastflowtransform/executors/postgres.py @@ -20,9 +20,9 @@ from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable from fastflowtransform.executors.base import BaseExecutor, _scalar -from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.budget.core import BudgetGuard from fastflowtransform.executors.common import _q_ident -from fastflowtransform.executors.query_stats import QueryStats +from fastflowtransform.executors.query_stats.core import QueryStats from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots.runtime.postgres import PostgresSnapshotRuntime diff --git 
a/src/fastflowtransform/executors/query_stats/__init__.py b/src/fastflowtransform/executors/query_stats/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/fastflowtransform/executors/query_stats.py b/src/fastflowtransform/executors/query_stats/core.py similarity index 100% rename from src/fastflowtransform/executors/query_stats.py rename to src/fastflowtransform/executors/query_stats/core.py diff --git a/src/fastflowtransform/executors/query_stats/runtime/__init__.py b/src/fastflowtransform/executors/query_stats/runtime/__init__.py new file mode 100644 index 0000000..fd26221 --- /dev/null +++ b/src/fastflowtransform/executors/query_stats/runtime/__init__.py @@ -0,0 +1,4 @@ +from fastflowtransform.executors.query_stats.runtime.base import BaseQueryStatsRuntime +from fastflowtransform.executors.query_stats.runtime.duckdb import DuckQueryStatsRuntime + +__all__ = ["BaseQueryStatsRuntime", "DuckQueryStatsRuntime"] diff --git a/src/fastflowtransform/executors/query_stats/runtime/base.py b/src/fastflowtransform/executors/query_stats/runtime/base.py new file mode 100644 index 0000000..e921564 --- /dev/null +++ b/src/fastflowtransform/executors/query_stats/runtime/base.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from time import perf_counter +from typing import Any, Protocol, TypeVar + +from fastflowtransform.executors._query_stats_adapter import ( + JobStatsAdapter, + QueryStatsAdapter, + RowcountStatsAdapter, +) +from fastflowtransform.executors.query_stats.core import QueryStats + + +class QueryStatsExecutor(Protocol): + """Minimal executor surface used by query-stats runtimes.""" + + def _record_query_stats(self, stats: QueryStats) -> None: ... + + +E = TypeVar("E", bound=QueryStatsExecutor) + + +@dataclass +class QueryTimer: + started_at: float + + +class BaseQueryStatsRuntime[E: QueryStatsExecutor]: + """ + Base runtime for collecting per-query stats. + + Executors compose this (like runtime contracts) and delegate stat recording + so the run engine can aggregate per-node metrics. + """ + + executor: E + + def __init__(self, executor: E): + self.executor = executor + + def start_timer(self) -> QueryTimer: + return QueryTimer(started_at=perf_counter()) + + def record_result( + self, + result: Any, + *, + timer: QueryTimer | None = None, + duration_ms: int | None = None, + estimated_bytes: int | None = None, + adapter: QueryStatsAdapter | None = None, + sql: str | None = None, + rowcount_extractor: Callable[[Any], int | None] | None = None, + extra_stats: Callable[[Any], QueryStats | None] | None = None, + post_estimate_fn: Callable[[str, Any], int | None] | None = None, + ) -> QueryStats: + """ + Collect stats from a result object and record them on the executor. + + Either pass a timer (from start_timer) or an explicit duration_ms. + If no adapter is given, a simple QueryStats with bytes/duration is recorded. 
+ """ + if duration_ms is None and timer is not None: + duration_ms = int((perf_counter() - timer.started_at) * 1000) + + stats_adapter = adapter + if stats_adapter is None and (rowcount_extractor or extra_stats or post_estimate_fn): + stats_adapter = RowcountStatsAdapter( + rowcount_extractor=rowcount_extractor, + extra_stats=extra_stats, + post_estimate_fn=post_estimate_fn, + sql=sql, + ) + + if stats_adapter is None: + stats = QueryStats(bytes_processed=estimated_bytes, rows=None, duration_ms=duration_ms) + else: + stats = stats_adapter.collect( + result, duration_ms=duration_ms, estimated_bytes=estimated_bytes + ) + + self.executor._record_query_stats(stats) + return stats + + def record_job(self, job: Any) -> QueryStats: + adapter = JobStatsAdapter() + stats = adapter.collect(job) + self.executor._record_query_stats(stats) + return stats diff --git a/src/fastflowtransform/executors/query_stats/runtime/duckdb.py b/src/fastflowtransform/executors/query_stats/runtime/duckdb.py new file mode 100644 index 0000000..54fd26d --- /dev/null +++ b/src/fastflowtransform/executors/query_stats/runtime/duckdb.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import Any + +import pandas as pd + +from fastflowtransform.executors.query_stats.core import QueryStats +from fastflowtransform.executors.query_stats.runtime.base import ( + BaseQueryStatsRuntime, + QueryStatsExecutor, +) + + +class DuckQueryStatsRuntime(BaseQueryStatsRuntime[QueryStatsExecutor]): + """DuckDB-specific runtime logic for stats extraction.""" + + def rowcount_from_result(self, result: Any) -> int | None: + rc = getattr(result, "rowcount", None) + if isinstance(rc, int) and rc >= 0: + return rc + return None + + def record_dataframe(self, df: pd.DataFrame, duration_ms: int) -> QueryStats: + rows = len(df) + bytes_estimate = int(df.memory_usage(deep=True).sum()) if rows > 0 else 0 + bytes_val = bytes_estimate if bytes_estimate > 0 else None + stats = QueryStats( + bytes_processed=bytes_val, + rows=rows if rows > 0 else None, + duration_ms=duration_ms, + ) + self.executor._record_query_stats(stats) + return stats diff --git a/src/fastflowtransform/executors/snowflake_snowpark.py b/src/fastflowtransform/executors/snowflake_snowpark.py index c3e963e..7cd81f7 100644 --- a/src/fastflowtransform/executors/snowflake_snowpark.py +++ b/src/fastflowtransform/executors/snowflake_snowpark.py @@ -15,9 +15,9 @@ from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable, rows_to_tuples from fastflowtransform.executors.base import BaseExecutor -from fastflowtransform.executors.budget import BudgetGuard +from fastflowtransform.executors.budget.core import BudgetGuard from fastflowtransform.executors.common import _q_ident -from fastflowtransform.executors.query_stats import QueryStats +from fastflowtransform.executors.query_stats.core import QueryStats from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots.runtime.snowflake_snowpark import ( SnowflakeSnowparkSnapshotRuntime, From be1af66c89298eb217a0b432990a94fd8182ea0f Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Wed, 17 Dec 2025 08:22:18 +0100 Subject: [PATCH 03/10] Refactored budgets and query stats for postgres engine --- .../executors/budget/runtime/__init__.py | 3 +- .../executors/budget/runtime/postgres.py | 102 ++++++++ src/fastflowtransform/executors/postgres.py | 246 +++++------------- .../executors/query_stats/runtime/__init__.py | 3 +- 
.../executors/query_stats/runtime/postgres.py | 33 +++ .../unit/executors/test_postgres_exec_unit.py | 41 ++- 6 files changed, 243 insertions(+), 185 deletions(-) create mode 100644 src/fastflowtransform/executors/budget/runtime/postgres.py create mode 100644 src/fastflowtransform/executors/query_stats/runtime/postgres.py diff --git a/src/fastflowtransform/executors/budget/runtime/__init__.py b/src/fastflowtransform/executors/budget/runtime/__init__.py index 35b288f..d8509e6 100644 --- a/src/fastflowtransform/executors/budget/runtime/__init__.py +++ b/src/fastflowtransform/executors/budget/runtime/__init__.py @@ -1,4 +1,5 @@ from fastflowtransform.executors.budget.runtime.base import BaseBudgetRuntime from fastflowtransform.executors.budget.runtime.duckdb import DuckBudgetRuntime +from fastflowtransform.executors.budget.runtime.postgres import PostgresBudgetRuntime -__all__ = ["BaseBudgetRuntime", "DuckBudgetRuntime"] +__all__ = ["BaseBudgetRuntime", "DuckBudgetRuntime", "PostgresBudgetRuntime"] diff --git a/src/fastflowtransform/executors/budget/runtime/postgres.py b/src/fastflowtransform/executors/budget/runtime/postgres.py new file mode 100644 index 0000000..3cb3296 --- /dev/null +++ b/src/fastflowtransform/executors/budget/runtime/postgres.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import json +from typing import Any, Protocol + +from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.budget.runtime.base import BaseBudgetRuntime, BudgetExecutor + + +class PostgresBudgetExecutor(BudgetExecutor, Protocol): + schema: str | None + + def _execute_sql_maintenance( + self, + sql: str, + *args: Any, + conn: Any | None = None, + set_search_path: bool = True, + **kwargs: Any, + ) -> Any: ... + + def _set_search_path(self, conn: Any) -> None: ... + def _extract_select_like(self, sql_or_body: str) -> str: ... 
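# Illustrative sketch (not from the patch; values invented): the shape of the
# EXPLAIN (FORMAT JSON) output that the Postgres runtime below consumes.
# Postgres returns a single-element list whose object carries a "Plan" node;
# with "Plan Rows" = 1000 and "Plan Width" = 64 the rows-times-width heuristic
# would put the estimate at 64_000 bytes.
_example_explain_output = [
    {"Plan": {"Node Type": "Seq Scan", "Plan Rows": 1000, "Plan Width": 64}}
]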
+ + +class PostgresBudgetRuntime(BaseBudgetRuntime[PostgresBudgetExecutor]): + """Postgres-specific budget runtime with EXPLAIN-based estimation.""" + + _DEFAULT_PG_ROW_WIDTH = 128 + + def __init__(self, executor: PostgresBudgetExecutor, guard: BudgetGuard | None): + super().__init__(executor, guard) + + def estimate_query_bytes(self, sql: str) -> int | None: + body = self.executor._extract_select_like(sql) + lower = body.lstrip().lower() + if not lower.startswith(("select", "with")): + return None + + explain_sql = f"EXPLAIN (FORMAT JSON) {body}" + + try: + raw = self.executor._execute_sql_maintenance(explain_sql, set_search_path=False) + except Exception: + return None + + if raw is None: + return None + + try: + data = json.loads(raw) + except Exception: + data = raw + + # Postgres JSON format: list with a single object + if isinstance(data, list) and data: + root = data[0] + elif isinstance(data, dict): + root = data + else: + return None + + plan = root.get("Plan") + if not isinstance(plan, dict): + if isinstance(root, dict) and "Node Type" in root: + plan = root + else: + return None + + return self._estimate_bytes_from_plan(plan) + + def _estimate_bytes_from_plan(self, plan: dict[str, Any]) -> int | None: + def _to_int(node: dict[str, Any], keys: tuple[str, ...]) -> int | None: + for key in keys: + val = node.get(key) + if val is None: + continue + try: + return int(val) + except (TypeError, ValueError): + continue + return None + + rows = _to_int(plan, ("Plan Rows", "Plan_Rows", "Rows")) + width = _to_int(plan, ("Plan Width", "Plan_Width", "Width")) + + if rows is None and width is None: + return None + + candidate: int | None + + if rows is not None and width is not None: + candidate = rows * width + elif rows is not None: + candidate = rows * self._DEFAULT_PG_ROW_WIDTH + else: + candidate = width + + if candidate is None or candidate <= 0: + return None + + return int(candidate) diff --git a/src/fastflowtransform/executors/postgres.py b/src/fastflowtransform/executors/postgres.py index f87f459..4e4326d 100644 --- a/src/fastflowtransform/executors/postgres.py +++ b/src/fastflowtransform/executors/postgres.py @@ -1,5 +1,4 @@ # fastflowtransform/executors/postgres.py -import json import re from collections.abc import Iterable from time import perf_counter @@ -16,13 +15,13 @@ from fastflowtransform.contracts.runtime.postgres import PostgresRuntimeContracts from fastflowtransform.core import Node from fastflowtransform.errors import ModelExecutionError, ProfileConfigError -from fastflowtransform.executors._budget_runner import run_sql_with_budget from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable from fastflowtransform.executors.base import BaseExecutor, _scalar from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.budget.runtime.postgres import PostgresBudgetRuntime from fastflowtransform.executors.common import _q_ident -from fastflowtransform.executors.query_stats.core import QueryStats +from fastflowtransform.executors.query_stats.runtime.postgres import PostgresQueryStatsRuntime from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots.runtime.postgres import PostgresSnapshotRuntime @@ -37,8 +36,9 @@ def _base_type(t: str) -> str: class PostgresExecutor(SqlIdentifierMixin, BaseExecutor[pd.DataFrame]): ENGINE_NAME: str = "postgres" runtime_contracts: PostgresRuntimeContracts + runtime_query_stats: 
PostgresQueryStatsRuntime + runtime_budget: PostgresBudgetRuntime snapshot_runtime: PostgresSnapshotRuntime - _DEFAULT_PG_ROW_WIDTH = 128 _BUDGET_GUARD = BudgetGuard( env_var="FF_PG_MAX_BYTES", estimator_attr="_estimate_query_bytes", @@ -62,14 +62,17 @@ def __init__(self, dsn: str, schema: str | None = None): if self.schema: try: - with self.engine.begin() as conn: - conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {_q_ident(self.schema)}")) + self._execute_sql_maintenance( + f"CREATE SCHEMA IF NOT EXISTS {_q_ident(self.schema)}", set_search_path=False + ) except SQLAlchemyError as exc: raise ProfileConfigError( f"Failed to ensure schema '{self.schema}' exists: {exc}" ) from exc - # Enable runtime contracts (cast/verify) for SQL and pandas models. + # Enable runtime helpers and contracts. + self.runtime_query_stats = PostgresQueryStatsRuntime(self) + self.runtime_budget = PostgresBudgetRuntime(self, self._BUDGET_GUARD) self.runtime_contracts = PostgresRuntimeContracts(self) self.snapshot_runtime = PostgresSnapshotRuntime(self) @@ -112,6 +115,7 @@ def _execute_sql_core( sql: str, *args: Any, conn: Connection, + set_search_path: bool = True, **kwargs: Any, ) -> Any: """ @@ -124,7 +128,8 @@ def _execute_sql_core( Used by both the high-level _execute_sql and maintenance helpers. """ - self._set_search_path(conn) + if set_search_path: + self._set_search_path(conn) return conn.execute(text(sql), *args, **kwargs) def _execute_sql_maintenance( @@ -132,6 +137,7 @@ def _execute_sql_maintenance( sql: str, *args: Any, conn: Connection | None = None, + set_search_path: bool = True, **kwargs: Any, ) -> Any: """ @@ -148,9 +154,13 @@ def _execute_sql_maintenance( """ if conn is None: with self.engine.begin() as local_conn: - return self._execute_sql_core(sql, *args, conn=local_conn, **kwargs) + return self._execute_sql_core( + sql, *args, conn=local_conn, set_search_path=set_search_path, **kwargs + ) else: - return self._execute_sql_core(sql, *args, conn=conn, **kwargs) + return self._execute_sql_core( + sql, *args, conn=conn, set_search_path=set_search_path, **kwargs + ) def _execute_sql( self, @@ -177,16 +187,12 @@ def _exec() -> Any: return self._execute_sql_core(sql, *args, conn=conn, **kwargs) def _rows(result: Any) -> int | None: - rc = getattr(result, "rowcount", None) - if isinstance(rc, int) and rc >= 0: - return rc - return None + return self.runtime_query_stats.rowcount_from_result(result) - return run_sql_with_budget( - self, + return self.runtime_budget.run_sql( sql, - guard=self._BUDGET_GUARD, exec_fn=_exec, + stats_runtime=self.runtime_query_stats, rowcount_extractor=_rows, estimate_fn=self._estimate_query_bytes, ) @@ -203,115 +209,18 @@ def _analyze_relations( - Uses passed-in conn if given, otherwise opens its own transaction. - Best-effort: logs and continues on failure. """ - owns_conn = False - if conn is None: - conn_ctx = self.engine.begin() - conn = conn_ctx.__enter__() - owns_conn = True - try: - self._set_search_path(conn) - for rel in relations: - try: - # If it already looks qualified, leave it; otherwise qualify. - qrel = self._qualified(rel) if "." not in rel else rel - conn.execute(text(f"ANALYZE {qrel}")) - except Exception: - pass - finally: - if owns_conn: - conn_ctx.__exit__(None, None, None) + for rel in relations: + try: + # If it already looks qualified, leave it; otherwise qualify. + qrel = self._qualified(rel) if "." 
not in rel else rel + self._execute_sql_maintenance(f"ANALYZE {qrel}", conn=conn) + except Exception: + pass # --- Cost estimation for the shared BudgetGuard ----------------- def _estimate_query_bytes(self, sql: str) -> int | None: - """ - Best-effort bytes estimate for a SELECT-ish query using - EXPLAIN (FORMAT JSON). - - Approximation: estimated_rows * avg_row_width (in bytes). - Returns None if: - - the query is not SELECT/CTE - - EXPLAIN fails - - the JSON structure is not what we expect - """ - body = self._extract_select_like(sql) - lower = body.lstrip().lower() - if not lower.startswith(("select", "with")): - # Only try to estimate for read-like queries - return None - - explain_sql = f"EXPLAIN (FORMAT JSON) {body}" - - try: - with self.engine.begin() as conn: - self._set_search_path(conn) - raw = conn.execute(text(explain_sql)).scalar() - except Exception: - return None - - if raw is None: - return None - - try: - data = json.loads(raw) - except Exception: - data = raw - - # Postgres JSON format: list with a single object - if isinstance(data, list) and data: - root = data[0] - elif isinstance(data, dict): - root = data - else: - return None - - plan = root.get("Plan") - if not isinstance(plan, dict): - if isinstance(root, dict) and "Node Type" in root: - plan = root - else: - return None - - return self._estimate_bytes_from_plan(plan) - - def _estimate_bytes_from_plan(self, plan: dict[str, Any]) -> int | None: - """ - Estimate bytes for the *model output* from the root plan node. - - Approximation: root.Plan Rows * root.Plan Width (or DEFAULT_PG_ROW_WIDTH - if width is missing). - """ - - def _to_int(node: dict[str, Any], keys: tuple[str, ...]) -> int | None: - for key in keys: - val = node.get(key) - if val is None: - continue - try: - return int(val) - except (TypeError, ValueError): - continue - return None - - rows = _to_int(plan, ("Plan Rows", "Plan_Rows", "Rows")) - width = _to_int(plan, ("Plan Width", "Plan_Width", "Width")) - - if rows is None and width is None: - return None - - candidate: int | None - - if rows is not None and width is not None: - candidate = rows * width - elif rows is not None: - candidate = rows * self._DEFAULT_PG_ROW_WIDTH - else: - candidate = width - - if candidate is None or candidate <= 0: - return None - - return int(candidate) + return self.runtime_budget.estimate_query_bytes(sql) # --- Helpers --------------------------------------------------------- def _quote_identifier(self, ident: str) -> str: @@ -370,19 +279,7 @@ def _write_dataframe_with_stats(self, relation: str, df: pd.DataFrame, node: Nod ) from e else: self._analyze_relations([relation]) - self._record_dataframe_stats(df, int((perf_counter() - start) * 1000)) - - def _record_dataframe_stats(self, df: pd.DataFrame, duration_ms: int) -> None: - rows = len(df) - bytes_estimate = int(df.memory_usage(deep=True).sum()) if rows > 0 else 0 - bytes_val = bytes_estimate if bytes_estimate > 0 else None - self._record_query_stats( - QueryStats( - bytes_processed=bytes_val, - rows=rows if rows > 0 else None, - duration_ms=duration_ms, - ) - ) + self.runtime_query_stats.record_dataframe(df, int((perf_counter() - start) * 1000)) def load_seed( self, table: str, df: pd.DataFrame, schema: str | None = None @@ -391,8 +288,7 @@ def load_seed( qualified = self._qualify_identifier(table, schema=target_schema) drop_sql = f"DROP TABLE IF EXISTS {qualified} CASCADE" - with self.engine.begin() as conn: - conn.exec_driver_sql(drop_sql) + self._execute_sql_maintenance(drop_sql) df.to_sql( table, @@ -403,8 
+299,7 @@ def load_seed( method="multi", ) - with self.engine.begin() as conn: - conn.exec_driver_sql(f"ANALYZE {qualified}") + self._execute_sql_maintenance(f"ANALYZE {qualified}") return True, qualified, False @@ -416,8 +311,8 @@ def _create_or_replace_view_from_table( q_back = self._qualified(backing_table) try: with self.engine.begin() as conn: - self._execute_sql(f"DROP VIEW IF EXISTS {q_view} CASCADE", conn=conn) - self._execute_sql( + self._execute_sql_maintenance(f"DROP VIEW IF EXISTS {q_view} CASCADE", conn=conn) + self._execute_sql_maintenance( f"CREATE OR REPLACE VIEW {q_view} AS SELECT * FROM {q_back}", conn=conn ) @@ -429,8 +324,8 @@ def _frame_name(self) -> str: def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: try: - self._execute_sql(f"DROP VIEW IF EXISTS {target_sql} CASCADE") - self._execute_sql(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}") + self._execute_sql_maintenance(f"DROP VIEW IF EXISTS {target_sql} CASCADE") + self._execute_sql_maintenance(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}") except Exception as e: preview = f"-- target={target_sql}\n{select_body}" raise ModelExecutionError(node.name, target_sql, str(e), sql_snippet=preview) from e @@ -441,7 +336,7 @@ def _create_or_replace_table(self, target_sql: str, select_body: str, node: Node Use DROP TABLE IF EXISTS + CREATE TABLE AS, and accept CTE bodies. """ try: - self._execute_sql(f"DROP TABLE IF EXISTS {target_sql} CASCADE") + self._execute_sql_maintenance(f"DROP TABLE IF EXISTS {target_sql} CASCADE") self._execute_sql(f"CREATE TABLE {target_sql} AS {select_body}") self._analyze_relations([target_sql]) except Exception as e: @@ -474,7 +369,7 @@ def exists_relation(self, relation: str) -> bool: limit 1 """ - return bool(self._execute_sql(sql, {"t": relation}).fetchone()) + return bool(self._execute_sql_maintenance(sql, {"t": relation}).fetchone()) def create_table_as(self, relation: str, select_sql: str) -> None: body = self._extract_select_like(select_sql) @@ -587,15 +482,8 @@ def utest_load_relation_from_rows(self, relation: str, rows: list[dict]) -> None if not rows: # Ensure an empty table exists (corner case). 
try: - with self.engine.begin() as conn: - self._execute_sql_maintenance( - f"DROP TABLE IF EXISTS {qualified} CASCADE", - conn=conn, - ) - self._execute_sql_maintenance( - f"CREATE TABLE {qualified} ()", - conn=conn, - ) + self._execute_sql_maintenance(f"DROP TABLE IF EXISTS {qualified} CASCADE") + self._execute_sql_maintenance(f"CREATE TABLE {qualified} ()") except SQLAlchemyError as e: raise ModelExecutionError( node_name=f"utest::{relation}", @@ -618,24 +506,20 @@ def utest_load_relation_from_rows(self, relation: str, rows: list[dict]) -> None insert_values_sql = ", ".join(f":{c}" for c in cols) try: - with self.engine.begin() as conn: - # Replace any existing table - self._execute_sql_maintenance( - f"DROP TABLE IF EXISTS {qualified} CASCADE", - conn=conn, - ) + # Replace any existing table + self._execute_sql_maintenance(f"DROP TABLE IF EXISTS {qualified} CASCADE") - # Create table from first row - create_sql = f"CREATE TABLE {qualified} AS SELECT {select_exprs}" - self._execute_sql_maintenance(create_sql, first, conn=conn) + # Create table from first row + create_sql = f"CREATE TABLE {qualified} AS SELECT {select_exprs}" + self._execute_sql_maintenance(create_sql, first) - # Insert remaining rows - if len(rows) > 1: - insert_sql = ( - f"INSERT INTO {qualified} ({col_list_sql}) VALUES ({insert_values_sql})" - ) - for row in rows[1:]: - self._execute_sql_maintenance(insert_sql, row, conn=conn) + # Insert remaining rows + if len(rows) > 1: + insert_sql = ( + f"INSERT INTO {qualified} ({col_list_sql}) VALUES ({insert_values_sql})" + ) + for row in rows[1:]: + self._execute_sql_maintenance(insert_sql, row) except SQLAlchemyError as e: raise ModelExecutionError( @@ -662,11 +546,10 @@ def utest_clean_target(self, relation: str) -> None: - dropping only the matching kinds. 
""" with self.engine.begin() as conn: - # Use the same search_path logic as the rest of the executor - self._set_search_path(conn) - # Decide which schema to inspect - cur_schema = conn.execute(text("select current_schema()")).scalar() + cur_schema = self._execute_sql_maintenance( + "select current_schema()", conn=conn + ).scalar() schema = self.schema or cur_schema # Find objects named in that schema @@ -684,17 +567,20 @@ def utest_clean_target(self, relation: str) -> None: ) s order by kind; """ - rows = conn.execute( - text(info_sql), - {"schema": schema, "rel": relation}, + rows = self._execute_sql_maintenance( + info_sql, {"schema": schema, "rel": relation}, conn=conn ).fetchall() for kind, table_schema, table_name in rows: qualified = f'"{table_schema}"."{table_name}"' if kind == "view": - conn.execute(text(f"DROP VIEW IF EXISTS {qualified} CASCADE")) + self._execute_sql_maintenance( + f"DROP VIEW IF EXISTS {qualified} CASCADE", conn=conn + ) else: # table - conn.execute(text(f"DROP TABLE IF EXISTS {qualified} CASCADE")) + self._execute_sql_maintenance( + f"DROP TABLE IF EXISTS {qualified} CASCADE", conn=conn + ) def _introspect_columns_metadata( self, @@ -743,7 +629,7 @@ def _introspect_columns_metadata( order by a.attnum """ - rows = self._execute_sql(sql, params).fetchall() + rows = self._execute_sql_maintenance(sql, params).fetchall() # Return canonical type *base* by default return [(str(name), _base_type(str(dtype))) for (name, dtype) in rows] diff --git a/src/fastflowtransform/executors/query_stats/runtime/__init__.py b/src/fastflowtransform/executors/query_stats/runtime/__init__.py index fd26221..55ab15f 100644 --- a/src/fastflowtransform/executors/query_stats/runtime/__init__.py +++ b/src/fastflowtransform/executors/query_stats/runtime/__init__.py @@ -1,4 +1,5 @@ from fastflowtransform.executors.query_stats.runtime.base import BaseQueryStatsRuntime from fastflowtransform.executors.query_stats.runtime.duckdb import DuckQueryStatsRuntime +from fastflowtransform.executors.query_stats.runtime.postgres import PostgresQueryStatsRuntime -__all__ = ["BaseQueryStatsRuntime", "DuckQueryStatsRuntime"] +__all__ = ["BaseQueryStatsRuntime", "DuckQueryStatsRuntime", "PostgresQueryStatsRuntime"] diff --git a/src/fastflowtransform/executors/query_stats/runtime/postgres.py b/src/fastflowtransform/executors/query_stats/runtime/postgres.py new file mode 100644 index 0000000..5abda97 --- /dev/null +++ b/src/fastflowtransform/executors/query_stats/runtime/postgres.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import Any + +import pandas as pd + +from fastflowtransform.executors.query_stats.core import QueryStats +from fastflowtransform.executors.query_stats.runtime.base import ( + BaseQueryStatsRuntime, + QueryStatsExecutor, +) + + +class PostgresQueryStatsRuntime(BaseQueryStatsRuntime[QueryStatsExecutor]): + """Postgres-specific stats helpers.""" + + def rowcount_from_result(self, result: Any) -> int | None: + rc = getattr(result, "rowcount", None) + if isinstance(rc, int) and rc >= 0: + return rc + return None + + def record_dataframe(self, df: pd.DataFrame, duration_ms: int) -> QueryStats: + rows = len(df) + bytes_estimate = int(df.memory_usage(deep=True).sum()) if rows > 0 else 0 + bytes_val = bytes_estimate if bytes_estimate > 0 else None + stats = QueryStats( + bytes_processed=bytes_val, + rows=rows if rows > 0 else None, + duration_ms=duration_ms, + ) + self.executor._record_query_stats(stats) + return stats diff --git 
a/tests/unit/executors/test_postgres_exec_unit.py b/tests/unit/executors/test_postgres_exec_unit.py index 90d0326..5c75618 100644 --- a/tests/unit/executors/test_postgres_exec_unit.py +++ b/tests/unit/executors/test_postgres_exec_unit.py @@ -2,7 +2,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any +from typing import Any, cast import pandas as pd import pytest @@ -66,6 +66,39 @@ def begin(self): return self._conn +class _FakeBudgetRuntime: + """Minimal budget runtime stub that just executes the provided fn.""" + + def __init__(self, executor: Any): + self.executor = executor + + def run_sql( + self, + sql: str, + *, + exec_fn: Any, + stats_runtime: Any, + rowcount_extractor=None, + estimate_fn=None, + **kwargs: Any, + ): + return exec_fn() + + +class _FakeQueryStatsRuntime: + """Minimal query stats stub.""" + + def __init__(self, executor: Any): + self.executor = executor + + def rowcount_from_result(self, result: Any) -> int | None: + rc = getattr(result, "rowcount", None) + return rc if isinstance(rc, int) and rc >= 0 else None + + def record_dataframe(self, df: Any, duration_ms: int): + return None + + # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @@ -347,6 +380,8 @@ def test_create_or_replace_table_wraps(fake_engine_and_conn, node_tmp): ex = PostgresExecutor.__new__(PostgresExecutor) ex.engine = engine ex.schema = "public" + ex.runtime_budget = cast(Any, _FakeBudgetRuntime(ex)) + ex.runtime_query_stats = cast(Any, _FakeQueryStatsRuntime(ex)) # Force the DB call to fail def bad_execute(stmt, params=None): @@ -519,13 +554,13 @@ def test_create_or_replace_view_from_table_happy(fake_engine_and_conn): or "CREATE OR REPLACE VIEW" in str(stmt) ] - expected_statement_len = 5 + expected_statement_len = 4 assert len(stmts) == expected_statement_len assert 'SET LOCAL search_path = "public"' in stmts[0][0] assert 'DROP VIEW IF EXISTS "public"."v_out" CASCADE' in stmts[1][0] assert ( - 'CREATE OR REPLACE VIEW "public"."v_out" AS SELECT * FROM "public"."src_tbl"' in stmts[4][0] + 'CREATE OR REPLACE VIEW "public"."v_out" AS SELECT * FROM "public"."src_tbl"' in stmts[3][0] ) From b67a38faab0b633a050e2b71b5cc3a017d89662e Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Wed, 17 Dec 2025 10:33:06 +0100 Subject: [PATCH 04/10] Refactored budget and query_stats for databricks_spark engine --- .../executors/_query_stats_adapter.py | 3 +- .../executors/budget/runtime/__init__.py | 10 +- .../budget/runtime/databricks_spark.py | 155 ++++++++++++++ .../executors/databricks_spark.py | 192 ++++-------------- .../executors/query_stats/runtime/__init__.py | 10 +- .../query_stats/runtime/databricks_spark.py | 32 +++ tests/common/fixtures.py | 8 +- 7 files changed, 246 insertions(+), 164 deletions(-) create mode 100644 src/fastflowtransform/executors/budget/runtime/databricks_spark.py create mode 100644 src/fastflowtransform/executors/query_stats/runtime/databricks_spark.py diff --git a/src/fastflowtransform/executors/_query_stats_adapter.py b/src/fastflowtransform/executors/_query_stats_adapter.py index eacc54a..8ceb8e0 100644 --- a/src/fastflowtransform/executors/_query_stats_adapter.py +++ b/src/fastflowtransform/executors/_query_stats_adapter.py @@ -126,8 +126,9 @@ def __init__(self, bytes_fn: Callable[[Any], int | None]) -> None: self.bytes_fn = bytes_fn def collect( - self, df: Any, *, duration_ms: int | None, estimated_bytes: int | None = None + self, 
result: Any, *, duration_ms: int | None, estimated_bytes: int | None = None ) -> QueryStats: + df = result bytes_val = estimated_bytes if bytes_val is None: try: diff --git a/src/fastflowtransform/executors/budget/runtime/__init__.py b/src/fastflowtransform/executors/budget/runtime/__init__.py index d8509e6..a03c41c 100644 --- a/src/fastflowtransform/executors/budget/runtime/__init__.py +++ b/src/fastflowtransform/executors/budget/runtime/__init__.py @@ -1,5 +1,13 @@ from fastflowtransform.executors.budget.runtime.base import BaseBudgetRuntime +from fastflowtransform.executors.budget.runtime.databricks_spark import ( + DatabricksSparkBudgetRuntime, +) from fastflowtransform.executors.budget.runtime.duckdb import DuckBudgetRuntime from fastflowtransform.executors.budget.runtime.postgres import PostgresBudgetRuntime -__all__ = ["BaseBudgetRuntime", "DuckBudgetRuntime", "PostgresBudgetRuntime"] +__all__ = [ + "BaseBudgetRuntime", + "DatabricksSparkBudgetRuntime", + "DuckBudgetRuntime", + "PostgresBudgetRuntime", +] diff --git a/src/fastflowtransform/executors/budget/runtime/databricks_spark.py b/src/fastflowtransform/executors/budget/runtime/databricks_spark.py new file mode 100644 index 0000000..1c41fb2 --- /dev/null +++ b/src/fastflowtransform/executors/budget/runtime/databricks_spark.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from typing import Any, Protocol + +from fastflowtransform.executors._query_stats_adapter import SparkDataFrameStatsAdapter +from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.budget.runtime.base import BaseBudgetRuntime, BudgetExecutor + + +class DatabricksSparkBudgetExecutor(BudgetExecutor, Protocol): + spark: Any + + def _selectable_body(self, sql: str) -> str: ... + + +class DatabricksSparkBudgetRuntime(BaseBudgetRuntime[DatabricksSparkBudgetExecutor]): + """Databricks/Spark budget runtime using logical-plan stats for estimation.""" + + def __init__(self, executor: DatabricksSparkBudgetExecutor, guard: BudgetGuard | None): + super().__init__(executor, guard) + self._default_size: int | None = self.detect_default_size() + + def estimate_query_bytes(self, sql: str) -> int | None: + return self._spark_plan_bytes(sql) + + def detect_default_size(self) -> int: + """ + Detect Spark's defaultSizeInBytes sentinel. + + - Prefer spark.sql.defaultSizeInBytes if available. + - Fall back to Long.MaxValue (2^63 - 1) otherwise. + """ + try: + conf_val = self.executor.spark.conf.get("spark.sql.defaultSizeInBytes") + if conf_val is not None: + return int(conf_val) + except Exception: + # config not set / older Spark / weird environment + pass + + # Fallback: Spark uses Long.MaxValue by default + return 2**63 - 1 # 9223372036854775807 + + def spark_stats_adapter(self, sql: str) -> SparkDataFrameStatsAdapter: + """ + Build a SparkDataFrameStatsAdapter tied to this runtime's estimation logic. 
+ """ + + def _bytes(df: Any) -> int | None: + estimate = self.dataframe_bytes(df) + if estimate is not None: + return estimate + return self.estimate_query_bytes(sql) + + return SparkDataFrameStatsAdapter(_bytes) + + # ---- Shared helpers for Spark stats ---- + def dataframe_bytes(self, df: Any) -> int | None: + try: + jdf = getattr(df, "_jdf", None) + if jdf is None: + return None + + qe = jdf.queryExecution() + jplan = qe.optimizedPlan() + + if self._jplan_uses_default_size(jplan): + return None + + stats = jplan.stats() + size_attr = getattr(stats, "sizeInBytes", None) + size_val = size_attr() if callable(size_attr) else size_attr + return self._parse_spark_stats_size(size_val) + except Exception: + return None + + def _spark_plan_bytes(self, sql: str) -> int | None: + """ + Inspect the optimized logical plan via the JVM and return sizeInBytes + as an integer, or None if not available. This does not execute the query. + """ + try: + normalized = self.executor._selectable_body(sql).rstrip(";\n\t ") + if not normalized: + normalized = sql + except Exception: + normalized = sql + + stmt = normalized.lstrip().lower() + if not stmt.startswith(("select", "with")): + # DDL/DML statements should not be executed twice. + return None + + try: + df = self.executor.spark.sql(normalized) + + jdf = getattr(df, "_jdf", None) + if jdf is None: + return None + + qe = jdf.queryExecution() + jplan = qe.optimizedPlan() + + if self._jplan_uses_default_size(jplan): + return None + + stats = jplan.stats() + size_attr = getattr(stats, "sizeInBytes", None) + size_val = size_attr() if callable(size_attr) else size_attr + + return self._parse_spark_stats_size(size_val) + except Exception: + return None + + def _jplan_uses_default_size(self, jplan: Any) -> bool: + """ + Recursively walk a JVM LogicalPlan and return True if any node's + stats.sizeInBytes equals spark.sql.defaultSizeInBytes. 
+ """ + spark_default_size = self._default_size + if spark_default_size is None: + return False + + try: + stats = jplan.stats() + size_val = stats.sizeInBytes() + size_int = int(str(size_val)) + if size_int == spark_default_size: + return True + except Exception: + # ignore stats errors and keep walking + pass + + # children() is a Scala Seq[LogicalPlan]; iterate via .size() / .apply(i) + try: + children = jplan.children() + n = children.size() + for idx in range(n): + child = children.apply(idx) + if self._jplan_uses_default_size(child): + return True + except Exception: + # if we can't inspect children, stop here + pass + + return False + + def _parse_spark_stats_size(self, size_val: Any) -> int | None: + if size_val is None: + return None + try: + size_int = int(str(size_val)) + except Exception: + return None + return size_int if size_int > 0 else None diff --git a/src/fastflowtransform/executors/databricks_spark.py b/src/fastflowtransform/executors/databricks_spark.py index e2ba7e2..5eb800f 100644 --- a/src/fastflowtransform/executors/databricks_spark.py +++ b/src/fastflowtransform/executors/databricks_spark.py @@ -5,7 +5,7 @@ from contextlib import suppress from pathlib import Path from time import perf_counter -from typing import Any, cast +from typing import Any from urllib.parse import unquote, urlparse import pandas as pd @@ -15,11 +15,15 @@ from fastflowtransform.contracts.runtime.databricks_spark import DatabricksSparkRuntimeContracts from fastflowtransform.core import REGISTRY, Node, relation_for from fastflowtransform.errors import ModelExecutionError -from fastflowtransform.executors._budget_runner import run_sql_with_budget -from fastflowtransform.executors._query_stats_adapter import SparkDataFrameStatsAdapter from fastflowtransform.executors._test_utils import make_fetchable, rows_to_tuples from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.budget.runtime.databricks_spark import ( + DatabricksSparkBudgetRuntime, +) +from fastflowtransform.executors.query_stats.runtime.databricks_spark import ( + DatabricksSparkQueryStatsRuntime, +) from fastflowtransform.logging import echo_debug from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots.runtime.databricks_spark import DatabricksSparkSnapshotRuntime @@ -183,10 +187,12 @@ class DatabricksSparkExecutor(BaseExecutor[SDF]): ENGINE_NAME: str = "databricks_spark" runtime_contracts: DatabricksSparkRuntimeContracts + runtime_query_stats: DatabricksSparkQueryStatsRuntime + runtime_budget: DatabricksSparkBudgetRuntime snapshot_runtime: DatabricksSparkSnapshotRuntime _BUDGET_GUARD = BudgetGuard( env_var="FF_SPK_MAX_BYTES", - estimator_attr="_estimate_query_bytes", + estimator_attr="runtime_budget_estimate_query_bytes", engine_label="Databricks/Spark", what="query", ) @@ -241,11 +247,6 @@ def __init__( # Apply Delta configuration last, after all Spark configs are set. if not wants_delta and self._user_spark is None: catalog_overridden = bool(catalog_value) - if not catalog_overridden: - # Leave Spark catalog untouched; downstream environments may supply - # their own defaults (e.g., Unity, Glue). We only force a catalog - # when the user explicitly opts into Delta. - pass # Apply Delta configuration last, after all Spark configs are set. 
if wants_delta and self._user_spark is None: @@ -273,9 +274,11 @@ def __init__( self.catalog = catalog self.database = database self.schema = database + self.runtime_query_stats = DatabricksSparkQueryStatsRuntime(self) + self.runtime_budget = DatabricksSparkBudgetRuntime(self, self._BUDGET_GUARD) if database: - self._execute_sql(f"CREATE DATABASE IF NOT EXISTS `{database}`") + self._execute_sql_basic(f"CREATE DATABASE IF NOT EXISTS `{database}`") with suppress(Exception): self.spark.catalog.setCurrentDatabase(database) @@ -309,122 +312,14 @@ def __init__( sql_runner=self._execute_sql, ) - self._spark_default_size = self._detect_default_size() self.runtime_contracts = DatabricksSparkRuntimeContracts(self) self.snapshot_runtime = DatabricksSparkSnapshotRuntime(self) # ---------- Cost estimation & central execution ---------- - def _detect_default_size(self) -> int: - """ - Detect Spark's defaultSizeInBytes sentinel. - - - Prefer spark.sql.defaultSizeInBytes if available. - - Fall back to Long.MaxValue (2^63 - 1) otherwise. - """ - try: - conf_val = self.spark.conf.get("spark.sql.defaultSizeInBytes") - if conf_val is not None: - return int(conf_val) - except Exception: - # config not set / older Spark / weird environment - pass - - # Fallback: Spark uses Long.MaxValue by default - return 2**63 - 1 # 9223372036854775807 - - def _parse_spark_stats_size(self, size_val: Any) -> int | None: - if size_val is None: - return None - try: - size_int = int(str(size_val)) - except Exception: - return None - return size_int if size_int > 0 else None - - def _jplan_uses_default_size(self, jplan: Any) -> bool: - """ - Recursively walk a JVM LogicalPlan and return True if any node's - stats.sizeInBytes equals spark.sql.defaultSizeInBytes. - """ - if self._spark_default_size is None: - return False - - try: - stats = jplan.stats() - size_val = stats.sizeInBytes() - size_int = int(str(size_val)) - if size_int == self._spark_default_size: - return True - except Exception: - # ignore stats errors and keep walking - pass - - # children() is a Scala Seq[LogicalPlan]; iterate via .size() / .apply(i) - try: - children = jplan.children() - n = children.size() - for idx in range(n): - child = children.apply(idx) - if self._jplan_uses_default_size(child): - return True - except Exception: - # if we can't inspect children, stop here - pass - - return False - - def _spark_plan_bytes(self, sql: str) -> int | None: - """ - Inspect the optimized logical plan via the JVM and return sizeInBytes - as an integer, or None if not available. - - This does *not* execute the query; it only goes through analysis/planning. - """ - try: - normalized = self._selectable_body(sql).rstrip(";\n\t ") - if not normalized: - normalized = sql - except Exception: - normalized = sql - - stmt = normalized.lstrip().lower() - if not stmt.startswith(("select", "with")): - # DDL/DML statements (ALTER/INSERT/etc.) should not be executed twice. 
- return None - - try: - df = self.spark.sql(normalized) - - jdf = cast(Any, getattr(df, "_jdf", None)) - if jdf is None: - return None - - qe = jdf.queryExecution() - jplan = qe.optimizedPlan() - - # If any node relies on defaultSizeInBytes, we don't trust the stats - if self._jplan_uses_default_size(jplan): - return None - - stats = jplan.stats() - - size_attr = getattr(stats, "sizeInBytes", None) - size_val = size_attr() if callable(size_attr) else size_attr - - return self._parse_spark_stats_size(size_val) - except Exception: - return None - - def _estimate_query_bytes(self, sql: str) -> int | None: - """ - Best-effort logical-plan size estimate using Spark's stats. - - It inspects the optimized plan's sizeInBytes via the JVM API without - executing the query. If unavailable or unsupported, returns None and - the guard is effectively disabled. - """ - return self._spark_plan_bytes(sql) + def runtime_budget_estimate_query_bytes(self, sql: str) -> int | None: + """Expose runtime_budget estimator for BudgetGuard.""" + return self.runtime_budget.estimate_query_bytes(sql) def execute_test_sql(self, stmt: Any) -> Any: """ @@ -453,6 +348,9 @@ def compute_freshness_delay_minutes(self, table: str, ts_col: str) -> tuple[floa val = row[0] if row else None return (float(val) if val is not None else None, sql) + def _execute_sql_basic(self, sql: str) -> SDF: + return self.spark.sql(sql) + def _execute_sql(self, sql: str) -> SDF: """ Central Spark SQL runner. @@ -465,13 +363,12 @@ def _execute_sql(self, sql: str) -> SDF: def _exec() -> SDF: return self.spark.sql(sql) - return run_sql_with_budget( - self, + return self.runtime_budget.run_sql( sql, - guard=self._BUDGET_GUARD, exec_fn=_exec, - estimate_fn=self._spark_plan_bytes, - post_estimate_fn=lambda _, __: self._spark_plan_bytes(sql), + stats_runtime=self.runtime_query_stats, + estimate_fn=self.runtime_budget_estimate_query_bytes, + stats_adapter=self.runtime_budget.spark_stats_adapter(sql), ) # ---------- Frame hooks (required) ---------- @@ -498,7 +395,7 @@ def _create_view_over_table(self, view_name: str, backing_table: str, node: Node """Compatibility hook: create a simple SELECT * view over an existing table.""" view_sql = self._sql_identifier(view_name) backing_sql = self._sql_identifier(backing_table) - self._execute_sql(f"CREATE OR REPLACE VIEW {view_sql} AS SELECT * FROM {backing_sql}") + self._execute_sql_basic(f"CREATE OR REPLACE VIEW {view_sql} AS SELECT * FROM {backing_sql}") def _validate_required( self, node_name: str, inputs: Any, requires: dict[str, set[str]] @@ -600,28 +497,7 @@ def _write_to_storage_path(self, relation: str, df: SDF, storage_meta: dict[str, self.spark.catalog.refreshByPath(path) def _record_spark_dataframe_stats(self, df: SDF, duration_ms: int) -> None: - adapter = SparkDataFrameStatsAdapter(self._spark_dataframe_bytes) - stats = adapter.collect(df, duration_ms=duration_ms) - self._record_query_stats(stats) - - def _spark_dataframe_bytes(self, df: SDF) -> int | None: - try: - jdf = cast(Any, getattr(df, "_jdf", None)) - if jdf is None: - return None - - qe = jdf.queryExecution() - jplan = qe.optimizedPlan() - - if self._jplan_uses_default_size(jplan): - return None - - stats = jplan.stats() - size_attr = getattr(stats, "sizeInBytes", None) - size_val = size_attr() if callable(size_attr) else size_attr - return self._parse_spark_stats_size(size_val) - except Exception: - return None + self.runtime_query_stats.record_dataframe(df, duration_ms) # ---- SQL hooks ---- def _format_relation_for_ref(self, name: str) -> 
str: @@ -788,12 +664,12 @@ def _save_df_as_table( self._format_handler.save_df_as_table(table_name, df) with suppress(Exception): - self._execute_sql( + self._execute_sql_basic( f"ANALYZE TABLE {self._sql_identifier(table_name)} COMPUTE STATISTICS" ) def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: - self._execute_sql(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}") + self._execute_sql_basic(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}") def _create_or_replace_table(self, target_sql: str, select_body: str, node: Node) -> None: preview = f"-- target={target_sql}\n{select_body}" @@ -809,7 +685,7 @@ def _create_or_replace_view_from_table( ) -> None: view_sql = self._sql_identifier(view_name) backing_sql = self._sql_identifier(backing_table) - self._execute_sql(f"CREATE OR REPLACE VIEW {view_sql} AS SELECT * FROM {backing_sql}") + self._execute_sql_basic(f"CREATE OR REPLACE VIEW {view_sql} AS SELECT * FROM {backing_sql}") # ---- Meta hook ---- def on_node_built(self, node: Node, relation: str, fingerprint: str) -> None: @@ -927,7 +803,7 @@ def alter_table_sync_schema( existing = {f.name for f in target_df.schema.fields} # Output schema from the SELECT body = self._first_select_body(select_sql).strip().rstrip(";\n\t ") - probe = self._execute_sql(f"SELECT * FROM ({body}) q LIMIT 0") + probe = self._execute_sql_basic(f"SELECT * FROM ({body}) q LIMIT 0") to_add = [f for f in probe.schema.fields if f.name not in existing] if not to_add: return @@ -942,7 +818,7 @@ def _spark_sql_type(dt: DataType) -> str: cols_sql = ", ".join([f"`{f.name}` {_spark_sql_type(f.dataType)}" for f in to_add]) table_sql = self._sql_identifier(relation) - self._execute_sql(f"ALTER TABLE {table_sql} ADD COLUMNS ({cols_sql})") + self._execute_sql_basic(f"ALTER TABLE {table_sql} ADD COLUMNS ({cols_sql})") # ── Snapshot runtime delegation ────────────────────────────────────── def run_snapshot_sql(self, node: Node, env: Environment) -> None: @@ -988,7 +864,9 @@ def load_seed( schema_part = self._strip_quotes(schema) if schema_part: # Ensure database exists when a separate schema is provided. - self._execute_sql(f"CREATE DATABASE IF NOT EXISTS {self._q_ident(schema_part)}") + self._execute_sql_basic( + f"CREATE DATABASE IF NOT EXISTS {self._q_ident(schema_part)}" + ) created_schema = True parts = [schema_part, parts[0]] @@ -1062,11 +940,11 @@ def utest_clean_target(self, relation: str) -> None: # Drop view first; ignore errors if it's actually a table or missing. with suppress(Exception): - self._execute_sql(f"DROP VIEW IF EXISTS {ident}") + self._execute_sql_basic(f"DROP VIEW IF EXISTS {ident}") # Then drop table; ignore errors if it's actually a view or missing. 
with suppress(Exception): - self._execute_sql(f"DROP TABLE IF EXISTS {ident}") + self._execute_sql_basic(f"DROP TABLE IF EXISTS {ident}") def _introspect_columns_metadata( self, diff --git a/src/fastflowtransform/executors/query_stats/runtime/__init__.py b/src/fastflowtransform/executors/query_stats/runtime/__init__.py index 55ab15f..efe9ab4 100644 --- a/src/fastflowtransform/executors/query_stats/runtime/__init__.py +++ b/src/fastflowtransform/executors/query_stats/runtime/__init__.py @@ -1,5 +1,13 @@ from fastflowtransform.executors.query_stats.runtime.base import BaseQueryStatsRuntime +from fastflowtransform.executors.query_stats.runtime.databricks_spark import ( + DatabricksSparkQueryStatsRuntime, +) from fastflowtransform.executors.query_stats.runtime.duckdb import DuckQueryStatsRuntime from fastflowtransform.executors.query_stats.runtime.postgres import PostgresQueryStatsRuntime -__all__ = ["BaseQueryStatsRuntime", "DuckQueryStatsRuntime", "PostgresQueryStatsRuntime"] +__all__ = [ + "BaseQueryStatsRuntime", + "DatabricksSparkQueryStatsRuntime", + "DuckQueryStatsRuntime", + "PostgresQueryStatsRuntime", +] diff --git a/src/fastflowtransform/executors/query_stats/runtime/databricks_spark.py b/src/fastflowtransform/executors/query_stats/runtime/databricks_spark.py new file mode 100644 index 0000000..138f12d --- /dev/null +++ b/src/fastflowtransform/executors/query_stats/runtime/databricks_spark.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import Any + +from fastflowtransform.executors._query_stats_adapter import SparkDataFrameStatsAdapter +from fastflowtransform.executors.query_stats.core import QueryStats +from fastflowtransform.executors.query_stats.runtime.base import ( + BaseQueryStatsRuntime, + QueryStatsExecutor, +) + + +class DatabricksSparkQueryStatsRuntime(BaseQueryStatsRuntime[QueryStatsExecutor]): + """Spark-specific stats helpers.""" + + def rowcount_from_result(self, result: Any) -> int | None: + # Avoid triggering extra Spark actions; rely on estimates instead. + rc = getattr(result, "count", None) + if isinstance(rc, int) and rc >= 0: + return rc + return None + + def record_dataframe(self, df: Any, duration_ms: int) -> QueryStats: + budget_runtime = getattr(self.executor, "runtime_budget", None) + adapter = ( + budget_runtime.spark_stats_adapter("") + if budget_runtime is not None + else SparkDataFrameStatsAdapter(lambda _: None) + ) + stats = adapter.collect(df, duration_ms=duration_ms, estimated_bytes=None) + self.executor._record_query_stats(stats) + return stats diff --git a/tests/common/fixtures.py b/tests/common/fixtures.py index cbb8042..bd321cd 100644 --- a/tests/common/fixtures.py +++ b/tests/common/fixtures.py @@ -138,8 +138,8 @@ def exec_minimal(monkeypatch): SP.builder.master.return_value.appName.return_value.getOrCreate.return_value = fake_spark ex = DatabricksSparkExecutor() # JVM plan inspection loops forever on MagicMocks; skip in unit tests. 
- monkeypatch.setattr(ex, "_spark_plan_bytes", lambda *_, **__: None) - monkeypatch.setattr(ex, "_spark_dataframe_bytes", lambda *_, **__: None) + monkeypatch.setattr(ex.runtime_budget, "_spark_plan_bytes", lambda *_, **__: None) + monkeypatch.setattr(ex.runtime_budget, "dataframe_bytes", lambda *_, **__: None) # accept mocks as frames in unit tests monkeypatch.setattr(ex, "_is_frame", lambda obj: True) return ex @@ -171,8 +171,8 @@ def _make(**kwargs) -> tuple[DatabricksSparkExecutorType, Any, MagicMock]: fake_builder.getOrCreate.return_value = fake_spark ex = DatabricksSparkExecutor(**kwargs) - ex._spark_plan_bytes = lambda *_, **__: None - ex._spark_dataframe_bytes = lambda *_, **__: None + ex.runtime_budget._spark_plan_bytes = lambda *_, **__: None + ex.runtime_budget.dataframe_bytes = lambda *_, **__: None return ex, fake_builder, fake_spark return _make From 16575a8229990f729ff0bc27292c1f424095bbdd Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Wed, 17 Dec 2025 11:32:56 +0100 Subject: [PATCH 05/10] Some design fixes for new budget and query_stats runtimes --- .../executors/budget/runtime/base.py | 9 +++++---- .../executors/budget/runtime/databricks_spark.py | 9 ++++++++- .../executors/budget/runtime/duckdb.py | 9 ++++++++- .../executors/budget/runtime/postgres.py | 9 ++++++++- .../executors/databricks_spark.py | 16 +--------------- src/fastflowtransform/executors/duckdb.py | 16 +--------------- src/fastflowtransform/executors/postgres.py | 15 +-------------- 7 files changed, 32 insertions(+), 51 deletions(-) diff --git a/src/fastflowtransform/executors/budget/runtime/base.py b/src/fastflowtransform/executors/budget/runtime/base.py index d924390..457658b 100644 --- a/src/fastflowtransform/executors/budget/runtime/base.py +++ b/src/fastflowtransform/executors/budget/runtime/base.py @@ -31,9 +31,9 @@ class BaseBudgetRuntime[E: BudgetExecutor]: executor: E guard: BudgetGuard | None - def __init__(self, executor: E, guard: BudgetGuard | None): + def __init__(self, executor: E, guard: BudgetGuard | None = None): self.executor = executor - self.guard = guard + self.guard = guard or getattr(type(self), "DEFAULT_GUARD", None) def apply_guard(self, sql: str) -> int | None: return self.executor._apply_budget_guard(self.guard, sql) @@ -52,13 +52,14 @@ def run_sql( stats_adapter: QueryStatsAdapter | None = None, ) -> Any: estimated_bytes = self.apply_guard(sql) + estimator = estimate_fn or getattr(self, "estimate_query_bytes", None) if ( estimated_bytes is None and not self.executor._is_budget_guard_active() - and estimate_fn is not None + and callable(estimator) ): with suppress(Exception): - estimated_bytes = estimate_fn(sql) + estimated_bytes = estimator(sql) if not record_stats: return exec_fn() diff --git a/src/fastflowtransform/executors/budget/runtime/databricks_spark.py b/src/fastflowtransform/executors/budget/runtime/databricks_spark.py index 1c41fb2..21bf4fa 100644 --- a/src/fastflowtransform/executors/budget/runtime/databricks_spark.py +++ b/src/fastflowtransform/executors/budget/runtime/databricks_spark.py @@ -16,7 +16,14 @@ def _selectable_body(self, sql: str) -> str: ... 
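# Illustrative sketch (hypothetical stub executor, not from the patch): with the
# DEFAULT_GUARD fallback added to BaseBudgetRuntime above, a runtime built
# without an explicit guard resolves the class-level default. A bare
# SimpleNamespace stands in for a real executor here; detect_default_size()
# then falls back to Long.MaxValue because no Spark session is reachable.
from types import SimpleNamespace

_rt = DatabricksSparkBudgetRuntime(SimpleNamespace())
assert _rt.guard is DatabricksSparkBudgetRuntime.DEFAULT_GUARD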
class DatabricksSparkBudgetRuntime(BaseBudgetRuntime[DatabricksSparkBudgetExecutor]): """Databricks/Spark budget runtime using logical-plan stats for estimation.""" - def __init__(self, executor: DatabricksSparkBudgetExecutor, guard: BudgetGuard | None): + DEFAULT_GUARD = BudgetGuard( + env_var="FF_SPK_MAX_BYTES", + estimator_attr="runtime_budget_estimate_query_bytes", + engine_label="Databricks/Spark", + what="query", + ) + + def __init__(self, executor: DatabricksSparkBudgetExecutor, guard: BudgetGuard | None = None): super().__init__(executor, guard) self._default_size: int | None = self.detect_default_size() diff --git a/src/fastflowtransform/executors/budget/runtime/duckdb.py b/src/fastflowtransform/executors/budget/runtime/duckdb.py index 933b327..ad94497 100644 --- a/src/fastflowtransform/executors/budget/runtime/duckdb.py +++ b/src/fastflowtransform/executors/budget/runtime/duckdb.py @@ -19,6 +19,13 @@ def _selectable_body(self, sql: str) -> str: ... class DuckBudgetRuntime(BaseBudgetRuntime[DuckBudgetExecutor]): """DuckDB-specific budget runtime with plan-based estimation.""" + DEFAULT_GUARD = BudgetGuard( + env_var="FF_DUCKDB_MAX_BYTES", + estimator_attr="_estimate_query_bytes", + engine_label="DuckDB", + what="query", + ) + _FIXED_TYPE_SIZES: ClassVar[dict[str, int]] = { "boolean": 1, "bool": 1, @@ -48,7 +55,7 @@ class DuckBudgetRuntime(BaseBudgetRuntime[DuckBudgetExecutor]): _VARCHAR_MAX_WIDTH = 1024 _DEFAULT_ROW_WIDTH = 128 - def __init__(self, executor: DuckBudgetExecutor, guard: BudgetGuard | None): + def __init__(self, executor: DuckBudgetExecutor, guard: BudgetGuard | None = None): super().__init__(executor, guard) self._table_row_width_cache: dict[tuple[str | None, str], int] = {} diff --git a/src/fastflowtransform/executors/budget/runtime/postgres.py b/src/fastflowtransform/executors/budget/runtime/postgres.py index 3cb3296..9e2d052 100644 --- a/src/fastflowtransform/executors/budget/runtime/postgres.py +++ b/src/fastflowtransform/executors/budget/runtime/postgres.py @@ -26,9 +26,16 @@ def _extract_select_like(self, sql_or_body: str) -> str: ... 
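# Illustrative sketch (invented env var, not from the patch): a guard passed
# explicitly still wins over the class-level DEFAULT_GUARD declared below,
# because the base runtime only falls back when no guard is given, e.g.
# PostgresBudgetRuntime(executor, guard=_ci_guard).
_ci_guard = BudgetGuard(
    env_var="FF_PG_MAX_BYTES_CI",
    estimator_attr="_estimate_query_bytes",
    engine_label="Postgres (CI)",
    what="query",
)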
class PostgresBudgetRuntime(BaseBudgetRuntime[PostgresBudgetExecutor]): """Postgres-specific budget runtime with EXPLAIN-based estimation.""" + DEFAULT_GUARD = BudgetGuard( + env_var="FF_PG_MAX_BYTES", + estimator_attr="_estimate_query_bytes", + engine_label="Postgres", + what="query", + ) + _DEFAULT_PG_ROW_WIDTH = 128 - def __init__(self, executor: PostgresBudgetExecutor, guard: BudgetGuard | None): + def __init__(self, executor: PostgresBudgetExecutor, guard: BudgetGuard | None = None): super().__init__(executor, guard) def estimate_query_bytes(self, sql: str) -> int | None: diff --git a/src/fastflowtransform/executors/databricks_spark.py b/src/fastflowtransform/executors/databricks_spark.py index 5eb800f..e0355b4 100644 --- a/src/fastflowtransform/executors/databricks_spark.py +++ b/src/fastflowtransform/executors/databricks_spark.py @@ -17,7 +17,6 @@ from fastflowtransform.errors import ModelExecutionError from fastflowtransform.executors._test_utils import make_fetchable, rows_to_tuples from fastflowtransform.executors.base import BaseExecutor -from fastflowtransform.executors.budget.core import BudgetGuard from fastflowtransform.executors.budget.runtime.databricks_spark import ( DatabricksSparkBudgetRuntime, ) @@ -190,12 +189,6 @@ class DatabricksSparkExecutor(BaseExecutor[SDF]): runtime_query_stats: DatabricksSparkQueryStatsRuntime runtime_budget: DatabricksSparkBudgetRuntime snapshot_runtime: DatabricksSparkSnapshotRuntime - _BUDGET_GUARD = BudgetGuard( - env_var="FF_SPK_MAX_BYTES", - estimator_attr="runtime_budget_estimate_query_bytes", - engine_label="Databricks/Spark", - what="query", - ) def __init__( self, @@ -275,7 +268,7 @@ def __init__( self.database = database self.schema = database self.runtime_query_stats = DatabricksSparkQueryStatsRuntime(self) - self.runtime_budget = DatabricksSparkBudgetRuntime(self, self._BUDGET_GUARD) + self.runtime_budget = DatabricksSparkBudgetRuntime(self) if database: self._execute_sql_basic(f"CREATE DATABASE IF NOT EXISTS `{database}`") @@ -315,12 +308,6 @@ def __init__( self.runtime_contracts = DatabricksSparkRuntimeContracts(self) self.snapshot_runtime = DatabricksSparkSnapshotRuntime(self) - # ---------- Cost estimation & central execution ---------- - - def runtime_budget_estimate_query_bytes(self, sql: str) -> int | None: - """Expose runtime_budget estimator for BudgetGuard.""" - return self.runtime_budget.estimate_query_bytes(sql) - def execute_test_sql(self, stmt: Any) -> Any: """ Execute lightweight SQL for DQ tests via Spark and return fetchable rows. 
@@ -367,7 +354,6 @@ def _exec() -> SDF: sql, exec_fn=_exec, stats_runtime=self.runtime_query_stats, - estimate_fn=self.runtime_budget_estimate_query_bytes, stats_adapter=self.runtime_budget.spark_stats_adapter(sql), ) diff --git a/src/fastflowtransform/executors/duckdb.py b/src/fastflowtransform/executors/duckdb.py index c416a1d..308dec9 100644 --- a/src/fastflowtransform/executors/duckdb.py +++ b/src/fastflowtransform/executors/duckdb.py @@ -17,7 +17,6 @@ from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable from fastflowtransform.executors.base import BaseExecutor, _scalar -from fastflowtransform.executors.budget.core import BudgetGuard from fastflowtransform.executors.budget.runtime.duckdb import DuckBudgetRuntime from fastflowtransform.executors.common import _q_ident from fastflowtransform.executors.query_stats.runtime.duckdb import DuckQueryStatsRuntime @@ -32,13 +31,6 @@ class DuckExecutor(SqlIdentifierMixin, BaseExecutor[pd.DataFrame]): runtime_budget: DuckBudgetRuntime snapshot_runtime: DuckSnapshotRuntime - _BUDGET_GUARD = BudgetGuard( - env_var="FF_DUCKDB_MAX_BYTES", - estimator_attr="_estimate_query_bytes", - engine_label="DuckDB", - what="query", - ) - def __init__( self, db_path: str = ":memory:", schema: str | None = None, catalog: str | None = None ): @@ -56,7 +48,7 @@ def __init__( else: self.catalog = self._detect_catalog() self.runtime_query_stats = DuckQueryStatsRuntime(self) - self.runtime_budget = DuckBudgetRuntime(self, self._BUDGET_GUARD) + self.runtime_budget = DuckBudgetRuntime(self) self.runtime_contracts = DuckRuntimeContracts(self) self.snapshot_runtime = DuckSnapshotRuntime(self) @@ -134,14 +126,8 @@ def _rows(result: Any) -> int | None: exec_fn=_exec, stats_runtime=self.runtime_query_stats, rowcount_extractor=_rows, - estimate_fn=self._estimate_query_bytes, ) - # --- Cost estimation for the shared BudgetGuard ----------------- - - def _estimate_query_bytes(self, sql: str) -> int | None: - return self.runtime_budget.estimate_query_bytes(sql) - def _detect_catalog(self) -> str | None: rows = self._execute_basic("PRAGMA database_list").fetchall() if rows: diff --git a/src/fastflowtransform/executors/postgres.py b/src/fastflowtransform/executors/postgres.py index 4e4326d..4710a1b 100644 --- a/src/fastflowtransform/executors/postgres.py +++ b/src/fastflowtransform/executors/postgres.py @@ -18,7 +18,6 @@ from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable from fastflowtransform.executors.base import BaseExecutor, _scalar -from fastflowtransform.executors.budget.core import BudgetGuard from fastflowtransform.executors.budget.runtime.postgres import PostgresBudgetRuntime from fastflowtransform.executors.common import _q_ident from fastflowtransform.executors.query_stats.runtime.postgres import PostgresQueryStatsRuntime @@ -39,12 +38,6 @@ class PostgresExecutor(SqlIdentifierMixin, BaseExecutor[pd.DataFrame]): runtime_query_stats: PostgresQueryStatsRuntime runtime_budget: PostgresBudgetRuntime snapshot_runtime: PostgresSnapshotRuntime - _BUDGET_GUARD = BudgetGuard( - env_var="FF_PG_MAX_BYTES", - estimator_attr="_estimate_query_bytes", - engine_label="Postgres", - what="query", - ) def __init__(self, dsn: str, schema: str | None = None): """ @@ -72,7 +65,7 @@ def __init__(self, dsn: str, schema: str | None = None): # Enable runtime helpers and contracts. 
self.runtime_query_stats = PostgresQueryStatsRuntime(self) - self.runtime_budget = PostgresBudgetRuntime(self, self._BUDGET_GUARD) + self.runtime_budget = PostgresBudgetRuntime(self) self.runtime_contracts = PostgresRuntimeContracts(self) self.snapshot_runtime = PostgresSnapshotRuntime(self) @@ -194,7 +187,6 @@ def _rows(result: Any) -> int | None: exec_fn=_exec, stats_runtime=self.runtime_query_stats, rowcount_extractor=_rows, - estimate_fn=self._estimate_query_bytes, ) def _analyze_relations( @@ -217,11 +209,6 @@ def _analyze_relations( except Exception: pass - # --- Cost estimation for the shared BudgetGuard ----------------- - - def _estimate_query_bytes(self, sql: str) -> int | None: - return self.runtime_budget.estimate_query_bytes(sql) - # --- Helpers --------------------------------------------------------- def _quote_identifier(self, ident: str) -> str: return _q_ident(ident) From ed72b469e6da0bdc549b2b74bec365ec5f510812 Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Thu, 18 Dec 2025 08:48:17 +0100 Subject: [PATCH 06/10] Refactored stats and budget for snowflake_snowpark engine --- .../executors/budget/runtime/__init__.py | 4 + .../budget/runtime/snowflake_snowpark.py | 144 +++++++++++++ .../executors/query_stats/runtime/__init__.py | 4 + .../query_stats/runtime/snowflake_snowpark.py | 40 ++++ .../executors/snowflake_snowpark.py | 202 ++++-------------- tests/common/fixtures.py | 11 + 6 files changed, 245 insertions(+), 160 deletions(-) create mode 100644 src/fastflowtransform/executors/budget/runtime/snowflake_snowpark.py create mode 100644 src/fastflowtransform/executors/query_stats/runtime/snowflake_snowpark.py diff --git a/src/fastflowtransform/executors/budget/runtime/__init__.py b/src/fastflowtransform/executors/budget/runtime/__init__.py index a03c41c..558dbe4 100644 --- a/src/fastflowtransform/executors/budget/runtime/__init__.py +++ b/src/fastflowtransform/executors/budget/runtime/__init__.py @@ -4,10 +4,14 @@ ) from fastflowtransform.executors.budget.runtime.duckdb import DuckBudgetRuntime from fastflowtransform.executors.budget.runtime.postgres import PostgresBudgetRuntime +from fastflowtransform.executors.budget.runtime.snowflake_snowpark import ( + SnowflakeSnowparkBudgetRuntime, +) __all__ = [ "BaseBudgetRuntime", "DatabricksSparkBudgetRuntime", "DuckBudgetRuntime", "PostgresBudgetRuntime", + "SnowflakeSnowparkBudgetRuntime", ] diff --git a/src/fastflowtransform/executors/budget/runtime/snowflake_snowpark.py b/src/fastflowtransform/executors/budget/runtime/snowflake_snowpark.py new file mode 100644 index 0000000..42b9f5d --- /dev/null +++ b/src/fastflowtransform/executors/budget/runtime/snowflake_snowpark.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import json +from contextlib import suppress +from typing import Any, Protocol + +from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.budget.runtime.base import BaseBudgetRuntime, BudgetExecutor + + +class SnowflakeSnowparkBudgetExecutor(BudgetExecutor, Protocol): + session: Any + + def _selectable_body(self, sql: str) -> str: ... 
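# Illustration only (not part of this patch): the Protocol above is structural, so any
# object that exposes a Snowpark-like `session`, a `_selectable_body` method and the
# BudgetExecutor hooks satisfies SnowflakeSnowparkBudgetExecutor without inheriting
# from the real executor class, which is what makes lightweight test doubles possible.
class _StubSnowflakeExecutor:
    def __init__(self, session: object) -> None:
        self.session = session

    def _selectable_body(self, sql: str) -> str:
        # Stand-in behaviour: treat the whole statement as the selectable body.
        return sql

    # BudgetExecutor members (guard and stats hooks) would be stubbed the same way.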
+ + +class SnowflakeSnowparkBudgetRuntime(BaseBudgetRuntime[SnowflakeSnowparkBudgetExecutor]): + """Snowflake Snowpark budget runtime using EXPLAIN for estimation.""" + + DEFAULT_GUARD = BudgetGuard( + env_var="FF_SF_MAX_BYTES", + estimator_attr="runtime_budget_estimate_query_bytes", + engine_label="Snowflake", + what="query", + ) + + def estimate_query_bytes(self, sql: str) -> int | None: + """ + Best-effort Snowflake bytes estimation using EXPLAIN USING JSON. + Mirrors the previous executor-side logic. + """ + try: + body = self.executor._selectable_body(sql) + except Exception: + body = sql + + try: + rows = self.executor.session.sql(f"EXPLAIN USING JSON {body}").collect() + if not rows: + return None + + parts: list[str] = [] + for r in rows: + try: + parts.append(str(r[0])) + except Exception: + as_dict: dict[str, Any] = getattr(r, "asDict", lambda: {})() + if as_dict: + parts.extend(str(v) for v in as_dict.values()) + + plan_text = "\n".join(parts).strip() + if not plan_text: + return None + + try: + plan_data = json.loads(plan_text) + except Exception: + return None + + bytes_val = self._extract_bytes_from_plan(plan_data) + if bytes_val is None or bytes_val <= 0: + return None + return bytes_val + except Exception: + # Any parsing / EXPLAIN issues → no estimate, guard skipped + return None + + def dataframe_bytes(self, df: Any) -> int | None: + """ + Best-effort bytes estimate for a Snowpark DataFrame. + """ + try: + sql_text = self._snowpark_df_sql(df) + if not isinstance(sql_text, str) or not sql_text.strip(): + return None + return self.estimate_query_bytes(sql_text) + except Exception: + return None + + def _extract_bytes_from_plan(self, plan_data: Any) -> int | None: + def _to_int(value: Any) -> int | None: + if value is None: + return None + try: + return int(value) + except Exception: + return None + + if isinstance(plan_data, dict): + global_stats = plan_data.get("GlobalStats") or plan_data.get("globalStats") + if isinstance(global_stats, dict): + candidate = _to_int( + global_stats.get("bytesAssigned") or global_stats.get("bytes_assigned") + ) + if candidate: + return candidate + for val in plan_data.values(): + bytes_val = self._extract_bytes_from_plan(val) + if bytes_val: + return bytes_val + elif isinstance(plan_data, list): + for item in plan_data: + bytes_val = self._extract_bytes_from_plan(item) + if bytes_val: + return bytes_val + return None + + def _snowpark_df_sql(self, df: Any) -> str | None: + """ + Extract the main SQL statement for a Snowpark DataFrame. + + Uses the documented public APIs: + - DataFrame.queries -> {"queries": [sql1, sql2, ...], "post_actions": [...]} + - Optionally falls back to df._plan.sql() if needed. 
+ """ + queries_dict = getattr(df, "queries", None) + + if isinstance(queries_dict, dict): + queries = queries_dict.get("queries") + if isinstance(queries, list) and queries: + candidates = [q.strip() for q in queries if isinstance(q, str) and q.strip()] + if candidates: + return max(candidates, key=len) + + plan = getattr(df, "_plan", None) + if plan is not None: + with suppress(Exception): + simplify = getattr(plan, "simplify", None) + if callable(simplify): + simplified = simplify() + to_sql = getattr(simplified, "sql", None) + if callable(to_sql): + sql = to_sql() + if isinstance(sql, str) and sql.strip(): + return sql.strip() + + with suppress(Exception): + to_sql = getattr(plan, "sql", None) + if callable(to_sql): + sql = to_sql() + if isinstance(sql, str) and sql.strip(): + return sql.strip() + + return None diff --git a/src/fastflowtransform/executors/query_stats/runtime/__init__.py b/src/fastflowtransform/executors/query_stats/runtime/__init__.py index efe9ab4..f62e12e 100644 --- a/src/fastflowtransform/executors/query_stats/runtime/__init__.py +++ b/src/fastflowtransform/executors/query_stats/runtime/__init__.py @@ -4,10 +4,14 @@ ) from fastflowtransform.executors.query_stats.runtime.duckdb import DuckQueryStatsRuntime from fastflowtransform.executors.query_stats.runtime.postgres import PostgresQueryStatsRuntime +from fastflowtransform.executors.query_stats.runtime.snowflake_snowpark import ( + SnowflakeSnowparkQueryStatsRuntime, +) __all__ = [ "BaseQueryStatsRuntime", "DatabricksSparkQueryStatsRuntime", "DuckQueryStatsRuntime", "PostgresQueryStatsRuntime", + "SnowflakeSnowparkQueryStatsRuntime", ] diff --git a/src/fastflowtransform/executors/query_stats/runtime/snowflake_snowpark.py b/src/fastflowtransform/executors/query_stats/runtime/snowflake_snowpark.py new file mode 100644 index 0000000..c117a4a --- /dev/null +++ b/src/fastflowtransform/executors/query_stats/runtime/snowflake_snowpark.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import Any + +from fastflowtransform.executors.query_stats.core import QueryStats +from fastflowtransform.executors.query_stats.runtime.base import ( + BaseQueryStatsRuntime, + QueryStatsExecutor, +) + + +class SnowflakeSnowparkQueryStatsRuntime(BaseQueryStatsRuntime[QueryStatsExecutor]): + """Snowflake Snowpark stats helpers.""" + + def rowcount_from_result(self, result: Any) -> int | None: + rc = getattr(result, "rowcount", None) + if isinstance(rc, int) and rc >= 0: + return rc + return None + + def record_dataframe(self, df: Any, duration_ms: int) -> QueryStats: + budget_runtime = getattr(self.executor, "runtime_budget", None) + + bytes_estimate: int | None = None + if budget_runtime is not None: + try: + bytes_estimate = budget_runtime.dataframe_bytes(df) + except Exception: + bytes_estimate = None + + if bytes_estimate is not None and bytes_estimate <= 0: + bytes_estimate = None + + stats = QueryStats( + bytes_processed=bytes_estimate, + rows=None, + duration_ms=duration_ms, + ) + self.executor._record_query_stats(stats) + return stats diff --git a/src/fastflowtransform/executors/snowflake_snowpark.py b/src/fastflowtransform/executors/snowflake_snowpark.py index 7cd81f7..298fb47 100644 --- a/src/fastflowtransform/executors/snowflake_snowpark.py +++ b/src/fastflowtransform/executors/snowflake_snowpark.py @@ -1,7 +1,6 @@ # src/fastflowtransform/executors/snowflake_snowpark.py from __future__ import annotations -import json from collections.abc import Iterable from contextlib import suppress from time import perf_counter 
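# For orientation, a minimal plan fragment of the shape _extract_bytes_from_plan()
# accepts. Field names beyond GlobalStats.bytesAssigned are assumptions about
# Snowflake's EXPLAIN USING JSON output; real plans carry many more fields.
_example_plan = {
    "GlobalStats": {"partitionsTotal": 8, "partitionsAssigned": 8, "bytesAssigned": 1_572_864},
    "Operations": [[{"id": 0, "operation": "Result"}]],
}
# runtime._extract_bytes_from_plan(_example_plan) -> 1572864, which the guard can then
# weigh against the FF_SF_MAX_BYTES limit before the query is executed.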
@@ -11,13 +10,16 @@ from fastflowtransform.contracts.runtime.snowflake_snowpark import SnowflakeSnowparkRuntimeContracts from fastflowtransform.core import Node, relation_for -from fastflowtransform.executors._budget_runner import run_sql_with_budget from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable, rows_to_tuples from fastflowtransform.executors.base import BaseExecutor -from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.budget.runtime.snowflake_snowpark import ( + SnowflakeSnowparkBudgetRuntime, +) from fastflowtransform.executors.common import _q_ident -from fastflowtransform.executors.query_stats.core import QueryStats +from fastflowtransform.executors.query_stats.runtime.snowflake_snowpark import ( + SnowflakeSnowparkQueryStatsRuntime, +) from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots.runtime.snowflake_snowpark import ( SnowflakeSnowparkSnapshotRuntime, @@ -28,14 +30,10 @@ class SnowflakeSnowparkExecutor(SqlIdentifierMixin, BaseExecutor[SNDF]): ENGINE_NAME: str = "snowflake_snowpark" runtime_contracts: SnowflakeSnowparkRuntimeContracts + runtime_query_stats: SnowflakeSnowparkQueryStatsRuntime + runtime_budget: SnowflakeSnowparkBudgetRuntime snapshot_runtime: SnowflakeSnowparkSnapshotRuntime """Snowflake executor operating on Snowpark DataFrames (no pandas).""" - _BUDGET_GUARD = BudgetGuard( - env_var="FF_SF_MAX_BYTES", - estimator_attr="_estimate_query_bytes", - engine_label="Snowflake", - what="query", - ) def __init__(self, cfg: dict): # cfg: {account, user, password, warehouse, database, schema, role?} @@ -45,6 +43,8 @@ def __init__(self, cfg: dict): self.allow_create_schema: bool = bool(cfg["allow_create_schema"]) self._ensure_schema() + self.runtime_query_stats = SnowflakeSnowparkQueryStatsRuntime(self) + self.runtime_budget = SnowflakeSnowparkBudgetRuntime(self) self.runtime_contracts = SnowflakeSnowparkRuntimeContracts(self) self.snapshot_runtime = SnowflakeSnowparkSnapshotRuntime(self) @@ -77,77 +77,18 @@ def compute_freshness_delay_minutes(self, table: str, ts_col: str) -> tuple[floa # ---------- Cost estimation & central execution ---------- - def _estimate_query_bytes(self, sql: str) -> int | None: - """ - Best-effort Snowflake bytes estimation. - - Uses `EXPLAIN USING TEXT` and tries to extract a "bytes="-style - metric from the textual plan. If parsing fails or Snowflake doesn't - expose such info, returns None and the guard is effectively disabled. 
- """ - try: - body = self._selectable_body(sql) - except Exception: - body = sql - - try: - rows = self.session.sql(f"EXPLAIN USING JSON {body}").collect() - if not rows: - return None - - parts: list[str] = [] - for r in rows: - try: - parts.append(str(r[0])) - except Exception: - as_dict: dict[str, Any] = getattr(r, "asDict", lambda: {})() - if as_dict: - parts.extend(str(v) for v in as_dict.values()) - - plan_text = "\n".join(parts).strip() - if not plan_text: - return None - - try: - plan_data = json.loads(plan_text) - except Exception: - return None - - bytes_val = self._extract_bytes_from_plan(plan_data) - if bytes_val is None or bytes_val <= 0: - return None - return bytes_val - except Exception: - # Any parsing / EXPLAIN issues → no estimate, guard skipped - return None + # def _estimate_query_bytes(self, sql: str) -> int | None: + # """Compatibility shim that delegates to the budget runtime estimator.""" + # return self.runtime_budget.estimate_query_bytes(sql) - def _extract_bytes_from_plan(self, plan_data: Any) -> int | None: - def _to_int(value: Any) -> int | None: - if value is None: - return None - try: - return int(value) - except Exception: - return None + # def runtime_budget_estimate_query_bytes(self, sql: str) -> int | None: + # """ + # Entry point for BudgetGuard to call into the runtime estimator. + # """ + # return self.runtime_budget.estimate_query_bytes(sql) - if isinstance(plan_data, dict): - global_stats = plan_data.get("GlobalStats") or plan_data.get("globalStats") - if isinstance(global_stats, dict): - candidate = _to_int( - global_stats.get("bytesAssigned") or global_stats.get("bytes_assigned") - ) - if candidate: - return candidate - for val in plan_data.values(): - bytes_val = self._extract_bytes_from_plan(val) - if bytes_val: - return bytes_val - elif isinstance(plan_data, list): - for item in plan_data: - bytes_val = self._extract_bytes_from_plan(item) - if bytes_val: - return bytes_val - return None + def _execute_sql_basic(self, sql: str) -> SNDF: + return self.session.sql(sql) def _execute_sql(self, sql: str) -> SNDF: """ @@ -160,12 +101,10 @@ def _execute_sql(self, sql: str) -> SNDF: def _exec() -> SNDF: return self.session.sql(sql) - return run_sql_with_budget( - self, + return self.runtime_budget.run_sql( sql, - guard=self._BUDGET_GUARD, exec_fn=_exec, - estimate_fn=self._estimate_query_bytes, + stats_runtime=self.runtime_query_stats, ) def _exec_many(self, sql: str) -> None: @@ -242,82 +181,23 @@ def _materialize_relation(self, relation: str, df: SNDF, node: Node) -> None: start = perf_counter() df.write.save_as_table(self._qualified(relation), mode="overwrite") duration_ms = int((perf_counter() - start) * 1000) - bytes_est = self._estimate_frame_bytes(df) - self._record_query_stats( - QueryStats( - bytes_processed=bytes_est, - rows=None, - duration_ms=duration_ms, - ) - ) + self.runtime_query_stats.record_dataframe(df, duration_ms) - def _estimate_frame_bytes(self, df: SNDF) -> int | None: - """ - Best-effort bytes estimate for a Snowpark DataFrame. - - Strategy: - 1) Use DataFrame.queries["queries"] (public Snowpark API) to get SQL. - 2) Optionally fall back to df._plan.sql() if queries is missing/empty. - 3) Run our existing _estimate_query_bytes(sql_text). 
- """ - try: - sql_text = self._snowpark_df_sql(df) - if not isinstance(sql_text, str) or not sql_text.strip(): - return None - return self._estimate_query_bytes(sql_text) - except Exception: - return None + # def _estimate_frame_bytes(self, df: SNDF) -> int | None: + # """ + # Best-effort bytes estimate for a Snowpark DataFrame. - def _snowpark_df_sql(self, df: Any) -> str | None: - """ - Extract the main SQL statement for a Snowpark DataFrame. - - Uses the documented public APIs: - - DataFrame.queries -> {"queries": [sql1, sql2, ...], "post_actions": [...]} - - Optionally falls back to df._plan.sql() if needed. - """ - # 1) Primary source: DataFrame.queries - queries_dict = getattr(df, "queries", None) - - if isinstance(queries_dict, dict): - queries = queries_dict.get("queries") - if isinstance(queries, list) and queries: - # Pick the most likely "main" query. - # Snowflake examples use queries['queries'][0], - # but we can be a bit safer and pick the longest non-empty SQL. - candidates = [q.strip() for q in queries if isinstance(q, str) and q.strip()] - if candidates: - # Heuristic: longest SQL string is usually the main SELECT/CTE. - return max(candidates, key=len) - - # 2) Fallback: internal plan (undocumented but widely used) - plan = getattr(df, "_plan", None) - if plan is not None: - # Prefer simplified plan if available - with suppress(Exception): - simplify = getattr(plan, "simplify", None) - if callable(simplify): - simplified = simplify() - to_sql = getattr(simplified, "sql", None) - if callable(to_sql): - sql = to_sql() - if isinstance(sql, str) and sql.strip(): - return sql.strip() - - # Raw plan.sql() - with suppress(Exception): - to_sql = getattr(plan, "sql", None) - if callable(to_sql): - sql = to_sql() - if isinstance(sql, str) and sql.strip(): - return sql.strip() - - return None + # Strategy: + # 1) Use DataFrame.queries["queries"] (public Snowpark API) to get SQL. + # 2) Optionally fall back to df._plan.sql() if queries is missing/empty. + # 3) Run the budget runtime estimator on the derived SQL. + # """ + # return self.runtime_budget.dataframe_bytes(df) def _create_view_over_table(self, view_name: str, backing_table: str, node: Node) -> None: qv = self._qualified(view_name) qb = self._qualified(backing_table) - self._execute_sql(f"CREATE OR REPLACE VIEW {qv} AS SELECT * FROM {qb}").collect() + self._execute_sql_basic(f"CREATE OR REPLACE VIEW {qv} AS SELECT * FROM {qb}").collect() def _validate_required( self, node_name: str, inputs: Any, requires: dict[str, set[str]] @@ -448,7 +328,7 @@ def _format_source_reference( return formatted def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: - self._execute_sql(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}").collect() + self._execute_sql_basic(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}").collect() def _create_or_replace_table(self, target_sql: str, select_body: str, node: Node) -> None: self._execute_sql(f"CREATE OR REPLACE TABLE {target_sql} AS {select_body}").collect() @@ -458,7 +338,9 @@ def _create_or_replace_view_from_table( ) -> None: view_id = self._qualified(view_name) back_id = self._qualified(backing_table) - self._execute_sql(f"CREATE OR REPLACE VIEW {view_id} AS SELECT * FROM {back_id}").collect() + self._execute_sql_basic( + f"CREATE OR REPLACE VIEW {view_id} AS SELECT * FROM {back_id}" + ).collect() def _format_test_table(self, table: str | None) -> str | None: # Bypass mixin qualification to avoid double-qualifying already dotted names. 
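# The split around these hunks, with the intent read from the call sites: statements
# that build or read model data keep going through _execute_sql, i.e. through
# runtime_budget.run_sql with the EXPLAIN-based estimate and query-stats recording,
# while bookkeeping DDL and information_schema probes use the plain _execute_sql_basic
# path. A rough sketch of the two paths:
def _materialize_table(executor, target_sql: str, select_body: str) -> None:
    # CTAS over model SQL: guarded and recorded.
    executor._execute_sql(f"CREATE OR REPLACE TABLE {target_sql} AS {select_body}").collect()


def _publish_view(executor, view_sql: str, backing_sql: str) -> None:
    # Pure metadata DDL: basic path, no budget check, no stats.
    executor._execute_sql_basic(
        f"CREATE OR REPLACE VIEW {view_sql} AS SELECT * FROM {backing_sql}"
    ).collect()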
@@ -494,7 +376,7 @@ def exists_relation(self, relation: str) -> bool: limit 1 """ try: - return bool(self._execute_sql(q).collect()) + return bool(self._execute_sql_basic(q).collect()) except Exception: return False @@ -555,7 +437,7 @@ def alter_table_sync_schema( try: existing = { r[0] - for r in self._execute_sql( + for r in self._execute_sql_basic( f""" select column_name from {db_ident}.information_schema.columns @@ -578,7 +460,7 @@ def alter_table_sync_schema( # Column names are identifiers → _q is correct here cols_sql = ", ".join(f"{_q_ident(c)} STRING" for c in to_add) - self._execute_sql(f"ALTER TABLE {qrel} ADD COLUMN {cols_sql}").collect() + self._execute_sql_basic(f"ALTER TABLE {qrel} ADD COLUMN {cols_sql}").collect() # ---- Snapshot runtime delegation -------------------------------------- def run_snapshot_sql(self, node: Node, env: Any) -> None: @@ -779,7 +661,7 @@ def introspect_table_physical_schema(self, table: str) -> dict[str, str]: order by ordinal_position """ - rows = self._execute_sql(sql).collect() + rows = self._execute_sql_basic(sql).collect() out: dict[str, str] = {} for r in rows or []: @@ -816,7 +698,7 @@ def introspect_column_physical_type(self, table: str, column: str) -> str | None limit 1 """ - rows = self._execute_sql(sql).collect() + rows = self._execute_sql_basic(sql).collect() if not rows: return None r = rows[0] diff --git a/tests/common/fixtures.py b/tests/common/fixtures.py index bd321cd..b7cb8ae 100644 --- a/tests/common/fixtures.py +++ b/tests/common/fixtures.py @@ -45,12 +45,20 @@ # Snowflake try: + from fastflowtransform.executors.budget.runtime.snowflake_snowpark import ( + SnowflakeSnowparkBudgetRuntime, + ) + from fastflowtransform.executors.query_stats.runtime.snowflake_snowpark import ( + SnowflakeSnowparkQueryStatsRuntime, + ) from fastflowtransform.executors.snowflake_snowpark import SnowflakeSnowparkExecutor from fastflowtransform.snapshots.runtime.snowflake_snowpark import ( SnowflakeSnowparkSnapshotRuntime, ) except ModuleNotFoundError: # pragma: no cover SnowflakeSnowparkExecutor = None # type: ignore[assignment] + SnowflakeSnowparkBudgetRuntime = None # type: ignore[assignment] + SnowflakeSnowparkQueryStatsRuntime = None # type: ignore[assignment] SnowflakeSnowparkSnapshotRuntime = None # type: ignore[assignment] @@ -451,6 +459,9 @@ def snowflake_executor_fake() -> Any: session = FakeSnowflakeSession() ex.session = session + # Wire runtimes to mirror real executor setup. + ex.runtime_query_stats = SnowflakeSnowparkQueryStatsRuntime(ex) + ex.runtime_budget = SnowflakeSnowparkBudgetRuntime(ex) # Wire snapshot runtime to mirror real executor setup. 
ex.snapshot_runtime = SnowflakeSnowparkSnapshotRuntime(ex) From 4fae94155974bf2eb69a69dc223fd7571b20e85d Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Thu, 18 Dec 2025 13:10:45 +0100 Subject: [PATCH 07/10] Refactored budget and query stats for bigquery engine to use new runtime modules --- .../executors/_budget_runner.py | 70 ----------------- .../executors/bigquery/base.py | 78 +++++++------------ .../executors/bigquery/pandas.py | 11 +-- .../executors/budget/runtime/__init__.py | 2 + .../executors/budget/runtime/bigquery.py | 55 +++++++++++++ .../executors/query_stats/core.py | 2 +- .../executors/query_stats/runtime/__init__.py | 2 + .../executors/query_stats/runtime/bigquery.py | 30 +++++++ .../executors/snowflake_snowpark.py | 23 ------ tests/common/fixtures.py | 3 +- .../executors/test_bigquery_bf_exec_unit.py | 7 +- .../unit/executors/test_bigquery_exec_unit.py | 7 +- 12 files changed, 128 insertions(+), 162 deletions(-) delete mode 100644 src/fastflowtransform/executors/_budget_runner.py create mode 100644 src/fastflowtransform/executors/budget/runtime/bigquery.py create mode 100644 src/fastflowtransform/executors/query_stats/runtime/bigquery.py diff --git a/src/fastflowtransform/executors/_budget_runner.py b/src/fastflowtransform/executors/_budget_runner.py deleted file mode 100644 index 9568f8f..0000000 --- a/src/fastflowtransform/executors/_budget_runner.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable -from contextlib import suppress -from time import perf_counter -from typing import Any - -from fastflowtransform.executors._query_stats_adapter import QueryStatsAdapter, RowcountStatsAdapter -from fastflowtransform.executors.budget.core import BudgetGuard -from fastflowtransform.executors.query_stats.core import QueryStats - - -def run_sql_with_budget( - executor: Any, - sql: str, - *, - guard: BudgetGuard, - exec_fn: Callable[[], Any], - rowcount_extractor: Callable[[Any], int | None] | None = None, - extra_stats: Callable[[Any], QueryStats | None] | None = None, - estimate_fn: Callable[[str], int | None] | None = None, - post_estimate_fn: Callable[[str, Any], int | None] | None = None, - record_stats: bool = True, - stats_adapter: QueryStatsAdapter | None = None, -) -> Any: - """ - Shared helper for guarded SQL execution with timing + stats recording. - - executor object exposing _apply_budget_guard, _is_budget_guard_active, _record_query_stats - sql statement (used for guard + optional estimator) - exec_fn callable that executes the statement and returns a result/job handle - rowcount_extractor(result) -> int|None best-effort row count (non-negative only) - extra_stats(result) -> QueryStats|None allows engines to override/extend stats post-exec - estimate_fn(sql) -> int|None optional best-effort bytes estimate when guard - inactive - post_estimate_fn(sql, result) -> int|None optional post-exec fallback when bytes are still None - record_stats set False to skip immediate stats (e.g., when a job handle records on .result()) - """ - estimated_bytes = executor._apply_budget_guard(guard, sql) - if ( - estimated_bytes is None - and not executor._is_budget_guard_active() - and estimate_fn is not None - ): - with suppress(Exception): - estimated_bytes = estimate_fn(sql) - - # If stats should be deferred (BigQuery job handles), just run and return. 
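# The record_stats=False deferral described just above carries over to the runtime-based
# replacement: BigQuery's runner asks runtime_budget.run_sql not to record stats
# immediately and returns a _TrackedQueryJob whose callback records them once .result()
# completes. A rough sketch of that hand-off (call shape taken from this patch,
# BaseBudgetRuntime.run_sql internals assumed):
def _guarded_bigquery_query(executor, sql: str):
    def _exec():
        # _execute_sql_basic submits the job and wraps it via
        # runtime_query_stats.wrap_job(job) -> _TrackedQueryJob.
        return executor._execute_sql_basic(sql)

    return executor.runtime_budget.run_sql(
        sql,
        exec_fn=_exec,
        stats_runtime=executor.runtime_query_stats,
        record_stats=False,  # stats are recorded when the caller invokes .result()
    )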
- if not record_stats: - return exec_fn() - - started = perf_counter() - result = exec_fn() - duration_ms = int((perf_counter() - started) * 1000) - - adapter = stats_adapter - if adapter is None and (rowcount_extractor or post_estimate_fn or extra_stats): - adapter = RowcountStatsAdapter( - rowcount_extractor=rowcount_extractor, - post_estimate_fn=post_estimate_fn, - extra_stats=extra_stats, - sql=sql, - ) - if adapter is None: - stats = QueryStats(bytes_processed=estimated_bytes, rows=None, duration_ms=duration_ms) - else: - stats = adapter.collect(result, duration_ms=duration_ms, estimated_bytes=estimated_bytes) - - executor._record_query_stats(stats) - return result diff --git a/src/fastflowtransform/executors/bigquery/base.py b/src/fastflowtransform/executors/bigquery/base.py index 5eeade5..0fec466 100644 --- a/src/fastflowtransform/executors/bigquery/base.py +++ b/src/fastflowtransform/executors/bigquery/base.py @@ -5,12 +5,12 @@ from typing import Any, TypeVar from fastflowtransform.core import Node, relation_for -from fastflowtransform.executors._budget_runner import run_sql_with_budget from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable from fastflowtransform.executors.base import BaseExecutor -from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.budget.runtime.bigquery import BigQueryBudgetRuntime from fastflowtransform.executors.query_stats.core import _TrackedQueryJob +from fastflowtransform.executors.query_stats.runtime.bigquery import BigQueryQueryStatsRuntime from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots.runtime.bigquery import BigQuerySnapshotRuntime from fastflowtransform.typing import BadRequest, Client, NotFound, bigquery @@ -32,12 +32,8 @@ class BigQueryBaseExecutor(SqlIdentifierMixin, BaseExecutor[TFrame]): # Subclasses override ENGINE_NAME ("bigquery", "bigquery_batch", ...) ENGINE_NAME = "bigquery_base" - _BUDGET_GUARD = BudgetGuard( - env_var="FF_BQ_MAX_BYTES", - estimator_attr="_estimate_query_bytes", - engine_label="BigQuery", - what="query", - ) + runtime_query_stats: BigQueryQueryStatsRuntime + runtime_budget: BigQueryBudgetRuntime def __init__( self, @@ -55,6 +51,8 @@ def __init__( project=self.project, location=self.location, ) + self.runtime_query_stats = BigQueryQueryStatsRuntime(self) + self.runtime_budget = BigQueryBudgetRuntime(self) self.snapshot_runtime = BigQuerySnapshotRuntime(self) # ---- Identifier helpers ---- @@ -180,6 +178,19 @@ def compute_freshness_delay_minutes(self, table: str, ts_col: str) -> tuple[floa val = delay[0] if delay else None return (float(val) if val is not None else None, sql) + def _execute_sql_basic(self, sql: str) -> _TrackedQueryJob: + job_config = bigquery.QueryJobConfig() + if self.dataset: + # Let unqualified tables resolve to project.dataset.table + job_config.default_dataset = bigquery.DatasetReference(self.project, self.dataset) + + job = self.client.query( + sql, + job_config=job_config, + location=self.location, + ) + return self.runtime_query_stats.wrap_job(job) + def _execute_sql(self, sql: str) -> _TrackedQueryJob: """ Central BigQuery query runner. 
@@ -189,52 +200,15 @@ def _execute_sql(self, sql: str) -> _TrackedQueryJob: """ def _exec() -> _TrackedQueryJob: - job_config = bigquery.QueryJobConfig() - if self.dataset: - # Let unqualified tables resolve to project.dataset.table - job_config.default_dataset = bigquery.DatasetReference(self.project, self.dataset) - - job = self.client.query( - sql, - job_config=job_config, - location=self.location, - ) - return _TrackedQueryJob(job, on_complete=self._record_query_job_stats) + return self._execute_sql_basic(sql) - return run_sql_with_budget( - self, + return self.runtime_budget.run_sql( sql, - guard=self._BUDGET_GUARD, exec_fn=_exec, - estimate_fn=self._estimate_query_bytes, + stats_runtime=self.runtime_query_stats, record_stats=False, ) - # --- Cost estimation for the shared BudgetGuard ----------------- - - def _estimate_query_bytes(self, sql: str) -> int | None: - """ - Estimate bytes for a BigQuery SQL statement using a dry-run. - - Returns the estimated bytes, or None if estimation is not possible. - """ - cfg = bigquery.QueryJobConfig( - dry_run=True, - use_query_cache=False, - ) - if self.dataset: - # Let unqualified tables resolve to project.dataset.table - cfg.default_dataset = bigquery.DatasetReference(self.project, self.dataset) - - job = self.client.query( - sql, - job_config=cfg, - location=self.location, - ) - # Dry-run is free; we just need the job metadata - job.result() - return int(getattr(job, "total_bytes_processed", 0) or 0) - # ---- DQ test table formatting (fft test) ---- def _format_test_table(self, table: str | None) -> str | None: """ @@ -275,7 +249,7 @@ def _apply_sql_materialization( ) from e def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: - self._execute_sql(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}").result() + self._execute_sql_basic(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}").result() def _create_or_replace_table(self, target_sql: str, select_body: str, node: Node) -> None: self._execute_sql(f"CREATE OR REPLACE TABLE {target_sql} AS {select_body}").result() @@ -289,7 +263,9 @@ def _create_or_replace_view_from_table( view_id = self._qualified_identifier(view_name) back_id = self._qualified_identifier(backing_table) self._ensure_dataset() - self._execute_sql(f"CREATE OR REPLACE VIEW {view_id} AS SELECT * FROM {back_id}").result() + self._execute_sql_basic( + f"CREATE OR REPLACE VIEW {view_id} AS SELECT * FROM {back_id}" + ).result() # ---- Meta hook ---- def on_node_built(self, node: Node, relation: str, fingerprint: str) -> None: @@ -424,7 +400,7 @@ def alter_table_sync_schema( for col in to_add: f = out_fields[col] typ = str(f.field_type) if hasattr(f, "field_type") else "STRING" - self._execute_sql(f"ALTER TABLE {target} ADD COLUMN {col} {typ}").result() + self._execute_sql_basic(f"ALTER TABLE {target} ADD COLUMN {col} {typ}").result() # ── Snapshots API (shared for pandas + BigFrames) ───────────────────── diff --git a/src/fastflowtransform/executors/bigquery/pandas.py b/src/fastflowtransform/executors/bigquery/pandas.py index c672395..45e9a35 100644 --- a/src/fastflowtransform/executors/bigquery/pandas.py +++ b/src/fastflowtransform/executors/bigquery/pandas.py @@ -10,7 +10,6 @@ from fastflowtransform.contracts.runtime.bigquery import BigQueryRuntimeContracts from fastflowtransform.core import Node from fastflowtransform.executors.bigquery.base import BigQueryBaseExecutor -from fastflowtransform.executors.query_stats.core import QueryStats from fastflowtransform.typing import 
BadRequest, Client, LoadJobConfig, NotFound, bigquery @@ -89,15 +88,7 @@ def _frame_name(self) -> str: return "pandas" def _record_dataframe_stats(self, df: pd.DataFrame, duration_ms: int) -> None: - rows = len(df) - bytes_val = int(df.memory_usage(deep=True).sum()) if rows > 0 else 0 - self._record_query_stats( - QueryStats( - bytes_processed=bytes_val if bytes_val > 0 else None, - rows=rows if rows > 0 else None, - duration_ms=duration_ms, - ) - ) + self.runtime_query_stats.record_dataframe(df, duration_ms) # ---- Unit-test helpers (pandas) --------------------------------------- diff --git a/src/fastflowtransform/executors/budget/runtime/__init__.py b/src/fastflowtransform/executors/budget/runtime/__init__.py index 558dbe4..cf315e4 100644 --- a/src/fastflowtransform/executors/budget/runtime/__init__.py +++ b/src/fastflowtransform/executors/budget/runtime/__init__.py @@ -1,4 +1,5 @@ from fastflowtransform.executors.budget.runtime.base import BaseBudgetRuntime +from fastflowtransform.executors.budget.runtime.bigquery import BigQueryBudgetRuntime from fastflowtransform.executors.budget.runtime.databricks_spark import ( DatabricksSparkBudgetRuntime, ) @@ -10,6 +11,7 @@ __all__ = [ "BaseBudgetRuntime", + "BigQueryBudgetRuntime", "DatabricksSparkBudgetRuntime", "DuckBudgetRuntime", "PostgresBudgetRuntime", diff --git a/src/fastflowtransform/executors/budget/runtime/bigquery.py b/src/fastflowtransform/executors/budget/runtime/bigquery.py new file mode 100644 index 0000000..ec1df12 --- /dev/null +++ b/src/fastflowtransform/executors/budget/runtime/bigquery.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import Any, Protocol + +from fastflowtransform.executors.budget.core import BudgetGuard +from fastflowtransform.executors.budget.runtime.base import BaseBudgetRuntime, BudgetExecutor +from fastflowtransform.typing import bigquery + + +class BigQueryBudgetExecutor(BudgetExecutor, Protocol): + project: str + dataset: str + location: str | None + client: Any + + +class BigQueryBudgetRuntime(BaseBudgetRuntime[BigQueryBudgetExecutor]): + """BigQuery budget runtime using dry-run estimation.""" + + DEFAULT_GUARD = BudgetGuard( + env_var="FF_BQ_MAX_BYTES", + estimator_attr="runtime_budget_estimate_query_bytes", + engine_label="BigQuery", + what="query", + ) + + def estimate_query_bytes(self, sql: str) -> int | None: + """ + Estimate bytes for a BigQuery SQL statement using a dry-run. + + Returns the estimated bytes, or None if estimation is not possible. + """ + cfg = bigquery.QueryJobConfig( + dry_run=True, + use_query_cache=False, + ) + if self.executor.dataset: + cfg.default_dataset = bigquery.DatasetReference( + self.executor.project, self.executor.dataset + ) + + try: + job = self.executor.client.query( + sql, + job_config=cfg, + location=self.executor.location, + ) + job.result() + except Exception: + return None + + try: + return int(getattr(job, "total_bytes_processed", 0) or 0) + except Exception: + return None diff --git a/src/fastflowtransform/executors/query_stats/core.py b/src/fastflowtransform/executors/query_stats/core.py index a4d58b7..d630a79 100644 --- a/src/fastflowtransform/executors/query_stats/core.py +++ b/src/fastflowtransform/executors/query_stats/core.py @@ -30,7 +30,7 @@ class _TrackedQueryJob: - Never raises from the callback; stats collection is strictly best-effort. 
""" - def __init__(self, inner_job: Any, *, on_complete: Callable[[Any], None]) -> None: + def __init__(self, inner_job: Any, *, on_complete: Callable[[Any], Any]) -> None: self._inner_job = inner_job self._on_complete = on_complete self._done = False diff --git a/src/fastflowtransform/executors/query_stats/runtime/__init__.py b/src/fastflowtransform/executors/query_stats/runtime/__init__.py index f62e12e..731f266 100644 --- a/src/fastflowtransform/executors/query_stats/runtime/__init__.py +++ b/src/fastflowtransform/executors/query_stats/runtime/__init__.py @@ -1,4 +1,5 @@ from fastflowtransform.executors.query_stats.runtime.base import BaseQueryStatsRuntime +from fastflowtransform.executors.query_stats.runtime.bigquery import BigQueryQueryStatsRuntime from fastflowtransform.executors.query_stats.runtime.databricks_spark import ( DatabricksSparkQueryStatsRuntime, ) @@ -10,6 +11,7 @@ __all__ = [ "BaseQueryStatsRuntime", + "BigQueryQueryStatsRuntime", "DatabricksSparkQueryStatsRuntime", "DuckQueryStatsRuntime", "PostgresQueryStatsRuntime", diff --git a/src/fastflowtransform/executors/query_stats/runtime/bigquery.py b/src/fastflowtransform/executors/query_stats/runtime/bigquery.py new file mode 100644 index 0000000..86db2a1 --- /dev/null +++ b/src/fastflowtransform/executors/query_stats/runtime/bigquery.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import Any + +import pandas as pd + +from fastflowtransform.executors.query_stats.core import QueryStats, _TrackedQueryJob +from fastflowtransform.executors.query_stats.runtime.base import ( + BaseQueryStatsRuntime, + QueryStatsExecutor, +) + + +class BigQueryQueryStatsRuntime(BaseQueryStatsRuntime[QueryStatsExecutor]): + """BigQuery-specific stats helpers.""" + + def wrap_job(self, job: Any) -> _TrackedQueryJob: + return _TrackedQueryJob(job, on_complete=self.record_job) + + def record_dataframe(self, df: pd.DataFrame, duration_ms: int) -> QueryStats: + rows = len(df) + bytes_estimate = int(df.memory_usage(deep=True).sum()) if rows > 0 else 0 + bytes_val = bytes_estimate if bytes_estimate > 0 else None + stats = QueryStats( + bytes_processed=bytes_val, + rows=rows if rows > 0 else None, + duration_ms=duration_ms, + ) + self.executor._record_query_stats(stats) + return stats diff --git a/src/fastflowtransform/executors/snowflake_snowpark.py b/src/fastflowtransform/executors/snowflake_snowpark.py index 298fb47..0c5d7dd 100644 --- a/src/fastflowtransform/executors/snowflake_snowpark.py +++ b/src/fastflowtransform/executors/snowflake_snowpark.py @@ -75,18 +75,6 @@ def compute_freshness_delay_minutes(self, table: str, ts_col: str) -> tuple[floa val = row[0] if row else None return (float(val) if val is not None else None, sql) - # ---------- Cost estimation & central execution ---------- - - # def _estimate_query_bytes(self, sql: str) -> int | None: - # """Compatibility shim that delegates to the budget runtime estimator.""" - # return self.runtime_budget.estimate_query_bytes(sql) - - # def runtime_budget_estimate_query_bytes(self, sql: str) -> int | None: - # """ - # Entry point for BudgetGuard to call into the runtime estimator. 
- # """ - # return self.runtime_budget.estimate_query_bytes(sql) - def _execute_sql_basic(self, sql: str) -> SNDF: return self.session.sql(sql) @@ -183,17 +171,6 @@ def _materialize_relation(self, relation: str, df: SNDF, node: Node) -> None: duration_ms = int((perf_counter() - start) * 1000) self.runtime_query_stats.record_dataframe(df, duration_ms) - # def _estimate_frame_bytes(self, df: SNDF) -> int | None: - # """ - # Best-effort bytes estimate for a Snowpark DataFrame. - - # Strategy: - # 1) Use DataFrame.queries["queries"] (public Snowpark API) to get SQL. - # 2) Optionally fall back to df._plan.sql() if queries is missing/empty. - # 3) Run the budget runtime estimator on the derived SQL. - # """ - # return self.runtime_budget.dataframe_bytes(df) - def _create_view_over_table(self, view_name: str, backing_table: str, node: Node) -> None: qv = self._qualified(view_name) qb = self._qualified(backing_table) diff --git a/tests/common/fixtures.py b/tests/common/fixtures.py index b7cb8ae..a8d30fa 100644 --- a/tests/common/fixtures.py +++ b/tests/common/fixtures.py @@ -14,6 +14,7 @@ import fastflowtransform.executors.bigquery.base as bq_base import fastflowtransform.executors.bigquery.pandas as bq_pandas +import fastflowtransform.executors.budget.runtime.bigquery as bq_budget_runtime import fastflowtransform.typing as fft_typing from fastflowtransform import utest from fastflowtransform.core import REGISTRY @@ -500,7 +501,7 @@ def bq_executor_fake(monkeypatch) -> BigQueryExecutor: # see the fake module. fake_bq = install_fake_bigquery( monkeypatch, - target_modules=[fft_typing, bq_base, bq_pandas], + target_modules=[fft_typing, bq_base, bq_pandas, bq_budget_runtime], ) # Instantiate FakeClient via the fake module so the types line up diff --git a/tests/unit/executors/test_bigquery_bf_exec_unit.py b/tests/unit/executors/test_bigquery_bf_exec_unit.py index f979631..a2267dd 100644 --- a/tests/unit/executors/test_bigquery_bf_exec_unit.py +++ b/tests/unit/executors/test_bigquery_bf_exec_unit.py @@ -19,6 +19,7 @@ import fastflowtransform.executors.bigquery.base as bq_base_mod import fastflowtransform.executors.bigquery.bigframes as bq_exec_mod +import fastflowtransform.executors.budget.runtime.bigquery as bq_budget_runtime_mod from fastflowtransform.core import Node from fastflowtransform.executors.base import BaseExecutor @@ -46,7 +47,7 @@ def read_gbq(self, table_id: str) -> Any: @pytest.fixture def bq_exec(monkeypatch): - _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_base_mod]) + _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_base_mod, bq_budget_runtime_mod]) # Test-only shim: ensure the fake bigquery module has DatasetReference, # which BigQueryBaseExecutor._execute_sql now relies on. 
@@ -140,7 +141,7 @@ def to_gbq(self, table_id, if_exists="replace"): @pytest.mark.unit @pytest.mark.bigquery def test_ensure_dataset_respects_flag(monkeypatch): - _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_base_mod]) + _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_base_mod, bq_budget_runtime_mod]) fake_bigframes = types.ModuleType("bigframes") fake_conf = types.ModuleType("bigframes._config") @@ -169,7 +170,7 @@ def test_ensure_dataset_respects_flag(monkeypatch): @pytest.mark.unit @pytest.mark.bigquery def test_ensure_dataset_creates_when_allowed(monkeypatch): - _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_base_mod]) + _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_base_mod, bq_budget_runtime_mod]) fake_bigframes = types.ModuleType("bigframes") fake_conf = types.ModuleType("bigframes._config") diff --git a/tests/unit/executors/test_bigquery_exec_unit.py b/tests/unit/executors/test_bigquery_exec_unit.py index 71f9586..80c06ee 100644 --- a/tests/unit/executors/test_bigquery_exec_unit.py +++ b/tests/unit/executors/test_bigquery_exec_unit.py @@ -19,13 +19,14 @@ import fastflowtransform.executors.bigquery.base as bq_base_mod import fastflowtransform.executors.bigquery.pandas as bq_exec_mod +import fastflowtransform.executors.budget.runtime.bigquery as bq_budget_runtime_mod from fastflowtransform.core import Node from fastflowtransform.executors.base import BaseExecutor @pytest.fixture def bq_exec(monkeypatch): - _ = install_fake_bigquery(monkeypatch, [bq_exec_mod]) + _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_base_mod, bq_budget_runtime_mod]) fake_client = FakeClient(project="p1", location="EU") @@ -171,7 +172,7 @@ def test_format_source_reference(bq_exec): @pytest.mark.unit @pytest.mark.bigquery def test_ensure_dataset_respects_flag(monkeypatch): - _ = install_fake_bigquery(monkeypatch, [bq_exec_mod]) + _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_base_mod, bq_budget_runtime_mod]) fake_client = FakeClient(project="p1", location="EU") ex = bq_exec_mod.BigQueryExecutor( @@ -189,7 +190,7 @@ def test_ensure_dataset_respects_flag(monkeypatch): @pytest.mark.unit @pytest.mark.bigquery def test_ensure_dataset_creates_when_allowed(monkeypatch): - _ = install_fake_bigquery(monkeypatch, [bq_exec_mod]) + _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_base_mod, bq_budget_runtime_mod]) fake_client = FakeClient(project="p1", location="EU") ex = bq_exec_mod.BigQueryExecutor( From 41b13a35bd6fb1b114045f8261886fe987971480 Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Fri, 19 Dec 2025 16:17:42 +0100 Subject: [PATCH 08/10] Updated docs from static html files to SPA + fuzzy search in docs and search palette --- src/fastflowtransform/cli/__init__.py | 2 + src/fastflowtransform/cli/docs_cmd.py | 424 ++++++++ src/fastflowtransform/docs.py | 242 ++++- .../templates/assets/spa.css | 179 ++++ src/fastflowtransform/templates/assets/spa.js | 935 ++++++++++++++++++ src/fastflowtransform/templates/index.html.j2 | 274 +---- src/fastflowtransform/templates/model.html.j2 | 235 ----- .../templates/source.html.j2 | 163 --- .../docs/test_docs_materialization_badges.py | 51 - tests/unit/docs/test_docs_unit.py | 59 +- 10 files changed, 1769 insertions(+), 795 deletions(-) create mode 100644 src/fastflowtransform/cli/docs_cmd.py create mode 100644 src/fastflowtransform/templates/assets/spa.css create mode 100644 src/fastflowtransform/templates/assets/spa.js delete mode 100644 src/fastflowtransform/templates/model.html.j2 delete mode 100644 
src/fastflowtransform/templates/source.html.j2 delete mode 100644 tests/unit/docs/test_docs_materialization_badges.py diff --git a/src/fastflowtransform/cli/__init__.py b/src/fastflowtransform/cli/__init__.py index 88fff42..d641f61 100644 --- a/src/fastflowtransform/cli/__init__.py +++ b/src/fastflowtransform/cli/__init__.py @@ -19,6 +19,7 @@ from fastflowtransform.cli.dag_cmd import dag, register as _register_dag from fastflowtransform.cli.deps_cmd import register as _register_deps from fastflowtransform.cli.docgen_cmd import docgen, register as _register_docgen +from fastflowtransform.cli.docs_cmd import register as _register_docs from fastflowtransform.cli.docs_utils import ( _build_docs_manifest, _infer_sql_ref_aliases, @@ -135,6 +136,7 @@ def main( _register_source(app) _register_ci(app) _register_deps(app) +_register_docs(app) __all__ = [ diff --git a/src/fastflowtransform/cli/docs_cmd.py b/src/fastflowtransform/cli/docs_cmd.py new file mode 100644 index 0000000..24c9e07 --- /dev/null +++ b/src/fastflowtransform/cli/docs_cmd.py @@ -0,0 +1,424 @@ +# fastflowtransform/cli/docs_cmd.py +from __future__ import annotations + +import contextlib +import queue +import threading +import time +import webbrowser +from collections.abc import Callable, Iterable +from dataclasses import dataclass, field +from functools import partial +from http import HTTPStatus +from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer +from pathlib import Path +from typing import Any +from urllib.parse import unquote + +import typer + +from fastflowtransform.cli.bootstrap import _prepare_context +from fastflowtransform.cli.docs_utils import _resolve_dag_out_dir +from fastflowtransform.cli.options import ( + EngineOpt, + EnvOpt, + OutOpt, + ProjectArg, + VarsOpt, + WithSchemaOpt, +) +from fastflowtransform.core import REGISTRY +from fastflowtransform.docs import render_site +from fastflowtransform.logging import echo, echo_debug + + +# --------------------------- +# Hot-reload broadcaster (SSE) +# --------------------------- +@dataclass +class _SSEClient: + q: queue.Queue[str] = field(default_factory=queue.Queue) + + +class _ReloadHub: + def __init__(self) -> None: + self._lock = threading.Lock() + self._clients: list[_SSEClient] = [] + + def add(self) -> _SSEClient: + c = _SSEClient() + with self._lock: + self._clients.append(c) + return c + + def remove(self, c: _SSEClient) -> None: + with self._lock, contextlib.suppress(ValueError): + self._clients.remove(c) + + def broadcast_reload(self) -> None: + with self._lock: + for c in list(self._clients): + # SSE "reload" event + c.q.put("event: reload\ndata: 1\n\n") + + def broadcast_log(self, msg: str) -> None: + # Optional: can be used later for in-browser toast/logging + payload = msg.replace("\n", " ").strip() + with self._lock: + for c in list(self._clients): + c.q.put(f"event: log\ndata: {payload}\n\n") + + +# --------------------------- +# HTTP handler +# --------------------------- +_RELOAD_JS = r""" +(() => { + const es = new EventSource("/__fft_events"); + es.addEventListener("reload", () => { + // Hard reload to also refresh CSS/JS + window.location.reload(); + }); + // Optional future: show server logs in UI + es.addEventListener("log", (ev) => { + // console.log("[FFT]", ev.data); + }); + es.onerror = () => { + // auto-reconnect is handled by EventSource, but we can hint in console + // console.warn("[FFT] reload channel error; reconnecting…"); + }; +})(); +""".lstrip() + + +class _DocsHandler(SimpleHTTPRequestHandler): + """ + - Serves static 
files from a directory + - SPA fallback: missing paths -> index.html (except /assets/* and /__fft_*) + - SSE endpoint for hot reload + - HTML injection to load reload script (dev only) + """ + + server_version = "FFTDocs/1.0" + + def __init__( + self, + *args: Any, + directory: str | None = None, + hub: _ReloadHub | None = None, + inject_reload: bool = False, + **kwargs: Any, + ) -> None: + self._hub = hub + self._inject_reload = inject_reload + super().__init__(*args, directory=directory, **kwargs) + + def end_headers(self) -> None: + # Make dev-server behavior more predictable + self.send_header("Cache-Control", "no-store") + super().end_headers() + + def do_GET(self) -> None: + # Special endpoints first + if self.path == "/__fft_reload.js": + self._serve_text(_RELOAD_JS, content_type="application/javascript; charset=utf-8") + return + + if self.path == "/__fft_events": + self._serve_sse() + return + + # Try regular file first + local_path = self.translate_path(self.path) + p = Path(local_path) + + # SPA fallback (don't eat assets or internal endpoints) + req_path = unquote(self.path.split("?", 1)[0]) + if ( + not p.exists() + and not req_path.startswith("/assets/") + and not req_path.startswith("/__fft_") + ): + idx = Path(self.translate_path("/index.html")) + if idx.exists(): + self._serve_file(idx, inject=self._inject_reload) + return + + # Serve real file, with optional HTML injection + if p.exists() and p.is_file() and p.suffix.lower() in (".html", ".htm"): + self._serve_file(p, inject=self._inject_reload) + return + + # Default static behavior + super().do_GET() + + def _serve_text(self, text: str, *, content_type: str) -> None: + data = text.encode("utf-8") + self.send_response(HTTPStatus.OK) + self.send_header("Content-Type", content_type) + self.send_header("Content-Length", str(len(data))) + self.end_headers() + self.wfile.write(data) + + def _serve_file(self, p: Path, *, inject: bool) -> None: + data = p.read_bytes() + ctype = self.guess_type(str(p)) + + if inject and ctype.startswith("text/html"): + try: + s = data.decode("utf-8") + tag = '\n' + s = s.replace("", f"{tag}") if "" in s else s + "\n" + tag + data = s.encode("utf-8") + except Exception: + # If decoding fails, serve raw. 
+ pass + + self.send_response(HTTPStatus.OK) + self.send_header( + "Content-Type", f"{ctype}; charset=utf-8" if ctype.startswith("text/") else ctype + ) + self.send_header("Content-Length", str(len(data))) + self.end_headers() + self.wfile.write(data) + + def _serve_sse(self) -> None: + if not self._hub: + self.send_error(HTTPStatus.NOT_FOUND, "Reload hub not configured") + return + + self.send_response(HTTPStatus.OK) + self.send_header("Content-Type", "text/event-stream; charset=utf-8") + self.send_header("Cache-Control", "no-cache") + self.send_header("Connection", "keep-alive") + self.end_headers() + + client = self._hub.add() + try: + # Initial hello + faster reconnect + self.wfile.write(b"retry: 1000\n\n") + self.wfile.flush() + + while True: + try: + msg = client.q.get(timeout=15.0) + except queue.Empty: + # keep-alive + msg = "event: ping\ndata: 1\n\n" + self.wfile.write(msg.encode("utf-8")) + self.wfile.flush() + except (BrokenPipeError, ConnectionResetError): + pass + finally: + # best-effort cleanup + with contextlib.suppress(Exception): + self._hub.remove(client) + + +# --------------------------- +# Watcher +# --------------------------- +def _glob_many(root: Path, patterns: Iterable[str]) -> list[Path]: + out: list[Path] = [] + for pat in patterns: + out.extend(root.glob(pat)) + return [p for p in out if p.exists()] + + +def _collect_watch_paths(project_dir: Path) -> list[Path]: + # Project files affecting docs content + paths: list[Path] = [] + paths += [project_dir / "project.yml"] + paths += [project_dir / "sources.yml"] + paths += _glob_many(project_dir, ["docs/**/*.md", "docs/**/*.yml", "docs/**/*.yaml"]) + paths += _glob_many(project_dir, ["models/**/*.ff.sql", "models/**/*.ff.py"]) + paths += _glob_many(project_dir, ["macros/**/*.sql", "macros/**/*.py"]) + + # Bundled templates + assets (docs.py loads templates from + # package/templates) :contentReference[oaicite:3]{index=3} + pkg_templates = Path(__file__).resolve().parents[1] / "templates" + if pkg_templates.exists(): + paths += _glob_many(pkg_templates, ["**/*.j2", "assets/**/*"]) + + # Keep only files + return sorted({p.resolve() for p in paths if p.exists() and p.is_file()}) + + +def _snapshot_mtime(paths: list[Path]) -> dict[Path, int]: + snap: dict[Path, int] = {} + for p in paths: + try: + snap[p] = p.stat().st_mtime_ns + except FileNotFoundError: + continue + return snap + + +def _watch_loop( + *, + project_dir: Path, + poll_s: float, + debounce_s: float, + on_change: Callable, +) -> None: + watched = _collect_watch_paths(project_dir) + prev = _snapshot_mtime(watched) + last_trigger = 0.0 + + echo_debug(f"Watching {len(watched)} files for docs reload") + + while True: + time.sleep(poll_s) + + # Refresh list occasionally (new files added) + if int(time.time()) % 10 == 0: + watched = _collect_watch_paths(project_dir) + + cur = _snapshot_mtime(watched) + + changed: list[Path] = [] + for p, mt in cur.items(): + if prev.get(p) != mt: + changed.append(p) + + # Removed files + for p in list(prev.keys()): + if p not in cur: + changed.append(p) + + if changed: + now = time.time() + if now - last_trigger < debounce_s: + prev = cur + continue + last_trigger = now + prev = cur + + # Pick a representative file for log + head = changed[0] + echo_debug(f"Docs change detected: {head}") + with contextlib.suppress(Exception): + on_change(head) + + +# --------------------------- +# CLI command +# --------------------------- +docs_app = typer.Typer(help="Docs tooling (dev server, live docs)") + + +def _build_docs_once( + *, + 
project: str, + env_name: str, + engine: Any, + vars: list[str] | None, + out: Path | None, + with_schema: bool, +) -> tuple[Path, Any]: + if out is not None: + out = out.resolve() + out.mkdir(parents=True, exist_ok=True) + + ctx = _prepare_context(project, env_name, engine, vars) + ex, *_ = ctx.make_executor() + + out_dir = _resolve_dag_out_dir(ctx.project, out) + out_dir.mkdir(parents=True, exist_ok=True) + + t0 = time.time() + render_site(out_dir, REGISTRY.nodes, executor=ex, with_schema=with_schema) + dt = time.time() - t0 + + echo(f"Docs written to {out_dir / 'index.html'} ({dt:.2f}s)") + return out_dir, ctx + + +@docs_app.command("serve", help="Serve docs with live reload (watch templates/YAML/MD/schema).") +def serve( + project: ProjectArg = ".", + env_name: EnvOpt = "dev", + engine: EngineOpt = None, + vars: VarsOpt = None, + out: OutOpt = None, + with_schema: WithSchemaOpt = True, + port: int = typer.Option(8000, "--port", help="Port to bind (default: 8000)"), + host: str = typer.Option("127.0.0.1", "--host", help="Host to bind (default: 127.0.0.1)"), + watch: bool = typer.Option( + True, + "--watch/--no-watch", + help="Watch project + templates and hot reload.", + show_default=True, + ), + open_browser: bool = typer.Option(False, "--open", help="Open docs in the default browser."), + poll: float = typer.Option(0.5, "--poll", help="Watch poll interval (seconds)."), + debounce: float = typer.Option(0.35, "--debounce", help="Debounce rebuilds (seconds)."), +) -> None: + # Initial build + out_dir, _ = _build_docs_once( + project=project, + env_name=env_name, + engine=engine, + vars=vars, + out=out, + with_schema=with_schema, + ) + + url = f"http://{host}:{port}/" + echo(f"Serving docs from {out_dir} at {url}") + + hub = _ReloadHub() + + # Rebuild callback used by watcher + rebuild_lock = threading.Lock() + + def _rebuild(changed: Path) -> None: + with rebuild_lock: + echo_debug(f"Rebuilding docs due to: {changed}") + try: + _build_docs_once( + project=project, + env_name=env_name, + engine=engine, + vars=vars, + out=out, + with_schema=with_schema, + ) + hub.broadcast_reload() + except Exception as e: + # Keep server alive + echo(f"Docs rebuild failed: {e}") + + if watch: + t = threading.Thread( + target=_watch_loop, + kwargs=dict( + project_dir=Path(project).resolve(), + poll_s=poll, + debounce_s=debounce, + on_change=_rebuild, + ), + daemon=True, + ) + t.start() + + handler = partial(_DocsHandler, directory=str(out_dir), hub=hub, inject_reload=True) + httpd = ThreadingHTTPServer((host, port), handler) + + if open_browser: + with contextlib.suppress(Exception): + webbrowser.open(url, new=2) + + try: + httpd.serve_forever() + except KeyboardInterrupt: + echo("Stopping docs server…") + finally: + httpd.server_close() + + +def register(app: typer.Typer) -> None: + app.add_typer(docs_app, name="docs") + + +__all__ = ["register"] diff --git a/src/fastflowtransform/docs.py b/src/fastflowtransform/docs.py index 15a7c42..75436c6 100644 --- a/src/fastflowtransform/docs.py +++ b/src/fastflowtransform/docs.py @@ -1,8 +1,11 @@ # fastflowtransform/docs.py from __future__ import annotations +import json import re +import shutil from dataclasses import dataclass, field +from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -155,6 +158,152 @@ def _init_jinja() -> Environment: ) +_TAG_RE = re.compile(r"<[^>]+>") + + +def _html_to_text(s: str | None) -> str | None: + if not s: + return None + # fast + good-enough for docs descriptions + txt = _TAG_RE.sub("", s) + txt = 
re.sub(r"\s+", " ", txt).strip() + return txt or None + + +def _copy_template_assets(out_dir: Path) -> None: + """ + Copy packaged static assets from templates/assets -> /assets. + Safe no-op if no assets exist. + """ + tmpl_dir = Path(__file__).parent / "templates" + src = tmpl_dir / "assets" + if not src.exists() or not src.is_dir(): + return + dst = out_dir / "assets" + dst.mkdir(parents=True, exist_ok=True) + for p in src.rglob("*"): + if p.is_dir(): + continue + rel = p.relative_to(src) + target = dst / rel + target.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(p, target) + + +def _project_name(proj_dir: Path | None) -> str: + if not proj_dir: + return "FastFlowTransform" + cfg_path = proj_dir / "project.yml" + try: + if cfg_path.exists(): + cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {} + if isinstance(cfg, dict) and cfg.get("name"): + return str(cfg["name"]) + except Exception: + pass + return proj_dir.name + + +def _build_spa_manifest( + *, + proj_name: str, + env_name: str | None, + with_schema: bool, + mermaid_src: str, + models: list[ModelDoc], + sources: list[SourceDoc], + macros: list[dict[str, str]], + used_by: dict[str, list[str]], + cols_by_table: dict[str, list[ColumnInfo]], + model_source_refs: dict[str, list[tuple[str, str]]], + sources_by_key: dict[tuple[str, str], SourceDoc], +) -> dict[str, Any]: + def _col_to_dict(c: ColumnInfo) -> dict[str, Any]: + html = c.description_html + html_s = str(html) if html is not None else None + return { + "name": c.name, + "dtype": c.dtype, + "nullable": bool(c.nullable), + "description_html": c.description_html, + "description_text": _html_to_text(html_s), + "lineage": c.lineage or [], + } + + out_models: list[dict[str, Any]] = [] + for m in models: + # model -> sources used (source(), table) + src_keys = model_source_refs.get(m.name, []) or [] + src_used = [] + for k in src_keys: + doc = sources_by_key.get(k) + if not doc: + continue + src_used.append( + { + "source_name": doc.source_name, + "table_name": doc.table_name, + "relation": doc.relation, + } + ) + + cols = [] + if with_schema and m.relation in cols_by_table: + cols = [_col_to_dict(c) for c in (cols_by_table.get(m.relation) or [])] + + model_desc_html = m.description_html + model_desc_html_s = str(model_desc_html) if model_desc_html is not None else None + + out_models.append( + { + "name": m.name, + "kind": m.kind, + "path": m.path, + "relation": m.relation, + "deps": list(m.deps or []), + "used_by": list(used_by.get(m.name, []) or []), + "materialized": m.materialized, + "description_html": m.description_html, + "description_text": _html_to_text(model_desc_html_s), + "description_short": m.description_short, + "sources_used": src_used, + "columns": cols, + } + ) + + out_sources: list[dict[str, Any]] = [] + for s in sources: + src_desc_html = s.description_html + src_desc_html_s = str(src_desc_html) if src_desc_html is not None else None + + out_sources.append( + { + "source_name": s.source_name, + "table_name": s.table_name, + "relation": s.relation, + "description_html": s.description_html, + "description_text": _html_to_text(src_desc_html_s), + "loaded_at_field": s.loaded_at_field, + "warn_after_minutes": s.warn_after_minutes, + "error_after_minutes": s.error_after_minutes, + "consumers": list(s.consumers or []), + } + ) + + return { + "project": { + "name": proj_name, + "generated_at": datetime.now(UTC).isoformat(), + "env": env_name, + "with_schema": bool(with_schema), + }, + "dag": {"mermaid": mermaid_src}, + "models": out_models, + 
"sources": out_sources, + "macros": macros, + } + + def _get_project_dir() -> Path | None: """Best-effort resolution of the project dir from the registry.""" if not hasattr(REGISTRY, "get_project_dir"): @@ -509,8 +658,11 @@ def render_site( executor: Any | None = None, *, with_schema: bool = True, + spa: bool = True, + legacy_pages: bool = False, ) -> None: out_dir.mkdir(parents=True, exist_ok=True) + _copy_template_assets(out_dir) env = _init_jinja() source_consumers, model_source_refs = _scan_source_refs(nodes) @@ -535,30 +687,76 @@ def render_site( _apply_descriptions_to_models(models, docs_meta, cols_by_table, with_schema=with_schema) _infer_and_attach_lineage(models, executor, docs_meta, cols_by_table, with_schema=with_schema) - _render_index( - env, - out_dir, - mermaid_src=Markup(mermaid_src), - models=models, - sources=sources, - materialization_legend=mat_legend, - macros=macro_list, - ) - used_by = _reverse_deps(nodes) - _render_model_pages( - env, - out_dir, - models=models, - used_by=used_by, - cols_by_table=cols_by_table, - materialization_legend=mat_legend, - macros=macro_list, - model_sources=model_source_refs, - sources_index=sources_by_key, - ) - _render_source_pages(env, out_dir, sources) + if spa: + _copy_template_assets(out_dir) + proj_name = _project_name(proj_dir) + env_name = getattr(REGISTRY, "active_engine", None) # best-effort, not perfect + manifest = _build_spa_manifest( + proj_name=proj_name, + env_name=env_name, + with_schema=with_schema, + mermaid_src=str(mermaid_src), + models=models, + sources=sources, + macros=macro_list, + used_by=used_by, + cols_by_table=cols_by_table, + model_source_refs=model_source_refs, + sources_by_key=sources_by_key, + ) + assets_dir = out_dir / "assets" + assets_dir.mkdir(parents=True, exist_ok=True) + (assets_dir / "docs_manifest.json").write_text( + json.dumps(manifest, indent=2), encoding="utf-8" + ) + + # SPA shell (index.html.j2) + _render_index( + env, + out_dir, + project_name=proj_name, + manifest_path="assets/docs_manifest.json", + ) + + # Optional legacy pages (useful during transition) + if legacy_pages: + _render_model_pages( + env, + out_dir, + models=models, + used_by=used_by, + cols_by_table=cols_by_table, + materialization_legend=mat_legend, + macros=macro_list, + model_sources=model_source_refs, + sources_index=sources_by_key, + ) + _render_source_pages(env, out_dir, sources) + else: + # Legacy behavior + _render_index( + env, + out_dir, + mermaid_src=Markup(mermaid_src), + models=models, + sources=sources, + materialization_legend=mat_legend, + macros=macro_list, + ) + _render_model_pages( + env, + out_dir, + models=models, + used_by=used_by, + cols_by_table=cols_by_table, + materialization_legend=mat_legend, + macros=macro_list, + model_sources=model_source_refs, + sources_index=sources_by_key, + ) + _render_source_pages(env, out_dir, sources) @dataclass diff --git a/src/fastflowtransform/templates/assets/spa.css b/src/fastflowtransform/templates/assets/spa.css new file mode 100644 index 0000000..11e1e4a --- /dev/null +++ b/src/fastflowtransform/templates/assets/spa.css @@ -0,0 +1,179 @@ +:root{ + --bg:#0b0f1a; --fg:#e5e7eb; --muted:#94a3b8; --card:#111827; --border:#1f2937; + --accent:#60a5fa; --accent2:#22c55e; --warn:#f59e0b; --err:#ef4444; + --mono: ui-monospace,SFMono-Regular,Menlo,Consolas,monospace; + --sans: ui-sans-serif,system-ui,-apple-system,Segoe UI,Roboto,Arial; +} +@media (prefers-color-scheme: light){ + :root{ --bg:#ffffff; --fg:#0f172a; --muted:#6b7280; --card:#ffffff; --border:#e5e7eb; 
--accent:#2563eb; --accent2:#16a34a; } +} +*{ box-sizing:border-box; } +body{ margin:0; background:var(--bg); color:var(--fg); font:14px/1.5 var(--sans); } +a{ color:var(--accent); text-decoration:none; } a:hover{ text-decoration:underline; } +code, .mono{ font-family:var(--mono); font-size:12px; } +hr{ border:0; border-top:1px solid var(--border); margin:16px 0; } + +.shell{ display:grid; grid-template-columns: 340px 1fr; min-height:100vh; } +@media (max-width: 1000px){ .shell{ grid-template-columns: 1fr; } } + +.sidebar{ + border-right:1px solid var(--border); + padding:16px; + position:sticky; top:0; align-self:start; + height:100vh; overflow:auto; + background:color-mix(in srgb, var(--bg), transparent 0%); +} +@media (max-width: 1000px){ + .sidebar{ position:relative; height:auto; border-right:0; border-bottom:1px solid var(--border); } +} + +.brand{ display:flex; align-items:baseline; justify-content:space-between; gap:8px; margin-bottom:12px; } +.brand h1{ font-size:14px; margin:0; } +.badge{ font-size:12px; border:1px solid var(--border); border-radius:999px; padding:2px 8px; color:var(--muted); } +.search{ width:100%; padding:10px 12px; border-radius:12px; border:1px solid var(--border); background:transparent; color:var(--fg); outline:none; } +.section{ margin-top:14px; } +.section h2{ margin:0 0 8px 0; font-size:12px; letter-spacing:.06em; text-transform:uppercase; color:var(--muted); } + +.list{ list-style:none; padding:0; margin:0; display:flex; flex-direction:column; gap:6px; } +.item a{ + display:flex; justify-content:space-between; gap:10px; + padding:8px 10px; border-radius:12px; border:1px solid transparent; + color:inherit; +} +.item a:hover{ border-color:var(--border); background:color-mix(in srgb, var(--card), transparent 70%); } +.pill{ font-size:11px; padding:1px 8px; border-radius:999px; border:1px solid var(--border); color:var(--muted); } +.pill.sql{ color: color-mix(in srgb, var(--accent), white 0%); } +.pill.python{ color: color-mix(in srgb, var(--accent2), white 0%); } + +.main{ padding:20px; } +.card{ background:var(--card); border:1px solid var(--border); border-radius:16px; padding:16px; } +.grid{ display:grid; gap:16px; } +.grid2{ display:grid; gap:16px; grid-template-columns: 2fr 1fr; } +@media (max-width: 1100px){ .grid2{ grid-template-columns:1fr; } } + +.kv{ display:grid; grid-template-columns: 140px 1fr; gap:8px 12px; align-items:start; } +.kv .k{ color:var(--muted); } + +.btn{ + border:1px solid var(--border); background:transparent; color:var(--fg); + padding:8px 12px; border-radius:12px; cursor:pointer; +} +.btn:hover{ border-color:var(--accent); } + +.table{ width:100%; border-collapse:collapse; } +.table th,.table td{ text-align:left; padding:8px 10px; border-bottom:1px solid var(--border); vertical-align:top; } +.table th{ color:var(--muted); font-weight:600; font-size:12px; } + +.empty{ color:var(--muted); } +.desc p{ margin: 0 0 10px 0; } +.desc p:last-child{ margin-bottom:0; } +.mermaidWrap{ overflow:auto; } + +/* Global search palette */ +.overlay{ + position:fixed; inset:0; + background: rgba(0,0,0,.55); + display:flex; align-items:flex-start; justify-content:center; + padding: 10vh 16px 16px; + z-index: 9999; +} +.palette{ + width: min(900px, 100%); + background: var(--card); + border: 1px solid var(--border); + border-radius: 18px; + overflow:hidden; + box-shadow: 0 20px 60px rgba(0,0,0,.45); +} +.paletteHead{ + padding: 12px; + border-bottom: 1px solid var(--border); + display:flex; + gap:10px; + align-items:center; +} +.paletteInput{ + 
width:100%; + padding: 12px 12px; + border-radius: 14px; + border: 1px solid var(--border); + background: transparent; + color: var(--fg); + outline:none; +} +.paletteHint{ + color: var(--muted); + font-size: 12px; + white-space: nowrap; +} +.paletteList{ + max-height: 55vh; + overflow:auto; +} +.result{ + display:flex; + justify-content:space-between; + gap:12px; + padding: 10px 12px; + border-bottom: 1px solid color-mix(in srgb, var(--border), transparent 35%); + cursor:pointer; +} +.result:last-child{ border-bottom:0; } +.result:hover{ background: color-mix(in srgb, var(--accent), transparent 92%); } +.result.sel{ background: color-mix(in srgb, var(--accent), transparent 88%); } +.resultMain{ display:flex; flex-direction:column; gap:2px; min-width:0; } +.resultTitle{ font-weight:600; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; } +.resultSub{ color: var(--muted); font-size:12px; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; } +.kbd{ + font-family: var(--mono); + font-size: 11px; + padding: 2px 6px; + border-radius: 8px; + border: 1px solid var(--border); + color: var(--muted); +} + +/* Sidebar search hint */ +.searchWrap{ position:relative; } +.searchKbd{ + position:absolute; + right:10px; + top:50%; + transform:translateY(-50%); + pointer-events:none; + opacity:.85; +} +.searchTip{ + margin-top:8px; + color:var(--muted); + font-size:12px; +} + +/* First-run toast */ +.toast{ + position:fixed; + right:16px; + bottom:16px; + max-width:min(440px, calc(100vw - 32px)); + background:var(--card); + border:1px solid var(--border); + border-radius:16px; + padding:12px 12px; + box-shadow: 0 18px 50px rgba(0,0,0,.35); + z-index: 10000; + display:flex; + gap:12px; + align-items:flex-start; +} +.toastTitle{ font-weight:700; margin:0 0 4px 0; font-size:13px; } +.toastBody{ margin:0; color:var(--muted); font-size:12px; } +.toastActions{ margin-left:auto; display:flex; gap:8px; } +.toastBtn{ + border:1px solid var(--border); + background:transparent; + color:var(--fg); + padding:6px 10px; + border-radius:12px; + cursor:pointer; +} +.toastBtn:hover{ border-color:var(--accent); } diff --git a/src/fastflowtransform/templates/assets/spa.js b/src/fastflowtransform/templates/assets/spa.js new file mode 100644 index 0000000..6ca2df2 --- /dev/null +++ b/src/fastflowtransform/templates/assets/spa.js @@ -0,0 +1,935 @@ +const MANIFEST_URL = window.__FFT_MANIFEST_PATH__ || "assets/docs_manifest.json"; + +function el(tag, attrs = {}, ...children) { + const n = document.createElement(tag); + for (const [k, v] of Object.entries(attrs || {})) { + if (k === "class") n.className = v; + else if (k === "html") n.innerHTML = v; + else if (k.startsWith("on") && typeof v === "function") n.addEventListener(k.slice(2), v); + else n.setAttribute(k, String(v)); + } + for (const c of children) { + if (c == null) continue; + n.appendChild(typeof c === "string" ? document.createTextNode(c) : c); + } + return n; +} + +function stripHtml(html) { + if (!html) return ""; + const div = document.createElement("div"); + div.innerHTML = html; + return (div.textContent || div.innerText || "").replace(/\s+/g, " ").trim(); +} + +// “Fuzzy-ish” scorer: subsequence match + bonuses for contiguity and word boundaries. +// Returns -1 for no match, higher is better. 
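+// Illustrative examples (not part of the original patch, added for clarity):
+//   fuzzyScore("usr", "users")  // > 0: "u","s","r" is a subsequence; the adjacent "us" earns the contiguity bonus
+//   fuzzyScore("xyz", "users")  // -1: not a subsequence, so no match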
+function fuzzyScore(query, text) { + query = (query || "").toLowerCase(); + text = (text || "").toLowerCase(); + if (!query) return 0; + + let qi = 0; + let score = 0; + let lastMatch = -10; + + for (let ti = 0; ti < text.length && qi < query.length; ti++) { + if (text[ti] === query[qi]) { + score += 10; + + // contiguous bonus + if (ti === lastMatch + 1) score += 8; + + // word boundary bonus + const prev = ti > 0 ? text[ti - 1] : " "; + if (prev === " " || prev === "_" || prev === "-" || prev === "." || prev === "/" ) score += 6; + + lastMatch = ti; + qi++; + } + } + + if (qi !== query.length) return -1; + + // shorter texts get a small bonus + score += Math.max(0, 30 - Math.min(text.length, 30)); + return score; +} + +function topN(items, n) { + items.sort((a, b) => b.score - a.score); + return items.slice(0, n); +} + +function escapeHashPart(s) { + return encodeURIComponent(String(s || "")).replaceAll("%2F", "/"); +} +function parseHash() { + const raw = (location.hash || "#/").slice(1); + const parts = raw.split("/").filter(Boolean); + if (parts.length === 0) return { route: "home" }; + if (parts[0] === "model" && parts[1]) return { route: "model", name: decodeURIComponent(parts.slice(1).join("/")) }; + if (parts[0] === "source" && parts[1] && parts[2]) { + return { route: "source", source: decodeURIComponent(parts[1]), table: decodeURIComponent(parts[2]) }; + } + if (parts[0] === "macros") return { route: "macros" }; + return { route: "home" }; +} + +async function initMermaid() { + try { + const prefersDark = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)").matches; + const mod = await import("https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs"); + const mermaid = mod.default; + mermaid.initialize({ startOnLoad: false, securityLevel: "loose", theme: prefersDark ? "dark" : "default" }); + return mermaid; + } catch (e) { + console.warn("Mermaid failed to load:", e); + return null; + } +} + +function byName(arr, keyFn) { + const m = new Map(); + for (const x of arr) m.set(keyFn(x), x); + return m; +} + +function pillForKind(kind) { + return el("span", { class: `pill ${kind}` }, kind); +} + +function renderSidebar(state, onNavigate) { + const { manifest, filter } = state; + const models = manifest.models || []; + const sources = manifest.sources || []; + + const q = (filter || "").trim().toLowerCase(); + + const filteredModels = q + ? models.filter(m => + (m.name || "").toLowerCase().includes(q) || + (m.relation || "").toLowerCase().includes(q) || + (m.description_short || "").toLowerCase().includes(q) + ) + : models; + + const filteredSources = q + ? 
sources.filter(s => + (`${s.source_name}.${s.table_name}`).toLowerCase().includes(q) || + (s.relation || "").toLowerCase().includes(q) + ) + : sources; + + return el( + "div", + { class: "sidebar" }, + el( + "div", + { class: "brand" }, + el("h1", {}, manifest.project?.name || "Docs"), + el("span", { class: "badge", title: `Generated: ${manifest.project?.generated_at || ""}` }, "SPA") + ), + el("div", { class: "searchWrap" }, + el("input", { + class: "search", + type: "search", + placeholder: "Filter sidebar… (press /)", + value: filter || "", + oninput: (e) => onNavigate({ type: "filter", value: e.target.value }), + }), + el("span", { class: "searchKbd kbd" }, "/") + ), + el("div", { class: "searchTip" }, "Tip: Press / (or Ctrl+K) to search everything (models, sources, columns)."), + el( + "div", + { class: "section" }, + el("h2", {}, `Models (${filteredModels.length})`), + el( + "ul", + { class: "list" }, + ...filteredModels.map(m => + el( + "li", + { class: "item" }, + el( + "a", + { + href: `#/model/${escapeHashPart(m.name)}`, + onclick: (e) => { e.preventDefault(); location.hash = `#/model/${escapeHashPart(m.name)}`; }, + title: m.description_short || m.name, + }, + el("span", {}, m.name), + pillForKind(m.kind === "python" ? "python" : "sql") + ) + ) + ) + ) + ), + el( + "div", + { class: "section" }, + el("h2", {}, `Sources (${filteredSources.length})`), + el( + "ul", + { class: "list" }, + ...filteredSources.map(s => { + const key = `${s.source_name}.${s.table_name}`; + return el( + "li", + { class: "item" }, + el( + "a", + { + href: `#/source/${escapeHashPart(s.source_name)}/${escapeHashPart(s.table_name)}`, + onclick: (e) => { e.preventDefault(); location.hash = `#/source/${escapeHashPart(s.source_name)}/${escapeHashPart(s.table_name)}`; }, + title: s.relation || key, + }, + el("span", {}, key), + el("span", { class: "pill" }, (s.consumers || []).length ? `${s.consumers.length}` : "–") + ) + ); + }) + ) + ), + el( + "div", + { class: "section" }, + el("h2", {}, "Other"), + el("ul", { class: "list" }, + el("li", { class: "item" }, + el("a", { + href: "#/macros", + onclick: (e) => { e.preventDefault(); location.hash = "#/macros"; }, + }, el("span", {}, "Macros"), el("span", { class: "pill" }, String((manifest.macros || []).length)))) + ) + ) + ); +} + +function renderHome(state) { + const { manifest, mermaid } = state; + const dagSrc = manifest.dag?.mermaid || ""; + + const dagCard = el("div", { class: "card" }, + el("div", { class: "grid" }, + el("div", { class: "grid2" }, + el("div", {}, + el("h2", {}, "DAG"), + el("p", { class: "empty" }, "Mermaid is rendered client-side.") + ), + el("div", {}, + el("button", { + class: "btn", + onclick: async () => { + try { await navigator.clipboard.writeText(dagSrc); } catch {} + } + }, "Copy Mermaid") + ) + ), + el("div", { class: "mermaidWrap" }, + el("div", { id: "mermaidTarget" }) + ) + ) + ); + + // Render mermaid after DOM is mounted + queueMicrotask(async () => { + const target = document.getElementById("mermaidTarget"); + if (!target) return; + if (!mermaid) { + target.textContent = dagSrc; + return; + } + target.innerHTML = `
<pre class="mermaid">${dagSrc}</pre>
`; + try { await mermaid.run({ querySelector: "#mermaidTarget .mermaid" }); } catch {} + }); + + const stats = el("div", { class: "card" }, + el("h2", {}, "Overview"), + el("div", { class: "kv" }, + el("div", { class: "k" }, "Models"), el("div", {}, String((manifest.models || []).length)), + el("div", { class: "k" }, "Sources"), el("div", {}, String((manifest.sources || []).length)), + el("div", { class: "k" }, "Macros"), el("div", {}, String((manifest.macros || []).length)), + el("div", { class: "k" }, "Schema"), el("div", {}, manifest.project?.with_schema ? "enabled" : "disabled"), + el("div", { class: "k" }, "Generated"), el("div", {}, manifest.project?.generated_at || "—") + ) + ); + + return el("div", { class: "grid2" }, dagCard, stats); +} + +function renderModel(state, name) { + const { manifest } = state; + const byModel = state.byModel; + const m = byModel.get(name); + + if (!m) { + return el("div", { class: "card" }, el("h2", {}, "Model not found"), el("p", { class: "empty" }, name)); + } + + const deps = (m.deps || []).map(d => el("a", { href: `#/model/${escapeHashPart(d)}` }, d)); + const usedBy = (m.used_by || []).map(u => el("a", { href: `#/model/${escapeHashPart(u)}` }, u)); + + const sourcesUsed = (m.sources_used || []).map(s => + el("a", { href: `#/source/${escapeHashPart(s.source_name)}/${escapeHashPart(s.table_name)}` }, `${s.source_name}.${s.table_name}`) + ); + + const head = el("div", { class: "card" }, + el("div", { class: "grid2" }, + el("div", {}, + el("h2", {}, m.name), + el("p", { class: "empty" }, m.relation ? `Relation: ${m.relation}` : "") + ), + el("div", {}, + el("button", { + class: "btn", + onclick: async () => { try { await navigator.clipboard.writeText(m.path || ""); } catch {} } + }, "Copy path") + ) + ), + el("div", { class: "kv" }, + el("div", { class: "k" }, "Kind"), el("div", {}, m.kind), + el("div", { class: "k" }, "Materialized"), el("div", {}, m.materialized || "—"), + el("div", { class: "k" }, "Path"), el("div", {}, el("code", {}, m.path || "—")), + el("div", { class: "k" }, "Deps"), el("div", {}, deps.length ? joinInline(deps) : el("span", { class: "empty" }, "—")), + el("div", { class: "k" }, "Used by"), el("div", {}, usedBy.length ? joinInline(usedBy) : el("span", { class: "empty" }, "—")), + el("div", { class: "k" }, "Sources"), el("div", {}, sourcesUsed.length ? joinInline(sourcesUsed) : el("span", { class: "empty" }, "—")), + ) + ); + + const desc = m.description_html + ? el("div", { class: "card" }, el("h2", {}, "Description"), el("div", { class: "desc", html: m.description_html })) + : null; + + const cols = (m.columns || []); + const colsCard = cols.length + ? el("div", { class: "card" }, + el("h2", {}, "Columns"), + el("table", { class: "table" }, + el("thead", {}, el("tr", {}, + el("th", {}, "Name"), + el("th", {}, "Type"), + el("th", {}, "Nullable"), + el("th", {}, "Description"), + el("th", {}, "Lineage"), + )), + el("tbody", {}, + ...cols.map(c => el("tr", {}, + el("td", {}, el("code", {}, c.name)), + el("td", {}, el("code", {}, c.dtype || "")), + el("td", {}, c.nullable ? "true" : "false"), + el("td", { html: c.description_html || '' }), + el("td", {}, renderLineage(c.lineage || [])) + )) + ) + ) + ) + : el("div", { class: "card" }, el("h2", {}, "Columns"), el("p", { class: "empty" }, manifest.project?.with_schema ? "No columns found." 
: "Schema collection disabled.")); + + return el("div", { class: "grid" }, head, desc, colsCard); +} + +function renderSource(state, sourceName, tableName) { + const key = `${sourceName}.${tableName}`; + const s = state.bySource.get(key); + + if (!s) { + return el("div", { class: "card" }, el("h2", {}, "Source not found"), el("p", { class: "empty" }, key)); + } + + const consumers = (s.consumers || []).map(m => el("a", { href: `#/model/${escapeHashPart(m)}` }, m)); + + const freshness = (() => { + const warn = s.warn_after_minutes != null ? `${s.warn_after_minutes}m warn` : null; + const err = s.error_after_minutes != null ? `${s.error_after_minutes}m error` : null; + const parts = [warn, err].filter(Boolean); + return parts.length ? parts.join(" • ") : "—"; + })(); + + return el("div", { class: "grid" }, + el("div", { class: "card" }, + el("h2", {}, key), + el("div", { class: "kv" }, + el("div", { class: "k" }, "Relation"), el("div", {}, el("code", {}, s.relation || "—")), + el("div", { class: "k" }, "Loaded at field"), el("div", {}, el("code", {}, s.loaded_at_field || "—")), + el("div", { class: "k" }, "Freshness"), el("div", {}, freshness), + el("div", { class: "k" }, "Consumers"), el("div", {}, consumers.length ? joinInline(consumers) : el("span", { class: "empty" }, "—")), + ) + ), + s.description_html + ? el("div", { class: "card" }, el("h2", {}, "Description"), el("div", { class: "desc", html: s.description_html })) + : null + ); +} + +function renderMacros(state) { + const ms = state.manifest.macros || []; + return el("div", { class: "card" }, + el("h2", {}, "Macros"), + ms.length + ? el("table", { class: "table" }, + el("thead", {}, el("tr", {}, + el("th", {}, "Name"), + el("th", {}, "Kind"), + el("th", {}, "Path"), + )), + el("tbody", {}, + ...ms.map(m => el("tr", {}, + el("td", {}, el("code", {}, m.name)), + el("td", {}, m.kind), + el("td", {}, el("code", {}, m.path)), + )) + ) + ) + : el("p", { class: "empty" }, "No macros discovered.") + ); +} + +function joinInline(nodes) { + const wrap = el("span", {}); + nodes.forEach((n, i) => { + if (i) wrap.appendChild(document.createTextNode(", ")); + wrap.appendChild(n); + }); + return wrap; +} + +function renderLineage(items) { + if (!items || !items.length) return el("span", { class: "empty" }, "—"); + // items are already normalized by docs.py lineage logic: + // { from_relation, from_column, transformed } + const ul = el("ul", { style: "margin:0; padding-left:16px;" }); + for (const it of items) { + const label = `${it.from_relation}.${it.from_column}` + (it.transformed ? " (xform)" : ""); + ul.appendChild(el("li", {}, el("code", {}, label))); + } + return ul; +} + +function toastOnce({ key, title, body, actionLabel, onAction }) { + try { + if (localStorage.getItem(key) === "1") return; + localStorage.setItem(key, "1"); + } catch {} + + const node = el("div", { class: "toast" }, + el("div", {}, + el("div", { class: "toastTitle" }, title), + el("div", { class: "toastBody" }, body) + ), + el("div", { class: "toastActions" }, + actionLabel ? 
el("button", { class: "toastBtn", onclick: () => { try { onAction?.(); } finally { node.remove(); } } }, actionLabel) : null, + el("button", { class: "toastBtn", onclick: () => node.remove() }, "Got it") + ) + ); + + document.body.appendChild(node); + setTimeout(() => { try { node.remove(); } catch {} }, 5500); +} + +async function loadManifest() { + const res = await fetch(MANIFEST_URL, { cache: "no-store" }); + if (!res.ok) throw new Error(`Failed to load manifest: ${res.status}`); + return await res.json(); +} + +async function main() { + const app = document.getElementById("app"); + app.textContent = "Loading…"; + + const [manifest, mermaid] = await Promise.all([loadManifest(), initMermaid()]); + const state = { + manifest, + mermaid, + filter: "", + byModel: byName(manifest.models || [], (m) => m.name), + bySource: byName(manifest.sources || [], (s) => `${s.source_name}.${s.table_name}`), + }; + state.sidebarMatches = { models: 0, sources: 0 }; + + const ui = { + app: document.getElementById("app"), + sidebarHost: null, + mainHost: null, + paletteOverlay: null, + paletteInput: null, + paletteList: null, + }; + state.ui = ui; + + // Mount shell once + const shell = el("div", { class: "shell" }, + (ui.sidebarHost = el("div")), + (ui.mainHost = el("div", { class: "main" })) + ); + ui.app.replaceChildren(shell); + + const projKey = (manifest.project?.name || "fft").toLowerCase().replace(/\s+/g, "_"); + toastOnce({ + key: `fft_docs_search_toast_seen:${projKey}`, + title: "Quick search", + body: "Press / (or Ctrl+K) to search models, sources, and columns.", + actionLabel: "Open search", + onAction: () => openPalette(""), + }); + + // Build a flat searchable index: models, sources, columns + const searchIndex = []; + + for (const m of (manifest.models || [])) { + const descTxt = (m.description_text != null && m.description_text !== "") + ? m.description_text + : stripHtml(m.description_html); + + const baseHay = [ + `model ${m.name}`, + m.relation || "", + descTxt || "", + m.path || "", + m.kind || "", + m.materialized || "", + ].join(" | "); + + searchIndex.push({ + kind: "model", + title: m.name, + subtitle: m.relation || (m.path || ""), + route: `#/model/${escapeHashPart(m.name)}`, + haystack: baseHay, + }); + + // Columns as their own results (so you can jump directly) + for (const c of (m.columns || [])) { + const cDesc = (c.description_text != null && c.description_text !== "") + ? c.description_text + : stripHtml(c.description_html); + + const colHay = [ + `column ${m.name}.${c.name}`, + c.name, + c.dtype || "", + cDesc || "", + m.name, + m.relation || "", + ].join(" | "); + + searchIndex.push({ + kind: "column", + title: `${m.name}.${c.name}`, + subtitle: `${m.relation || ""}${c.dtype ? " • " + c.dtype : ""}`, + route: `#/model/${escapeHashPart(m.name)}`, // navigates to model; we can later auto-scroll to column + haystack: colHay, + }); + } + } + + for (const s of (manifest.sources || [])) { + const key = `${s.source_name}.${s.table_name}`; + const descTxt = (s.description_text != null && s.description_text !== "") + ? 
s.description_text + : stripHtml(s.description_html); + + const hay = [ + `source ${key}`, + s.relation || "", + descTxt || "", + s.loaded_at_field || "", + (s.consumers || []).join(" "), + ].join(" | "); + + searchIndex.push({ + kind: "source", + title: key, + subtitle: s.relation || "", + route: `#/source/${escapeHashPart(s.source_name)}/${escapeHashPart(s.table_name)}`, + haystack: hay, + }); + } + + state.search = { + open: false, + query: "", + selected: 0, + results: [], + }; + + function runSearch(q) { + const query = (q || "").trim(); + if (!query) { + // show a helpful default: top models + sources (no scoring) + const defaults = []; + for (const it of searchIndex) { + if (it.kind === "model" || it.kind === "source") defaults.push({ ...it, score: 0 }); + if (defaults.length >= 30) break; + } + state.search.results = defaults; + state.search.selected = 0; + return; + } + + const scored = []; + for (const it of searchIndex) { + const score = fuzzyScore(query, it.haystack); + if (score >= 0) scored.push({ ...it, score }); + } + state.search.results = topN(scored, 80); + state.search.selected = 0; + } + + function renderPaletteResults() { + const results = state.search.results || []; + const sel = Math.max(0, Math.min(state.search.selected || 0, results.length - 1)); + + state.ui.paletteList.replaceChildren( + ...(results.length + ? results.map((r, idx) => + el("div", { + class: `result ${idx === sel ? "sel" : ""}`, + onclick: () => { + closePalette(); + location.hash = r.route; + }, + }, + el("div", { class: "resultMain" }, + el("div", { class: "resultTitle" }, r.title), + el("div", { class: "resultSub" }, `${r.kind.toUpperCase()} • ${r.subtitle || ""}`) + ), + el("div", { class: "kbd" }, "↵") + ) + ) + : [el("div", { class: "result" }, + el("div", { class: "resultMain" }, + el("div", { class: "resultTitle" }, "No results"), + el("div", { class: "resultSub" }, "Try a different query.") + ) + )] + ) + ); + } + + function buildPalette() { + if (state.ui.paletteOverlay) return; + + state.ui.paletteList = el("div", { class: "paletteList" }); + state.ui.paletteInput = el("input", { + id: "globalSearch", + class: "paletteInput", + type: "search", + placeholder: "Search models, sources, columns…", + value: state.search.query || "", + oninput: (e) => { + state.search.query = e.target.value || ""; + runSearch(state.search.query); + renderPaletteResults(); // ✅ no app rerender + }, + onkeydown: (e) => { + // Key handling while focused in the input + if (e.key === "Escape") { + e.preventDefault(); + if (state.search.query) { + state.search.query = ""; + state.ui.paletteInput.value = ""; + runSearch(""); + renderPaletteResults(); + } else { + closePalette(); + } + return; + } + if (e.key === "ArrowDown") { + e.preventDefault(); + const n = (state.search.results || []).length; + if (n) state.search.selected = (state.search.selected + 1) % n; + renderPaletteResults(); + return; + } + if (e.key === "ArrowUp") { + e.preventDefault(); + const n = (state.search.results || []).length; + if (n) state.search.selected = (state.search.selected - 1 + n) % n; + renderPaletteResults(); + return; + } + if (e.key === "Enter") { + e.preventDefault(); + const results = state.search.results || []; + const idx = Math.max(0, Math.min(state.search.selected || 0, results.length - 1)); + const hit = results[idx]; + if (hit) { + closePalette(); + location.hash = hit.route; + } + } + } + }); + + const overlay = el("div", { + class: "overlay", + onclick: (e) => { + if (e.target.classList.contains("overlay")) 
closePalette(); + } + }, + el("div", { class: "palette" }, + el("div", { class: "paletteHead" }, + state.ui.paletteInput, + el("div", { class: "paletteHint" }, + el("span", { class: "kbd" }, "Esc"), " close ", + el("span", { class: "kbd" }, "↑↓"), " select ", + el("span", { class: "kbd" }, "Enter"), " go" + ) + ), + state.ui.paletteList + ) + ); + + overlay.style.display = "none"; + state.ui.paletteOverlay = overlay; + document.body.appendChild(overlay); + } + + function openPalette(prefill = "") { + buildPalette(); + + state.search.open = true; + state.search.query = prefill; + state.search.selected = 0; + + state.ui.paletteOverlay.style.display = "flex"; + state.ui.paletteInput.value = state.search.query; + + runSearch(state.search.query); + renderPaletteResults(); + + // focus once, no re-render + queueMicrotask(() => { + state.ui.paletteInput.focus(); + state.ui.paletteInput.select(); + }); + } + + function closePalette() { + if (!state.ui.paletteOverlay) return; + state.search.open = false; + state.ui.paletteOverlay.style.display = "none"; + } + + // Sidebar UI handles (persistent DOM nodes) + ui.sidebar = { + root: null, + input: null, + modelsTitle: null, + sourcesTitle: null, + modelsList: null, + sourcesList: null, + macrosCount: null, + }; + + function buildSidebar() { + if (ui.sidebar.root) return; + + ui.sidebar.input = el("input", { + class: "search", + type: "search", + placeholder: "Filter sidebar… (press /)", + value: state.filter || "", + oninput: (e) => { + state.filter = e.target.value || ""; + updateSidebarLists(); + }, + onkeydown: (e) => { + if (e.key !== "Enter") return; + + const q = (state.filter || "").trim(); + const total = (state.sidebarMatches.models || 0) + (state.sidebarMatches.sources || 0); + + // Empty input => Enter opens global palette + if (!q) { + e.preventDefault(); + openPalette(""); + return; + } + + // No sidebar matches => Enter escalates to global palette (prefilled) + if (total === 0) { + e.preventDefault(); + openPalette(q); + return; + } + + // Otherwise: normal behavior (do nothing special) + }, + }); + + ui.sidebar.modelsTitle = el("h2", {}, "Models"); + ui.sidebar.sourcesTitle = el("h2", {}, "Sources"); + ui.sidebar.macrosCount = el("span", { class: "pill" }, "0"); + + ui.sidebar.modelsList = el("ul", { class: "list" }); + ui.sidebar.sourcesList = el("ul", { class: "list" }); + + ui.sidebar.root = el( + "div", + { class: "sidebar" }, + el( + "div", + { class: "brand" }, + el("h1", {}, state.manifest.project?.name || "Docs"), + el("span", { class: "badge" }, "SPA") + ), + el( + "div", + { class: "searchWrap" }, + ui.sidebar.input, + el("span", { class: "searchKbd kbd" }, "/") + ), + el( + "div", + { class: "searchTip" }, + "Tip: Press / (or Ctrl+K) to search everything (models, sources, columns)." + ), + + el("div", { class: "section" }, ui.sidebar.modelsTitle, ui.sidebar.modelsList), + el("div", { class: "section" }, ui.sidebar.sourcesTitle, ui.sidebar.sourcesList), + + el("div", { class: "section" }, + el("h2", {}, "Other"), + el("ul", { class: "list" }, + el("li", { class: "item" }, + el("a", { + href: "#/macros", + onclick: (e) => { e.preventDefault(); location.hash = "#/macros"; }, + }, el("span", {}, "Macros"), ui.sidebar.macrosCount) + ) + ) + ) + ); + + ui.sidebarHost.replaceChildren(ui.sidebar.root); + } + + function updateSidebarLists() { + const q = (state.filter || "").trim().toLowerCase(); + const models = state.manifest.models || []; + const sources = state.manifest.sources || []; + + const filteredModels = q + ? 
models.filter(m => + (m.name || "").toLowerCase().includes(q) || + (m.relation || "").toLowerCase().includes(q) || + (m.description_short || "").toLowerCase().includes(q) + ) + : models; + + const filteredSources = q + ? sources.filter(s => + (`${s.source_name}.${s.table_name}`).toLowerCase().includes(q) || + (s.relation || "").toLowerCase().includes(q) + ) + : sources; + + state.sidebarMatches.models = filteredModels.length; + state.sidebarMatches.sources = filteredSources.length; + + ui.sidebar.modelsTitle.textContent = `Models (${filteredModels.length})`; + ui.sidebar.sourcesTitle.textContent = `Sources (${filteredSources.length})`; + ui.sidebar.macrosCount.textContent = String((state.manifest.macros || []).length); + + ui.sidebar.modelsList.replaceChildren( + ...filteredModels.map(m => + el("li", { class: "item" }, + el("a", { + href: `#/model/${escapeHashPart(m.name)}`, + onclick: (e) => { e.preventDefault(); location.hash = `#/model/${escapeHashPart(m.name)}`; }, + title: m.description_short || m.name, + }, + el("span", {}, m.name), + pillForKind(m.kind === "python" ? "python" : "sql") + ) + ) + ) + ); + + ui.sidebar.sourcesList.replaceChildren( + ...filteredSources.map(s => { + const key = `${s.source_name}.${s.table_name}`; + return el("li", { class: "item" }, + el("a", { + href: `#/source/${escapeHashPart(s.source_name)}/${escapeHashPart(s.table_name)}`, + onclick: (e) => { e.preventDefault(); location.hash = `#/source/${escapeHashPart(s.source_name)}/${escapeHashPart(s.table_name)}`; }, + title: s.relation || key, + }, + el("span", {}, key), + el("span", { class: "pill" }, (s.consumers || []).length ? `${s.consumers.length}` : "–") + ) + ); + }) + ); + } + + function updateMain() { + const route = parseHash(); + let view; + if (route.route === "model") view = renderModel(state, route.name); + else if (route.route === "source") view = renderSource(state, route.source, route.table); + else if (route.route === "macros") view = renderMacros(state); + else view = renderHome(state); + + state.ui.mainHost.replaceChildren(view); + + // If home view contains mermaid, render it now (same as before) + if (route.route === "home") { + queueMicrotask(async () => { + const target = document.getElementById("mermaidTarget"); + if (!target) return; + const dagSrc = state.manifest.dag?.mermaid || ""; + if (!state.mermaid) { + target.textContent = dagSrc; + return; + } + target.innerHTML = `
<pre class="mermaid">${dagSrc}</pre>
`; + try { await state.mermaid.run({ querySelector: "#mermaidTarget .mermaid" }); } catch {} + }); + } + } + + window.addEventListener("keydown", (e) => { + const tag = e.target?.tagName?.toLowerCase(); + const typing = tag === "input" || tag === "textarea" || e.target?.isContentEditable; + + // Ctrl+K (or Cmd+K on mac) opens palette + const ctrlK = (e.key.toLowerCase() === "k") && (e.ctrlKey || e.metaKey); + + if (!typing && (e.key === "/" || ctrlK)) { + e.preventDefault(); + openPalette(""); + } + }); + + runSearch(""); + buildPalette(); + + window.addEventListener("hashchange", () => { + closePalette(); // optional: close palette on navigation + updateMain(); + }); + + buildSidebar(); + updateSidebarLists(); + + updateMain(); + +} + +main().catch((e) => { + const app = document.getElementById("app"); + app.replaceChildren( + el("div", { class: "main" }, + el("div", { class: "card" }, + el("h2", {}, "Docs failed to load"), + el("p", { class: "empty" }, String(e?.message || e)), + el("p", { class: "empty" }, `Manifest URL: ${MANIFEST_URL}`) + ) + ) + ); +}); diff --git a/src/fastflowtransform/templates/index.html.j2 b/src/fastflowtransform/templates/index.html.j2 index 512142b..30216e8 100644 --- a/src/fastflowtransform/templates/index.html.j2 +++ b/src/fastflowtransform/templates/index.html.j2 @@ -1,277 +1,17 @@ + - FastFlowTransform - DAG & Mini Docs - - - - - -
-
-

FastFlowTransform - DAG & Mini Docs

-
Mermaid renders automatically (light/dark)
-
-
- - -
-
- -
-
-

DAG

-
- SQL - Python - - Materialization: - {% for key, item in materialization_legend.items() %} - {{ item.label }} - {% endfor %} -
-
{{ mermaid_src | safe }}
-
- - - -
-

Macros

- {% if macros and macros|length %} - - - - - - - - - - {% for m in macros %} - - - - - - {% endfor %} - -
NameTypePath
{{ m.name }} - - {{ m.kind }} - - {{ m.path }}
- {% else %} -

No macros found.

- {% endif %} -
- -
-

Sources

- {% if sources and sources|length %} - - - - - - - - - - - {% for s in sources %} - - - - - - - {% endfor %} - -
NameRelationFreshnessConsumers
- {{ s.source_name }}.{{ s.table_name }} - {% if s.description_html %} - {{ s.description_html|striptags|trim|truncate(160, True, '…') }} - {% endif %} - {{ s.relation }} - {% if s.loaded_at_field %} -
field: {{ s.loaded_at_field }}
- {% endif %} -
- {% if s.warn_after_minutes is not none %} - warn {{ s.warn_after_minutes }}m - {% else %} - no warn - {% endif %} - {% if s.error_after_minutes is not none %} - error {{ s.error_after_minutes }}m - {% else %} - no error - {% endif %} -
-
- {% if s.consumers and s.consumers|length %} - {{ s.consumers|length }} model{{ 's' if s.consumers|length > 1 else '' }} - {% else %} - - {% endif %} -
- {% else %} -

No sources declared.

- {% endif %} -
-
- - +
+ diff --git a/src/fastflowtransform/templates/model.html.j2 b/src/fastflowtransform/templates/model.html.j2 deleted file mode 100644 index 61e230e..0000000 --- a/src/fastflowtransform/templates/model.html.j2 +++ /dev/null @@ -1,235 +0,0 @@ - - - - - - {{ m.name }} – FastFlowTransform - - - -

← Back to overview

- -
-
-

- {{ m.name }} - {{ m.materialized }} -

-
Model Detail • FastFlowTransform
-
- {{ m.kind }} -
- -
- {% if m.description_html %} -
-

Description

-
{{ m.description_html | safe }}
-
- {% endif %} -
-

Metadata

-
-
Materialized
-
{{ m.materialized }}
- -
Relation
-
{{ m.relation or m.name }}
- -
Path
-
- {{ m.path }} - -
- -
Dependencies
-
- {% if m.deps and m.deps|length > 0 %} -
    - {% for d in m.deps %} -
  • {{ d }}
  • - {% endfor %} -
- {% else %} - - {% endif %} -
- -
Sources
-
- {% if sources_used and sources_used|length %} - - {% else %} - No source() refs - {% endif %} -
- - {% if used_by is defined and used_by %} -
Referenced by
-
-
    - {% for u in used_by %} -
  • {{ u }}
  • - {% endfor %} -
-
- {% endif %} -
-
- - {% if m.description_html %} -
-

Description

- -
{{ m.description_html | safe }}
-
- {% endif %} - - {% if cols %} -
-

Columns

- - - - - - - - - - - - - - - - - - - {% for col in cols %} - - - - - - - - {% endfor %} - -
NameTypeNullableDescriptionLineage
{{ col.name }}{{ col.dtype }} - {% if col.nullable %} - yes - {% else %} - no - {% endif %} - - {% if col.description_html %} - {{ col.description_html | safe }} - {% else %} - — - {% endif %} - - {% if col.lineage and col.lineage|length %} - {% for src in col.lineage %} - {{ src.from_relation }}.{{ src.from_column }} - {% if src.transformed %} - transformed - {% else %} - direct - {% endif %} - {% if not loop.last %}, {% endif %} - {% endfor %} - {% else %} - unknown - {% endif %} -
-
- {% endif %} - - -
- - - - diff --git a/src/fastflowtransform/templates/source.html.j2 b/src/fastflowtransform/templates/source.html.j2 deleted file mode 100644 index 1313c9b..0000000 --- a/src/fastflowtransform/templates/source.html.j2 +++ /dev/null @@ -1,163 +0,0 @@ - - - - - - {{ source.source_name }}.{{ source.table_name }} – Source – FastFlowTransform - - - -

← Back to overview

- -
-
-

- {{ source.source_name }}.{{ source.table_name }} -

-
Source Detail • FastFlowTransform
-
- source -
- -
- {% if source.description_html %} -
-

Description

-
{{ source.description_html | safe }}
-
- {% endif %} - -
-

Metadata

-
-
Source
-
{{ source.source_name }}
- -
Table
-
{{ source.table_name }}
- -
Relation
-
{{ source.relation }}
- -
Loaded at field
-
- {% if source.loaded_at_field %} - {{ source.loaded_at_field }} - {% else %} - - {% endif %} -
- -
Warn after
-
- {% if source.warn_after_minutes is not none %} - {{ source.warn_after_minutes }} min - {% else %} - not set - {% endif %} -
- -
Error after
-
- {% if source.error_after_minutes is not none %} - {{ source.error_after_minutes }} min - {% else %} - not set - {% endif %} -
- -
Consumers
-
- {% if source.consumers and source.consumers|length %} -
    - {% for m in source.consumers %} -
  • {{ m }}
  • - {% endfor %} -
- {% else %} - No models referencing this source (via source('{{ source.source_name }}','{{ source.table_name }}')) - {% endif %} -
-
-
- - -
- - diff --git a/tests/unit/docs/test_docs_materialization_badges.py b/tests/unit/docs/test_docs_materialization_badges.py deleted file mode 100644 index 7566daf..0000000 --- a/tests/unit/docs/test_docs_materialization_badges.py +++ /dev/null @@ -1,51 +0,0 @@ -# tests/common/test_docs_materialization_badges.py -from pathlib import Path - -from fastflowtransform.core import Node -from fastflowtransform.docs import render_site - - -def _mk_node(tmp: Path, name: str, kind: str, materialized: str): - p = tmp / f"{name}.sql" - p.write_text("-- stub\n", encoding="utf-8") - n = Node(name=name, kind=kind, path=p, deps=[]) - # attach meta dynamically (docs.py reads getattr(n, "meta", {})) - n.meta = {"materialized": materialized} - return n - - -def test_docs_show_badges_in_index_and_detail(tmp_path: Path): - # arrange: 3 nodes with different materializations - a = _mk_node(tmp_path, "a_table", "sql", "table") - b = _mk_node(tmp_path, "b_view", "sql", "view") - c = _mk_node(tmp_path, "c_ephemeral", "sql", "ephemeral") - nodes = {a.name: a, b.name: b, c.name: c} - - out = tmp_path / "site" - render_site(out, nodes, executor=None) - - index = (out / "index.html").read_text(encoding="utf-8") - # legend badges present - assert "badge-table" in index - assert "badge-view" in index - assert "badge-ephemeral" in index - assert "Materialization" in index - - # per-model pages have correct badge class - assert "badge-table" in (out / "a_table.html").read_text(encoding="utf-8") - assert "badge-view" in (out / "b_view.html").read_text(encoding="utf-8") - assert "badge-ephemeral" in (out / "c_ephemeral.html").read_text(encoding="utf-8") - - -def test_docs_default_materialization_is_table(tmp_path: Path): - # node without meta → should default to 'table' in docs - p = tmp_path / "m.sql" - p.write_text("-- stub\n", encoding="utf-8") - n = Node(name="m", kind="sql", path=p, deps=[]) - nodes = {"m": n} - - out = tmp_path / "site" - render_site(out, nodes, executor=None) - - index = (out / "index.html").read_text(encoding="utf-8") - assert "badge-table" in index diff --git a/tests/unit/docs/test_docs_unit.py b/tests/unit/docs/test_docs_unit.py index 5424266..2c41f42 100644 --- a/tests/unit/docs/test_docs_unit.py +++ b/tests/unit/docs/test_docs_unit.py @@ -4,7 +4,7 @@ import textwrap from pathlib import Path from types import SimpleNamespace -from typing import Any, cast +from typing import Any import pytest from jinja2 import TemplateNotFound @@ -136,9 +136,7 @@ def test_scan_source_refs(tmp_path: Path): sql_path = models_dir / "model_a.sql" sql_path.write_text("select * from {{ source('crm', 'customers') }}", encoding="utf-8") - nodes = { - "model_a": SimpleNamespace(name="model_a", kind="sql", path=sql_path, deps=[], meta={}) - } + nodes = {"model_a": Node(name="model_a", kind="sql", path=sql_path, deps=[], meta={})} by_source, by_model = docs_mod._scan_source_refs(nodes) @@ -243,59 +241,6 @@ def test_apply_descriptions_to_models_applies_short_and_column_desc(): assert cols_by_table["db.sc.m1"][1].description_html == "
<p>Col 2</p>
" -# --------------------------------------------------------------------------- -# render_site (with patched jinja + registry) -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -def test_render_site_writes_index_and_model_pages(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): - fake_nodes_raw = { - "model_a": SimpleNamespace( - name="model_a", - kind="sql", - path=tmp_path / "models" / "model_a.sql", - deps=["model_b"], - meta={"materialized": "view"}, - ), - "model_b": SimpleNamespace( - name="model_b", - kind="python", - path=tmp_path / "models" / "model_b.py", - deps=[], - meta={}, - ), - } - - monkeypatch.setattr( - docs_mod, - "REGISTRY", - SimpleNamespace( - nodes=fake_nodes_raw, - macros={}, - get_project_dir=lambda: tmp_path, - ), - raising=True, - ) - - monkeypatch.setattr(docs_mod, "_init_jinja", lambda: _FakeEnv(), raising=True) - fake_nodes = cast(dict[str, Node], fake_nodes_raw) - - docs_mod.render_site(tmp_path, fake_nodes, executor=None, with_schema=False) - - index_file = tmp_path / "index.html" - assert index_file.exists() - assert "INDEX" in index_file.read_text(encoding="utf-8") - - model_a_file = tmp_path / "model_a.html" - model_b_file = tmp_path / "model_b.html" - assert model_a_file.exists() - assert model_b_file.exists() - - assert "MODEL model_a" in model_a_file.read_text(encoding="utf-8") - assert "MODEL model_b" in model_b_file.read_text(encoding="utf-8") - - # --------------------------------------------------------------------------- # _collect_columns engine stubs # --------------------------------------------------------------------------- From e5291f3815841810dad9d55fce54a8a240fbf5fa Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Fri, 19 Dec 2025 16:38:14 +0100 Subject: [PATCH 09/10] Added collapsible sidebar sections to docs --- src/fastflowtransform/templates/assets/spa.js | 292 +++++++++--------- 1 file changed, 139 insertions(+), 153 deletions(-) diff --git a/src/fastflowtransform/templates/assets/spa.js b/src/fastflowtransform/templates/assets/spa.js index 6ca2df2..9c280be 100644 --- a/src/fastflowtransform/templates/assets/spa.js +++ b/src/fastflowtransform/templates/assets/spa.js @@ -15,6 +15,21 @@ function el(tag, attrs = {}, ...children) { return n; } +function safeGet(key) { + try { return localStorage.getItem(key); } catch { return null; } +} +function safeSet(key, value) { + try { localStorage.setItem(key, value); } catch {} +} +function safeGetJSON(key, fallback) { + const raw = safeGet(key); + if (!raw) return fallback; + try { return JSON.parse(raw); } catch { return fallback; } +} +function safeSetJSON(key, obj) { + safeSet(key, JSON.stringify(obj)); +} + function stripHtml(html) { if (!html) return ""; const div = document.createElement("div"); @@ -99,114 +114,6 @@ function pillForKind(kind) { return el("span", { class: `pill ${kind}` }, kind); } -function renderSidebar(state, onNavigate) { - const { manifest, filter } = state; - const models = manifest.models || []; - const sources = manifest.sources || []; - - const q = (filter || "").trim().toLowerCase(); - - const filteredModels = q - ? models.filter(m => - (m.name || "").toLowerCase().includes(q) || - (m.relation || "").toLowerCase().includes(q) || - (m.description_short || "").toLowerCase().includes(q) - ) - : models; - - const filteredSources = q - ? 
sources.filter(s => - (`${s.source_name}.${s.table_name}`).toLowerCase().includes(q) || - (s.relation || "").toLowerCase().includes(q) - ) - : sources; - - return el( - "div", - { class: "sidebar" }, - el( - "div", - { class: "brand" }, - el("h1", {}, manifest.project?.name || "Docs"), - el("span", { class: "badge", title: `Generated: ${manifest.project?.generated_at || ""}` }, "SPA") - ), - el("div", { class: "searchWrap" }, - el("input", { - class: "search", - type: "search", - placeholder: "Filter sidebar… (press /)", - value: filter || "", - oninput: (e) => onNavigate({ type: "filter", value: e.target.value }), - }), - el("span", { class: "searchKbd kbd" }, "/") - ), - el("div", { class: "searchTip" }, "Tip: Press / (or Ctrl+K) to search everything (models, sources, columns)."), - el( - "div", - { class: "section" }, - el("h2", {}, `Models (${filteredModels.length})`), - el( - "ul", - { class: "list" }, - ...filteredModels.map(m => - el( - "li", - { class: "item" }, - el( - "a", - { - href: `#/model/${escapeHashPart(m.name)}`, - onclick: (e) => { e.preventDefault(); location.hash = `#/model/${escapeHashPart(m.name)}`; }, - title: m.description_short || m.name, - }, - el("span", {}, m.name), - pillForKind(m.kind === "python" ? "python" : "sql") - ) - ) - ) - ) - ), - el( - "div", - { class: "section" }, - el("h2", {}, `Sources (${filteredSources.length})`), - el( - "ul", - { class: "list" }, - ...filteredSources.map(s => { - const key = `${s.source_name}.${s.table_name}`; - return el( - "li", - { class: "item" }, - el( - "a", - { - href: `#/source/${escapeHashPart(s.source_name)}/${escapeHashPart(s.table_name)}`, - onclick: (e) => { e.preventDefault(); location.hash = `#/source/${escapeHashPart(s.source_name)}/${escapeHashPart(s.table_name)}`; }, - title: s.relation || key, - }, - el("span", {}, key), - el("span", { class: "pill" }, (s.consumers || []).length ? `${s.consumers.length}` : "–") - ) - ); - }) - ) - ), - el( - "div", - { class: "section" }, - el("h2", {}, "Other"), - el("ul", { class: "list" }, - el("li", { class: "item" }, - el("a", { - href: "#/macros", - onclick: (e) => { e.preventDefault(); location.hash = "#/macros"; }, - }, el("span", {}, "Macros"), el("span", { class: "pill" }, String((manifest.macros || []).length)))) - ) - ) - ); -} - function renderHome(state) { const { manifest, mermaid } = state; const dagSrc = manifest.dag?.mermaid || ""; @@ -465,7 +372,32 @@ async function main() { ); ui.app.replaceChildren(shell); - const projKey = (manifest.project?.name || "fft").toLowerCase().replace(/\s+/g, "_"); + const projKey = (manifest.project?.name || "fft") + .toLowerCase() + .replace(/\s+/g, "_") + .replace(/[^a-z0-9_]+/g, ""); + + const STORE = { + filter: `fft_docs:${projKey}:sidebar_filter`, + collapsed: `fft_docs:${projKey}:sidebar_collapsed`, + lastHash: `fft_docs:${projKey}:last_hash`, + paletteQuery: `fft_docs:${projKey}:palette_query`, + }; + + // Persisted UI state + state.filter = safeGet(STORE.filter) ?? 
""; + state.sidebarCollapsed = safeGetJSON(STORE.collapsed, { + models: false, + sources: false, + macros: false, + }); + + // Restore last route only if user is on the default route + const last = safeGet(STORE.lastHash); + if ((!location.hash || location.hash === "#/" || location.hash === "#") && last) { + location.hash = last; + } + toastOnce({ key: `fft_docs_search_toast_seen:${projKey}`, title: "Quick search", @@ -620,6 +552,7 @@ async function main() { value: state.search.query || "", oninput: (e) => { state.search.query = e.target.value || ""; + safeSet(STORE.paletteQuery, state.search.query); runSearch(state.search.query); renderPaletteResults(); // ✅ no app rerender }, @@ -691,8 +624,11 @@ async function main() { function openPalette(prefill = "") { buildPalette(); + const remembered = safeGet(STORE.paletteQuery) ?? ""; + const initial = prefill != null && prefill !== "" ? prefill : remembered; + state.search.open = true; - state.search.query = prefill; + state.search.query = initial; state.search.selected = 0; state.ui.paletteOverlay.style.display = "flex"; @@ -724,6 +660,28 @@ async function main() { sourcesList: null, macrosCount: null, }; + ui.sidebar.macrosList = null; + ui.sidebar.modelsSection = null; + ui.sidebar.sourcesSection = null; + ui.sidebar.macrosSection = null; + + function sectionHeader(titleNode, key, labelWhenOpen) { + const btn = el("button", { + class: "btn", + style: "width:100%; display:flex; justify-content:space-between; align-items:center; padding:8px 10px;", + onclick: () => { + state.sidebarCollapsed[key] = !state.sidebarCollapsed[key]; + safeSetJSON(STORE.collapsed, state.sidebarCollapsed); + applySidebarCollapse(); // show/hide without rebuilding + } + }, + el("span", {}, labelWhenOpen), + el("span", { class: "kbd" }, state.sidebarCollapsed[key] ? 
"+" : "–") + ); + // store reference for label updates + titleNode.replaceChildren(btn); + return btn; + } function buildSidebar() { if (ui.sidebar.root) return; @@ -735,6 +693,7 @@ async function main() { value: state.filter || "", oninput: (e) => { state.filter = e.target.value || ""; + safeSet(STORE.filter, state.filter); updateSidebarLists(); }, onkeydown: (e) => { @@ -761,51 +720,53 @@ async function main() { }, }); - ui.sidebar.modelsTitle = el("h2", {}, "Models"); - ui.sidebar.sourcesTitle = el("h2", {}, "Sources"); - ui.sidebar.macrosCount = el("span", { class: "pill" }, "0"); + ui.sidebar.modelsTitle = el("div"); + ui.sidebar.sourcesTitle = el("div"); + ui.sidebar.macrosTitle = el("div"); - ui.sidebar.modelsList = el("ul", { class: "list" }); - ui.sidebar.sourcesList = el("ul", { class: "list" }); + ui.sidebar.modelsList = el("ul", { class: "list" }); + ui.sidebar.sourcesList = el("ul", { class: "list" }); + ui.sidebar.macrosList = el("ul", { class: "list" }); - ui.sidebar.root = el( - "div", - { class: "sidebar" }, - el( - "div", - { class: "brand" }, - el("h1", {}, state.manifest.project?.name || "Docs"), - el("span", { class: "badge" }, "SPA") - ), - el( - "div", - { class: "searchWrap" }, - ui.sidebar.input, - el("span", { class: "searchKbd kbd" }, "/") - ), - el( + ui.sidebar.modelsSection = el("div", { class: "section" }, ui.sidebar.modelsTitle, ui.sidebar.modelsList); + ui.sidebar.sourcesSection = el("div", { class: "section" }, ui.sidebar.sourcesTitle, ui.sidebar.sourcesList); + ui.sidebar.macrosSection = el("div", { class: "section" }, ui.sidebar.macrosTitle, ui.sidebar.macrosList); + + ui.sidebar.root = el( "div", - { class: "searchTip" }, - "Tip: Press / (or Ctrl+K) to search everything (models, sources, columns)." - ), + { class: "sidebar" }, + el( + "div", + { class: "brand" }, + el("h1", {}, state.manifest.project?.name || "Docs"), + el("span", { class: "badge", title: `Generated: ${state.manifest.project?.generated_at || ""}` }, "SPA") + ), + el( + "div", + { class: "searchWrap" }, + ui.sidebar.input, + el("span", { class: "searchKbd kbd" }, "/") + ), + el("div", { class: "searchTip" }, "Tip: Press / (or Ctrl+K) to search everything (models, sources, columns)."), + ui.sidebar.modelsSection, + ui.sidebar.sourcesSection, + ui.sidebar.macrosSection, + ); - el("div", { class: "section" }, ui.sidebar.modelsTitle, ui.sidebar.modelsList), - el("div", { class: "section" }, ui.sidebar.sourcesTitle, ui.sidebar.sourcesList), - - el("div", { class: "section" }, - el("h2", {}, "Other"), - el("ul", { class: "list" }, - el("li", { class: "item" }, - el("a", { - href: "#/macros", - onclick: (e) => { e.preventDefault(); location.hash = "#/macros"; }, - }, el("span", {}, "Macros"), ui.sidebar.macrosCount) - ) - ) - ) - ); + ui.sidebarHost.replaceChildren(ui.sidebar.root); - ui.sidebarHost.replaceChildren(ui.sidebar.root); + // Turn titles into toggle headers + sectionHeader(ui.sidebar.modelsTitle, "models", "Models"); + sectionHeader(ui.sidebar.sourcesTitle, "sources", "Sources"); + sectionHeader(ui.sidebar.macrosTitle, "macros", "Macros"); + + } + + function applySidebarCollapse() { + const c = state.sidebarCollapsed || {}; + ui.sidebar.modelsList.style.display = c.models ? "none" : ""; + ui.sidebar.sourcesList.style.display = c.sources ? "none" : ""; + ui.sidebar.macrosList.style.display = c.macros ? 
"none" : ""; } function updateSidebarLists() { @@ -833,7 +794,6 @@ async function main() { ui.sidebar.modelsTitle.textContent = `Models (${filteredModels.length})`; ui.sidebar.sourcesTitle.textContent = `Sources (${filteredSources.length})`; - ui.sidebar.macrosCount.textContent = String((state.manifest.macros || []).length); ui.sidebar.modelsList.replaceChildren( ...filteredModels.map(m => @@ -865,6 +825,30 @@ async function main() { ); }) ); + + const macros = state.manifest.macros || []; + + sectionHeader(ui.sidebar.modelsTitle, "models", `Models (${filteredModels.length})`); + sectionHeader(ui.sidebar.sourcesTitle, "sources", `Sources (${filteredSources.length})`); + sectionHeader(ui.sidebar.macrosTitle, "macros", `Macros (${macros.length})`); + + ui.sidebar.macrosList.replaceChildren( + ...macros.map(m => + el("li", { class: "item" }, + el("a", { + href: "#/macros", + onclick: (e) => { e.preventDefault(); location.hash = "#/macros"; }, + title: m.path || m.name, + }, + el("span", {}, m.name), + el("span", { class: "pill" }, m.kind) + ) + ) + ) + ); + + applySidebarCollapse(); + } function updateMain() { @@ -907,16 +891,18 @@ async function main() { }); runSearch(""); - buildPalette(); window.addEventListener("hashchange", () => { + safeSet(STORE.lastHash, location.hash || "#/"); closePalette(); // optional: close palette on navigation updateMain(); }); + safeSet(STORE.lastHash, location.hash || "#/"); + buildSidebar(); updateSidebarLists(); - + buildPalette(); // palette exists but hidden updateMain(); } From 8fdad700ab648258d35e6723960bf7f3c24944c6 Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Sat, 20 Dec 2025 20:49:17 +0100 Subject: [PATCH 10/10] Added model page tabs to docs --- src/fastflowtransform/docs.py | 129 +------ src/fastflowtransform/executors/base.py | 17 + .../executors/bigquery/base.py | 32 +- .../executors/databricks_spark.py | 56 ++- src/fastflowtransform/executors/duckdb.py | 38 +- src/fastflowtransform/executors/postgres.py | 26 +- .../executors/snowflake_snowpark.py | 36 +- .../templates/assets/spa.css | 25 ++ src/fastflowtransform/templates/assets/spa.js | 329 +++++++++++++++--- tests/unit/docs/test_docs_unit.py | 187 +--------- 10 files changed, 518 insertions(+), 357 deletions(-) diff --git a/src/fastflowtransform/docs.py b/src/fastflowtransform/docs.py index 75436c6..d1e6147 100644 --- a/src/fastflowtransform/docs.py +++ b/src/fastflowtransform/docs.py @@ -12,10 +12,10 @@ import yaml from jinja2 import Environment, FileSystemLoader, TemplateNotFound, select_autoescape from markupsafe import Markup -from sqlalchemy import text from fastflowtransform.core import REGISTRY, Node, relation_for from fastflowtransform.dag import mermaid as dag_mermaid +from fastflowtransform.executors.base import ColumnInfo from fastflowtransform.lineage import ( infer_py_lineage, infer_sql_lineage, @@ -57,22 +57,18 @@ def _safe_filename(name: str) -> str: def _collect_columns(executor: Any) -> dict[str, list[ColumnInfo]]: """ - Best-effort schema discovery for supported engines. + Best-effort schema discovery delegated to the executor. Returns an empty mapping if unsupported or on errors. 
""" + fn = getattr(executor, "collect_docs_columns", None) + if not callable(fn): + return {} try: - if hasattr(executor, "spark"): - return _columns_spark(executor.spark) - if hasattr(executor, "con"): # DuckDB - return _columns_duckdb(executor.con) - if hasattr(executor, "engine"): # Postgres - return _columns_postgres(executor.engine) - if hasattr(executor, "session"): - return _columns_snowflake(executor.session) + res = fn() + return res if isinstance(res, dict) else {} except Exception: # Fail-open: no schema info, UI will simply hide the columns card. return {} - return {} def _read_project_yaml_docs(project_dir: Path) -> dict[str, Any]: @@ -759,117 +755,6 @@ def render_site( _render_source_pages(env, out_dir, sources) -@dataclass -class ColumnInfo: - name: str - dtype: str - nullable: bool - description_html: str | None = None - lineage: list[dict[str, Any]] | None = None - - -def _columns_duckdb(con: Any) -> dict[str, list[ColumnInfo]]: - rows = con.execute(""" - select table_name, column_name, data_type, is_nullable - from information_schema.columns - where table_schema in ('main','temp') - order by table_name, ordinal_position - """).fetchall() - out: dict[str, list[ColumnInfo]] = {} - for t, c, dt, null in rows: - out.setdefault(t, []).append(ColumnInfo(c, str(dt), null in (True, "YES", "Yes"))) - return out - - -def _columns_postgres(engine: Any) -> dict[str, list[ColumnInfo]]: - with engine.begin() as conn: - rows = conn.execute( - text(""" - select table_name, column_name, data_type, is_nullable - from information_schema.columns - where table_schema = current_schema() - order by table_name, ordinal_position - """) - ).fetchall() - out: dict[str, list[ColumnInfo]] = {} - for t, c, dt, null in rows: - out.setdefault(t, []).append(ColumnInfo(c, str(dt), null == "YES")) - return out - - -def _columns_snowflake(session: Any) -> dict[str, list[ColumnInfo]]: - rows = session.sql(""" - select table_name, column_name, data_type, is_nullable - from information_schema.columns - where table_schema = current_schema() - order by table_name, ordinal_position - """).collect() - out: dict[str, list[ColumnInfo]] = {} - for r in rows: - t = r["TABLE_NAME"] - c = r["COLUMN_NAME"] - dt = r["DATA_TYPE"] - null = r["IS_NULLABLE"] - out.setdefault(t, []).append(ColumnInfo(c, str(dt), null == "YES")) - return out - - -def _columns_spark(spark: Any) -> dict[str, list[ColumnInfo]]: - """ - Collect column metadata from a SparkSession (Databricks / Spark SQL). - Uses catalog.listTables/listColumns, available on vanilla Spark 3+. 
- """ - try: - tables = list(spark.catalog.listTables()) - except Exception: - return {} - - out: dict[str, list[ColumnInfo]] = {} - seen: set[tuple[str | None, str]] = set() - - def _list_columns(table_name: str, database: str | None) -> list[Any]: - ident = table_name if not database else f"{database}.{table_name}" - try: - return list(spark.catalog.listColumns(ident)) - except TypeError: - return list(spark.catalog.listColumns(table_name, database)) - - for tbl in tables: - database = getattr(tbl, "database", None) - raw_name = getattr(tbl, "name", None) - if not raw_name: - continue - table_name = str(raw_name) - key = (database, table_name) - if key in seen: - continue - seen.add(key) - try: - cols = _list_columns(table_name, database) - except Exception: - continue - if not cols: - continue - - keys: set[str] = {table_name} - catalog = getattr(tbl, "catalog", None) - if database: - keys.add(f"{database}.{table_name}") - if database and catalog: - keys.add(f"{catalog}.{database}.{table_name}") - for c in cols: - nullable = bool(getattr(c, "nullable", False)) - dtype = str(getattr(c, "dataType", "")) - col_name = getattr(c, "name", None) - if not col_name: - continue - info = ColumnInfo(str(col_name), dtype, nullable) - for k in keys: - out.setdefault(k, []).append(info) - - return out - - def read_docs_metadata(project_dir: Path) -> dict[str, Any]: """ Merge YAML + Markdown descriptions with priority: Markdown > YAML. diff --git a/src/fastflowtransform/executors/base.py b/src/fastflowtransform/executors/base.py index ee55e66..3a147c6 100644 --- a/src/fastflowtransform/executors/base.py +++ b/src/fastflowtransform/executors/base.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Mapping from contextlib import suppress +from dataclasses import dataclass from pathlib import Path from typing import Any, TypeVar, cast @@ -93,6 +94,15 @@ def _scalar(executor: BaseExecutor, sql: Any) -> Any: TFrame = TypeVar("TFrame") +@dataclass +class ColumnInfo: + name: str + dtype: str + nullable: bool + description_html: str | None = None + lineage: list[dict[str, Any]] | None = None + + class _ThisProxy: """ Jinja compatible proxy for {{ this }}: @@ -1194,6 +1204,13 @@ def normalize_physical_type(self, t: str | None) -> str: """ return (t or "").strip().lower() + def collect_docs_columns(self) -> dict[str, list[ColumnInfo]]: + """ + Return column metadata for docs rendering keyed by physical relation name. + Engines can override; default is empty mapping. 
+ """ + return {} + # ── Seed loading hook ─────────────────────────────────────────────── def load_seed( self, table: str, df: Any, schema: str | None = None diff --git a/src/fastflowtransform/executors/bigquery/base.py b/src/fastflowtransform/executors/bigquery/base.py index 0fec466..2cbb9aa 100644 --- a/src/fastflowtransform/executors/bigquery/base.py +++ b/src/fastflowtransform/executors/bigquery/base.py @@ -7,7 +7,7 @@ from fastflowtransform.core import Node, relation_for from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable -from fastflowtransform.executors.base import BaseExecutor +from fastflowtransform.executors.base import BaseExecutor, ColumnInfo from fastflowtransform.executors.budget.runtime.bigquery import BigQueryBudgetRuntime from fastflowtransform.executors.query_stats.core import _TrackedQueryJob from fastflowtransform.executors.query_stats.runtime.bigquery import BigQueryQueryStatsRuntime @@ -494,6 +494,36 @@ def introspect_table_physical_schema(self, table: str) -> dict[str, str]: # keys are lowercased to match the DuckRuntimeContracts verify logic return {name: dtype for (name, dtype) in rows} + def collect_docs_columns(self) -> dict[str, list[ColumnInfo]]: + """ + Column metadata for docs (project+dataset scoped). + """ + sql = f""" + select table_name, column_name, data_type, is_nullable + from `{self.project}.{self.dataset}.INFORMATION_SCHEMA.COLUMNS` + order by table_name, ordinal_position + """ + try: + job = self.client.query( + sql, + job_config=bigquery.QueryJobConfig( + default_dataset=bigquery.DatasetReference(self.project, self.dataset) + ), + location=self.location, + ) + rows = list(job.result()) + except Exception: + return {} + + out: dict[str, list[ColumnInfo]] = {} + for row in rows: + table = str(row["table_name"]) + col = str(row["column_name"]) + dtype = str(row["data_type"]) + nullable = str(row["is_nullable"]).upper() == "YES" + out.setdefault(table, []).append(ColumnInfo(col, dtype, nullable)) + return out + def load_seed(self, table: str, df: Any, schema: str | None = None) -> tuple[bool, str, bool]: dataset_id = schema or self.dataset diff --git a/src/fastflowtransform/executors/databricks_spark.py b/src/fastflowtransform/executors/databricks_spark.py index e0355b4..c755deb 100644 --- a/src/fastflowtransform/executors/databricks_spark.py +++ b/src/fastflowtransform/executors/databricks_spark.py @@ -16,7 +16,7 @@ from fastflowtransform.core import REGISTRY, Node, relation_for from fastflowtransform.errors import ModelExecutionError from fastflowtransform.executors._test_utils import make_fetchable, rows_to_tuples -from fastflowtransform.executors.base import BaseExecutor +from fastflowtransform.executors.base import BaseExecutor, ColumnInfo from fastflowtransform.executors.budget.runtime.databricks_spark import ( DatabricksSparkBudgetRuntime, ) @@ -932,6 +932,60 @@ def utest_clean_target(self, relation: str) -> None: with suppress(Exception): self._execute_sql_basic(f"DROP TABLE IF EXISTS {ident}") + def collect_docs_columns(self) -> dict[str, list[ColumnInfo]]: + """ + Collect column metadata via Spark catalog for docs rendering. 
+ """ + try: + tables = list(self.spark.catalog.listTables()) + except Exception: + return {} + + out: dict[str, list[ColumnInfo]] = {} + seen: set[tuple[str | None, str]] = set() + + def _list_columns(table_name: str, database: str | None) -> list[Any]: + ident = table_name if not database else f"{database}.{table_name}" + try: + return list(self.spark.catalog.listColumns(ident)) + except TypeError: + return list(self.spark.catalog.listColumns(table_name, database)) + + for tbl in tables: + database = getattr(tbl, "database", None) + raw_name = getattr(tbl, "name", None) + if not raw_name: + continue + table_name = str(raw_name) + key = (database, table_name) + if key in seen: + continue + seen.add(key) + try: + cols = _list_columns(table_name, database) + except Exception: + continue + if not cols: + continue + + keys: set[str] = {table_name} + catalog = getattr(tbl, "catalog", None) + if database: + keys.add(f"{database}.{table_name}") + if database and catalog: + keys.add(f"{catalog}.{database}.{table_name}") + for c in cols: + nullable = bool(getattr(c, "nullable", False)) + dtype = str(getattr(c, "dataType", "")) + col_name = getattr(c, "name", None) + if not col_name: + continue + info = ColumnInfo(str(col_name), dtype, nullable) + for k in keys: + out.setdefault(k, []).append(info) + + return out + def _introspect_columns_metadata( self, table: str, diff --git a/src/fastflowtransform/executors/duckdb.py b/src/fastflowtransform/executors/duckdb.py index 308dec9..9b35b1c 100644 --- a/src/fastflowtransform/executors/duckdb.py +++ b/src/fastflowtransform/executors/duckdb.py @@ -16,7 +16,7 @@ from fastflowtransform.core import Node from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable -from fastflowtransform.executors.base import BaseExecutor, _scalar +from fastflowtransform.executors.base import BaseExecutor, ColumnInfo, _scalar from fastflowtransform.executors.budget.runtime.duckdb import DuckBudgetRuntime from fastflowtransform.executors.common import _q_ident from fastflowtransform.executors.query_stats.runtime.duckdb import DuckQueryStatsRuntime @@ -414,6 +414,42 @@ def utest_clean_target(self, relation: str) -> None: with suppress(Exception): self._execute_basic(f"drop table if exists {target}") + def collect_docs_columns(self) -> dict[str, list[ColumnInfo]]: + """ + Best-effort column metadata for docs (schema-aware, supports catalog). 
+ """ + where: list[str] = [] + params: list[str] = [] + + if self.catalog: + where.append("lower(table_catalog) = lower(?)") + params.append(self.catalog) + if self.schema: + where.append("lower(table_schema) = lower(?)") + params.append(self.schema) + else: + where.append("table_schema in ('main','temp')") + + where_sql = " AND ".join(where) if where else "1=1" + sql = f""" + select table_name, column_name, data_type, is_nullable + from information_schema.columns + where {where_sql} + order by table_schema, table_name, ordinal_position + """ + + try: + rows = self._execute_basic(sql, params or None).fetchall() + except Exception: + return {} + + out: dict[str, list[ColumnInfo]] = {} + for table, col, dtype, nullable in rows: + out.setdefault(table, []).append( + ColumnInfo(col, str(dtype), str(nullable) in (True, "YES", "Yes")) + ) + return out + def _introspect_columns_metadata( self, table: str, diff --git a/src/fastflowtransform/executors/postgres.py b/src/fastflowtransform/executors/postgres.py index 4710a1b..787fcd9 100644 --- a/src/fastflowtransform/executors/postgres.py +++ b/src/fastflowtransform/executors/postgres.py @@ -17,7 +17,7 @@ from fastflowtransform.errors import ModelExecutionError, ProfileConfigError from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable -from fastflowtransform.executors.base import BaseExecutor, _scalar +from fastflowtransform.executors.base import BaseExecutor, ColumnInfo, _scalar from fastflowtransform.executors.budget.runtime.postgres import PostgresBudgetRuntime from fastflowtransform.executors.common import _q_ident from fastflowtransform.executors.query_stats.runtime.postgres import PostgresQueryStatsRuntime @@ -569,6 +569,30 @@ def utest_clean_target(self, relation: str) -> None: f"DROP TABLE IF EXISTS {qualified} CASCADE", conn=conn ) + def collect_docs_columns(self) -> dict[str, list[ColumnInfo]]: + """ + Column metadata for docs, scoped to the effective schema. 
+ """ + sql = """ + select table_name, column_name, data_type, is_nullable + from information_schema.columns + where table_schema = current_schema() + order by table_name, ordinal_position + """ + try: + with self.engine.begin() as conn: + self._set_search_path(conn) + rows = conn.execute(text(sql)).fetchall() + except Exception: + return {} + + out: dict[str, list[ColumnInfo]] = {} + for table, col, dtype, nullable in rows: + out.setdefault(table, []).append( + ColumnInfo(col, str(dtype), str(nullable).upper() == "YES") + ) + return out + def _introspect_columns_metadata( self, table: str, diff --git a/src/fastflowtransform/executors/snowflake_snowpark.py b/src/fastflowtransform/executors/snowflake_snowpark.py index 0c5d7dd..282ff9c 100644 --- a/src/fastflowtransform/executors/snowflake_snowpark.py +++ b/src/fastflowtransform/executors/snowflake_snowpark.py @@ -12,7 +12,7 @@ from fastflowtransform.core import Node, relation_for from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable, rows_to_tuples -from fastflowtransform.executors.base import BaseExecutor +from fastflowtransform.executors.base import BaseExecutor, ColumnInfo from fastflowtransform.executors.budget.runtime.snowflake_snowpark import ( SnowflakeSnowparkBudgetRuntime, ) @@ -547,6 +547,40 @@ def utest_clean_target(self, relation: str) -> None: with suppress(Exception): self.session.sql(f"DROP TABLE IF EXISTS {qualified}").collect() + def collect_docs_columns(self) -> dict[str, list[ColumnInfo]]: + """ + Best-effort column metadata for docs (scoped to configured DB/schema). + """ + schema_pred = ( + f"lower(table_schema) = '{self.schema.lower()}'" + if self.schema + else "table_schema = current_schema()" + ) + catalog_pred = ( + f" AND lower(table_catalog) = '{self.database.lower()}'" if self.database else "" + ) + sql = f""" + select table_name, column_name, data_type, is_nullable + from information_schema.columns + where {schema_pred}{catalog_pred} + order by table_schema, table_name, ordinal_position + """ + try: + rows = self.session.sql(sql).collect() + except Exception: + return {} + + out: dict[str, list[ColumnInfo]] = {} + for r in rows: + table = r["TABLE_NAME"] + col = r["COLUMN_NAME"] + dtype = r["DATA_TYPE"] + nullable = r["IS_NULLABLE"] + out.setdefault(table, []).append( + ColumnInfo(col, str(dtype), str(nullable).upper() == "YES") + ) + return out + def _normalize_table_parts_for_introspection(self, table: str) -> tuple[str, str, str]: """ Return (database, schema, table_name) for a possibly qualified identifier. 
diff --git a/src/fastflowtransform/templates/assets/spa.css b/src/fastflowtransform/templates/assets/spa.css index 11e1e4a..6d0f7a4 100644 --- a/src/fastflowtransform/templates/assets/spa.css +++ b/src/fastflowtransform/templates/assets/spa.css @@ -177,3 +177,28 @@ hr{ border:0; border-top:1px solid var(--border); margin:16px 0; } cursor:pointer; } .toastBtn:hover{ border-color:var(--accent); } + +.tabs{ display:flex; gap:8px; flex-wrap:wrap; } +.tab{ + border:1px solid var(--border); + background:transparent; + color:var(--fg); + padding:6px 10px; + border-radius:999px; + cursor:pointer; + font-size:12px; +} +.tab:hover{ border-color:var(--accent); } +.tab.active{ + border-color: color-mix(in srgb, var(--accent), var(--border) 20%); + background: color-mix(in srgb, var(--accent), transparent 88%); +} +.tabPanel{ margin-top:12px; } + +tr.colHit td{ + background: color-mix(in srgb, var(--accent), transparent 88%); +} +tr.colHit{ + outline: 1px solid color-mix(in srgb, var(--accent), transparent 45%); + outline-offset: -1px; +} diff --git a/src/fastflowtransform/templates/assets/spa.js b/src/fastflowtransform/templates/assets/spa.js index 9c280be..dd5c081 100644 --- a/src/fastflowtransform/templates/assets/spa.js +++ b/src/fastflowtransform/templates/assets/spa.js @@ -79,15 +79,55 @@ function topN(items, n) { function escapeHashPart(s) { return encodeURIComponent(String(s || "")).replaceAll("%2F", "/"); } -function parseHash() { - const raw = (location.hash || "#/").slice(1); - const parts = raw.split("/").filter(Boolean); + +function parseHashWithQuery() { + const full = (location.hash || "#/").slice(1); // remove leading '#' + const [pathPart, queryPart] = full.split("?", 2); + const parts = pathPart.split("/").filter(Boolean); + + const query = new URLSearchParams(queryPart || ""); + return { parts, query }; +} + +function setTabInHash(tab) { + const full = (location.hash || "#/").slice(1); + const [pathPart, queryPart] = full.split("?", 2); + const q = new URLSearchParams(queryPart || ""); + if (tab) q.set("tab", tab); + else q.delete("tab"); + const next = q.toString() ? `${pathPart}?${q.toString()}` : `${pathPart}`; + location.hash = `#${next.startsWith("/") ? "" : "/"}${next}`; +} + +function setModelQuery({ tab, col }) { + const full = (location.hash || "#/").slice(1); + const [pathPart, queryPart] = full.split("?", 2); + const q = new URLSearchParams(queryPart || ""); + + if (tab) q.set("tab", tab); else q.delete("tab"); + if (col) q.set("col", col); else q.delete("col"); + + const next = q.toString() ? `${pathPart}?${q.toString()}` : `${pathPart}`; + location.hash = `#${next.startsWith("/") ? 
"" : "/"}${next}`; +} + +function parseRoute() { + const { parts, query } = parseHashWithQuery(); if (parts.length === 0) return { route: "home" }; - if (parts[0] === "model" && parts[1]) return { route: "model", name: decodeURIComponent(parts.slice(1).join("/")) }; + + if (parts[0] === "model" && parts[1]) { + return { + route: "model", + name: decodeURIComponent(parts.slice(1).join("/")), + tab: query.get("tab") || "", + col: query.get("col") || "", + }; + } if (parts[0] === "source" && parts[1] && parts[2]) { return { route: "source", source: decodeURIComponent(parts[1]), table: decodeURIComponent(parts[2]) }; } if (parts[0] === "macros") return { route: "macros" }; + return { route: "home" }; } @@ -166,23 +206,20 @@ function renderHome(state) { return el("div", { class: "grid2" }, dagCard, stats); } -function renderModel(state, name) { - const { manifest } = state; - const byModel = state.byModel; - const m = byModel.get(name); - +function renderModel(state, name, tabFromRoute, colFromRoute) { + const m = state.byModel.get(name); if (!m) { return el("div", { class: "card" }, el("h2", {}, "Model not found"), el("p", { class: "empty" }, name)); } - const deps = (m.deps || []).map(d => el("a", { href: `#/model/${escapeHashPart(d)}` }, d)); - const usedBy = (m.used_by || []).map(u => el("a", { href: `#/model/${escapeHashPart(u)}` }, u)); + const active = (tabFromRoute || state.modelTabDefault || "overview").toLowerCase(); + const hasCol = !!(colFromRoute && String(colFromRoute).trim()); - const sourcesUsed = (m.sources_used || []).map(s => - el("a", { href: `#/source/${escapeHashPart(s.source_name)}/${escapeHashPart(s.table_name)}` }, `${s.source_name}.${s.table_name}`) - ); + let tab = ["overview","columns","lineage","code","meta"].includes(active) ? active : "overview"; + // Only force columns if col is present AND the URL didn't explicitly set a tab + if (hasCol && !tabFromRoute) tab = "columns"; - const head = el("div", { class: "card" }, + const header = el("div", { class: "card" }, el("div", { class: "grid2" }, el("div", {}, el("h2", {}, m.name), @@ -195,46 +232,155 @@ function renderModel(state, name) { }, "Copy path") ) ), - el("div", { class: "kv" }, - el("div", { class: "k" }, "Kind"), el("div", {}, m.kind), - el("div", { class: "k" }, "Materialized"), el("div", {}, m.materialized || "—"), - el("div", { class: "k" }, "Path"), el("div", {}, el("code", {}, m.path || "—")), - el("div", { class: "k" }, "Deps"), el("div", {}, deps.length ? joinInline(deps) : el("span", { class: "empty" }, "—")), - el("div", { class: "k" }, "Used by"), el("div", {}, usedBy.length ? joinInline(usedBy) : el("span", { class: "empty" }, "—")), - el("div", { class: "k" }, "Sources"), el("div", {}, sourcesUsed.length ? joinInline(sourcesUsed) : el("span", { class: "empty" }, "—")), - ) + renderTabs(tab, (next) => { + // Persist default for convenience + state.modelTabDefault = next; + safeSet(state.STORE.modelTab, next); + + setModelQuery({ + tab: next, + col: (next === "columns") ? (colFromRoute || "") : "" // clear col when leaving Columns + }); + + }) ); - const desc = m.description_html - ? el("div", { class: "card" }, el("h2", {}, "Description"), el("div", { class: "desc", html: m.description_html })) - : null; + const panel = el("div", { class: "tabPanel" }, renderModelPanel(state, m, tab, colFromRoute)); - const cols = (m.columns || []); - const colsCard = cols.length - ? 
el("div", { class: "card" }, - el("h2", {}, "Columns"), - el("table", { class: "table" }, - el("thead", {}, el("tr", {}, - el("th", {}, "Name"), - el("th", {}, "Type"), - el("th", {}, "Nullable"), - el("th", {}, "Description"), - el("th", {}, "Lineage"), - )), - el("tbody", {}, - ...cols.map(c => el("tr", {}, - el("td", {}, el("code", {}, c.name)), - el("td", {}, el("code", {}, c.dtype || "")), - el("td", {}, c.nullable ? "true" : "false"), - el("td", { html: c.description_html || '' }), - el("td", {}, renderLineage(c.lineage || [])) - )) + return el("div", { class: "grid" }, header, panel); +} + +function renderModelPanel(state, m, tab, colFromRoute) { + if (tab === "overview") { + const deps = (m.deps || []).map(d => el("a", { href: `#/model/${escapeHashPart(d)}` }, d)); + const usedBy = (m.used_by || []).map(u => el("a", { href: `#/model/${escapeHashPart(u)}` }, u)); + const sourcesUsed = (m.sources_used || []).map(s => + el("a", { href: `#/source/${escapeHashPart(s.source_name)}/${escapeHashPart(s.table_name)}` }, `${s.source_name}.${s.table_name}`) + ); + + return el("div", { class: "grid" }, + el("div", { class: "card" }, + el("h3", {}, "Summary"), + el("div", { class: "kv" }, + el("div", { class: "k" }, "Kind"), el("div", {}, m.kind), + el("div", { class: "k" }, "Materialized"), el("div", {}, m.materialized || "—"), + el("div", { class: "k" }, "Path"), el("div", {}, el("code", {}, m.path || "—")), + el("div", { class: "k" }, "Deps"), el("div", {}, deps.length ? joinInline(deps) : el("span", { class: "empty" }, "—")), + el("div", { class: "k" }, "Used by"), el("div", {}, usedBy.length ? joinInline(usedBy) : el("span", { class: "empty" }, "—")), + el("div", { class: "k" }, "Sources"), el("div", {}, sourcesUsed.length ? joinInline(sourcesUsed) : el("span", { class: "empty" }, "—")), + ) + ), + m.description_html + ? el("div", { class: "card" }, el("h3", {}, "Description"), el("div", { class: "desc", html: m.description_html })) + : el("div", { class: "card" }, el("h3", {}, "Description"), el("p", { class: "empty" }, "No description.")) + ); + } + + if (tab === "columns") { + const cols = m.columns || []; + + const card = cols.length + ? el("div", { class: "card" }, + el("h3", {}, `Columns (${cols.length})`), + el("table", { class: "table" }, + el("thead", {}, el("tr", {}, + el("th", {}, "Name"), + el("th", {}, "Type"), + el("th", {}, "Nullable"), + el("th", {}, "Description"), + )), + el("tbody", {}, + ...cols.map(c => el( + "tr", + { id: `col-${cssSafeId(m.name)}-${cssSafeId(c.name)}` }, + el("td", {}, el("code", {}, c.name)), + el("td", {}, el("code", {}, c.dtype || "")), + el("td", {}, c.nullable ? "true" : "false"), + el("td", { html: c.description_html || '' }), + )) + ) ) ) - ) - : el("div", { class: "card" }, el("h2", {}, "Columns"), el("p", { class: "empty" }, manifest.project?.with_schema ? "No columns found." : "Schema collection disabled.")); + : el("div", { class: "card" }, + el("h3", {}, "Columns"), + el("p", { class: "empty" }, state.manifest.project?.with_schema ? "No columns found." 
: "Schema collection disabled.") + ); + + // Scroll + highlight if col query param is present + const colName = (colFromRoute || "").trim(); + if (cols.length && colName) { + queueMicrotask(() => { + const rowId = `col-${cssSafeId(m.name)}-${cssSafeId(colName)}`; + const row = document.getElementById(rowId); + if (!row) return; - return el("div", { class: "grid" }, head, desc, colsCard); + // clear previous hit + document.querySelectorAll("tr.colHit").forEach(n => n.classList.remove("colHit")); + + row.classList.add("colHit"); + row.scrollIntoView({ block: "center", behavior: "smooth" }); + + // remove highlight after a moment (optional) + setTimeout(() => row.classList.remove("colHit"), 2200); + }); + } + + return card; + } + + if (tab === "lineage") { + const cols = m.columns || []; + const rows = cols + .filter(c => (c.lineage || []).length) + .map(c => + el("tr", {}, + el("td", {}, el("code", {}, c.name)), + el("td", {}, renderLineage(c.lineage || [])) + ) + ); + + return el("div", { class: "card" }, + el("h3", {}, "Column lineage"), + rows.length + ? el("table", { class: "table" }, + el("thead", {}, el("tr", {}, el("th", {}, "Column"), el("th", {}, "Lineage"))), + el("tbody", {}, ...rows) + ) + : el("p", { class: "empty" }, "No lineage available for this model’s columns.") + ); + } + + if (tab === "code") { + // Placeholder until we add compiled SQL / python source to manifest + return el("div", { class: "card" }, + el("h3", {}, "Code"), + el("p", { class: "empty" }, "Code view not yet available. Next step: include rendered SQL / Python source in the manifest.") + ); + } + + if (tab === "meta") { + // Show a structured dump of whatever we have + const meta = { + name: m.name, + kind: m.kind, + relation: m.relation, + materialized: m.materialized, + path: m.path, + deps: m.deps || [], + used_by: m.used_by || [], + sources_used: m.sources_used || [], + }; + return el("div", { class: "card" }, + el("h3", {}, "Meta"), + el("pre", { class: "mono", style: "white-space:pre-wrap; margin:0;" }, JSON.stringify(meta, null, 2)) + ); + } + + return el("div", { class: "card" }, el("p", { class: "empty" }, "Unknown tab.")); +} + +function cssSafeId(s) { + return String(s || "").replace(/[^a-zA-Z0-9_-]+/g, "_"); } function renderSource(state, sourceName, tableName) { @@ -335,6 +481,43 @@ function toastOnce({ key, title, body, actionLabel, onAction }) { setTimeout(() => { try { node.remove(); } catch {} }, 5500); } +function renderTabs(active, onPick) { + const tabs = [ + ["overview", "Overview"], + ["columns", "Columns"], + ["lineage", "Lineage"], + ["code", "Code"], + ["meta", "Meta"], + ]; + + return el("div", { class: "tabs" }, + ...tabs.map(([id, label]) => + el("button", { + class: `tab ${active === id ? "active" : ""}`, + onclick: () => onPick(id), + }, label) + ) + ); +} + +function makeSnippet(text, query, maxLen = 90) { + const t = (text || "").replace(/\s+/g, " ").trim(); + if (!t) return ""; + + const q = (query || "").trim().toLowerCase(); + if (!q) return t.length > maxLen ? t.slice(0, maxLen - 1) + "…" : t; + + const idx = t.toLowerCase().indexOf(q); + if (idx < 0) return t.length > maxLen ? t.slice(0, maxLen - 1) + "…" : t; + + const start = Math.max(0, idx - Math.floor(maxLen * 0.35)); + const end = Math.min(t.length, start + maxLen); + + const prefix = start > 0 ? "…" : ""; + const suffix = end < t.length ? 
"…" : ""; + return prefix + t.slice(start, end) + suffix; +} + async function loadManifest() { const res = await fetch(MANIFEST_URL, { cache: "no-store" }); if (!res.ok) throw new Error(`Failed to load manifest: ${res.status}`); @@ -383,6 +566,9 @@ async function main() { lastHash: `fft_docs:${projKey}:last_hash`, paletteQuery: `fft_docs:${projKey}:palette_query`, }; + STORE.modelTab = `fft_docs:${projKey}:model_tab_default`; + state.modelTabDefault = safeGet(STORE.modelTab) || "overview"; + state.STORE = STORE; // Persisted UI state state.filter = safeGet(STORE.filter) ?? ""; @@ -448,9 +634,14 @@ async function main() { searchIndex.push({ kind: "column", + model: m.name, + column: c.name, + relation: m.relation || "", + dtype: c.dtype || "", + descText: cDesc || "", title: `${m.name}.${c.name}`, subtitle: `${m.relation || ""}${c.dtype ? " • " + c.dtype : ""}`, - route: `#/model/${escapeHashPart(m.name)}`, // navigates to model; we can later auto-scroll to column + route: `#/model/${escapeHashPart(m.name)}?tab=columns&col=${escapeHashPart(c.name)}`, haystack: colHay, }); } @@ -513,6 +704,33 @@ async function main() { const results = state.search.results || []; const sel = Math.max(0, Math.min(state.search.selected || 0, results.length - 1)); + const q = (state.search.query || "").trim(); + const sub = (() => { + if (r.kind === "column") { + const parts = [ + "COLUMN", + r.model || "", + r.relation ? `• ${r.relation}` : "", + r.dtype ? `• ${r.dtype}` : "", + ].filter(Boolean).join(" "); + const snip = makeSnippet(r.descText || "", q, 90); + return snip ? `${parts} • ${snip}` : parts; + } + if (r.kind === "model") { + const snip = makeSnippet((r.descText || ""), q, 90); + return snip ? `MODEL • ${r.subtitle || ""} • ${snip}` : `MODEL • ${r.subtitle || ""}`; + } + if (r.kind === "source") { + const snip = makeSnippet((r.descText || ""), q, 90); + return snip ? `SOURCE • ${r.subtitle || ""} • ${snip}` : `SOURCE • ${r.subtitle || ""}`; + } + return `${(r.kind || "").toUpperCase()} • ${r.subtitle || ""}`; + })(); + + const right = r.kind === "column" && r.dtype + ? el("span", { class: "pill" }, r.dtype) + : el("div", { class: "kbd" }, "↵"); + state.ui.paletteList.replaceChildren( ...(results.length ? 
results.map((r, idx) => @@ -525,9 +743,9 @@ async function main() { }, el("div", { class: "resultMain" }, el("div", { class: "resultTitle" }, r.title), - el("div", { class: "resultSub" }, `${r.kind.toUpperCase()} • ${r.subtitle || ""}`) + el("div", { class: "resultSub" }, sub) ), - el("div", { class: "kbd" }, "↵") + right ) ) : [el("div", { class: "result" }, @@ -554,7 +772,7 @@ async function main() { state.search.query = e.target.value || ""; safeSet(STORE.paletteQuery, state.search.query); runSearch(state.search.query); - renderPaletteResults(); // ✅ no app rerender + renderPaletteResults(); }, onkeydown: (e) => { // Key handling while focused in the input @@ -658,7 +876,6 @@ async function main() { sourcesTitle: null, modelsList: null, sourcesList: null, - macrosCount: null, }; ui.sidebar.macrosList = null; ui.sidebar.modelsSection = null; @@ -852,9 +1069,9 @@ async function main() { } function updateMain() { - const route = parseHash(); + const route = parseRoute(); let view; - if (route.route === "model") view = renderModel(state, route.name); + if (route.route === "model") view = renderModel(state, route.name, route.tab, route.col); else if (route.route === "source") view = renderSource(state, route.source, route.table); else if (route.route === "macros") view = renderMacros(state); else view = renderHome(state); diff --git a/tests/unit/docs/test_docs_unit.py b/tests/unit/docs/test_docs_unit.py index 2c41f42..ff97c4c 100644 --- a/tests/unit/docs/test_docs_unit.py +++ b/tests/unit/docs/test_docs_unit.py @@ -241,187 +241,26 @@ def test_apply_descriptions_to_models_applies_short_and_column_desc(): assert cols_by_table["db.sc.m1"][1].description_html == "
<p>Col 2</p>
" -# --------------------------------------------------------------------------- -# _collect_columns engine stubs -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -def test_collect_columns_prefers_spark(): - class FakeCol: - def __init__(self, name: str): - self.name = name - self.dataType = "INT" - self.nullable = True - - class FakeTable: - def __init__(self, name: str): - self.name = name - self.database = None - self.catalog = None - - class FakeSparkCatalog: - def listTables(self): - return [FakeTable("T1")] - - def listColumns(self, ident, database=None): - return [FakeCol("C1"), FakeCol("C2")] - - class FakeSpark: - catalog = FakeSparkCatalog() - - cols = docs_mod._collect_columns(SimpleNamespace(spark=FakeSpark())) - assert "T1" in cols - assert [c.name for c in cols["T1"]] == ["C1", "C2"] - - -@pytest.mark.unit -def test_collect_columns_with_unknown_executor_returns_empty(): - cols = docs_mod._collect_columns(object()) - assert cols == {} - - -# ---------------------- _columns_duckdb ---------------------- - - -@pytest.mark.unit -def test_columns_duckdb_collects_tables_and_cols(): - class FakeCursor: - def __init__(self, rows): - self._rows = rows - - def fetchall(self): - return self._rows - - class FakeConn: - def __init__(self, rows): - self._rows = rows - - def execute(self, _sql: str): - return FakeCursor(self._rows) - - rows = [ - # table_name, column_name, data_type, is_nullable - ("my_table", "id", "INTEGER", "NO"), - ("my_table", "name", "TEXT", "YES"), - ("other", "x", "BOOLEAN", "YES"), - ] - fake_con = FakeConn(rows) - - cols = docs_mod._columns_duckdb(fake_con) - - assert set(cols.keys()) == {"my_table", "other"} - mt = cols["my_table"] - assert [c.name for c in mt] == ["id", "name"] - assert mt[0].dtype == "INTEGER" - assert mt[0].nullable is False - assert mt[1].nullable is True - - -# ---------------------- _columns_postgres ---------------------- - - @pytest.mark.unit -def test_columns_postgres_collects_from_engine(): - class FakeResult: - def __init__(self, rows): - self._rows = rows - - def fetchall(self): - return self._rows - - class FakeConn: - def __init__(self, rows): - self._rows = rows - - def execute(self, _stmt): - return FakeResult(self._rows) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - class FakeEngine: - def __init__(self, rows): - self._rows = rows - - def begin(self): - return FakeConn(self._rows) - - rows = [ - # table_name, column_name, data_type, is_nullable - ("public_tbl", "id", "integer", "YES"), - ("public_tbl", "email", "text", "NO"), - ] - fake_engine = FakeEngine(rows) - - cols = docs_mod._columns_postgres(fake_engine) - - assert "public_tbl" in cols - tcols = cols["public_tbl"] - assert [c.name for c in tcols] == ["id", "email"] - assert tcols[0].dtype == "integer" - # in deiner Implementierung: nullable == "YES" - assert tcols[0].nullable is True - assert tcols[1].nullable is False +def test_collect_columns_uses_executor_hook_when_available(): + expected = {"tbl": [docs_mod.ColumnInfo("c1", "INT", True)]} + class FakeExecutor: + def collect_docs_columns(self): + return expected -# ---------------------- _columns_snowflake ---------------------- + cols = docs_mod._collect_columns(FakeExecutor()) + assert cols is expected @pytest.mark.unit -def test_columns_snowflake_collects_from_session(): - class FakeDF: - def __init__(self, rows): - self._rows = rows - - def collect(self): - return self._rows - - class FakeSession: - def 
__init__(self, rows): - self._rows = rows - - def sql(self, _sql: str): - return FakeDF(self._rows) - - rows = [ - { - "TABLE_NAME": "T1", - "COLUMN_NAME": "ID", - "DATA_TYPE": "NUMBER", - "IS_NULLABLE": "NO", - }, - { - "TABLE_NAME": "T1", - "COLUMN_NAME": "NAME", - "DATA_TYPE": "TEXT", - "IS_NULLABLE": "YES", - }, - { - "TABLE_NAME": "T2", - "COLUMN_NAME": "TS", - "DATA_TYPE": "TIMESTAMP_NTZ", - "IS_NULLABLE": "YES", - }, - ] - fake_session = FakeSession(rows) - - cols = docs_mod._columns_snowflake(fake_session) - - assert set(cols.keys()) == {"T1", "T2"} - t1 = cols["T1"] - assert [c.name for c in t1] == ["ID", "NAME"] - assert t1[0].dtype == "NUMBER" - assert t1[0].nullable is False - assert t1[1].nullable is True +def test_collect_columns_swallows_errors_and_unknown(): + class BoomExecutor: + def collect_docs_columns(self): + raise RuntimeError("boom") - t2 = cols["T2"] - assert t2[0].name == "TS" - assert t2[0].dtype == "TIMESTAMP_NTZ" - assert t2[0].nullable is True + assert docs_mod._collect_columns(BoomExecutor()) == {} + assert docs_mod._collect_columns(object()) == {} @pytest.mark.unit