From dcb388dfa92b891f59058cf5252fe83e6adab31e Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 10 May 2026 18:40:37 -0600 Subject: [PATCH 01/10] Track per-call wall-clock time and emit token-rate in cost summary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds elapsed_seconds wall-clock timing on every LLM API call, aggregates it through the agent loop, and surfaces it in the QUERY TOTAL log line and provenance JSON as `elapsed=X.XXs out_tps=Y.Y`. Provider-agnostic — works for Ollama (whose OpenAI-compatible endpoint omits its native timing fields), Anthropic, OpenAI, and GitHub Models alike. The reported rate is `output_tokens / elapsed_wallclock`, which includes prefill + network and is therefore a lower bound on the model's true decode rate; callers that want a clean decode rate should probe the model directly. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/datasight/agent.py | 7 ++++++ src/datasight/cli.py | 9 +++++--- src/datasight/cost.py | 47 ++++++++++++++++++++++++++----------- src/datasight/llm.py | 22 +++++++++++++++++- tests/test_cli_tools.py | 51 ++++++++++++++++++++++++++++++++++++----- 5 files changed, 113 insertions(+), 23 deletions(-) diff --git a/src/datasight/agent.py b/src/datasight/agent.py index f19290ae..76bf5702 100644 --- a/src/datasight/agent.py +++ b/src/datasight/agent.py @@ -816,6 +816,7 @@ class AgentResult: total_output_tokens: int = 0 total_cache_creation_input_tokens: int = 0 total_cache_read_input_tokens: int = 0 + total_elapsed_seconds: float = 0.0 api_calls: int = 0 retries_performed: int = 0 truncated: bool = False @@ -903,6 +904,7 @@ async def run_agent_loop( # noqa: C901 total_output_tokens = 0 total_cache_creation_input_tokens = 0 total_cache_read_input_tokens = 0 + total_elapsed_seconds = 0.0 api_calls = 0 retries_performed = 0 @@ -919,6 +921,7 @@ async def run_agent_loop( # noqa: C901 total_output_tokens += response.usage.output_tokens 
total_cache_creation_input_tokens += response.usage.cache_creation_input_tokens total_cache_read_input_tokens += response.usage.cache_read_input_tokens + total_elapsed_seconds += response.usage.elapsed_seconds retries_performed += response.call_stats.retries_performed if max_cost_usd is not None: @@ -947,6 +950,7 @@ async def run_agent_loop( # noqa: C901 total_output_tokens=total_output_tokens, total_cache_creation_input_tokens=total_cache_creation_input_tokens, total_cache_read_input_tokens=total_cache_read_input_tokens, + total_elapsed_seconds=total_elapsed_seconds, api_calls=api_calls, retries_performed=retries_performed, ) @@ -968,6 +972,7 @@ async def run_agent_loop( # noqa: C901 total_output_tokens=total_output_tokens, total_cache_creation_input_tokens=total_cache_creation_input_tokens, total_cache_read_input_tokens=total_cache_read_input_tokens, + total_elapsed_seconds=total_elapsed_seconds, api_calls=api_calls, retries_performed=retries_performed, truncated=True, @@ -1020,6 +1025,7 @@ async def run_agent_loop( # noqa: C901 total_output_tokens=total_output_tokens, total_cache_creation_input_tokens=total_cache_creation_input_tokens, total_cache_read_input_tokens=total_cache_read_input_tokens, + total_elapsed_seconds=total_elapsed_seconds, api_calls=api_calls, retries_performed=retries_performed, ) @@ -1031,6 +1037,7 @@ async def run_agent_loop( # noqa: C901 total_output_tokens=total_output_tokens, total_cache_creation_input_tokens=total_cache_creation_input_tokens, total_cache_read_input_tokens=total_cache_read_input_tokens, + total_elapsed_seconds=total_elapsed_seconds, api_calls=api_calls, retries_performed=retries_performed, ) diff --git a/src/datasight/cli.py b/src/datasight/cli.py index 09356d61..277d1f16 100644 --- a/src/datasight/cli.py +++ b/src/datasight/cli.py @@ -20,6 +20,7 @@ resolve_settings as resolve_settings, ) from datasight.config import create_sql_runner_from_settings +from datasight.cost import build_cost_data, log_query_cost from 
datasight.data_profile import ( build_measure_overview, build_prompt_recipes, @@ -117,7 +118,6 @@ async def run_ask_pipeline( load_schema_description, load_time_series_config, ) - from datasight.cost import build_cost_data, log_query_cost from datasight.data_profile import format_time_series_prompt_context from datasight.prompts import build_system_prompt from datasight.query_log import QueryLogger @@ -257,6 +257,7 @@ async def run_ask_pipeline( result.total_output_tokens, cache_creation_input_tokens=result.total_cache_creation_input_tokens, cache_read_input_tokens=result.total_cache_read_input_tokens, + elapsed_seconds=result.total_elapsed_seconds, provider=settings.llm.provider, ) cost_data = build_cost_data( @@ -266,6 +267,7 @@ async def run_ask_pipeline( result.total_output_tokens, cache_creation_input_tokens=result.total_cache_creation_input_tokens, cache_read_input_tokens=result.total_cache_read_input_tokens, + elapsed_seconds=result.total_elapsed_seconds, provider=settings.llm.provider, ) query_logger.log_cost( @@ -1157,8 +1159,6 @@ def build_cli_provenance( project_dir: str, provider: str | None = None, ) -> dict[str, Any]: - from datasight.cost import build_cost_data - cost_data = build_cost_data( model, result.api_calls, @@ -1166,6 +1166,7 @@ def build_cli_provenance( result.total_output_tokens, cache_creation_input_tokens=result.total_cache_creation_input_tokens, cache_read_input_tokens=result.total_cache_read_input_tokens, + elapsed_seconds=result.total_elapsed_seconds, provider=provider, ) tools = [] @@ -1206,6 +1207,8 @@ def build_cli_provenance( "api_calls": result.api_calls, "input_tokens": result.total_input_tokens, "output_tokens": result.total_output_tokens, + "elapsed_seconds": cost_data.get("elapsed_seconds"), + "output_tokens_per_sec": cost_data.get("output_tokens_per_sec"), "estimated_cost": cost_data.get("estimated_cost"), }, "warnings": warnings, diff --git a/src/datasight/cost.py b/src/datasight/cost.py index 11511451..7ed895e2 100644 --- 
a/src/datasight/cost.py +++ b/src/datasight/cost.py @@ -38,6 +38,7 @@ def build_cost_data( *, cache_creation_input_tokens: int = 0, cache_read_input_tokens: int = 0, + elapsed_seconds: float | None = None, provider: str | None = None, ) -> dict[str, Any]: """Build a cost/token summary dict for a single turn. @@ -47,6 +48,12 @@ def build_cost_data( ``None`` even if the model name happens to be in ``MODEL_PRICING`` — GitHub Models reuses OpenAI model names but is billed by quota, not per-token, so pricing them against OpenAI rates is misleading. + + When ``elapsed_seconds`` is supplied and positive, ``elapsed_seconds``, + ``output_tokens_per_sec`` and ``total_tokens_per_sec`` are added to the + returned dict. ``output_tokens_per_sec`` is a wall-clock aggregate — + it includes prefill + decode + network, so it's a lower bound on the + model's true decode rate. """ data: dict[str, Any] = { "api_calls": api_calls, @@ -56,20 +63,25 @@ def build_cost_data( "cache_read_input_tokens": cache_read_input_tokens, "estimated_cost": None, } - if provider is not None and provider not in _PROVIDERS_WITH_PRICING: - return data - pricing = MODEL_PRICING.get(model) - if pricing: - input_cost = input_tokens * pricing[0] / 1_000_000 - output_cost = output_tokens * pricing[1] / 1_000_000 - # Anthropic prompt-cache writes are billed at 1.25x input price and - # cache reads at 0.1x input price for the ephemeral cache used here. 
- cache_creation_cost = cache_creation_input_tokens * pricing[0] * 1.25 / 1_000_000 - cache_read_cost = cache_read_input_tokens * pricing[0] * 0.1 / 1_000_000 - data["estimated_cost"] = round( - input_cost + output_cost + cache_creation_cost + cache_read_cost, - 6, + if elapsed_seconds is not None and elapsed_seconds > 0: + data["elapsed_seconds"] = round(elapsed_seconds, 4) + data["output_tokens_per_sec"] = round(output_tokens / elapsed_seconds, 2) + data["total_tokens_per_sec"] = round( + (input_tokens + output_tokens) / elapsed_seconds, 2 ) + if provider is None or provider in _PROVIDERS_WITH_PRICING: + pricing = MODEL_PRICING.get(model) + if pricing: + input_cost = input_tokens * pricing[0] / 1_000_000 + output_cost = output_tokens * pricing[1] / 1_000_000 + # Anthropic prompt-cache writes are billed at 1.25x input price and + # cache reads at 0.1x input price for the ephemeral cache used here. + cache_creation_cost = cache_creation_input_tokens * pricing[0] * 1.25 / 1_000_000 + cache_read_cost = cache_read_input_tokens * pricing[0] * 0.1 / 1_000_000 + data["estimated_cost"] = round( + input_cost + output_cost + cache_creation_cost + cache_read_cost, + 6, + ) return data @@ -81,6 +93,7 @@ def log_query_cost( *, cache_creation_input_tokens: int = 0, cache_read_input_tokens: int = 0, + elapsed_seconds: float | None = None, provider: str | None = None, ) -> None: """Emit a one-line loguru summary of token usage and estimated cost.""" @@ -91,13 +104,21 @@ def log_query_cost( output_tokens, cache_creation_input_tokens=cache_creation_input_tokens, cache_read_input_tokens=cache_read_input_tokens, + elapsed_seconds=elapsed_seconds, provider=provider, ) cost = data["estimated_cost"] cost_str = f" est_cost=${cost:.4f}" if cost is not None else "" + rate_str = "" + if "elapsed_seconds" in data: + rate_str = ( + f" elapsed={data['elapsed_seconds']:.2f}s" + f" out_tps={data['output_tokens_per_sec']:.1f}" + ) logger.info( f"[tokens] QUERY TOTAL: api_calls={api_calls} " 
f"input={input_tokens} output={output_tokens} " f"cache_create={cache_creation_input_tokens} cache_read={cache_read_input_tokens}" + f"{rate_str}" f"{cost_str}" ) diff --git a/src/datasight/llm.py b/src/datasight/llm.py index ecf7584b..bd5040d6 100644 --- a/src/datasight/llm.py +++ b/src/datasight/llm.py @@ -18,6 +18,7 @@ import asyncio import json +import time from dataclasses import dataclass, field from typing import Any, Protocol, cast @@ -118,12 +119,20 @@ class ToolUseBlock: @dataclass class Usage: - """Token usage statistics from an LLM response.""" + """Token usage statistics from an LLM response. + + ``elapsed_seconds`` is wall-clock time for the API call, measured + client-side. It is the only available timing signal for providers whose + response payloads omit native timing fields (Ollama's OpenAI-compatible + endpoint, for example). Includes prefill + decode + network; callers + that want a clean decode rate should probe the model directly. + """ input_tokens: int = 0 output_tokens: int = 0 cache_creation_input_tokens: int = 0 cache_read_input_tokens: int = 0 + elapsed_seconds: float = 0.0 @dataclass @@ -299,9 +308,14 @@ async def create_message( # noqa: C901 response = None retries = 0 + elapsed_seconds = 0.0 for attempt in range(DEFAULT_MAX_ATTEMPTS): is_last = attempt == DEFAULT_MAX_ATTEMPTS - 1 try: + # Wall-clock the actual API call; this is the only timing + # signal available across all providers since their response + # payloads don't expose per-call timing uniformly. 
+ _t0 = time.monotonic() response = await self._client.messages.create( model=model, max_tokens=max_tokens, @@ -310,6 +324,7 @@ async def create_message( # noqa: C901 tools=anthropic_tools, messages=anthropic_messages, ) + elapsed_seconds = time.monotonic() - _t0 break except anthropic.AuthenticationError as e: # Auth failures are a configuration problem, not a transient @@ -392,6 +407,7 @@ async def create_message( # noqa: C901 cache_creation_input_tokens=getattr(response.usage, "cache_creation_input_tokens", 0) or 0, cache_read_input_tokens=getattr(response.usage, "cache_read_input_tokens", 0) or 0, + elapsed_seconds=elapsed_seconds, ) return LLMResponse( content=content, @@ -599,9 +615,11 @@ async def create_message( # noqa: C901 response = None retries = 0 + elapsed_seconds = 0.0 for attempt in range(DEFAULT_MAX_ATTEMPTS): is_last = attempt == DEFAULT_MAX_ATTEMPTS - 1 try: + _t0 = time.monotonic() if openai_tools: response = await self._client.chat.completions.create( model=model, @@ -617,6 +635,7 @@ async def create_message( # noqa: C901 max_tokens=max_tokens, temperature=0, ) + elapsed_seconds = time.monotonic() - _t0 break except oa.AuthenticationError as e: # Auth failures are a configuration problem — don't retry. 
@@ -736,6 +755,7 @@ async def create_message( # noqa: C901 usage = Usage( input_tokens=response.usage.prompt_tokens if response.usage else 0, output_tokens=response.usage.completion_tokens if response.usage else 0, + elapsed_seconds=elapsed_seconds, ) return LLMResponse( content=content, diff --git a/tests/test_cli_tools.py b/tests/test_cli_tools.py index 1ac250af..500620fe 100644 --- a/tests/test_cli_tools.py +++ b/tests/test_cli_tools.py @@ -10,6 +10,7 @@ from click.testing import CliRunner from datasight.cli import cli +from datasight.cost import build_cost_data from datasight.llm import LLMResponse, TextBlock, ToolUseBlock, Usage @@ -1486,6 +1487,7 @@ def _make_sql_result(text="answer", queries=None): total_output_tokens=0, total_cache_creation_input_tokens=0, total_cache_read_input_tokens=0, + total_elapsed_seconds=0.0, api_calls=0, ) @@ -1939,8 +1941,6 @@ def test_ask_sql_script_rejects_with_file(project_dir, tmp_path): def test_build_cost_data_known_model_returns_estimated_cost(): - from datasight.cost import build_cost_data - data = build_cost_data( "claude-sonnet-4-6", api_calls=2, @@ -1955,8 +1955,6 @@ def test_build_cost_data_known_model_returns_estimated_cost(): def test_build_cost_data_counts_anthropic_cache_tokens(): - from datasight.cost import build_cost_data - data = build_cost_data( "claude-sonnet-4-6", api_calls=1, @@ -1972,8 +1970,6 @@ def test_build_cost_data_counts_anthropic_cache_tokens(): def test_build_cost_data_unknown_model_returns_none_cost(): - from datasight.cost import build_cost_data - data = build_cost_data( "made-up-model", api_calls=1, @@ -1986,6 +1982,49 @@ def test_build_cost_data_unknown_model_returns_none_cost(): assert data["estimated_cost"] is None +def test_build_cost_data_with_elapsed_adds_token_rates(): + data = build_cost_data( + "qwen2.5", + api_calls=2, + input_tokens=1000, + output_tokens=500, + elapsed_seconds=10.0, + provider="ollama", + ) + assert data["elapsed_seconds"] == 10.0 + assert 
data["output_tokens_per_sec"] == 50.0 + assert data["total_tokens_per_sec"] == 150.0 + # Ollama is not in _PROVIDERS_WITH_PRICING, so no cost estimate. + assert data["estimated_cost"] is None + + +def test_build_cost_data_without_elapsed_omits_rate_fields(): + data = build_cost_data( + "qwen2.5", + api_calls=1, + input_tokens=100, + output_tokens=50, + provider="ollama", + ) + assert "elapsed_seconds" not in data + assert "output_tokens_per_sec" not in data + + +def test_build_cost_data_zero_elapsed_omits_rate_fields(): + # Zero elapsed (e.g. a synthetic agent test that never made an API call) + # would divide by zero — guard against it by skipping the rate fields. + data = build_cost_data( + "qwen2.5", + api_calls=0, + input_tokens=0, + output_tokens=0, + elapsed_seconds=0.0, + provider="ollama", + ) + assert "elapsed_seconds" not in data + assert "output_tokens_per_sec" not in data + + def test_run_ask_pipeline_logs_cost_entry(monkeypatch, project_dir): """``datasight ask`` must persist a turn-level cost summary to the query log.""" import asyncio From 0e6751c635fdc46ba653af4d8a304b69172793fe Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 10 May 2026 18:40:51 -0600 Subject: [PATCH 02/10] Detect and repair grounding-file drift after schema changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds two complementary pieces that catch the failure mode where queries.yaml / schema_description.md / time_series.yaml fall out of sync with the live database (typically after a `datasight tidy review` reshape). When that happens the LLM agent is silently being trained on wrong column names — hallucinating, refusing, or returning tables full of zeros despite a clean execution. Plan 1 — `datasight verify` static drift check - New `datasight.grounding` library: AST-walks queries.yaml with sqlglot, scans backticked identifiers in schema_description.md, and validates time_series.yaml entries against the live schema. 
- Suppresses false positives from prose enumeration values by auto-loading distinct values of low-cardinality VARCHAR columns. - Wired into `datasight verify` as a cheap pre-flight check. New `--static-only` flag skips the LLM phase for fast drift-only checks; `--skip-grounding-check` disables it. Plan 2 — `datasight tidy review` post-apply repair hook - New `datasight.grounding_repair` library: snapshots schema before the tidy transform, calls the configured LLM with both old and new schemas plus the drift report, validates every rewritten SQL against the live DB (retries up to 2x on failures), and returns proposed file contents without writing. - Hook in `tidy review` runs after `apply_proposal` completes: shows drift, prompts to repair, displays unified diff, prompts to apply, and writes atomically via tempfile + os.replace. - Scope limited to queries.yaml, schema_description.md, and time_series.yaml — schema.yaml and measures.yaml stay owned by tidy_review's existing update helpers. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/reference/cli.md | 8 + src/datasight/cli_commands/tidy.py | 152 ++++++++ src/datasight/cli_commands/verify.py | 72 +++- src/datasight/grounding.py | 511 +++++++++++++++++++++++++++ src/datasight/grounding_repair.py | 429 ++++++++++++++++++++++ tests/test_grounding.py | 272 ++++++++++++++ tests/test_grounding_repair.py | 331 +++++++++++++++++ 7 files changed, 1774 insertions(+), 1 deletion(-) create mode 100644 src/datasight/grounding.py create mode 100644 src/datasight/grounding_repair.py create mode 100644 tests/test_grounding.py create mode 100644 tests/test_grounding_repair.py diff --git a/docs/reference/cli.md b/docs/reference/cli.md index d6095115..f5b11815 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -437,10 +437,16 @@ Runs each question from queries.yaml through the full LLM pipeline, executes the generated SQL, and compares results against expected values. 
Use this to validate correctness across different models and providers. +Before the LLM phase, runs a static schema-drift check that flags +references to columns or tables that no longer exist in the live +database. ``--static-only`` skips the LLM phase entirely; +``--skip-grounding-check`` skips the static check. + Examples: ``` datasight verify +datasight verify --static-only datasight verify --queries verification.yaml datasight verify --model gpt-4o ``` @@ -470,6 +476,8 @@ datasight verify [OPTIONS] | `--project-dir` | Project directory containing .env and queries.yaml. Default: `.`. | | `--model` | Model name (overrides .env). | | `--queries` | Path to queries YAML file (default: queries.yaml in project dir). | +| `--static-only` | Run only the cheap schema-drift check (no LLM, no query execution). Reports unresolved column/table references in queries.yaml, schema_description.md, and time_series.yaml against the live DB. | +| `--skip-grounding-check` | Skip the static drift check that normally runs before the LLM phase. 
| ### `datasight ask` diff --git a/src/datasight/cli_commands/tidy.py b/src/datasight/cli_commands/tidy.py index 00396f06..8cd419c4 100644 --- a/src/datasight/cli_commands/tidy.py +++ b/src/datasight/cli_commands/tidy.py @@ -13,8 +13,20 @@ from datasight import cli from datasight.cli_helpers import format_epilog +from datasight.config import create_sql_runner_from_settings from datasight.data_profile import find_table_info from datasight.explore import create_files_session_for_settings +from datasight.grounding import ( + build_enum_values_sync, + build_schema_truth_sync, + check_grounding_drift, + format_drift_report, +) +from datasight.grounding_repair import ( + format_repair_summary, + repair_grounding, + write_repair_atomic, +) from datasight.schema import introspect_schema from datasight.tidy import _detect_period_groups, analyze_tidy_patterns from datasight.tidy_llm import propose_reshapes @@ -614,8 +626,148 @@ async def _load_schema(): click.echo("No proposals approved.") return + # Snapshot the pre-apply schema so the grounding repair flow (post-apply, + # below) can show the LLM both old and new schemas. Skipped on dry runs + # since nothing changes; skipped on snapshot errors so a quirky DuckDB + # state doesn't block the apply itself. + old_schema: dict[str, set[str]] | None = None + if not dry_run: + try: + ro = duckdb.connect(resolved_db_path, read_only=True) + try: + old_schema = build_schema_truth_sync(ro) + finally: + ro.close() + except duckdb.Error: + old_schema = None + _apply_review_proposals(approved, disposition, as_mode, dry_run, resolved_db_path, project_dir) + if not dry_run and old_schema is not None: + _offer_grounding_repair(resolved_db_path, project_dir, settings, old_schema) + + +def _offer_grounding_repair( # noqa: C901 + db_path: str, + project_dir: str, + settings: Any, + old_schema: dict[str, set[str]], +) -> None: + """Post-apply hook: detect grounding drift, offer LLM-driven repair. 
+ + No-op when the schema is unchanged, when drift is clean against the new + schema, when the user declines the prompt, or when LLM settings are + incomplete. + """ + try: + ro = duckdb.connect(db_path, read_only=True) + try: + new_schema = build_schema_truth_sync(ro) + enum_values = build_enum_values_sync(ro, new_schema) + finally: + ro.close() + except duckdb.Error as exc: + click.echo(f"warn: grounding check skipped (DB unreadable): {exc}", err=True) + return + + if old_schema == new_schema: + return + + drift = check_grounding_drift( + Path(project_dir), new_schema, enum_values=enum_values + ) + if drift.is_clean: + return + + click.echo("") + click.echo(format_drift_report(drift)) + click.echo("") + + try: + cli.validate_settings_for_llm(settings) + except (click.UsageError, click.ClickException, SystemExit) as exc: + click.echo( + f"Grounding files have drifted but no LLM is configured to repair them: {exc}", + err=True, + ) + return + + if not click.confirm( + "Repair grounding files with the configured LLM?", default=False + ): + return + + resolved_model = settings.llm.model + try: + result = asyncio.run( + _run_grounding_repair( + project_dir, old_schema, new_schema, drift, settings, resolved_model + ) + ) + except Exception as exc: # noqa: BLE001 — broad on purpose; surface to user + click.echo(f"Repair failed: {exc}", err=True) + return + + if not result.any_changes: + click.echo("LLM proposed no changes.") + return + + click.echo("") + click.echo(format_repair_summary(result)) + for f in result.files: + if not f.changed: + continue + click.echo("") + click.echo(f.unified_diff(), nl=False) + + if not result.overall_ok: + click.echo("") + click.echo( + "Some proposed files failed validation after retries. 
Skipping apply; " + "edit the files manually using the diff above as a starting point.", + err=True, + ) + return + + click.echo("") + if not click.confirm("Apply this diff?", default=False): + return + + written = write_repair_atomic(result, Path(project_dir)) + for p in written: + click.echo(f"Wrote {p}") + + +async def _run_grounding_repair( + project_dir: str, + old_schema: dict[str, set[str]], + new_schema: dict[str, set[str]], + drift, + settings: Any, + resolved_model: str, +): + """Wire up the LLM client + SQL runner the repair library needs.""" + llm_client = cli.create_llm_client( + provider=settings.llm.provider, + api_key=settings.llm.api_key, + base_url=settings.llm.base_url, + timeout=settings.llm.timeout, + model=resolved_model, + ) + try: + sql_runner = create_sql_runner_from_settings(settings.database, project_dir) + return await repair_grounding( + Path(project_dir), + old_schema, + new_schema, + drift, + llm_client=llm_client, + model=resolved_model, + run_sql=sql_runner.run_sql, + ) + finally: + await llm_client.aclose() + def _propose_via_llm( project_dir: str, settings: Any, source_table: str | None, sample_rows: int diff --git a/src/datasight/cli_commands/verify.py b/src/datasight/cli_commands/verify.py index 3b00cfc2..a7ebfdc8 100644 --- a/src/datasight/cli_commands/verify.py +++ b/src/datasight/cli_commands/verify.py @@ -5,9 +5,16 @@ import sys from pathlib import Path +import duckdb import rich_click as click from datasight.config import create_sql_runner_from_settings +from datasight.grounding import ( + build_enum_values_sync, + build_schema_truth_sync, + check_grounding_drift, + format_drift_report, +) from datasight import cli from datasight.cli_helpers import format_epilog @@ -19,6 +26,7 @@ Examples: datasight verify + datasight verify --static-only datasight verify --queries verification.yaml datasight verify --model gpt-4o @@ -50,16 +58,78 @@ default=None, help="Path to queries YAML file (default: queries.yaml in project dir).", ) 
-def verify(project_dir, model, queries_path): # noqa: C901 +@click.option( + "--static-only", + is_flag=True, + default=False, + help=( + "Run only the cheap schema-drift check (no LLM, no query execution). " + "Reports unresolved column/table references in queries.yaml, " + "schema_description.md, and time_series.yaml against the live DB." + ), +) +@click.option( + "--skip-grounding-check", + is_flag=True, + default=False, + help="Skip the static drift check that normally runs before the LLM phase.", +) +def verify(project_dir, model, queries_path, static_only, skip_grounding_check): # noqa: C901 """Verify LLM-generated SQL against expected results. Runs each question from queries.yaml through the full LLM pipeline, executes the generated SQL, and compares results against expected values. Use this to validate correctness across different models and providers. + + Before the LLM phase, runs a static schema-drift check that flags + references to columns or tables that no longer exist in the live + database. ``--static-only`` skips the LLM phase entirely; + ``--skip-grounding-check`` skips the static check. """ project_dir = str(Path(project_dir).resolve()) + # Static drift check first. Cheap, no LLM, no async — runs against a + # direct DuckDB connection. For non-DuckDB backends we skip the + # static check (information_schema.columns availability varies) and + # rely on the LLM-phase query execution to surface drift. 
+ if not skip_grounding_check: + settings_preflight, _ = cli.resolve_settings(project_dir, model) + if settings_preflight.database.mode == "duckdb": + resolved_db_path = cli.resolve_db_path(settings_preflight, project_dir) + if resolved_db_path and os.path.exists(resolved_db_path): + conn = duckdb.connect(resolved_db_path, read_only=True) + try: + truth = build_schema_truth_sync(conn) + enums = build_enum_values_sync(conn, truth) + finally: + conn.close() + report = check_grounding_drift( + Path(project_dir), truth, enum_values=enums + ) + if not report.is_clean: + click.echo(format_drift_report(report), err=True) + click.echo("", err=True) + click.echo( + "Drift found. Run `datasight tidy review` to repair, " + "or pass `--skip-grounding-check` to proceed anyway.", + err=True, + ) + if static_only: + sys.exit(1) + # Continue into the LLM phase; the user can opt out via + # Ctrl-C if they didn't want to burn tokens. + elif static_only: + click.echo("grounding clean: no drift detected.") + sys.exit(0) + elif static_only: + click.echo( + "--static-only requires DuckDB; database.mode is " + f"{settings_preflight.database.mode!r}.", + err=True, + ) + sys.exit(2) + # Load queries from datasight.config import load_example_queries diff --git a/src/datasight/grounding.py b/src/datasight/grounding.py new file mode 100644 index 00000000..df15c7be --- /dev/null +++ b/src/datasight/grounding.py @@ -0,0 +1,511 @@ +"""Cheap schema-drift check for grounding files. + +The LLM agent loop is steered by three files that live in the project +directory: + +- ``queries.yaml`` — few-shot SQL examples +- ``schema_description.md`` — prose schema description +- ``time_series.yaml`` — temporal-structure declarations + +Any of these can fall out of sync with the live database after a schema +change (for example a ``datasight tidy review`` that reshapes a wide +table into long form). 
When that happens the LLM is silently being +trained on wrong column names: the agent either hallucinates plausible +columns, refuses with a citation of the stale grounding, or returns +``SELECT`` results full of zeros. + +This module performs a no-LLM, AST-driven check that every column and +table reference in the grounding files resolves against the current +schema. It is intentionally cheap — the goal is to surface drift loudly +at known checkpoints (``datasight verify``, the post-apply step of +``datasight tidy review``) rather than at every agent invocation. + +The companion module :mod:`datasight.grounding_repair` consumes a +:class:`DriftReport` from here and uses an LLM to rewrite the affected +files; this module never invokes an LLM. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from difflib import get_close_matches +from pathlib import Path +from typing import Any, Awaitable, Callable + +import sqlglot +import yaml +from sqlglot import exp + + +# Backtick-quoted lowercase identifier in markdown. Matches ``foo``, +# ``foo_bar``, ``foo.bar``. Anything that isn't a snake_case identifier +# inside backticks (prose, multiword content, SQL fragments) is ignored. +_MD_BACKTICK_IDENT = re.compile(r"`([a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)?)`") + + +# SQL keywords / built-ins that may legitimately appear in backticks +# inside ``schema_description.md`` — used to suppress false positives +# from the markdown scan. Not exhaustive: only the words that show up in +# prose. Anything not here AND not in the current schema gets flagged. 
+_SQL_KEYWORDS: frozenset[str] = frozenset({ + "all", "and", "as", "asc", "avg", "between", "boolean", "by", + "case", "cast", "ceil", "coalesce", "corr", "count", "current_date", + "date", "date_trunc", "datetime", "day", "desc", "distinct", + "double", "else", "end", "extract", "false", "floor", "from", + "group", "having", "in", "inner", "integer", "is", "join", "left", + "limit", "max", "min", "month", "not", "now", "null", "offset", + "on", "or", "order", "outer", "over", "regr_intercept", "regr_r2", + "regr_slope", "right", "round", "row_number", "select", "sum", + "then", "timestamp", "to_date", "true", "union", "varchar", + "when", "where", "with", "year", +}) + + +_DEFAULT_GROUNDING_FILES = ("queries.yaml", "schema_description.md", "time_series.yaml") + + +@dataclass +class DriftItem: + """One finding: a claim in a grounding file that doesn't resolve. + + Attributes + ---------- + file : str + Source file path (string for serializability). + line : int | None + 1-based line number when the source format makes it easy to find + (markdown). ``None`` for YAML SQL bodies where the line of the + offending token is harder to localize without a full AST walk. + kind : str + One of ``"table"``, ``"column"``, ``"ts_table"``, ``"ts_column"``, + ``"parse_error"``. + claim : str + The identifier as it appeared in the source. + detail : str + Human-readable explanation suitable for terminal output. + suggestion : str | None + Nearest match in the current schema by edit distance, if one is + within the similarity cutoff. + """ + + file: str + line: int | None + kind: str + claim: str + detail: str + suggestion: str | None = None + + +@dataclass +class DriftReport: + """Result of a grounding-drift check. + + A ``DriftReport`` with no items means every grounding-file claim + resolved against the live schema. The grouping helpers exist so + formatters can render findings per-file without re-iterating. 
+ """ + + items: list[DriftItem] = field(default_factory=list) + + @property + def is_clean(self) -> bool: + """True if no drift was detected.""" + return not self.items + + def by_file(self) -> dict[str, list[DriftItem]]: + """Group items by source file, preserving insertion order.""" + out: dict[str, list[DriftItem]] = {} + for item in self.items: + out.setdefault(item.file, []).append(item) + return out + + +def check_grounding_drift( + project_dir: Path, + schema_truth: dict[str, set[str]], + *, + enum_values: set[str] | None = None, + queries_path: Path | None = None, + schema_description_path: Path | None = None, + time_series_path: Path | None = None, +) -> DriftReport: + """Check grounding files for references the current schema can't resolve. + + Parameters + ---------- + project_dir : Path + Project directory; default file paths are resolved relative to it. + schema_truth : dict[str, set[str]] + ``{table_name: set(column_names)}`` for the current live schema. + Build this with :func:`build_schema_truth_sync` or + :func:`build_schema_truth_async`. + enum_values : set[str] | None + Optional allowlist of legal values (distinct VARCHAR values from + the live schema). When the markdown scan finds a backticked + identifier that matches one of these, it is treated as an enum + value rather than a missing-column candidate. Build this with + :func:`build_enum_values_sync` to suppress false positives from + prose like "Values: \\`east_north_central\\`, ...". + queries_path, schema_description_path, time_series_path : Path | None + Override default file locations. Missing files are silently + skipped — a project that hasn't added one of the grounding files + yet should not produce drift items. + + Returns + ------- + DriftReport + ``is_clean`` is True when no drift was detected. 
+ """ + report = DriftReport() + qpath = queries_path or project_dir / "queries.yaml" + sdpath = schema_description_path or project_dir / "schema_description.md" + tspath = time_series_path or project_dir / "time_series.yaml" + + if qpath.exists(): + _check_queries(qpath, schema_truth, report) + if sdpath.exists(): + _check_schema_description(sdpath, schema_truth, enum_values or set(), report) + if tspath.exists(): + _check_time_series(tspath, schema_truth, report) + + return report + + +def _check_queries( + path: Path, schema_truth: dict[str, set[str]], report: DriftReport +) -> None: + """Walk every ``sql:`` block, flag unresolved table/column references.""" + text = path.read_text(encoding="utf-8") + try: + docs = yaml.safe_load(text) or [] + except yaml.YAMLError as exc: + report.items.append(DriftItem( + file=str(path), line=None, kind="parse_error", + claim="", detail=f"yaml parse error: {exc}", + )) + return + if not isinstance(docs, list): + return + + all_tables = set(schema_truth.keys()) + all_columns: set[str] = set() + for cols in schema_truth.values(): + all_columns |= cols + + for entry in docs: + if not isinstance(entry, dict): + continue + sql = entry.get("sql", "") + if not sql: + continue + try: + parsed = sqlglot.parse(sql, read="duckdb") + except sqlglot.errors.ParseError: + # Out of scope here — ``sql_validation.py`` covers parse errors. + continue + + # Collect CTE names so we don't flag them as missing tables. + cte_names: set[str] = set() + for stmt in parsed: + if stmt is None: + continue + for cte in stmt.find_all(exp.CTE): + cte_names.add(cte.alias_or_name) + + # Also collect output aliases (``AS alias``) so we don't flag + # them when they're referenced later in the same query (e.g. in + # an ``ORDER BY`` clause on a computed column). 
        output_aliases: set[str] = set()
        for stmt in parsed:
            if stmt is None:
                continue
            for alias in stmt.find_all(exp.Alias):
                a = alias.alias_or_name
                if a:
                    output_aliases.add(a)

        # Second pass: flag any table or column reference that resolves
        # to neither the live schema nor a name defined by the query itself.
        for stmt in parsed:
            if stmt is None:
                continue
            for tref in stmt.find_all(exp.Table):
                name = tref.name
                if not name or name in cte_names or name in all_tables:
                    continue
                report.items.append(DriftItem(
                    file=str(path), line=None, kind="table",
                    claim=name,
                    detail=f"table '{name}' not in current schema",
                    suggestion=_nearest(name, all_tables),
                ))
            for cref in stmt.find_all(exp.Column):
                name = cref.name
                if not name or name in all_columns or name in output_aliases:
                    continue
                # Qualified-but-unknown table prefixes are caught by the
                # table check above; ignore the column part to avoid
                # double-reporting.
                # NOTE(review): this also skips alias-qualified columns
                # (``FROM load_data ld ... ld.colx``) because the alias
                # ``ld`` is not in ``all_tables`` — a potential false
                # negative. Resolving aliases would need a scope walk;
                # TODO confirm whether that gap matters in practice.
                if cref.table and cref.table not in all_tables:
                    continue
                report.items.append(DriftItem(
                    file=str(path), line=None, kind="column",
                    claim=name,
                    detail=f"column '{name}' not in any table",
                    suggestion=_nearest(name, all_columns),
                ))


def _check_schema_description(
    path: Path,
    schema_truth: dict[str, set[str]],
    enum_values: set[str],
    report: DriftReport,
) -> None:
    """Flag backticked identifiers in markdown that don't resolve."""
    all_tables = set(schema_truth.keys())
    all_columns: set[str] = set()
    for cols in schema_truth.values():
        all_columns |= cols
    # Everything in ``known`` is never flagged: real schema names, SQL
    # keywords, and caller-supplied enum values.
    known = all_tables | all_columns | _SQL_KEYWORDS | enum_values

    text = path.read_text(encoding="utf-8")
    seen_on_line: set[tuple[int, str]] = set()
    for lineno, line in enumerate(text.splitlines(), start=1):
        for m in _MD_BACKTICK_IDENT.finditer(line):
            ident = m.group(1).lower()

            # ``table.column`` — check the column against that table's
            # column list. Unknown tables are silently ignored here to
            # avoid noise from prose that mentions tables from other DBs.
+ parts = ident.split(".") + if len(parts) == 2: + table, col = parts + if table not in all_tables: + continue + if col not in schema_truth.get(table, set()): + key = (lineno, ident) + if key in seen_on_line: + continue + seen_on_line.add(key) + suggestion = _nearest(col, schema_truth.get(table, set())) + report.items.append(DriftItem( + file=str(path), line=lineno, kind="column", + claim=ident, + detail=f"`{ident}` not a column of '{table}'", + suggestion=f"{table}.{suggestion}" if suggestion else None, + )) + continue + + if ident in known: + continue + # Identifier heuristics: snake_case, not all digits, length >= 3. + # The minimum length suppresses common English words ("on", + # "is", "as") that happen to slip through the keyword set. + if not re.match(r"^[a-z][a-z0-9_]*$", ident) or len(ident) < 3: + continue + key = (lineno, ident) + if key in seen_on_line: + continue + seen_on_line.add(key) + report.items.append(DriftItem( + file=str(path), line=lineno, kind="column", + claim=ident, + detail=f"`{ident}` not in current schema (column or table)", + suggestion=_nearest(ident, all_columns | all_tables), + )) + + +def _check_time_series( + path: Path, schema_truth: dict[str, set[str]], report: DriftReport +) -> None: + """Verify each entry's ``table`` / ``timestamp_column`` / ``group_columns``.""" + text = path.read_text(encoding="utf-8") + try: + docs = yaml.safe_load(text) or [] + except yaml.YAMLError as exc: + report.items.append(DriftItem( + file=str(path), line=None, kind="parse_error", + claim="", detail=f"yaml parse error: {exc}", + )) + return + if not isinstance(docs, list): + return + + for entry in docs: + if not isinstance(entry, dict): + continue + table = entry.get("table") + if not table: + continue + if table not in schema_truth: + report.items.append(DriftItem( + file=str(path), line=None, kind="ts_table", + claim=str(table), + detail=f"time_series table '{table}' not in current schema", + suggestion=_nearest(str(table), 
set(schema_truth.keys())), + )) + continue + ts_col = entry.get("timestamp_column") + if ts_col and ts_col not in schema_truth[table]: + report.items.append(DriftItem( + file=str(path), line=None, kind="ts_column", + claim=str(ts_col), + detail=f"time_series timestamp_column '{ts_col}' not a column of '{table}'", + suggestion=_nearest(str(ts_col), schema_truth[table]), + )) + for col in entry.get("group_columns") or []: + if col not in schema_truth[table]: + report.items.append(DriftItem( + file=str(path), line=None, kind="ts_column", + claim=str(col), + detail=f"time_series group_column '{col}' not a column of '{table}'", + suggestion=_nearest(str(col), schema_truth[table]), + )) + + +def _nearest(claim: str, candidates: set[str]) -> str | None: + """Closest match by edit distance, or None when nothing's similar enough.""" + if not candidates: + return None + matches = get_close_matches(claim, list(candidates), n=1, cutoff=0.6) + return matches[0] if matches else None + + +def build_enum_values_sync( + conn: Any, schema_truth: dict[str, set[str]], *, max_per_column: int = 200 +) -> set[str]: + """Collect distinct values from low-cardinality VARCHAR columns. + + Used to suppress false positives in the markdown drift scan: a + prose listing like "Values: \\`east_north_central\\`, ..." would + otherwise be flagged as references to missing columns. Columns with + more than ``max_per_column`` distinct values are skipped — those + aren't enums and aren't worth scanning for. + + Parameters + ---------- + conn : duckdb.DuckDBPyConnection + Open DuckDB connection. + schema_truth : dict[str, set[str]] + Output of :func:`build_schema_truth_sync`. The function reads + column types from ``information_schema.columns`` and queries + only those typed VARCHAR/STRING. + max_per_column : int + Skip columns that exceed this distinct-value count. Defaults + to 200, which fits typical enum-shaped columns (regions, + subsectors, status codes) and skips free-text columns. 
+ + Returns + ------- + set[str] + Distinct string values across all qualifying columns. Values + are lowercased to match the case used by the markdown scan. + """ + out: set[str] = set() + rows = conn.execute( + "SELECT table_name, column_name, data_type " + "FROM information_schema.columns " + "WHERE table_schema = current_schema() " + "AND table_catalog = current_database()" + ).fetchall() + for table, col, dtype in rows: + if table not in schema_truth or col not in schema_truth[table]: + continue + if "char" not in str(dtype).lower() and "string" not in str(dtype).lower(): + continue + try: + count = conn.execute( + f"SELECT COUNT(DISTINCT {col}) FROM {table}" + ).fetchone() + except Exception: # noqa: BLE001 — never let one bad column abort the whole scan + continue + if count is None or count[0] > max_per_column: + continue + try: + values = conn.execute( + f"SELECT DISTINCT {col} FROM {table} WHERE {col} IS NOT NULL" + ).fetchall() + except Exception: # noqa: BLE001 + continue + for (v,) in values: + if isinstance(v, str): + out.add(v.lower()) + return out + + +def build_schema_truth_sync(conn: Any) -> dict[str, set[str]]: + """Build the ``{table: set(columns)}`` truth set from a sync DuckDB conn. + + Filters to user-visible objects in the current schema/database so a + ``schema_description.md`` reference can't be silently validated + against a same-named column in an attached DB. + + Parameters + ---------- + conn : duckdb.DuckDBPyConnection + Open DuckDB connection. + + Returns + ------- + dict[str, set[str]] + ``{table_name: {column_name, ...}}``. 
+ """ + rows = conn.execute( + "SELECT table_name, column_name " + "FROM information_schema.columns " + "WHERE table_schema = current_schema() " + "AND table_catalog = current_database()" + ).fetchall() + out: dict[str, set[str]] = {} + for table, col in rows: + out.setdefault(table, set()).add(col) + return out + + +async def build_schema_truth_async( + run_sql: Callable[[str], Awaitable[Any]], +) -> dict[str, set[str]]: + """Build the truth set from the async ``run_sql`` callable used by datasight. + + Uses the same ``information_schema.columns`` query as the sync + variant; works against DuckDB and PostgreSQL alike. For SQLite, + callers should construct the dict from the existing + :func:`datasight.schema.introspect_schema` result instead — SQLite + has no ``information_schema``. + """ + df = await run_sql( + "SELECT table_name, column_name " + "FROM information_schema.columns " + "WHERE table_schema NOT IN ('information_schema', 'pg_catalog')" + ) + out: dict[str, set[str]] = {} + for _, row in df.iterrows(): + out.setdefault(row["table_name"], set()).add(row["column_name"]) + return out + + +def format_drift_report( + report: DriftReport, *, max_items_per_file: int = 20 +) -> str: + """Render a DriftReport as a multi-line string for terminal output. + + Truncates per-file listings beyond ``max_items_per_file`` with a + summary line so a wholesale schema change doesn't dump hundreds of + items into the terminal. + """ + if report.is_clean: + return "grounding clean: no drift detected." + parts: list[str] = [ + f"grounding drift: {len(report.items)} reference(s) don't resolve", + "", + ] + for file, items in report.by_file().items(): + parts.append(f" {file}:") + for item in items[:max_items_per_file]: + loc = f":{item.line}" if item.line else "" + sug = f" (did you mean: {item.suggestion}?)" if item.suggestion else "" + parts.append(f" {loc:<5} {item.kind:<10} {item.claim!r}{sug}") + if len(items) > max_items_per_file: + parts.append(f" ... 
and {len(items) - max_items_per_file} more") + parts.append("") + return "\n".join(parts).rstrip() diff --git a/src/datasight/grounding_repair.py b/src/datasight/grounding_repair.py new file mode 100644 index 00000000..7a0f2d90 --- /dev/null +++ b/src/datasight/grounding_repair.py @@ -0,0 +1,429 @@ +"""LLM-driven repair of grounding files after a schema-changing transform. + +This module is invoked from ``datasight tidy review`` (or the web UI +equivalent) after :func:`datasight.tidy_review.apply_proposal` reshapes +a table. The drift detector in :mod:`datasight.grounding` has already +told us *that* the grounding files are stale; here we ask the LLM to +rewrite them. + +Design +------ +The repair flow is deliberately conservative: + +1. The caller snapshots the schema **before** the tidy transform, so + the LLM can see both old and new structure. Without the before + snapshot the prompt degenerates to "rewrite from scratch" and loses + any human customizations. +2. The LLM is asked to return a single JSON object keyed by filename. + No unified-diff parsing; the full proposed file contents are + returned and we compute diffs locally. +3. Every SQL example in the proposed ``queries.yaml`` is executed + against the live database before the result is shown to the user. + Failures trigger up to ``max_retries`` LLM retries with the error + context attached. If a query still fails after retries, the + :class:`RepairFile` records the validation errors and the orchestrator + decides whether to fall through to manual edit mode. +4. Nothing is written to disk inside this module. The orchestrator in + the CLI flow shows the diff, prompts for confirmation, and only then + calls :func:`write_repair_atomic`. + +The atomic-write helper writes each accepted file via a sibling +``.new`` tempfile + ``os.replace``, so an interrupted repair can't +half-corrupt the grounding. 
+""" + +from __future__ import annotations + +import difflib +import json +import os +import re +import tempfile +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Awaitable, Callable + +import yaml +from loguru import logger + +from datasight.grounding import DriftReport +from datasight.llm import LLMClient, TextBlock + + +# Files the repair flow may touch. Other grounding-adjacent files +# (schema.yaml, measures.yaml) are owned by tidy_review's own update +# helpers and stay out of scope here. +REPAIR_FILE_NAMES: tuple[str, ...] = ( + "queries.yaml", + "schema_description.md", + "time_series.yaml", +) + + +# Match an opening ```json or ``` fence, then anything (non-greedy) up +# to the closing ```. Used to peel a JSON object out of an LLM response +# that wraps it in a markdown code block. +_FENCED_JSON = re.compile(r"```(?:json)?\s*\n(.*?)\n```", re.DOTALL) + + +@dataclass +class RepairFile: + """One file's before/after content and validation outcome.""" + + name: str + path: Path + old_text: str + new_text: str + validation_errors: list[str] = field(default_factory=list) + + @property + def changed(self) -> bool: + """True if the LLM proposed any change to this file.""" + return self.new_text != self.old_text + + @property + def ok(self) -> bool: + """True if the file is unchanged or its changes validated cleanly.""" + return not self.validation_errors + + def unified_diff(self) -> str: + """Unified diff between ``old_text`` and ``new_text``.""" + return "".join( + difflib.unified_diff( + self.old_text.splitlines(keepends=True), + self.new_text.splitlines(keepends=True), + fromfile=f"a/{self.name}", + tofile=f"b/{self.name}", + ) + ) + + +@dataclass +class RepairResult: + """Outcome of one repair attempt, before any files are written.""" + + files: list[RepairFile] + llm_retries: int = 0 + + @property + def overall_ok(self) -> bool: + """True if every changed file's validation came back clean.""" + return all(f.ok for f 
in self.files) + + @property + def any_changes(self) -> bool: + """True if the LLM proposed changes to at least one file.""" + return any(f.changed for f in self.files) + + +async def repair_grounding( + project_dir: Path, + old_schema: dict[str, set[str]], + new_schema: dict[str, set[str]], + drift: DriftReport, + *, + llm_client: LLMClient, + model: str, + run_sql: Callable[[str], Awaitable[Any]], + max_tokens: int = 16384, + max_retries: int = 2, +) -> RepairResult: + """Ask the LLM to rewrite grounding files so they match ``new_schema``. + + Does not write to disk. The caller is responsible for showing the + diff, prompting for confirmation, and calling + :func:`write_repair_atomic` to apply. + + Parameters + ---------- + project_dir : Path + Directory containing the grounding files. + old_schema, new_schema : dict[str, set[str]] + Schemas before and after the transform that triggered repair. + drift : DriftReport + Output of :func:`datasight.grounding.check_grounding_drift` + against the new schema. + llm_client : LLMClient + Same client the agent loop uses; this repair is a one-shot + text completion (no tool use). + model : str + Model name passed through to ``llm_client.create_message``. + run_sql : async callable + Async SQL runner used to validate each proposed query. + max_tokens : int + Output budget for the LLM call. Defaults to 16 384, which is + enough to rewrite a schema_description.md plus queries.yaml + for most projects. + max_retries : int + How many times to re-prompt the LLM with validation-error + context if proposed queries don't execute. Defaults to 2 (so + up to 3 LLM calls in total). 
+ """ + files = _load_repair_files(project_dir) + prompt = _build_repair_prompt(old_schema, new_schema, drift, files) + system = _SYSTEM_PROMPT + + proposed: dict[str, str] | None = None + last_error: str | None = None + retries = 0 + for attempt in range(max_retries + 1): + user_prompt = prompt if last_error is None else ( + f"{prompt}\n\nYour previous response had validation errors. " + f"Fix them and return the full corrected JSON object:\n\n{last_error}" + ) + response = await llm_client.create_message( + model=model, + system=system, + messages=[{"role": "user", "content": user_prompt}], + tools=[], + max_tokens=max_tokens, + ) + text = "".join(b.text for b in response.content if isinstance(b, TextBlock)) + try: + proposed = _parse_repair_json(text) + except ValueError as exc: + last_error = f"Could not parse JSON from your response: {exc}" + retries = attempt + 1 + logger.warning(f"repair attempt {attempt + 1}: {last_error}") + continue + + # Apply proposed contents to RepairFile objects and validate. + for f in files: + if f.name in proposed: + f.new_text = proposed[f.name] + f.validation_errors = [] + await _validate_repair(files, run_sql=run_sql) + + if all(f.ok for f in files): + return RepairResult(files=files, llm_retries=attempt) + + # Build a summarized error report for the next retry. + error_lines: list[str] = [] + for f in files: + for err in f.validation_errors: + error_lines.append(f"- {f.name}: {err}") + last_error = "\n".join(error_lines) + retries = attempt + 1 + logger.warning( + f"repair attempt {attempt + 1}: {len(error_lines)} validation error(s)" + ) + + # Out of retries — return the last attempt with its errors so the + # caller can fall back to manual edit mode. + return RepairResult(files=files, llm_retries=retries) + + +def write_repair_atomic(result: RepairResult, project_dir: Path) -> list[Path]: + """Write each changed, validated file via tempfile + ``os.replace``. 
+ + Only files where ``RepairFile.changed`` and ``RepairFile.ok`` are + True are written. The function fsyncs each temp file before the + rename so an interrupted run can't leave partial content visible. + + Parameters + ---------- + result : RepairResult + Output of :func:`repair_grounding`. + project_dir : Path + Project directory. Files are written under their stored + ``RepairFile.path`` which is rooted here. + + Returns + ------- + list[Path] + Paths that were written. Empty list when nothing changed. + """ + written: list[Path] = [] + for f in result.files: + if not f.changed or not f.ok: + continue + parent = f.path.parent + parent.mkdir(parents=True, exist_ok=True) + # NamedTemporaryFile with delete=False so we can keep the path + # after closing; os.replace then atomically swaps it in. + with tempfile.NamedTemporaryFile( + mode="w", encoding="utf-8", dir=parent, delete=False, prefix=".grounding-", + ) as tmp: + tmp.write(f.new_text) + tmp.flush() + os.fsync(tmp.fileno()) + tmp_path = Path(tmp.name) + try: + os.replace(tmp_path, f.path) + except OSError: + tmp_path.unlink(missing_ok=True) + raise + written.append(f.path) + return written + + +def format_repair_summary(result: RepairResult) -> str: + """One-paragraph summary suitable for terminal output before the diff.""" + changed = [f for f in result.files if f.changed] + if not changed: + return "Repair: no files changed." 
+ parts = [f"Repair: LLM proposed changes to {len(changed)} file(s)"] + if result.llm_retries: + parts.append(f"after {result.llm_retries} retry/retries") + parts.append(":") + for f in changed: + status = "ok" if f.ok else f"FAILED ({len(f.validation_errors)} error(s))" + parts.append(f" {f.name} [{status}]") + return "\n".join(parts) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +_SYSTEM_PROMPT = ( + "You are rewriting grounding files for a SQL-aware LLM agent so that every " + "column and table reference resolves against the current database schema. " + "Preserve human prose, comments, structure, and any example that still " + "applies. Do not invent new sections or example questions. Every SQL " + "snippet in the rewritten queries.yaml must execute successfully against " + "the NEW schema. Reply with a single JSON object whose keys are the " + "filenames you are rewriting and whose values are the full new file " + "contents as strings. Do not include any prose outside the JSON." 
+) + + +def _build_repair_prompt( + old_schema: dict[str, set[str]], + new_schema: dict[str, set[str]], + drift: DriftReport, + files: list[RepairFile], +) -> str: + """Compose the user-side prompt sent to the repair LLM.""" + sections: list[str] = [] + sections.append("OLD SCHEMA (before the transform):") + sections.append(_schema_text(old_schema)) + sections.append("") + sections.append("NEW SCHEMA (current state):") + sections.append(_schema_text(new_schema)) + sections.append("") + sections.append("DRIFT DETECTED:") + if drift.is_clean: + sections.append("(no drift; only schema-shape changes)") + else: + for item in drift.items: + loc = f":{item.line}" if item.line else "" + sug = f" (suggested: {item.suggestion})" if item.suggestion else "" + sections.append(f" - {item.file}{loc} {item.kind} {item.claim!r}{sug}") + sections.append("") + sections.append("FILES TO REWRITE:") + sections.append("") + for f in files: + sections.append(f"--- {f.name} ---") + sections.append(f.old_text) + sections.append("") + sections.append( + "Return a single JSON object. Keys: filenames. Values: full new " + "contents. Files that need no changes may be omitted." + ) + return "\n".join(sections) + + +def _schema_text(schema: dict[str, set[str]]) -> str: + """Compact human-readable rendering of a {table: columns} dict.""" + if not schema: + return "(empty)" + lines: list[str] = [] + for table in sorted(schema.keys()): + cols = sorted(schema[table]) + lines.append(f" {table}({', '.join(cols)})") + return "\n".join(lines) + + +def _load_repair_files(project_dir: Path) -> list[RepairFile]: + """Read each in-scope grounding file. 
Missing files are skipped.""" + out: list[RepairFile] = [] + for name in REPAIR_FILE_NAMES: + path = project_dir / name + if not path.exists(): + continue + text = path.read_text(encoding="utf-8") + out.append(RepairFile(name=name, path=path, old_text=text, new_text=text)) + return out + + +def _parse_repair_json(text: str) -> dict[str, str]: + """Extract a JSON object from the LLM's text response. + + Accepts the JSON either bare or wrapped in a ```json fenced block. + Raises ``ValueError`` with a short reason on parse failure. + """ + candidate = text.strip() + fence_match = _FENCED_JSON.search(candidate) + if fence_match: + candidate = fence_match.group(1).strip() + # The bare-JSON case: find the first ``{`` and parse forward. + if not candidate.startswith("{"): + start = candidate.find("{") + if start == -1: + msg = "no JSON object found in LLM response" + raise ValueError(msg) + candidate = candidate[start:] + try: + parsed = json.loads(candidate) + except json.JSONDecodeError as exc: + msg = f"invalid JSON: {exc}" + raise ValueError(msg) from exc + if not isinstance(parsed, dict): + msg = f"expected JSON object, got {type(parsed).__name__}" + raise ValueError(msg) + out: dict[str, str] = {} + for key, value in parsed.items(): + if not isinstance(value, str): + msg = f"value for {key!r} is not a string" + raise ValueError(msg) + if key in REPAIR_FILE_NAMES: + out[key] = value + # Unknown keys are silently ignored — the LLM may include + # commentary fields that we don't want to write to disk. + return out + + +async def _validate_repair( + files: list[RepairFile], + *, + run_sql: Callable[[str], Awaitable[Any]], +) -> None: + """Run each SQL example from the proposed queries.yaml against the DB. + + Records any execution errors on the corresponding ``RepairFile``. + Also catches yaml parse failures for queries.yaml and time_series.yaml. 
+ """ + for f in files: + if not f.changed: + continue + if f.name == "queries.yaml": + try: + docs = yaml.safe_load(f.new_text) or [] + except yaml.YAMLError as exc: + f.validation_errors.append(f"yaml parse error: {exc}") + continue + if not isinstance(docs, list): + f.validation_errors.append("expected a top-level YAML list") + continue + for i, entry in enumerate(docs, start=1): + if not isinstance(entry, dict): + continue + sql = entry.get("sql") + if not sql: + continue + try: + await run_sql(sql) + except Exception as exc: # noqa: BLE001 + question = entry.get("question", "(no question)") + f.validation_errors.append( + f"query {i} ({question!r}) failed: {exc}" + ) + elif f.name == "time_series.yaml": + try: + yaml.safe_load(f.new_text) + except yaml.YAMLError as exc: + f.validation_errors.append(f"yaml parse error: {exc}") + # schema_description.md has no executable content; trust the + # markdown parse to be valid by inspection. diff --git a/tests/test_grounding.py b/tests/test_grounding.py new file mode 100644 index 00000000..65a2e2fb --- /dev/null +++ b/tests/test_grounding.py @@ -0,0 +1,272 @@ +"""Tests for the grounding-drift detector.""" + +from __future__ import annotations + +import textwrap +from pathlib import Path + +import duckdb +import pytest + +from datasight.grounding import ( + DriftItem, + DriftReport, + build_enum_values_sync, + build_schema_truth_sync, + check_grounding_drift, + format_drift_report, +) + + +def _make_db(tmp_path: Path, rows: list[tuple]) -> str: + """Build a tiny long-format DuckDB and return its path.""" + db_path = tmp_path / "test.duckdb" + conn = duckdb.connect(str(db_path)) + conn.execute( + "CREATE TABLE load_data " + "(geography VARCHAR, fuel_type VARCHAR, end_use VARCHAR, " + "time_year BIGINT, energy_mwh DOUBLE)" + ) + for row in rows: + conn.execute( + "INSERT INTO load_data VALUES (?, ?, ?, ?, ?)", row + ) + conn.close() + return str(db_path) + + +def test_build_schema_truth_sync_returns_table_to_columns(): + 
conn = duckdb.connect(":memory:") + conn.execute("CREATE TABLE t1 (a INT, b VARCHAR)") + conn.execute("CREATE TABLE t2 (c DOUBLE)") + truth = build_schema_truth_sync(conn) + assert truth == {"t1": {"a", "b"}, "t2": {"c"}} + + +def test_build_enum_values_sync_collects_distinct_strings(tmp_path): + db_path = _make_db(tmp_path, [ + ("pacific", "elec", "heating", 2020, 1.0), + ("pacific", "ng", "cooling", 2020, 2.0), + ("south_atlantic", "elec", "heating", 2020, 3.0), + ]) + conn = duckdb.connect(db_path, read_only=True) + truth = build_schema_truth_sync(conn) + values = build_enum_values_sync(conn, truth) + assert "pacific" in values + assert "south_atlantic" in values + assert "elec" in values + assert "ng" in values + assert "heating" in values + assert "cooling" in values + + +def test_build_enum_values_sync_skips_high_cardinality(tmp_path): + db_path = tmp_path / "big.duckdb" + conn = duckdb.connect(str(db_path)) + conn.execute("CREATE TABLE t (label VARCHAR)") + for i in range(50): + conn.execute("INSERT INTO t VALUES (?)", (f"label_{i}",)) + conn.close() + conn = duckdb.connect(str(db_path), read_only=True) + truth = build_schema_truth_sync(conn) + values = build_enum_values_sync(conn, truth, max_per_column=20) + assert values == set() + + +def test_check_clean_grounding_reports_no_drift(tmp_path): + truth = {"load_data": {"geography", "fuel_type", "energy_mwh"}} + (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + - question: "Total energy by region" + sql: | + SELECT geography, SUM(energy_mwh) AS total + FROM load_data + WHERE fuel_type = 'elec' + GROUP BY geography; + """).strip()) + report = check_grounding_drift(tmp_path, truth, enum_values={"elec"}) + assert report.is_clean + + +def test_queries_yaml_missing_column_is_flagged(tmp_path): + truth = {"load_data": {"geography", "fuel_type", "energy_mwh"}} + (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + - question: "Stale" + sql: SELECT elec_heating FROM load_data; + """).strip()) + 
report = check_grounding_drift(tmp_path, truth) + assert not report.is_clean + assert any(item.claim == "elec_heating" for item in report.items) + + +def test_queries_yaml_missing_table_is_flagged(tmp_path): + truth = {"load_data": {"geography"}} + (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + - question: "Wrong table" + sql: SELECT geography FROM missing_table; + """).strip()) + report = check_grounding_drift(tmp_path, truth) + assert any( + item.kind == "table" and item.claim == "missing_table" + for item in report.items + ) + + +def test_queries_yaml_cte_name_is_not_flagged_as_missing_table(tmp_path): + truth = {"load_data": {"geography", "energy_mwh"}} + (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + - question: "CTE chain" + sql: | + WITH yearly AS ( + SELECT geography, SUM(energy_mwh) AS total FROM load_data GROUP BY geography + ) + SELECT * FROM yearly; + """).strip()) + report = check_grounding_drift(tmp_path, truth) + assert report.is_clean, [item.detail for item in report.items] + + +def test_queries_yaml_output_alias_is_not_flagged(tmp_path): + truth = {"load_data": {"geography", "energy_mwh"}} + (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + - question: "Aliased output" + sql: | + SELECT geography, SUM(energy_mwh) AS total_energy + FROM load_data + GROUP BY geography + ORDER BY total_energy DESC; + """).strip()) + report = check_grounding_drift(tmp_path, truth) + assert report.is_clean, [item.detail for item in report.items] + + +def test_schema_description_md_flags_missing_column_reference(tmp_path): + truth = {"load_data": {"geography", "fuel_type"}} + (tmp_path / "schema_description.md").write_text( + "# Schema\n\nThe `elec_heating` column tracks electricity heating.\n" + ) + report = check_grounding_drift(tmp_path, truth) + assert any( + item.claim == "elec_heating" and item.line == 3 + for item in report.items + ) + + +def test_schema_description_md_qualified_table_column_resolves(tmp_path): + truth = 
{"load_data": {"geography", "fuel_type"}} + (tmp_path / "schema_description.md").write_text( + "# Schema\n\nUse `load_data.geography` for region filtering.\n" + ) + report = check_grounding_drift(tmp_path, truth) + assert report.is_clean + + +def test_schema_description_md_qualified_unknown_column_is_flagged(tmp_path): + truth = {"load_data": {"geography"}} + (tmp_path / "schema_description.md").write_text( + "# Schema\n\nUse `load_data.elec_heating` for heating.\n" + ) + report = check_grounding_drift(tmp_path, truth) + assert any(item.claim == "load_data.elec_heating" for item in report.items) + + +def test_schema_description_md_enum_values_are_allowlisted(tmp_path): + truth = {"load_data": {"geography"}} + (tmp_path / "schema_description.md").write_text( + "# Schema\n\nValues: `pacific`, `mountain`, `new_england`.\n" + ) + report = check_grounding_drift( + tmp_path, truth, enum_values={"pacific", "mountain", "new_england"} + ) + assert report.is_clean + + +def test_time_series_yaml_missing_table_is_flagged(tmp_path): + truth = {"load_data": {"geography"}} + (tmp_path / "time_series.yaml").write_text(textwrap.dedent(""" + - table: missing_table + timestamp_column: ts + frequency: PT1H + """).strip()) + report = check_grounding_drift(tmp_path, truth) + assert any( + item.kind == "ts_table" and item.claim == "missing_table" + for item in report.items + ) + + +def test_time_series_yaml_missing_timestamp_column_is_flagged(tmp_path): + truth = {"load_data": {"geography"}} + (tmp_path / "time_series.yaml").write_text(textwrap.dedent(""" + - table: load_data + timestamp_column: ts + frequency: PT1H + """).strip()) + report = check_grounding_drift(tmp_path, truth) + assert any( + item.kind == "ts_column" and item.claim == "ts" + for item in report.items + ) + + +def test_time_series_yaml_missing_group_column_is_flagged(tmp_path): + truth = {"load_data": {"geography", "ts"}} + (tmp_path / "time_series.yaml").write_text(textwrap.dedent(""" + - table: load_data + 
timestamp_column: ts + group_columns: [geography, missing_dim] + frequency: PT1H + """).strip()) + report = check_grounding_drift(tmp_path, truth) + assert any( + item.kind == "ts_column" and item.claim == "missing_dim" + for item in report.items + ) + + +def test_missing_grounding_files_are_silently_skipped(tmp_path): + truth = {"load_data": {"geography"}} + report = check_grounding_drift(tmp_path, truth) + assert report.is_clean + + +def test_format_drift_report_shows_per_file_breakdown(): + report = DriftReport(items=[ + DriftItem( + file="a/queries.yaml", line=None, kind="column", + claim="foo", detail="foo not found", suggestion="bar", + ), + DriftItem( + file="a/schema_description.md", line=10, kind="column", + claim="baz", detail="baz not found", + ), + ]) + text = format_drift_report(report) + assert "queries.yaml" in text + assert "schema_description.md" in text + assert "foo" in text + assert "did you mean: bar" in text + assert ":10" in text + + +def test_format_drift_report_clean_returns_clean_message(): + text = format_drift_report(DriftReport()) + assert "no drift" in text.lower() + + +def test_drift_report_by_file_groups_items(): + report = DriftReport(items=[ + DriftItem(file="x", line=None, kind="column", claim="a", detail=""), + DriftItem(file="y", line=None, kind="column", claim="b", detail=""), + DriftItem(file="x", line=None, kind="column", claim="c", detail=""), + ]) + grouped = report.by_file() + assert list(grouped.keys()) == ["x", "y"] + assert len(grouped["x"]) == 2 + assert len(grouped["y"]) == 1 + + +def test_queries_yaml_with_invalid_yaml_reports_parse_error(tmp_path): + truth = {"load_data": {"x"}} + (tmp_path / "queries.yaml").write_text("- this: is\n not: valid: yaml:") + report = check_grounding_drift(tmp_path, truth) + assert any(item.kind == "parse_error" for item in report.items) diff --git a/tests/test_grounding_repair.py b/tests/test_grounding_repair.py new file mode 100644 index 00000000..d9679bfd --- /dev/null +++ 
b/tests/test_grounding_repair.py @@ -0,0 +1,331 @@ +"""Tests for the LLM-driven grounding-repair flow.""" + +from __future__ import annotations + +import asyncio +import json +import textwrap +from pathlib import Path +from typing import Any + +import duckdb +import pandas as pd +import pytest + +from datasight.grounding import DriftItem, DriftReport +from datasight.grounding_repair import ( + REPAIR_FILE_NAMES, + RepairFile, + RepairResult, + _parse_repair_json, + format_repair_summary, + repair_grounding, + write_repair_atomic, +) +from datasight.llm import CallStats, LLMResponse, TextBlock, Usage + + +class _FakeLLMClient: + """Returns canned text responses in sequence. + + Each ``create_message`` call pops the next response. ``aclose`` is a + no-op so the contract matches :class:`datasight.llm.LLMClient`. + """ + + def __init__(self, responses: list[str]): + self._responses = list(responses) + self.calls: list[dict[str, Any]] = [] + + async def create_message(self, *, model, system, messages, tools, max_tokens): + self.calls.append({"messages": messages, "tools": tools}) + if not self._responses: + msg = "fake client ran out of canned responses" + raise RuntimeError(msg) + text = self._responses.pop(0) + return LLMResponse( + content=[TextBlock(text=text)], + stop_reason="end_turn", + usage=Usage(input_tokens=10, output_tokens=20), + call_stats=CallStats(), + ) + + async def aclose(self) -> None: + return None + + +def _make_run_sql(db_path: str): + """Async wrapper matching datasight's ``run_sql`` signature.""" + + async def run_sql(sql: str) -> pd.DataFrame: + conn = duckdb.connect(db_path, read_only=True) + try: + return conn.execute(sql).fetchdf() + finally: + conn.close() + + return run_sql + + +def _long_format_db(tmp_path: Path) -> str: + db_path = tmp_path / "test.duckdb" + conn = duckdb.connect(str(db_path)) + conn.execute( + "CREATE TABLE load_data " + "(geography VARCHAR, fuel_type VARCHAR, end_use VARCHAR, energy_mwh DOUBLE)" + ) + conn.execute( + 
"INSERT INTO load_data VALUES " + "('pacific', 'elec', 'heating', 10.0), " + "('pacific', 'ng', 'heating', 20.0)" + ) + conn.close() + return str(db_path) + + +def test_parse_repair_json_bare_object(): + result = _parse_repair_json('{"queries.yaml": "- q: hi"}') + assert result == {"queries.yaml": "- q: hi"} + + +def test_parse_repair_json_fenced_block(): + text = '```json\n{"queries.yaml": "a"}\n```' + assert _parse_repair_json(text) == {"queries.yaml": "a"} + + +def test_parse_repair_json_with_prose_prefix(): + text = 'Here you go:\n{"queries.yaml": "a"}' + assert _parse_repair_json(text) == {"queries.yaml": "a"} + + +def test_parse_repair_json_rejects_non_object(): + # An object-shaped string would parse, but a top-level number with a + # stray brace must still be rejected as not-an-object. + with pytest.raises(ValueError): + _parse_repair_json("just text 42") + + +def test_parse_repair_json_rejects_malformed(): + with pytest.raises(ValueError, match="no JSON object"): + _parse_repair_json("nothing useful here") + + +def test_parse_repair_json_drops_unknown_keys(): + text = json.dumps({ + "queries.yaml": "x", + "comment": "ignore me", + "schema_description.md": "y", + }) + result = _parse_repair_json(text) + assert set(result.keys()) == {"queries.yaml", "schema_description.md"} + + +def test_parse_repair_json_rejects_non_string_value(): + with pytest.raises(ValueError, match="not a string"): + _parse_repair_json('{"queries.yaml": 42}') + + +def test_repair_file_unified_diff_reflects_changes(tmp_path): + f = RepairFile( + name="queries.yaml", + path=tmp_path / "queries.yaml", + old_text="- old\n", + new_text="- new\n", + ) + diff = f.unified_diff() + assert "a/queries.yaml" in diff + assert "-- old" in diff or "-old" in diff + assert "+- new" in diff or "+new" in diff + + +def test_repair_file_unchanged_when_text_equal(tmp_path): + f = RepairFile( + name="queries.yaml", + path=tmp_path / "queries.yaml", + old_text="x", + new_text="x", + ) + assert not f.changed + 
+ +def test_write_repair_atomic_writes_only_validated_changes(tmp_path): + (tmp_path / "queries.yaml").write_text("- old\n") + (tmp_path / "schema_description.md").write_text("# old\n") + f1 = RepairFile( + name="queries.yaml", + path=tmp_path / "queries.yaml", + old_text="- old\n", + new_text="- new\n", + ) + f2 = RepairFile( + name="schema_description.md", + path=tmp_path / "schema_description.md", + old_text="# old\n", + new_text="# new\n", + validation_errors=["bad sql"], + ) + result = RepairResult(files=[f1, f2]) + written = write_repair_atomic(result, tmp_path) + assert written == [tmp_path / "queries.yaml"] + assert (tmp_path / "queries.yaml").read_text() == "- new\n" + # File with validation errors is left untouched. + assert (tmp_path / "schema_description.md").read_text() == "# old\n" + + +def test_write_repair_atomic_no_changes_writes_nothing(tmp_path): + f = RepairFile( + name="queries.yaml", + path=tmp_path / "queries.yaml", + old_text="same", + new_text="same", + ) + written = write_repair_atomic(RepairResult(files=[f]), tmp_path) + assert written == [] + assert not (tmp_path / "queries.yaml").exists() + + +def test_format_repair_summary_lists_changed_files(): + f = RepairFile( + name="queries.yaml", + path=Path("queries.yaml"), + old_text="a", + new_text="b", + ) + result = RepairResult(files=[f]) + text = format_repair_summary(result) + assert "queries.yaml" in text + assert "[ok]" in text + + +def test_format_repair_summary_with_no_changes(): + text = format_repair_summary(RepairResult(files=[])) + assert "no files changed" in text.lower() + + +def test_repair_grounding_happy_path(tmp_path): + """LLM proposal validates cleanly on first try; result is well-formed.""" + db_path = _long_format_db(tmp_path) + (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + - question: "Old top regions" + sql: SELECT * FROM load_data WHERE elec_heating > 0; + """).strip()) + + new_queries = textwrap.dedent(""" + - question: "Top regions" + sql: SELECT 
geography, SUM(energy_mwh) AS total FROM load_data WHERE fuel_type = 'elec' AND end_use = 'heating' GROUP BY geography; + """).strip() + llm_response = json.dumps({"queries.yaml": new_queries}) + client = _FakeLLMClient([llm_response]) + run_sql = _make_run_sql(db_path) + + drift = DriftReport(items=[DriftItem( + file=str(tmp_path / "queries.yaml"), line=None, kind="column", + claim="elec_heating", detail="missing", + )]) + old_schema = {"load_data": {"elec_heating", "geography"}} + new_schema = {"load_data": {"geography", "fuel_type", "end_use", "energy_mwh"}} + + result = asyncio.run(repair_grounding( + tmp_path, old_schema, new_schema, drift, + llm_client=client, model="test", run_sql=run_sql, + )) + assert result.overall_ok + assert result.any_changes + assert result.llm_retries == 0 + q_file = next(f for f in result.files if f.name == "queries.yaml") + assert q_file.changed + assert not q_file.validation_errors + + +def test_repair_grounding_retries_on_invalid_sql(tmp_path): + """First proposal fails to execute; second one validates.""" + db_path = _long_format_db(tmp_path) + (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + - question: "Stale" + sql: SELECT foo FROM load_data; + """).strip()) + + broken = json.dumps({"queries.yaml": "- question: q\n sql: SELECT still_broken FROM load_data;"}) + good = json.dumps({"queries.yaml": "- question: q\n sql: SELECT geography FROM load_data;"}) + client = _FakeLLMClient([broken, good]) + run_sql = _make_run_sql(db_path) + + drift = DriftReport(items=[DriftItem( + file=str(tmp_path / "queries.yaml"), line=None, kind="column", + claim="foo", detail="missing", + )]) + old_schema = {"load_data": {"foo"}} + new_schema = {"load_data": {"geography", "fuel_type", "end_use", "energy_mwh"}} + + result = asyncio.run(repair_grounding( + tmp_path, old_schema, new_schema, drift, + llm_client=client, model="test", run_sql=run_sql, + max_retries=2, + )) + assert result.overall_ok + assert result.llm_retries == 1 + # The 
client should have been called twice and the second user prompt + # should include the validation error context. + assert len(client.calls) == 2 + second_user_prompt = client.calls[1]["messages"][0]["content"] + assert "still_broken" in second_user_prompt or "validation error" in second_user_prompt.lower() + + +def test_repair_grounding_gives_up_after_max_retries(tmp_path): + """When every proposal is broken, the result surfaces validation errors.""" + db_path = _long_format_db(tmp_path) + (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + - question: "Stale" + sql: SELECT foo FROM load_data; + """).strip()) + + bad = json.dumps({"queries.yaml": "- question: q\n sql: SELECT nope_a FROM load_data;"}) + worse = json.dumps({"queries.yaml": "- question: q\n sql: SELECT nope_b FROM load_data;"}) + worst = json.dumps({"queries.yaml": "- question: q\n sql: SELECT nope_c FROM load_data;"}) + client = _FakeLLMClient([bad, worse, worst]) + run_sql = _make_run_sql(db_path) + + drift = DriftReport(items=[]) + old_schema: dict[str, set[str]] = {} + new_schema = {"load_data": {"geography"}} + + result = asyncio.run(repair_grounding( + tmp_path, old_schema, new_schema, drift, + llm_client=client, model="test", run_sql=run_sql, + max_retries=2, + )) + assert not result.overall_ok + q_file = next(f for f in result.files if f.name == "queries.yaml") + assert q_file.validation_errors + + +def test_repair_files_are_loaded_from_disk(tmp_path): + """Only existing files in the scope set are presented to the LLM.""" + (tmp_path / "queries.yaml").write_text("- q: y") + (tmp_path / "schema_description.md").write_text("hello") + # time_series.yaml deliberately missing. 
+ + db_path = _long_format_db(tmp_path) + run_sql = _make_run_sql(db_path) + new_queries_value = "- question: q\n sql: SELECT geography FROM load_data;" + response = json.dumps({"queries.yaml": new_queries_value}) + client = _FakeLLMClient([response]) + + drift = DriftReport(items=[]) + new_schema = {"load_data": {"geography"}} + + result = asyncio.run(repair_grounding( + tmp_path, {}, new_schema, drift, + llm_client=client, model="test", run_sql=run_sql, + )) + file_names = {f.name for f in result.files} + assert file_names == {"queries.yaml", "schema_description.md"} + assert "time_series.yaml" not in file_names + + +def test_repair_file_names_constant_matches_user_scope(): + """The hardcoded scope in REPAIR_FILE_NAMES matches what was agreed.""" + assert REPAIR_FILE_NAMES == ( + "queries.yaml", + "schema_description.md", + "time_series.yaml", + ) From 48d29c0a42552cbe7cf1472c57bca8d044203f33 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 10 May 2026 19:29:05 -0600 Subject: [PATCH 03/10] Document measured per-model memory footprints for Apple Silicon MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an "Apple Silicon: MLX-native models" subsection to the choosing-an-llm concepts page with measured numbers from the resource monitor: - qwen2.5:7b (q4_K_M GGUF): ~2 GB resident — only option that fits 16 GB - gemma4:e2b-mlx-bf16: ~11 GB — NOT low-memory despite the "e2b" naming, because the default 256K KV-cache buffer dominates - qwen3.6:35b-a3b-coding-mxfp8: ~38 GB — best answer quality, sparse MoE works well on Apple Silicon's unified memory Tier table makes the recommendation explicit: 16 GB → qwen2.5:7b, 32 GB → qwen2.5:7b or gemma4 (tradeoff: terse vs. richer), 48 GB+ → qwen3.6:35b-A3B. Cross-references added from install.md and configuration.md so users landing on either page get pointed at the measured guidance. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/reference/configuration.md | 2 +- docs/use/concepts/choosing-an-llm.md | 39 +++++++++++++++++++++++++++- docs/use/how-to/install.md | 11 +++++--- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index 022a0798..737bbcf9 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -74,7 +74,7 @@ For help picking a provider, see [Choosing an LLM](../use/concepts/choosing-an-l | Variable | Default | Description | |----------|---------|-------------| -| `OLLAMA_MODEL` | `qwen2.5:7b` | Ollama model name (must support tool calling). `qwen2.5:7b` works well for CLI queries; for the web UI with visualizations, try `qwen2.5:14b`. | +| `OLLAMA_MODEL` | `qwen2.5:7b` | Ollama model name (must support tool calling). `qwen2.5:7b` is the safest cross-platform default (~2 GB resident, fits on 16 GB Macs). For Apple Silicon with 48 GB+ unified memory, `qwen3.6:35b-a3b-coding-mxfp8` gives richer answers at comparable decode speed. See [Choosing an AI provider](../use/concepts/choosing-an-llm.md#apple-silicon-mlx-native-models). 
| | `OLLAMA_BASE_URL` | `http://localhost:11434/v1` | Ollama API endpoint | ### Database settings diff --git a/docs/use/concepts/choosing-an-llm.md b/docs/use/concepts/choosing-an-llm.md index c70bc40d..0999e6e3 100644 --- a/docs/use/concepts/choosing-an-llm.md +++ b/docs/use/concepts/choosing-an-llm.md @@ -97,7 +97,7 @@ So a Llama 3.1 8B model fits in ~5 GB VRAM at 4-bit, a 70B model needs |---|---| | Apple Silicon with 16 GB unified memory | 7–8B models at 4-bit | | Apple Silicon with 32 GB | 13B at 4-bit, or 8B at 8-bit | -| Apple Silicon with 64 GB+ | 34–70B at 4-bit | +| Apple Silicon with 64 GB+ | 34–70B at 4-bit, or sparse-MoE models like Qwen3.6 35B-A3B | | NVIDIA laptop GPU, 8 GB VRAM | 7–8B at 4-bit | | NVIDIA laptop GPU, 16 GB VRAM | 13B at 4-bit | @@ -107,6 +107,43 @@ visualizations, step up to `qwen2.5:14b` — the 7B model struggles with the more complex multi-step agent interactions required for chart generation. Smaller models often struggle with realistic schemas. +### Apple Silicon: MLX-native models + +If you're on Apple Silicon, models tagged `-mlx-*` use Apple's MLX +runtime and Metal compute. They typically decode 10–30% faster than the +equivalent GGUF model, but the *resident memory* can be much larger than +the weight size alone suggests because MLX allocates a large KV-cache +buffer for the model's default context window (often 256K tokens). +Measure before recommending to users — the model card's parameter count +is not a reliable predictor of laptop fit. + +Measured on a single benchmark dataset (5 questions, agent loop with +tool calls, Ollama server keep-alive at default 5 min) on a Mac with +unified memory: + +| Model | Decode (tok/s) | Resident memory (incl. 
KV cache) | Answer style | +|---|---|---|---| +| `qwen2.5:7b` (q4_K_M, GGUF) | ~85 | **~2 GB** | Middle: substantive but can hit `max_tokens` | +| `gemma4:e2b-mlx-bf16` | ~95 | ~11 GB | Tersest: dumps data tables, minimal analysis | +| `qwen3.6:35b-a3b-coding-mxfp8` | ~90 | ~38 GB | Richest: includes slopes, R², regional context | + +The headline surprise: **`gemma4:e2b-mlx-bf16` is not a low-memory +option**, despite the "e2b" (effective 2B) naming. Its weights are +small but the default 256K-token context allocation dominates resident +memory. Use it on 32 GB+ Macs only. + +Apple Silicon recommendations by RAM tier: + +| Unified memory | Recommended model | Why | +|---|---|---| +| 16 GB | `qwen2.5:7b` (GGUF) | Only option that fits with headroom for the OS, browser, and IDE. | +| 32 GB | `qwen2.5:7b` or `gemma4:e2b-mlx-bf16` | Either fits. Gemma is faster but its answers are tersest; pick based on whether you want interpretation or just raw data. | +| 48 GB+ | `qwen3.6:35b-a3b-coding-mxfp8` | Sparse MoE (3B active params) — best answer quality, comparable speed, properly leverages Apple Silicon's unified memory + Metal. | + +If you have an Apple Silicon machine but aren't sure which tag to use, +start with `qwen2.5:7b` (the cross-platform recommendation above). It +works on every backend and has the smallest memory footprint by far. + ### On an HPC GPU node If your HPC has GPU nodes, they typically unlock much larger models. diff --git a/docs/use/how-to/install.md b/docs/use/how-to/install.md index 5a5f868b..c982afdb 100644 --- a/docs/use/how-to/install.md +++ b/docs/use/how-to/install.md @@ -112,9 +112,14 @@ Alternatively, paste the key directly into a project's `.env` file instead. OLLAMA_MODEL=qwen2.5:7b ``` - `qwen2.5:7b` works well for CLI queries (`datasight ask`). For the web UI with - chart generation, `qwen2.5:14b` handles the more complex interactions better. - See [Choosing an AI provider](../concepts/choosing-an-llm.md) for hardware sizing guidance. 
+ `qwen2.5:7b` works well for CLI queries (`datasight ask`) and is the safest + cross-platform default — it uses ~2 GB resident on Apple Silicon, so it fits + even on 16 GB Macs. For the web UI with chart generation, `qwen2.5:14b` handles + the more complex interactions better. On Apple Silicon with 48 GB+ of unified + memory, `qwen3.6:35b-a3b-coding-mxfp8` produces noticeably richer answers with + comparable decode speed (sparse MoE). See + [Choosing an AI provider](../concepts/choosing-an-llm.md#apple-silicon-mlx-native-models) + for measured per-model memory footprints and Apple-Silicon-specific guidance. See the [Configuration reference](../../reference/configuration.md) for every supported variable. From a349afbf9ea35a6b052df9a17951be9a2cc2328e Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 10 May 2026 19:33:37 -0600 Subject: [PATCH 04/10] Address CodeQL findings: drop unused _DEFAULT_GROUNDING_FILES and pytest import - grounding.py: the _DEFAULT_GROUNDING_FILES tuple was a vestige of an earlier design pass where callers iterated over it; the final API takes each path as a keyword argument so the constant is dead code. - tests/test_grounding.py: pytest was imported but the tests use only assert statements (no raises/fixtures from pytest itself). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/datasight/grounding.py | 3 --- tests/test_grounding.py | 1 - 2 files changed, 4 deletions(-) diff --git a/src/datasight/grounding.py b/src/datasight/grounding.py index df15c7be..f9f6ef94 100644 --- a/src/datasight/grounding.py +++ b/src/datasight/grounding.py @@ -62,9 +62,6 @@ }) -_DEFAULT_GROUNDING_FILES = ("queries.yaml", "schema_description.md", "time_series.yaml") - - @dataclass class DriftItem: """One finding: a claim in a grounding file that doesn't resolve. 
diff --git a/tests/test_grounding.py b/tests/test_grounding.py index 65a2e2fb..db844ac8 100644 --- a/tests/test_grounding.py +++ b/tests/test_grounding.py @@ -6,7 +6,6 @@ from pathlib import Path import duckdb -import pytest from datasight.grounding import ( DriftItem, From 79bb66b67649b14eab022a9c81d06a44c769d6bc Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 10 May 2026 20:13:28 -0600 Subject: [PATCH 05/10] Hook grounding repair into the web tidy-apply endpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The CLI ``datasight tidy review`` already calls grounding repair after ``apply_proposal`` completes; the web equivalent at ``/api/tidy/apply`` did not. This wires the same flow into the web endpoint: 1. Snapshot the pre-apply schema from ``state.schema_info`` before the transform runs (so the LLM can see both old and new shapes). 2. After the apply + existing schema.yaml/measures.yaml updates and the schema re-introspection, build the post-apply truth set and run the drift check. 3. When drift exists and an LLM client is configured, call ``repair_grounding``, validate each rewritten SQL against the live DB, and atomically write any files that pass validation. Failures are surfaced in the response but never fail the apply itself — the database mutation has already committed. 4. Surface the outcome as ``grounding_repair`` in the JSON response so the UI can tell the user what changed (or why nothing did). The repair runs *before* the schema_text rebuild because ``_load_user_description`` reads schema_description.md from disk; rewriting first ensures the next LLM call sees the repaired grounding. Tests cover the three branches the response shape distinguishes: applied successfully, skipped (no LLM client configured), and surfaced validation errors. 
``repair_grounding`` is monkeypatched in tests since the FastAPI startup hook re-initializes ``state.llm_client`` from the environment and would clobber an in-place stub. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/datasight/web/app.py | 152 ++++++++++++++++++++++++++++++++++++ tests/test_web_tidy.py | 161 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 313 insertions(+) diff --git a/src/datasight/web/app.py b/src/datasight/web/app.py index fd956942..81807568 100644 --- a/src/datasight/web/app.py +++ b/src/datasight/web/app.py @@ -97,6 +97,14 @@ load_global_env, restore_original_env, ) +from datasight.grounding import ( + build_enum_values_sync, + check_grounding_drift, +) +from datasight.grounding_repair import ( + repair_grounding, + write_repair_atomic, +) from datasight.sql_validation import build_measure_rule_map, build_schema_map, validate_sql from datasight.tidy import _detect_period_groups, analyze_tidy_patterns from datasight.tidy_llm import propose_reshapes @@ -3339,6 +3347,15 @@ async def tidy_apply(request: Request, state: AppState = Depends(get_state)): # db_path = runner._database_path + # Snapshot the pre-apply schema so the post-apply grounding-repair step + # can show the LLM both old and new shapes. Without an explicit + # snapshot the repair prompt degenerates to "regenerate from scratch" + # and loses any human customizations in the grounding files. + old_schema: dict[str, set[str]] = { + t["name"]: {c["name"] for c in t["columns"]} + for t in state.schema_info + } + async with state.state_lock: # Cached DataFrames can keep DuckDB buffers alive and block the # read-write reopen of the same file; drop them before closing. @@ -3431,6 +3448,37 @@ async def tidy_apply(request: Request, state: AppState = Depends(get_state)): # } for t in tables ] + + # Grounding repair: when the schema actually changed, rewrite + # queries.yaml / schema_description.md / time_series.yaml so the + # next LLM call sees stale-reference-free grounding. 
Must happen + # *before* the schema_text rebuild below, because + # ``_load_user_description`` reads schema_description.md from + # disk — we want it to load the repaired content, not the stale + # one. Failures here are logged and surfaced in the response but + # don't fail the apply (the database mutation already committed). + grounding_summary: dict[str, Any] | None = None + new_schema: dict[str, set[str]] = { + t["name"]: {c["name"] for c in t["columns"]} + for t in state.schema_info + } + if ( + state.project_dir + and state.llm_client is not None + and state.model + and new_schema != old_schema + and not state.is_ephemeral + ): + grounding_summary = await _repair_grounding_after_tidy( + project_dir=state.project_dir, + db_path=db_path, + old_schema=old_schema, + new_schema=new_schema, + llm_client=state.llm_client, + model=state.model, + run_sql=state.sql_runner.run_sql, + ) + state.schema_text = format_schema_context( tables, user_description=None if state.is_ephemeral else _load_user_description(state), @@ -3442,6 +3490,110 @@ async def tidy_apply(request: Request, state: AppState = Depends(get_state)): # "success": True, "result": result.to_dict(), "schema_info": state.schema_info, + "grounding_repair": grounding_summary, + } + + +async def _repair_grounding_after_tidy( # noqa: C901 + *, + project_dir: str, + db_path: str, + old_schema: dict[str, set[str]], + new_schema: dict[str, set[str]], + llm_client: LLMClient, + model: str, + run_sql, +) -> dict[str, Any]: + """Run the grounding-drift check and, if needed, the LLM repair. 
+ + Returns a structured summary the API caller can render in the UI: + + - ``drift_items``: how many unresolved references the static check found + - ``files_written``: paths that were rewritten and accepted + - ``applied``: True iff at least one file was successfully written + - ``validation_errors``: per-file SQL execution errors when the LLM's + proposal still didn't validate after retries (empty on success) + - ``error``: top-level exception message when the flow itself crashed + + The endpoint already committed the database mutation by the time this + runs, so failures here are surfaced as info but don't fail the apply. + """ + import duckdb as _duckdb + + # Load enum values from a fresh read-only conn — pure data fetch. + try: + ro = _duckdb.connect(db_path, read_only=True) + try: + enum_values = build_enum_values_sync(ro, new_schema) + finally: + ro.close() + except Exception: # noqa: BLE001 + enum_values = set() + + drift = check_grounding_drift( + Path(project_dir), new_schema, enum_values=enum_values + ) + if drift.is_clean: + return { + "drift_items": 0, + "files_written": [], + "applied": False, + "skipped": "no_drift", + } + + try: + result = await repair_grounding( + Path(project_dir), + old_schema, + new_schema, + drift, + llm_client=llm_client, + model=model, + run_sql=run_sql, + ) + except Exception as exc: # noqa: BLE001 + logger.exception("Grounding repair LLM call failed") + return { + "drift_items": len(drift.items), + "files_written": [], + "applied": False, + "error": str(exc), + } + + if not result.any_changes: + return { + "drift_items": len(drift.items), + "files_written": [], + "applied": False, + "skipped": "no_llm_changes", + } + if not result.overall_ok: + validation_errors = [ + f"{f.name}: {err}" + for f in result.files + for err in f.validation_errors + ] + return { + "drift_items": len(drift.items), + "files_written": [], + "applied": False, + "validation_errors": validation_errors, + } + + try: + written = 
write_repair_atomic(result, Path(project_dir)) + except OSError as exc: + logger.exception("Grounding repair atomic write failed") + return { + "drift_items": len(drift.items), + "files_written": [], + "applied": False, + "error": str(exc), + } + return { + "drift_items": len(drift.items), + "files_written": [p.name for p in written], + "applied": True, } diff --git a/tests/test_web_tidy.py b/tests/test_web_tidy.py index c5fb5fb5..25a9cbe9 100644 --- a/tests/test_web_tidy.py +++ b/tests/test_web_tidy.py @@ -17,6 +17,7 @@ from fastapi.testclient import TestClient import datasight.web.app as web_app +from datasight.grounding_repair import RepairFile, RepairResult from datasight.runner import DuckDBRunner from datasight.tidy import _detect_period_groups from datasight.tidy_llm import ProposeResult @@ -457,6 +458,166 @@ def test_apply_creates_table_with_replace_disposition(loaded_state): assert any(c["name"] == "year" for c in sales_cols) +def test_apply_repairs_stale_grounding_when_llm_configured(loaded_state, monkeypatch): + """When grounding files reference columns the reshape removes, the + apply endpoint should detect drift, call the LLM, and rewrite the + affected files atomically — surfacing what changed in the response. + + Monkeypatches :func:`repair_grounding` so the test exercises the + wiring (drift detection → repair call → atomic write → response + shaping) without depending on a live LLM. The endpoint's TestClient + startup re-initializes ``state.llm_client`` from the env, so a + state-level stub gets clobbered before the request runs. 
+ """ + project_dir = Path(loaded_state.project_dir) + queries_path = project_dir / "queries.yaml" + queries_path.write_text( + "- question: Stale wide reference\n" + " sql: SELECT region, sales_2020 FROM sales;\n", + encoding="utf-8", + ) + repaired_yaml = ( + "- question: Long-form sales\n" + " sql: SELECT region, period, sales FROM sales_long;\n" + ) + + async def fake_repair(project_dir, old_schema, new_schema, drift, **kwargs): + # Verify the wiring passes the right schemas: drop disposition → + # ``sales`` is gone from new_schema, ``sales_long`` is present. + assert "sales_2020" in old_schema["sales"] + assert "sales" not in new_schema + assert "sales_long" in new_schema + return RepairResult( + files=[ + RepairFile( + name="queries.yaml", + path=queries_path, + old_text=queries_path.read_text(encoding="utf-8"), + new_text=repaired_yaml, + ), + ], + ) + + monkeypatch.setattr(web_app, "repair_grounding", fake_repair) + + suggestion = _detect_period_groups(loaded_state.schema_info[0])[0] + proposal = _suggestion_to_proposal_dict(suggestion) + + with TestClient(web_app.app) as client: + response = client.post( + "/api/tidy/apply", + json={ + "proposal": proposal, + "mode": "table", + "disposition": {"mode": "drop"}, + }, + ) + + body = response.json() + assert body["success"], body + summary = body["grounding_repair"] + assert summary is not None, body + assert summary["applied"] is True, summary + assert summary["files_written"] == ["queries.yaml"] + assert summary["drift_items"] >= 1 + + rewritten = queries_path.read_text(encoding="utf-8") + assert "sales_2020" not in rewritten + assert "sales_long" in rewritten + + +def test_apply_skips_grounding_repair_when_no_llm_configured(loaded_state, monkeypatch): + """No llm_client → repair flow short-circuits before the drift check. 
+ Response carries ``grounding_repair: None`` and ``repair_grounding`` + is never called even when drift exists on disk.""" + project_dir = Path(loaded_state.project_dir) + queries_path = project_dir / "queries.yaml" + original_text = "- question: Stale\n sql: SELECT sales_2020 FROM sales;\n" + queries_path.write_text(original_text, encoding="utf-8") + + repair_called: list[Any] = [] + + async def fake_repair(*args, **kwargs): + repair_called.append((args, kwargs)) + raise AssertionError("repair_grounding should not be called without llm_client") + + monkeypatch.setattr(web_app, "repair_grounding", fake_repair) + + suggestion = _detect_period_groups(loaded_state.schema_info[0])[0] + proposal = _suggestion_to_proposal_dict(suggestion) + + with TestClient(web_app.app) as client: + # Drop the LLM client after startup has re-initialized it from env. + loaded_state.llm_client = None + response = client.post( + "/api/tidy/apply", + json={ + "proposal": proposal, + "mode": "table", + "disposition": {"mode": "drop"}, + }, + ) + + body = response.json() + assert body["success"], body + assert body["grounding_repair"] is None, body + assert repair_called == [], "repair_grounding was called unexpectedly" + # The stale queries.yaml stays exactly as written. 
+ assert queries_path.read_text(encoding="utf-8") == original_text + + +def test_apply_grounding_repair_surfaces_validation_errors(loaded_state, monkeypatch): + """When the LLM proposal validates dirty, the file is left untouched + and the summary surfaces the validation errors so the UI can show + them to the user.""" + project_dir = Path(loaded_state.project_dir) + queries_path = project_dir / "queries.yaml" + original_text = "- question: Stale\n sql: SELECT sales_2020 FROM sales;\n" + queries_path.write_text(original_text, encoding="utf-8") + + async def fake_repair(project_dir, old_schema, new_schema, drift, **kwargs): + return RepairResult( + files=[ + RepairFile( + name="queries.yaml", + path=queries_path, + old_text=original_text, + new_text=( + "- question: still broken\n" + " sql: SELECT nonexistent_col FROM sales_long;\n" + ), + validation_errors=["query 1 ('still broken') failed: nonexistent_col"], + ), + ], + ) + + monkeypatch.setattr(web_app, "repair_grounding", fake_repair) + + suggestion = _detect_period_groups(loaded_state.schema_info[0])[0] + proposal = _suggestion_to_proposal_dict(suggestion) + + with TestClient(web_app.app) as client: + response = client.post( + "/api/tidy/apply", + json={ + "proposal": proposal, + "mode": "table", + "disposition": {"mode": "drop"}, + }, + ) + + body = response.json() + assert body["success"], body + summary = body["grounding_repair"] + assert summary is not None, body + assert summary["applied"] is False + assert summary["files_written"] == [] + assert summary["validation_errors"], summary + assert any("nonexistent_col" in err for err in summary["validation_errors"]) + # File contents untouched — write_repair_atomic skipped the bad file. + assert queries_path.read_text(encoding="utf-8") == original_text + + def test_apply_creates_table_with_bare_drop_disposition(loaded_state): """``drop`` disposition (post-rename) is the bare drop: source goes away but the long form keeps its target name. 
Downstream code that From 229c55b0a5d8c2e489c10f0fe6d14cd2de2ffb83 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 11 May 2026 07:26:55 -0600 Subject: [PATCH 06/10] Make grounding repair an independent operation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The post-apply grounding repair was previously locked behind a synchronous call inside `tidy review` apply (CLI) and /api/tidy/apply (web). When the LLM call timed out — common with large prompts on local models — the user had no way to retry without re-running the database transform. Web requests also blocked on the slow LLM call, which made the apply UX feel unrelated to the actual DB work it had already committed. This commit splits repair from apply on both surfaces and adds the missing retry path: - Persist the pre-apply schema to .datasight/grounding_snapshot.json on every CLI tidy review and web apply. Both surfaces and the new standalone CLI command read it. - New `datasight grounding {check, repair}` group. `repair` reads the snapshot, runs the LLM rewrite, shows a diff, and prompts to apply. `--model` overrides the configured model for retry, `--from-csv` derives the pre-tidy schema from a source CSV when no snapshot exists, `--dry-run` skips the write. - Web /api/tidy/apply now returns immediately after the DB mutation and includes a static drift summary. The slow LLM repair moved to /api/tidy/grounding/repair, called from a banner the drawer shows after apply. The endpoint falls back to the on-disk snapshot when the in-memory copy is gone (post-restart). - New /api/grounding/status + header pill so users see drift any time it exists, not only after they just applied something. - `tidy review` accepts --model so the same flag works on every LLM-using command. 
Docs: - Per-workload model recommendation in choosing-an-llm.md: tool-use commands (tidy review) favor general qwen3.6, long-form generation commands (grounding repair) favor the coding-mxfp8 variant. Drops the previous single-observation hedge now that we have evidence from both LLM-using paths. - New "Repair grounding" section in curate-with-tidy-review.md. - cli.md regenerated with the new grounding group. - Top-level CLI help layout: orphans (config, session, tidy, grounding) moved into existing or new groups so there is no generic "Commands" bucket. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/reference/cli.md | 95 +++++ docs/use/concepts/choosing-an-llm.md | 32 +- docs/use/how-to/curate-with-tidy-review.md | 95 +++++ frontend/src/App.svelte | 8 + frontend/src/lib/api/tidy.ts | 60 +++ frontend/src/lib/components/Header.svelte | 71 ++++ frontend/src/lib/components/TidyDrawer.svelte | 145 +++++++ frontend/src/lib/stores/grounding.svelte.ts | 165 ++++++++ frontend/src/lib/stores/tidy.svelte.ts | 104 +++++ src/datasight/cli.py | 10 +- src/datasight/cli_commands/grounding.py | 362 ++++++++++++++++++ src/datasight/cli_commands/tidy.py | 55 ++- src/datasight/grounding_repair.py | 90 +++++ src/datasight/web/app.py | 221 +++++++++-- tests/test_web_tidy.py | 112 ++++-- 15 files changed, 1560 insertions(+), 65 deletions(-) create mode 100644 frontend/src/lib/stores/grounding.svelte.ts create mode 100644 src/datasight/cli_commands/grounding.py diff --git a/docs/reference/cli.md b/docs/reference/cli.md index f5b11815..56becfc5 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -66,6 +66,7 @@ datasight [OPTIONS] COMMAND [ARGS]... - `measures`: Surface likely measures and default aggregations. - `quality`: Audit data quality - nulls, suspicious ranges, and date coverage. - `tidy`: Detect untidy column shapes and reshape into long form. +- `grounding`: Detect and repair drift between grounding files and the live schema. 
- `integrity`: Audit cross-table referential integrity - keys, orphans, and join risks. - `distribution`: Profile value distributions - percentiles, outliers, and measure flags. - `validate`: Run declarative validation rules against the database. @@ -780,6 +781,100 @@ datasight tidy review [OPTIONS] | `--replace-source` | Drop the source after a successful reshape and rename the long-form table to take the source's old name. Downstream code that referenced the source keeps working without edits. Requires '--as table' — a view's body references its source by name. | | `--drop-source` | Drop the source after a successful reshape; the long form keeps its target name. Pick this when the new shape is the canonical one going forward and you don't need to preserve the source's name. Requires '--as table'. NOTE: previously this flag carried the semantics now moved to '--replace-source'; scripts depending on the old behavior should switch to '--replace-source'. | | `--sample` | Send N sample rows per candidate to the configured LLM provider (default 0). Sample values get sent over the network — opt in only when the LLM seeing the values is acceptable. | +| `--model` | LLM model name to use for the propose-reshapes call and the post-apply grounding-repair call (overrides .env). Useful when different models suit each workload — see docs/use/concepts/choosing-an-llm.md. | + +### `datasight grounding` + +Detect and repair drift between grounding files and the live schema. + +Grounding files (``queries.yaml``, ``schema_description.md``, +``time_series.yaml``) describe the database to the LLM. When the +schema changes (typically after ``datasight tidy review``), these +files fall out of sync and the agent silently hallucinates against +columns that no longer exist. + +- ``check`` reports drift without changing anything. 
+- ``repair`` asks the configured LLM to rewrite the stale files + against the current schema, validates each proposed query, and + writes atomically after you confirm the diff. + +Examples: + +``` +datasight grounding check +datasight grounding repair +datasight grounding repair --model qwen3.6 +datasight grounding repair --from-csv load_data.csv +datasight grounding repair --dry-run +``` + +```bash +datasight grounding [OPTIONS] COMMAND [ARGS]... +``` + +**Subcommands** + +- `check`: Report stale references in grounding files against the live schema. +- `repair`: Run the LLM grounding repair against an existing drift. + +#### `datasight grounding check` + +Report stale references in grounding files against the live schema. + +Static — no LLM, no query execution. Exits 0 when grounding is +clean, 1 when drift is detected. Use ``datasight grounding +repair`` to fix what this command finds. + +Examples: + +``` +datasight grounding check +datasight grounding check --project-dir /path/to/project +``` + +```bash +datasight grounding check [OPTIONS] +``` + +**Parameters** + +| Name | Details | +| --- | --- | +| `--project-dir` | Project directory containing .env and grounding files. Default: `.`. | + +#### `datasight grounding repair` + +Run the LLM grounding repair against an existing drift. + +Reads the pre-tidy schema snapshot persisted by the most recent +apply (``.datasight/grounding_snapshot.json``). When no snapshot +is on file, ``--from-csv`` lets you supply the wide-form schema +by pointing at the source CSV(s). + +Shows the unified diff and prompts for confirmation before writing. +Use ``--dry-run`` to skip the write entirely. 
+ +Examples: + +``` +datasight grounding repair +datasight grounding repair --model qwen3.6 +datasight grounding repair --from-csv load_data.csv +datasight grounding repair --dry-run +``` + +```bash +datasight grounding repair [OPTIONS] +``` + +**Parameters** + +| Name | Details | +| --- | --- | +| `--project-dir` | Project directory containing .env and grounding files. Default: `.`. | +| `--model` | LLM model name to use for the repair (overrides .env). Useful for retrying with a different model after a timeout. | +| `--from-csv` | Derive the pre-tidy schema from CSV headers when no snapshot is available. Pass once per source file (e.g. the wide-format input the apply consumed). Each CSV becomes a single table named after the file stem. Combinable with the snapshot — snapshot tables win on conflict. | +| `--dry-run` | Show drift + LLM proposal + diff, but don't write any files. | ### `datasight integrity` diff --git a/docs/use/concepts/choosing-an-llm.md b/docs/use/concepts/choosing-an-llm.md index 0999e6e3..602ebd69 100644 --- a/docs/use/concepts/choosing-an-llm.md +++ b/docs/use/concepts/choosing-an-llm.md @@ -132,13 +132,43 @@ option**, despite the "e2b" (effective 2B) naming. Its weights are small but the default 256K-token context allocation dominates resident memory. Use it on 32 GB+ Macs only. +The benchmark above measures `datasight ask` only. The other LLM-using +commands have very different shapes, and observed behavior on the +two qwen3.6 variants splits cleanly along those shapes: + +| Workload | Calls a tool? 
| Output budget | Best of the two | +|---|---|---|---| +| `datasight ask` | yes (multi-turn agent) | small per turn | either; coding MoE for richer prose | +| `datasight tidy review` (LLM advisor) | yes (single `propose_reshapes` call) | 4 K | **general `qwen3.6`** | +| `datasight grounding repair` | no (long-form file rewrite) | 16 K | **`qwen3.6:35b-a3b-coding-mxfp8`** | + +The split is consistent with what code-specialized fine-tunes are known +to trade: better long-form structured generation (winning grounding +repair, where the prompt and the output are both large) at the cost of +weaker tool-call adherence (losing `tidy review`, where the model has +to emit a structured tool call instead of free text). Observed in +practice: the coding variant silently emitted zero proposals on +`tidy review`'s `propose_reshapes`, while the general variant timed out +on grounding repair against the same database. + +**Practical setup**: pull both models. Use `qwen3.6` as your default +`OLLAMA_MODEL`, and override per-call where the coding variant wins: + +```bash +datasight grounding repair --model qwen3.6:35b-a3b-coding-mxfp8 +datasight tidy review --model qwen3.6 # explicit default; useful in scripts +``` + +Both `tidy review` and `grounding repair` accept `--model`, as do +`ask`, `verify`, and `run`. + Apple Silicon recommendations by RAM tier: | Unified memory | Recommended model | Why | |---|---|---| | 16 GB | `qwen2.5:7b` (GGUF) | Only option that fits with headroom for the OS, browser, and IDE. | | 32 GB | `qwen2.5:7b` or `gemma4:e2b-mlx-bf16` | Either fits. Gemma is faster but its answers are tersest; pick based on whether you want interpretation or just raw data. | -| 48 GB+ | `qwen3.6:35b-a3b-coding-mxfp8` | Sparse MoE (3B active params) — best answer quality, comparable speed, properly leverages Apple Silicon's unified memory + Metal. 
| +| 48 GB+ | Both `qwen3.6` and `qwen3.6:35b-a3b-coding-mxfp8` (switch per command) | Sparse MoE (3B active params) — properly leverages Apple Silicon's unified memory + Metal. The two variants are complementary, not interchangeable: general for tool-use commands (`ask`, `tidy review`), coding for long-form generation (`grounding repair`, `ask` when you want richer prose). See workload table above. | If you have an Apple Silicon machine but aren't sure which tag to use, start with `qwen2.5:7b` (the cross-platform recommendation above). It diff --git a/docs/use/how-to/curate-with-tidy-review.md b/docs/use/how-to/curate-with-tidy-review.md index 5e0cb98e..94e6451d 100644 --- a/docs/use/how-to/curate-with-tidy-review.md +++ b/docs/use/how-to/curate-with-tidy-review.md @@ -419,6 +419,101 @@ whose values would duplicate or drop rows — the transaction rolls back in step 2, before the source is touched. The error message names the table and the count it expected. +## Repair grounding after a reshape + +A successful Tidy review changes the database schema — long-form +column and table names replace the wide-form ones. That breaks any +reference to the old column names in your grounding files +(`queries.yaml`, `schema_description.md`, `time_series.yaml`), which +the LLM agent reads on every turn. Stale grounding silently teaches +the agent to hallucinate against columns that no longer exist. + +### How drift is detected + +Every Tidy review apply (CLI or web) does two things in addition to +the reshape itself: + +1. Writes a snapshot of the pre-apply schema to + `.datasight/grounding_snapshot.json`. This is the "before" picture + the LLM needs to rewrite the grounding files in context. +2. Runs a fast static check against the new schema. The CLI + surfaces drift in an interactive prompt (see the `--apply-all` + sections above); the web UI shows an orange "Grounding may be + stale" banner with a **Repair grounding** button. 
+ +### Run the repair + +Web UI: click **Repair grounding** in the banner. The agent rewrites +the affected files, validates every SQL example against the live DB, +and applies the changes if validation passes. + +CLI: + +```bash +datasight grounding repair +``` + +This reads the snapshot, runs the LLM repair, prints a unified diff +of the proposed rewrites, and asks `Apply this diff? [y/N]` before +writing. + +### Retry with a different model + +Local LLMs sometimes time out on the repair (the prompt includes +your full grounding files plus both schemas, which can be large). +Retry with a different model — no need to re-run the reshape: + +```bash +datasight grounding repair --model qwen3.6:35b-a3b-coding-mxfp8 +``` + +`--model` overrides the configured `OLLAMA_MODEL` (or +`ANTHROPIC_MODEL`, etc.) for this one call. + +The two LLM-using steps in the Tidy review flow have very different +shapes and reward different model variants — `tidy review`'s +proposal step is a tool call (favors general-purpose models), +`grounding repair` is a long-form file rewrite (favors +coding-specialized models). Both `tidy review` and `grounding +repair` accept `--model` so you can pick per call. See +[Choosing an LLM](../concepts/choosing-an-llm.md) for the per-workload +recommendation table. + +### When the snapshot is missing + +If you applied a reshape before snapshotting was wired in, or you +deleted `.datasight/grounding_snapshot.json`, the repair has nothing +to compare against. Pass `--from-csv` pointing at the original +wide-form source so the CLI can derive the pre-tidy schema from the +header row: + +```bash +datasight grounding repair --from-csv generation_fuel_wide.csv +``` + +You can pass `--from-csv` multiple times for multi-file inputs; +each CSV becomes a single table named after the file stem. + +### Preview without writing + +Add `--dry-run` to skip the confirmation and the write. 
The diff +prints as usual, then the command exits — useful for inspecting the +LLM's proposal in CI or before committing the result: + +```bash +datasight grounding repair --dry-run +``` + +### Check drift without repairing + +```bash +datasight grounding check +``` + +Static, no LLM. Exits 0 when grounding is clean, 1 when drift +exists. Same logic as `datasight verify --static-only`, exposed +under a more discoverable name. + ## Recipes ### Curate one wide table from the web UI diff --git a/frontend/src/App.svelte b/frontend/src/App.svelte index d5725d5e..dd3b3fb6 100644 --- a/frontend/src/App.svelte +++ b/frontend/src/App.svelte @@ -26,6 +26,7 @@ import { sqlEditorStore } from "$lib/stores/sql_editor.svelte"; import { paletteStore } from "$lib/stores/palette.svelte"; import { tidyStore } from "$lib/stores/tidy.svelte"; + import { groundingStore } from "$lib/stores/grounding.svelte"; import { exitExploreSession, getProjectStatus } from "$lib/api/projects"; import { loadSettings, loadLlmConfig } from "$lib/api/settings"; import { loadSchema, loadQueries, loadRecipes } from "$lib/api/schema"; @@ -114,6 +115,12 @@ loadMeasureCatalog(), ]); + // Refresh the always-on grounding pill. Fires after the schema + // load above so the server's drift check sees the up-to-date + // schema (the static check doesn't care about timing, but this + // keeps the lifecycle obvious). Cheap call — no LLM. 
+ groundingStore.check(); + // Run pending starter if one was selected on landing page if (fromLanding) { await maybeRunPendingStarter(); @@ -189,6 +196,7 @@ dashboardStore.clear(); sqlEditorStore.clearAll(); sessionStore.reset(); + groundingStore.reset(); dashboardStore.currentView = "chat"; exportMode = false; exportExcludeIndices = new Set(); diff --git a/frontend/src/lib/api/tidy.ts b/frontend/src/lib/api/tidy.ts index 3e660534..cec58a48 100644 --- a/frontend/src/lib/api/tidy.ts +++ b/frontend/src/lib/api/tidy.ts @@ -78,10 +78,40 @@ export interface TidyApplyResult { dry_run: boolean; } +/** Static drift summary returned by /api/tidy/apply. The slow LLM + * repair runs separately via /api/tidy/grounding/repair so the apply + * response can return immediately after the database mutation. */ +export interface TidyGroundingDrift { + needs_repair: boolean; + drift_items: number; +} + export interface TidyApplyResponse { success: boolean; result?: TidyApplyResult; schema_info?: TableInfo[]; + /** Present when apply changed the schema and a drift check ran. */ + grounding_drift?: TidyGroundingDrift | null; + error?: string; +} + +/** Result of the LLM grounding repair. Mirrors the summary built on + * the server in ``_run_grounding_repair``. */ +export interface TidyGroundingRepairSummary { + drift_items: number; + files_written: string[]; + applied: boolean; + /** Set when the repair was a no-op (e.g. "no_drift", "no_llm_changes"). */ + skipped?: string; + /** Set when the LLM proposal failed SQL validation after retries. */ + validation_errors?: string[]; + /** Set when the LLM call itself crashed (timeout, etc). */ + error?: string; +} + +export interface TidyGroundingRepairResponse { + success: boolean; + grounding_repair?: TidyGroundingRepairSummary; error?: string; } @@ -228,3 +258,33 @@ export async function applyTidy(args: { }): Promise { return postJson("/api/tidy/apply", args); } + +/** Trigger the LLM grounding repair after a tidy apply. 
Slow (the LLM + * call can take minutes on local models); the caller should show a + * spinner. The server reads the pre-tidy schema snapshot it stashed + * during /api/tidy/apply, so this only works while that snapshot is + * still live (cleared on project change or after a successful repair). */ +export async function repairGrounding(): Promise { + return postJson( + "/api/tidy/grounding/repair", + {}, + ); +} + +/** Always-on grounding-drift status, polled by the header pill on + * project load (and after applies/repairs). When ``available`` is + * false the UI hides the pill — covers ephemeral sessions, no-project + * states, and non-DuckDB backends. ``has_snapshot`` tells the pill + * whether the repair button should be enabled or whether the user + * needs the CLI's ``--from-csv`` fallback. */ +export interface GroundingStatusResponse { + available: boolean; + reason?: string; + needs_repair?: boolean; + drift_items?: number; + has_snapshot?: boolean; +} + +export async function fetchGroundingStatus(): Promise { + return fetchJson("/api/grounding/status"); +} diff --git a/frontend/src/lib/components/Header.svelte b/frontend/src/lib/components/Header.svelte index 67464c2d..8188cb94 100644 --- a/frontend/src/lib/components/Header.svelte +++ b/frontend/src/lib/components/Header.svelte @@ -6,6 +6,8 @@ import { formatCost } from "$lib/utils/format"; import { summarizeDataset } from "$lib/api/summarize"; import { chatStore } from "$lib/stores/chat.svelte"; + import { groundingStore } from "$lib/stores/grounding.svelte"; + import { toastStore } from "$lib/stores/toast.svelte"; interface Props { theme: string; @@ -50,6 +52,47 @@ let showCost = $derived( settingsStore.showCost && queriesStore.sessionTotalCost > 0, ); + + let groundingPillTitle = $derived( + !groundingStore.hasSnapshot + ? "Grounding files reference columns that don't exist. 
" + + "No pre-tidy snapshot available — use the CLI: " + + "datasight grounding repair --from-csv .csv" + : groundingStore.repairStatus === "running" + ? "Repairing grounding files..." + : `${groundingStore.driftItems} stale reference${ + groundingStore.driftItems === 1 ? "" : "s" + } in queries.yaml / schema_description.md / time_series.yaml. ` + + "Click to repair with the configured LLM.", + ); + + async function handleRepairFromPill() { + if (!groundingStore.hasSnapshot) { + toastStore.show( + "No snapshot available — repair via CLI: datasight grounding repair --from-csv .csv", + "error", + ); + return; + } + await groundingStore.runRepair(); + if (groundingStore.repairStatus === "success") { + const summary = groundingStore.repairSummary; + const written = summary?.files_written ?? []; + toastStore.show( + written.length > 0 + ? `Rewrote ${written.join(", ")}` + : "Grounding files were already up to date", + "success", + ); + groundingStore.dismissRepairResult(); + } else if (groundingStore.repairStatus === "error") { + toastStore.show( + groundingStore.repairError ?? "Grounding repair failed", + "error", + ); + groundingStore.dismissRepairResult(); + } + }
{projectName} {/if} + + + {#if groundingStore.pillVisible} + + {/if} {/if} diff --git a/frontend/src/lib/components/TidyDrawer.svelte b/frontend/src/lib/components/TidyDrawer.svelte index 645e4760..13a0e226 100644 --- a/frontend/src/lib/components/TidyDrawer.svelte +++ b/frontend/src/lib/components/TidyDrawer.svelte @@ -2,6 +2,7 @@ import { tidyStore } from "$lib/stores/tidy.svelte"; import { toastStore } from "$lib/stores/toast.svelte"; import { schemaStore, type TableInfo } from "$lib/stores/schema.svelte"; + import { groundingStore } from "$lib/stores/grounding.svelte"; import TidyProposalCard from "./TidyProposalCard.svelte"; let renameValid = $derived( @@ -58,6 +59,9 @@ `Applied ${applied} proposal${applied === 1 ? "" : "s"}`, "success", ); + // Apply changed the schema; refresh the always-on header pill so + // it reflects the new drift state (typically: drift just appeared). + groundingStore.check(); } if (failed > 0) { toastStore.show( @@ -67,6 +71,27 @@ } } + async function handleRepairGrounding() { + await tidyStore.runGroundingRepair(); + if (tidyStore.repairStatus === "success") { + const written = tidyStore.repairSummary?.files_written ?? []; + const msg = written.length > 0 + ? `Rewrote ${written.join(", ")}` + : "Grounding files were already up to date"; + toastStore.show(msg, "success"); + } else if (tidyStore.repairStatus === "error") { + toastStore.show( + tidyStore.repairError ?? "Grounding repair failed", + "error", + ); + } + // Repair (success or partial-failure) changed something on disk; + // re-poll so the header pill matches what the user just did. + groundingStore.check(); + } + + let isRepairing = $derived(tidyStore.repairStatus === "running"); + function handleSampleRowsBlur(value: string) { const n = parseInt(value, 10); tidyStore.sampleRows = Number.isFinite(n) ? n : 0; @@ -109,6 +134,77 @@
+ + {#if tidyStore.groundingDrift?.needs_repair} +
+
+
+

+ {#if isRepairing} + + Repairing grounding files… + {:else} + Grounding may be stale + {/if} +

+

+ {#if isRepairing} + The agent is rewriting queries.yaml, + schema_description.md, and time_series.yaml against + the new schema. Local models can take a few minutes. + {:else} + The reshape changed the schema and the static check + found {tidyStore.groundingDrift.drift_items} stale + reference{tidyStore.groundingDrift.drift_items === 1 + ? "" + : "s"} in your grounding files. Run the LLM + repair to rewrite them, or skip if you'd rather edit + by hand. + {/if} +

+
+ {#if !isRepairing} +
+ + +
+ {/if} +
+
+ {/if} + + + {#if tidyStore.repairStatus === "success" && tidyStore.repairSummary} + + {/if} + {#if tidyStore.repairStatus === "error"} + + {/if} +
@@ -524,6 +620,55 @@ border-color: color-mix(in srgb, #ef4444 36%, var(--border)); color: color-mix(in srgb, #ef4444 80%, var(--text)); } + .banner-ok { + background: color-mix(in srgb, #22c55e 8%, var(--surface)); + border-color: color-mix(in srgb, #22c55e 36%, var(--border)); + color: color-mix(in srgb, #16a34a 80%, var(--text)); + } + + /* Grounding-repair prompt: same visual language as .agent-panel so + it reads as a peer prompt rather than an inline notice. Orange tint + to flag that action is needed; flips to a more urgent border while + the LLM call is in flight. */ + .repair-panel { + display: grid; + gap: 10px; + padding: 12px 14px; + border: 1px solid color-mix(in srgb, var(--orange) 32%, var(--border)); + border-radius: 10px; + background: color-mix(in srgb, var(--orange) 6%, var(--surface)); + transition: background 0.2s, border-color 0.2s; + } + .repair-panel-running { + border-color: color-mix(in srgb, var(--orange) 60%, var(--border)); + background: color-mix(in srgb, var(--orange) 10%, var(--surface)); + } + .repair-head { + display: flex; + align-items: flex-start; + justify-content: space-between; + gap: 12px; + } + .repair-head h4 { + margin: 0 0 4px; + font-size: 0.86rem; + color: var(--text); + display: inline-flex; + align-items: center; + gap: 8px; + } + .repair-head p { + margin: 0; + max-width: 52ch; + font-size: 0.74rem; + line-height: 1.5; + color: var(--text-secondary); + } + .repair-actions { + display: inline-flex; + gap: 8px; + flex-shrink: 0; + } .warnings { margin: 0; diff --git a/frontend/src/lib/stores/grounding.svelte.ts b/frontend/src/lib/stores/grounding.svelte.ts new file mode 100644 index 00000000..a6f77e15 --- /dev/null +++ b/frontend/src/lib/stores/grounding.svelte.ts @@ -0,0 +1,165 @@ +/** Always-on grounding-drift state for the header pill. + * + * Distinct from ``tidyStore``: tidyStore tracks the in-flight tidy + * review drawer (its grounding banner is a one-shot post-apply + * surface). 
This store tracks the project-wide truth — "is the + * grounding currently stale?" — so the header pill can advertise + * drift any time the user is in the project, not just after they + * applied something. + * + * Both surfaces call the same backend repair endpoint, but they own + * their own UI state: the post-apply banner clears when the user + * dismisses or repairs from the drawer; the header pill clears when + * the next status check returns clean. + */ + +import { + fetchGroundingStatus, + repairGrounding, + type TidyGroundingRepairSummary, +} from "$lib/api/tidy"; + +export type GroundingStatus = + | "idle" + | "checking" + | "clean" + | "drift" + | "unavailable" + | "error"; + +export type GroundingRepairStatus = "idle" | "running" | "success" | "error"; + +function createGroundingStore() { + let status = $state("idle"); + let driftItems = $state(0); + let hasSnapshot = $state(false); + // ``unavailableReason`` is the server's hint about why the check + // doesn't apply (no_project / non_duckdb / etc). Surfaced to the + // pill tooltip so a confused user can see why no chip is rendered. + let unavailableReason = $state(null); + let lastError = $state(null); + + let repairStatus = $state("idle"); + let repairSummary = $state(null); + let repairError = $state(null); + + return { + get status() { + return status; + }, + get driftItems() { + return driftItems; + }, + get hasSnapshot() { + return hasSnapshot; + }, + get unavailableReason() { + return unavailableReason; + }, + get lastError() { + return lastError; + }, + get repairStatus() { + return repairStatus; + }, + get repairSummary() { + return repairSummary; + }, + get repairError() { + return repairError; + }, + /** True when the pill should be visible. Covers the only case the + * user can act on — known drift in a real project. */ + get pillVisible() { + return status === "drift"; + }, + + /** Re-poll /api/grounding/status. Cheap (no LLM), safe to call on + * project load and after any apply/repair. 
Sets status="checking" + * during the call so the UI can debounce duplicate triggers. */ + async check(): Promise { + if (status === "checking") return; + status = "checking"; + lastError = null; + try { + const resp = await fetchGroundingStatus(); + if (!resp.available) { + status = "unavailable"; + unavailableReason = resp.reason ?? null; + driftItems = 0; + hasSnapshot = false; + return; + } + unavailableReason = null; + driftItems = resp.drift_items ?? 0; + hasSnapshot = resp.has_snapshot ?? false; + status = resp.needs_repair ? "drift" : "clean"; + } catch (err) { + status = "error"; + lastError = (err as Error).message ?? "Failed to check grounding status"; + } + }, + + /** Clear all state — call when switching projects so a stale + * drift count from project A doesn't leak into project B. */ + reset(): void { + status = "idle"; + driftItems = 0; + hasSnapshot = false; + unavailableReason = null; + lastError = null; + repairStatus = "idle"; + repairSummary = null; + repairError = null; + }, + + /** Trigger the LLM grounding repair from the header pill. Mirrors + * tidyStore.runGroundingRepair but updates this store's repair + * state. On success, re-checks status so the pill goes from + * "drift" → "clean" if the rewrite resolved everything. */ + async runRepair(): Promise { + if (repairStatus === "running") return; + repairStatus = "running"; + repairSummary = null; + repairError = null; + try { + const resp = await repairGrounding(); + if (!resp.success) { + repairStatus = "error"; + repairError = resp.error ?? "Grounding repair failed"; + return; + } + const summary = resp.grounding_repair ?? 
null; + repairSummary = summary; + if (summary?.applied) { + repairStatus = "success"; + } else if (summary?.error) { + repairStatus = "error"; + repairError = summary.error; + } else if (summary?.validation_errors?.length) { + repairStatus = "error"; + repairError = summary.validation_errors.join("; "); + } else { + repairStatus = "success"; + } + // Either way, re-check the canonical status — drift may be + // gone (success) or partially gone (some files written). The + // pill should reflect what the file system actually shows now. + await this.check(); + } catch (err) { + repairStatus = "error"; + repairError = (err as Error).message ?? "Grounding repair failed"; + } + }, + + /** Dismiss a one-shot success/error toast without touching the + * underlying status. The pill stays visible if drift remains. */ + dismissRepairResult(): void { + repairStatus = "idle"; + repairSummary = null; + repairError = null; + }, + }; +} + +export const groundingStore = createGroundingStore(); diff --git a/frontend/src/lib/stores/tidy.svelte.ts b/frontend/src/lib/stores/tidy.svelte.ts index 88aa8755..26a6bf2a 100644 --- a/frontend/src/lib/stores/tidy.svelte.ts +++ b/frontend/src/lib/stores/tidy.svelte.ts @@ -15,9 +15,12 @@ import { detectTidy, previewTidy, proposeTidy, + repairGrounding, type TidyApplyResult, type TidyDisposition, type TidyDispositionMode, + type TidyGroundingDrift, + type TidyGroundingRepairSummary, type TidyMaterializeMode, type TidyProposal, } from "$lib/api/tidy"; @@ -109,6 +112,12 @@ function applyEdits(p: TidyProposal, edits: ProposalEdits): TidyProposal { }; } +/** Lifecycle of the post-apply grounding-repair LLM call. The repair is + * user-triggered (the drawer shows a banner after apply succeeds), and + * runs as a separate slow request — keeping its status here so the + * banner can show a spinner / result without blocking the drawer. 
*/ +export type RepairStatus = "idle" | "running" | "success" | "error"; + function createTidyStore() { let open = $state(false); let table = $state(null); @@ -121,6 +130,12 @@ function createTidyStore() { let dispositionRenameTo = $state(""); let sampleRows = $state(0); let abortController = $state(null); + // Grounding drift detected on the last apply (null when no apply has + // run yet, or when no drift was found). Drives the repair banner. + let groundingDrift = $state(null); + let repairStatus = $state("idle"); + let repairSummary = $state(null); + let repairError = $state(null); // The store keeps its own onApplied hook so callers (App.svelte) can // reload the schema sidebar without introducing a circular import @@ -139,6 +154,10 @@ function createTidyStore() { dispositionMode = "keep"; dispositionRenameTo = ""; sampleRows = 0; + groundingDrift = null; + repairStatus = "idle"; + repairSummary = null; + repairError = null; } return { @@ -387,6 +406,7 @@ function createTidyStore() { let applied = 0; let failed = 0; let lastSchema: unknown = null; + let lastDrift: TidyGroundingDrift | null = null; const disposition: TidyDisposition = { mode: dispositionMode, @@ -394,6 +414,13 @@ function createTidyStore() { dispositionMode === "rename" ? dispositionRenameTo.trim() : undefined, }; + // Starting a fresh apply batch: clear any prior repair status so + // the banner reflects this run's drift. + groundingDrift = null; + repairStatus = "idle"; + repairSummary = null; + repairError = null; + for (const id of queue) { const target = proposals.find((p) => p.id === id); if (!target) continue; @@ -414,6 +441,13 @@ function createTidyStore() { if (resp.success) { applied += 1; lastSchema = resp.schema_info ?? lastSchema; + // Each apply re-runs the drift check against the latest + // schema, so the last response's drift is the only one that + // matters — earlier ones describe intermediate states the + // user can't see. 
+ if (resp.grounding_drift) { + lastDrift = resp.grounding_drift; + } proposals = proposals.map((p) => p.id === id ? { @@ -452,8 +486,78 @@ function createTidyStore() { appliedHook(lastSchema); } + // Only surface the banner when there's actually drift to repair — + // a clean check (or a server that didn't run one) shouldn't nag. + if (lastDrift?.needs_repair) { + groundingDrift = lastDrift; + } + return { applied, failed }; }, + + get groundingDrift() { + return groundingDrift; + }, + get repairStatus() { + return repairStatus; + }, + get repairSummary() { + return repairSummary; + }, + get repairError() { + return repairError; + }, + + /** Trigger the slow LLM grounding repair for the most recent apply. + * The banner calls this; status flips to "running" so the UI can + * show a spinner. On success the drift is cleared and the summary + * is exposed via :prop:`repairSummary` for a one-shot confirmation. */ + async runGroundingRepair(): Promise { + if (repairStatus === "running") return; + repairStatus = "running"; + repairSummary = null; + repairError = null; + try { + const resp = await repairGrounding(); + if (!resp.success) { + repairStatus = "error"; + repairError = resp.error ?? "Grounding repair failed"; + return; + } + const summary = resp.grounding_repair ?? null; + repairSummary = summary; + if (summary?.applied) { + repairStatus = "success"; + // Drift is resolved (or partially) — drop the banner. Summary + // stays visible so the user can see what was rewritten. + groundingDrift = null; + } else if (summary?.error) { + repairStatus = "error"; + repairError = summary.error; + } else if (summary?.validation_errors?.length) { + repairStatus = "error"; + repairError = summary.validation_errors.join("; "); + } else { + // No-op (no_drift / no_llm_changes): treat as success and + // clear the banner so we don't keep prompting. 
+ repairStatus = "success"; + groundingDrift = null; + } + } catch (err) { + repairStatus = "error"; + repairError = (err as Error).message ?? "Grounding repair failed"; + } + }, + + /** Dismiss the grounding-repair banner without running the LLM call. + * Used by the Skip button. Doesn't touch the server snapshot — it + * just suppresses the prompt for this drawer session. */ + dismissGroundingPrompt(): void { + groundingDrift = null; + repairStatus = "idle"; + repairSummary = null; + repairError = null; + }, }; } diff --git a/src/datasight/cli.py b/src/datasight/cli.py index 277d1f16..525b549b 100644 --- a/src/datasight/cli.py +++ b/src/datasight/cli.py @@ -1316,7 +1316,7 @@ def write_batch_result_files( # noqa: C901 "datasight": [ { "name": "Quick start", - "commands": ["inspect", "run"], + "commands": ["inspect", "run", "config"], }, { "name": "Project setup", @@ -1326,6 +1326,10 @@ def write_batch_result_files( # noqa: C901 "name": "AI-powered", "commands": ["ask", "verify"], }, + { + "name": "Schema curation", + "commands": ["tidy", "grounding"], + }, { "name": "Data analysis (no LLM)", "commands": [ @@ -1344,7 +1348,7 @@ def write_batch_result_files( # noqa: C901 }, { "name": "Session history", - "commands": ["log", "export", "report"], + "commands": ["log", "export", "report", "session"], }, { "name": "Demo datasets", @@ -1433,6 +1437,7 @@ def _register_commands() -> None: from datasight.cli_commands.doctor import doctor from datasight.cli_commands.export import export from datasight.cli_commands.generate import generate + from datasight.cli_commands.grounding import grounding from datasight.cli_commands.init import init from datasight.cli_commands.inspect import inspect from datasight.cli_commands.integrity import integrity @@ -1462,6 +1467,7 @@ def _register_commands() -> None: cli.add_command(measures) cli.add_command(quality) cli.add_command(tidy) + cli.add_command(grounding) cli.add_command(integrity) cli.add_command(distribution) 
cli.add_command(validate) diff --git a/src/datasight/cli_commands/grounding.py b/src/datasight/cli_commands/grounding.py new file mode 100644 index 00000000..f59467a7 --- /dev/null +++ b/src/datasight/cli_commands/grounding.py @@ -0,0 +1,362 @@ +"""``datasight grounding`` — manage grounding-file drift independently. + +Two subcommands: + +- ``check``: run the static drift detector against the live database + and print a report. Same logic as ``datasight verify --static-only``, + exposed under a more discoverable name. +- ``repair``: run the LLM repair against an existing drift, using the + schema snapshot persisted by the most recent ``tidy review`` / + web-UI apply (or a CSV fallback). Supports ``--model`` to retry with + a different model after a timeout, and ``--dry-run`` to preview the + diff without writing. + +The repair flow is deliberately decoupled from ``tidy review`` apply +so a slow/failed LLM call can be retried any time, with a different +model, without re-running the database transform. +""" + +from __future__ import annotations + +import asyncio +import csv +import os +import sys +from pathlib import Path +from typing import Any + +import duckdb +import rich_click as click + +from datasight import cli +from datasight.cli_helpers import format_epilog +from datasight.config import create_sql_runner_from_settings +from datasight.grounding import ( + build_enum_values_sync, + build_schema_truth_sync, + check_grounding_drift, + format_drift_report, +) +from datasight.grounding_repair import ( + format_repair_summary, + read_snapshot, + repair_grounding, + snapshot_path, + write_repair_atomic, +) + + +@click.group( + epilog=format_epilog( + """ + Examples: + + datasight grounding check + datasight grounding repair + datasight grounding repair --model qwen3.6 + datasight grounding repair --from-csv load_data.csv + datasight grounding repair --dry-run + """ + ) +) +def grounding(): + """Detect and repair drift between grounding files and the live schema. 
+ + Grounding files (``queries.yaml``, ``schema_description.md``, + ``time_series.yaml``) describe the database to the LLM. When the + schema changes (typically after ``datasight tidy review``), these + files fall out of sync and the agent silently hallucinates against + columns that no longer exist. + + \b + - ``check`` reports drift without changing anything. + - ``repair`` asks the configured LLM to rewrite the stale files + against the current schema, validates each proposed query, and + writes atomically after you confirm the diff. + """ + + +@click.command( + name="check", + epilog=format_epilog( + """ + Examples: + + datasight grounding check + datasight grounding check --project-dir /path/to/project + """ + ), +) +@click.option( + "--project-dir", + type=click.Path(exists=True), + default=".", + help="Project directory containing .env and grounding files.", +) +def grounding_check(project_dir: str) -> None: + """Report stale references in grounding files against the live schema. + + Static — no LLM, no query execution. Exits 0 when grounding is + clean, 1 when drift is detected. Use ``datasight grounding + repair`` to fix what this command finds. 
+ """ + project_dir = str(Path(project_dir).resolve()) + settings, _ = cli.resolve_settings(project_dir) + if settings.database.mode != "duckdb": + click.echo( + "grounding check requires DuckDB; database.mode is " + f"{settings.database.mode!r}.", + err=True, + ) + sys.exit(2) + resolved_db_path = cli.resolve_db_path(settings, project_dir) + if not resolved_db_path or not os.path.exists(resolved_db_path): + click.echo(f"Error: Database file not found: {resolved_db_path}", err=True) + sys.exit(1) + + conn = duckdb.connect(resolved_db_path, read_only=True) + try: + truth = build_schema_truth_sync(conn) + enums = build_enum_values_sync(conn, truth) + finally: + conn.close() + + report = check_grounding_drift(Path(project_dir), truth, enum_values=enums) + if report.is_clean: + click.echo("grounding clean: no drift detected.") + sys.exit(0) + click.echo(format_drift_report(report), err=True) + click.echo("", err=True) + click.echo( + "Run `datasight grounding repair` to rewrite the affected files.", + err=True, + ) + sys.exit(1) + + +@click.command( + name="repair", + epilog=format_epilog( + """ + Examples: + + datasight grounding repair + datasight grounding repair --model qwen3.6 + datasight grounding repair --from-csv load_data.csv + datasight grounding repair --dry-run + """ + ), +) +@click.option( + "--project-dir", + type=click.Path(exists=True), + default=".", + help="Project directory containing .env and grounding files.", +) +@click.option( + "--model", + default=None, + help=( + "LLM model name to use for the repair (overrides .env). " + "Useful for retrying with a different model after a timeout." + ), +) +@click.option( + "--from-csv", + "from_csv", + type=click.Path(exists=True, dir_okay=False), + multiple=True, + help=( + "Derive the pre-tidy schema from CSV headers when no snapshot " + "is available. Pass once per source file (e.g. the wide-format " + "input the apply consumed). Each CSV becomes a single table " + "named after the file stem. 
Combinable with the snapshot — " + "snapshot tables win on conflict." + ), +) +@click.option( + "--dry-run", + is_flag=True, + default=False, + help="Show drift + LLM proposal + diff, but don't write any files.", +) +def grounding_repair( # noqa: C901 + project_dir: str, + model: str | None, + from_csv: tuple[str, ...], + dry_run: bool, +) -> None: + """Run the LLM grounding repair against an existing drift. + + Reads the pre-tidy schema snapshot persisted by the most recent + apply (``.datasight/grounding_snapshot.json``). When no snapshot + is on file, ``--from-csv`` lets you supply the wide-form schema + by pointing at the source CSV(s). + + Shows the unified diff and prompts for confirmation before writing. + Use ``--dry-run`` to skip the write entirely. + """ + project_dir = str(Path(project_dir).resolve()) + settings, resolved_model = cli.resolve_settings(project_dir, model) + if settings.database.mode != "duckdb": + click.echo( + "grounding repair requires DuckDB; database.mode is " + f"{settings.database.mode!r}.", + err=True, + ) + sys.exit(2) + resolved_db_path = cli.resolve_db_path(settings, project_dir) + if not resolved_db_path or not os.path.exists(resolved_db_path): + click.echo(f"Error: Database file not found: {resolved_db_path}", err=True) + sys.exit(1) + + # Build old_schema from the snapshot, then merge in any --from-csv + # entries for tables the snapshot doesn't cover. We deliberately + # let snapshot tables win — the snapshot was captured by the actual + # apply, the CSV is just a structural fallback. + old_schema: dict[str, set[str]] = {} + snapshot = read_snapshot(project_dir) + if snapshot: + old_schema.update(snapshot) + click.echo(f"Loaded snapshot: {snapshot_path(project_dir)}") + for csv_path in from_csv: + table_name = Path(csv_path).stem + cols = _read_csv_header(Path(csv_path)) + if table_name in old_schema: + click.echo( + f" --from-csv {csv_path}: snapshot already covers " + f"{table_name!r}; using snapshot version." 
+ ) + continue + old_schema[table_name] = set(cols) + click.echo(f" --from-csv {csv_path}: derived {table_name} ({len(cols)} columns)") + + if not old_schema: + click.echo( + "No pre-tidy schema available: snapshot file missing and no " + "--from-csv passed. Either run `datasight tidy review` first " + "(which writes the snapshot), or pass --from-csv pointing at " + "the wide-form source.", + err=True, + ) + sys.exit(1) + + conn = duckdb.connect(resolved_db_path, read_only=True) + try: + new_schema = build_schema_truth_sync(conn) + enums = build_enum_values_sync(conn, new_schema) + finally: + conn.close() + + drift = check_grounding_drift(Path(project_dir), new_schema, enum_values=enums) + if drift.is_clean: + click.echo("grounding clean: no drift detected. Nothing to repair.") + return + + click.echo("") + click.echo(format_drift_report(drift)) + click.echo("") + + try: + cli.validate_settings_for_llm(settings) + except (click.UsageError, click.ClickException, SystemExit) as exc: + click.echo(f"No LLM configured to run the repair: {exc}", err=True) + sys.exit(1) + + click.echo(f"Running repair with model: {resolved_model}") + try: + result = asyncio.run( + _run_repair( + project_dir, old_schema, new_schema, drift, settings, resolved_model + ) + ) + except Exception as exc: # noqa: BLE001 — surface to user + click.echo(f"Repair failed: {exc}", err=True) + sys.exit(1) + + if not result.any_changes: + click.echo("LLM proposed no changes.") + return + + click.echo("") + click.echo(format_repair_summary(result)) + for f in result.files: + if not f.changed: + continue + click.echo("") + click.echo(f.unified_diff(), nl=False) + + if not result.overall_ok: + click.echo("") + click.echo( + "Some proposed files failed validation after retries. 
Skipping write; " + "edit the files manually using the diff above as a starting point.", + err=True, + ) + sys.exit(1) + + if dry_run: + click.echo("") + click.echo("--dry-run: no files written.") + return + + click.echo("") + if not click.confirm("Apply this diff?", default=False): + click.echo("Aborted; no files written.") + return + + written = write_repair_atomic(result, Path(project_dir)) + for p in written: + click.echo(f"Wrote {p}") + + +async def _run_repair( + project_dir: str, + old_schema: dict[str, set[str]], + new_schema: dict[str, set[str]], + drift: Any, + settings: Any, + resolved_model: str, +): + """Wire up the LLM client + SQL runner the repair library needs.""" + llm_client = cli.create_llm_client( + provider=settings.llm.provider, + api_key=settings.llm.api_key, + base_url=settings.llm.base_url, + timeout=settings.llm.timeout, + model=resolved_model, + ) + try: + sql_runner = create_sql_runner_from_settings(settings.database, project_dir) + return await repair_grounding( + Path(project_dir), + old_schema, + new_schema, + drift, + llm_client=llm_client, + model=resolved_model, + run_sql=sql_runner.run_sql, + ) + finally: + await llm_client.aclose() + + +def _read_csv_header(path: Path) -> list[str]: + """Read the first line of a CSV as the header row. + + Uses :mod:`csv` rather than ``str.split`` so quoted fields with + embedded commas don't get miscounted. Strips whitespace because + real-world CSV headers are inconsistently formatted. 
+ """ + with path.open("r", encoding="utf-8", newline="") as f: + reader = csv.reader(f) + try: + row = next(reader) + except StopIteration: + msg = f"CSV is empty: {path}" + raise click.ClickException(msg) from None + return [c.strip() for c in row if c.strip()] + + +grounding.add_command(grounding_check) +grounding.add_command(grounding_repair) diff --git a/src/datasight/cli_commands/tidy.py b/src/datasight/cli_commands/tidy.py index 8cd419c4..a0b5db98 100644 --- a/src/datasight/cli_commands/tidy.py +++ b/src/datasight/cli_commands/tidy.py @@ -26,6 +26,7 @@ format_repair_summary, repair_grounding, write_repair_atomic, + write_snapshot, ) from datasight.schema import introspect_schema from datasight.tidy import _detect_period_groups, analyze_tidy_patterns @@ -482,6 +483,16 @@ def tidy_table(project_dir, source_table, dry_run): "when the LLM seeing the values is acceptable." ), ) +@click.option( + "--model", + default=None, + help=( + "LLM model name to use for the propose-reshapes call and the " + "post-apply grounding-repair call (overrides .env). Useful when " + "different models suit each workload — see " + "docs/use/concepts/choosing-an-llm.md." + ), +) def tidy_review( # noqa: C901 project_dir, source_table, @@ -495,6 +506,7 @@ def tidy_review( # noqa: C901 replace_source, drop_source, sample_rows, + model, ): """LLM-augmented advisor that proposes reshapes for the developer to review. @@ -551,6 +563,11 @@ def tidy_review( # noqa: C901 if settings.database.mode != "duckdb": msg = "tidy review requires DuckDB; the apply path opens a writable DuckDB connection." raise click.UsageError(msg) + # ``--model`` overrides the configured default for both the + # propose-reshapes call and the post-apply grounding-repair call. + # Both flows accept the override; the rest of the command (snapshot, + # validation, DDL) doesn't touch the LLM. + resolved_model = model if model else settings.llm.model # Load suggestions. Three sources, in priority order: # 1. 
--from PLAN : load that plan and skip the LLM entirely. @@ -568,7 +585,9 @@ def tidy_review( # noqa: C901 ) suggestions = [s for sugs in suggestions_by_table.values() for s in sugs] else: - suggestions = _propose_via_llm(project_dir, settings, source_table, sample_rows) + suggestions = _propose_via_llm( + project_dir, settings, source_table, sample_rows, resolved_model + ) if source_table: suggestions = [s for s in suggestions if s.table == source_table] @@ -627,9 +646,12 @@ async def _load_schema(): return # Snapshot the pre-apply schema so the grounding repair flow (post-apply, - # below) can show the LLM both old and new schemas. Skipped on dry runs - # since nothing changes; skipped on snapshot errors so a quirky DuckDB - # state doesn't block the apply itself. + # below) can show the LLM both old and new schemas. Persist it to + # ``.datasight/grounding_snapshot.json`` so a later + # ``datasight grounding repair`` invocation can retry against the + # same baseline (e.g. with a different model after a timeout). + # Skipped on dry runs since nothing changes; skipped on snapshot + # errors so a quirky DuckDB state doesn't block the apply itself. 
old_schema: dict[str, set[str]] | None = None if not dry_run: try: @@ -640,11 +662,22 @@ async def _load_schema(): ro.close() except duckdb.Error: old_schema = None + if old_schema is not None: + try: + write_snapshot(project_dir, old_schema) + except OSError as exc: + click.echo( + f"warn: grounding snapshot write failed ({exc}); " + f"`datasight grounding repair` won't be able to retry against this baseline.", + err=True, + ) _apply_review_proposals(approved, disposition, as_mode, dry_run, resolved_db_path, project_dir) if not dry_run and old_schema is not None: - _offer_grounding_repair(resolved_db_path, project_dir, settings, old_schema) + _offer_grounding_repair( + resolved_db_path, project_dir, settings, old_schema, resolved_model + ) def _offer_grounding_repair( # noqa: C901 @@ -652,6 +685,7 @@ def _offer_grounding_repair( # noqa: C901 project_dir: str, settings: Any, old_schema: dict[str, set[str]], + resolved_model: str, ) -> None: """Post-apply hook: detect grounding drift, offer LLM-driven repair. @@ -697,7 +731,6 @@ def _offer_grounding_repair( # noqa: C901 ): return - resolved_model = settings.llm.model try: result = asyncio.run( _run_grounding_repair( @@ -770,7 +803,11 @@ async def _run_grounding_repair( def _propose_via_llm( - project_dir: str, settings: Any, source_table: str | None, sample_rows: int + project_dir: str, + settings: Any, + source_table: str | None, + sample_rows: int, + resolved_model: str, ) -> list: """Call the configured LLM provider for tidy-reshape proposals. 
@@ -809,7 +846,7 @@ async def _gather(): api_key=settings.llm.api_key, base_url=settings.llm.base_url, timeout=settings.llm.timeout, - model=settings.llm.model, + model=resolved_model, ) async def _call(): @@ -818,7 +855,7 @@ async def _call(): try: return await propose_reshapes( llm_client, - model=settings.llm.model, + model=resolved_model, schema_info=schema_info, deterministic_hits=deterministic_hits, samples=samples or None, diff --git a/src/datasight/grounding_repair.py b/src/datasight/grounding_repair.py index 7a0f2d90..94abbf34 100644 --- a/src/datasight/grounding_repair.py +++ b/src/datasight/grounding_repair.py @@ -40,6 +40,7 @@ import re import tempfile from dataclasses import dataclass, field +from datetime import datetime, timezone from pathlib import Path from typing import Any, Awaitable, Callable @@ -49,6 +50,95 @@ from datasight.grounding import DriftReport from datasight.llm import LLMClient, TextBlock +# Pre-tidy schema snapshot: written by tidy apply, consumed by grounding +# repair. Lives under ``.datasight/`` (same dir as conversations, +# query_log.jsonl, etc.) so it survives server restarts and is +# accessible to both the CLI and the web endpoint. +SNAPSHOT_RELATIVE_PATH = Path(".datasight") / "grounding_snapshot.json" +SNAPSHOT_SCHEMA_VERSION = 1 + + +def snapshot_path(project_dir: Path | str) -> Path: + """Return the absolute path to the project's grounding snapshot file.""" + return Path(project_dir) / SNAPSHOT_RELATIVE_PATH + + +def write_snapshot(project_dir: Path | str, schema: dict[str, set[str]]) -> Path: + """Persist the pre-tidy schema snapshot atomically. + + Overwrites any prior snapshot — there's only one "most recent + apply" per project. Sets are serialized as sorted lists for stable + file content (so a snapshot diff in git review is meaningful). + + Parameters + ---------- + project_dir : Path | str + Project root containing ``.datasight/``. 
+ schema : dict[str, set[str]] + Tables → column-name set, as captured immediately before the + tidy apply mutates the database. + + Returns + ------- + Path + The path that was written. + """ + payload = { + "schema_version": SNAPSHOT_SCHEMA_VERSION, + "applied_at": datetime.now(timezone.utc).isoformat(), + "schema": {table: sorted(cols) for table, cols in schema.items()}, + } + target = snapshot_path(project_dir) + target.parent.mkdir(parents=True, exist_ok=True) + with tempfile.NamedTemporaryFile( + mode="w", encoding="utf-8", dir=target.parent, delete=False, + prefix=".grounding-snapshot-", + ) as tmp: + json.dump(payload, tmp, indent=2, sort_keys=True) + tmp.flush() + os.fsync(tmp.fileno()) + tmp_path = Path(tmp.name) + try: + os.replace(tmp_path, target) + except OSError: + tmp_path.unlink(missing_ok=True) + raise + return target + + +def read_snapshot(project_dir: Path | str) -> dict[str, set[str]] | None: + """Load the most recent pre-tidy schema snapshot, or None if absent. + + Returns None for missing files, malformed JSON, or unrecognized + schema versions — callers treat absence as "no prior apply on + record" and surface that to the user. We deliberately don't raise + on corruption so a bad snapshot can't permanently brick the + repair flow; the user can re-apply or pass an explicit fallback. 
+ """ + path = snapshot_path(project_dir) + if not path.exists(): + return None + try: + payload = json.loads(path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError) as exc: + logger.warning(f"grounding snapshot at {path} is unreadable: {exc}") + return None + if not isinstance(payload, dict) or payload.get("schema_version") != SNAPSHOT_SCHEMA_VERSION: + logger.warning( + f"grounding snapshot at {path} has unexpected schema_version " + f"{payload.get('schema_version') if isinstance(payload, dict) else 'n/a'!r}; ignoring" + ) + return None + raw_schema = payload.get("schema") + if not isinstance(raw_schema, dict): + return None + out: dict[str, set[str]] = {} + for table, cols in raw_schema.items(): + if not isinstance(table, str) or not isinstance(cols, list): + continue + out[table] = {c for c in cols if isinstance(c, str)} + return out + # Files the repair flow may touch. Other grounding-adjacent files # (schema.yaml, measures.yaml) are owned by tidy_review's own update diff --git a/src/datasight/web/app.py b/src/datasight/web/app.py index 81807568..6e600d65 100644 --- a/src/datasight/web/app.py +++ b/src/datasight/web/app.py @@ -102,8 +102,10 @@ check_grounding_drift, ) from datasight.grounding_repair import ( + read_snapshot, repair_grounding, write_repair_atomic, + write_snapshot, ) from datasight.sql_validation import build_measure_rule_map, build_schema_map, validate_sql from datasight.tidy import _detect_period_groups, analyze_tidy_patterns @@ -604,6 +606,10 @@ def __init__(self) -> None: self.is_ephemeral: bool = False self.ephemeral_tables_info: list[dict[str, Any]] = [] self.time_series_configs: list[dict[str, Any]] = [] + # Snapshot of the schema captured immediately before the most recent + # tidy_apply. Consumed by /api/tidy/grounding/repair so the LLM + # repair prompt can show both shapes; cleared on project change. 
+ self.pre_tidy_schema: dict[str, set[str]] | None = None def clear_project(self) -> None: """Clear project-specific state.""" @@ -638,6 +644,7 @@ def clear_project(self) -> None: self.ephemeral_tables_info = [] self._ephemeral_messages = {} self._session_locks.clear() + self.pre_tidy_schema = None def rebuild_system_prompt(self) -> None: """Rebuild the system prompt after settings change.""" @@ -3347,14 +3354,27 @@ async def tidy_apply(request: Request, state: AppState = Depends(get_state)): # db_path = runner._database_path - # Snapshot the pre-apply schema so the post-apply grounding-repair step + # Snapshot the pre-apply schema so a follow-up grounding-repair call # can show the LLM both old and new shapes. Without an explicit # snapshot the repair prompt degenerates to "regenerate from scratch" - # and loses any human customizations in the grounding files. + # and loses any human customizations in the grounding files. Stashed + # on AppState because grounding repair runs as a separate request + # (the LLM call is too slow to block the apply response). old_schema: dict[str, set[str]] = { t["name"]: {c["name"] for c in t["columns"]} for t in state.schema_info } + state.pre_tidy_schema = old_schema + # Persist alongside the in-memory snapshot so the user can retry + # grounding repair after a server restart, or run it from the CLI + # against this same baseline. Best-effort: a write failure here + # shouldn't block the apply itself (the in-memory copy still works + # for the immediate banner flow). 
+ if state.project_dir and not state.is_ephemeral: + try: + write_snapshot(state.project_dir, old_schema) + except OSError as exc: + logger.warning(f"grounding snapshot write failed: {exc}") async with state.state_lock: # Cached DataFrames can keep DuckDB buffers alive and block the @@ -3449,34 +3469,26 @@ async def tidy_apply(request: Request, state: AppState = Depends(get_state)): # for t in tables ] - # Grounding repair: when the schema actually changed, rewrite - # queries.yaml / schema_description.md / time_series.yaml so the - # next LLM call sees stale-reference-free grounding. Must happen - # *before* the schema_text rebuild below, because - # ``_load_user_description`` reads schema_description.md from - # disk — we want it to load the repaired content, not the stale - # one. Failures here are logged and surfaced in the response but - # don't fail the apply (the database mutation already committed). - grounding_summary: dict[str, Any] | None = None + # Grounding drift check: fast, static. When the schema actually + # changed, scan queries.yaml / schema_description.md / + # time_series.yaml for stale references so the UI can prompt the + # user to run grounding repair as a separate, slow step. The LLM + # rewrite happens in /api/tidy/grounding/repair — keeping it out + # of the apply path so the response returns immediately. 
+ grounding_drift: dict[str, Any] | None = None new_schema: dict[str, set[str]] = { t["name"]: {c["name"] for c in t["columns"]} for t in state.schema_info } if ( state.project_dir - and state.llm_client is not None - and state.model and new_schema != old_schema and not state.is_ephemeral ): - grounding_summary = await _repair_grounding_after_tidy( + grounding_drift = _summarize_grounding_drift( project_dir=state.project_dir, db_path=db_path, - old_schema=old_schema, new_schema=new_schema, - llm_client=state.llm_client, - model=state.model, - run_sql=state.sql_runner.run_sql, ) state.schema_text = format_schema_context( @@ -3490,11 +3502,44 @@ async def tidy_apply(request: Request, state: AppState = Depends(get_state)): # "success": True, "result": result.to_dict(), "schema_info": state.schema_info, - "grounding_repair": grounding_summary, + "grounding_drift": grounding_drift, + } + + +def _summarize_grounding_drift( + *, + project_dir: str, + db_path: str, + new_schema: dict[str, set[str]], +) -> dict[str, Any]: + """Run the static drift check and return a summary for the UI banner. + + Fast: just reads the project's grounding YAML/markdown files and + cross-references them against ``new_schema``. Used by ``tidy_apply`` + so the response can tell the client whether to prompt the user for + grounding repair. 
+ """ + import duckdb as _duckdb + + try: + ro = _duckdb.connect(db_path, read_only=True) + try: + enum_values = build_enum_values_sync(ro, new_schema) + finally: + ro.close() + except Exception: # noqa: BLE001 + enum_values = set() + + drift = check_grounding_drift( + Path(project_dir), new_schema, enum_values=enum_values + ) + return { + "needs_repair": not drift.is_clean, + "drift_items": len(drift.items), } -async def _repair_grounding_after_tidy( # noqa: C901 +async def _run_grounding_repair( # noqa: C901 *, project_dir: str, db_path: str, @@ -3504,7 +3549,7 @@ async def _repair_grounding_after_tidy( # noqa: C901 model: str, run_sql, ) -> dict[str, Any]: - """Run the grounding-drift check and, if needed, the LLM repair. + """Run the LLM grounding repair and atomically write changed files. Returns a structured summary the API caller can render in the UI: @@ -3514,13 +3559,9 @@ async def _repair_grounding_after_tidy( # noqa: C901 - ``validation_errors``: per-file SQL execution errors when the LLM's proposal still didn't validate after retries (empty on success) - ``error``: top-level exception message when the flow itself crashed - - The endpoint already committed the database mutation by the time this - runs, so failures here are surfaced as info but don't fail the apply. """ import duckdb as _duckdb - # Load enum values from a fresh read-only conn — pure data fetch. try: ro = _duckdb.connect(db_path, read_only=True) try: @@ -3597,6 +3638,136 @@ async def _repair_grounding_after_tidy( # noqa: C901 } +@app.post("/api/tidy/grounding/repair") +async def tidy_grounding_repair(state: AppState = Depends(get_state)): + """Run the slow LLM grounding repair triggered by the user post-apply. + + Split out from ``tidy_apply`` so the apply response can return + immediately. The client calls this endpoint after seeing + ``grounding_drift.needs_repair`` in the apply response. Reads the + pre-tidy schema snapshot stashed on AppState during the apply. 
+ """ + if not state.project_loaded or not state.project_dir: + return {"success": False, "error": "No project loaded"} + if state.is_ephemeral: + return {"success": False, "error": "Grounding repair not supported in explore sessions"} + if state.llm_client is None or not state.model: + return {"success": False, "error": "LLM client not configured"} + if state.sql_runner is None: + return {"success": False, "error": "No SQL runner available"} + # Prefer the in-memory snapshot from the current server's apply, but + # fall back to the on-disk snapshot so the user can retry repair + # after a restart (or after the LLM call timed out and nothing was + # written). The CLI's `datasight grounding repair` reads the same + # file — they're interchangeable baselines. + old_schema = state.pre_tidy_schema + if old_schema is None: + old_schema = read_snapshot(state.project_dir) + if old_schema is None: + return { + "success": False, + "error": ( + "No pre-tidy schema snapshot available. Run a tidy apply " + "first, or use `datasight grounding repair --from-csv ...` " + "to supply a baseline from a source CSV." + ), + } + + from datasight.runner import DuckDBRunner + + runner = state.sql_runner + if isinstance(runner, CachingSqlRunner): + runner = runner._inner + if not isinstance(runner, DuckDBRunner): + return {"success": False, "error": "Grounding repair requires a project DuckDB runner"} + db_path = runner._database_path + + new_schema: dict[str, set[str]] = { + t["name"]: {c["name"] for c in t["columns"]} + for t in state.schema_info + } + + async with state.state_lock: + summary = await _run_grounding_repair( + project_dir=state.project_dir, + db_path=db_path, + old_schema=old_schema, + new_schema=new_schema, + llm_client=state.llm_client, + model=state.model, + run_sql=state.sql_runner.run_sql, + ) + + # Reload the user-facing prose so the system prompt picks up the + # repaired schema_description.md. 
schema_info itself didn't + # change in this endpoint — only the grounding files on disk did. + if summary.get("applied"): + tables = await introspect_schema( + state.sql_runner.run_sql, runner=state.sql_runner + ) + state.schema_text = format_schema_context( + tables, + user_description=_load_user_description(state), + ) + state.schema_map = build_schema_map(state.schema_info) + state.rebuild_system_prompt() + # Snapshot is consumed — clear it so a stale one can't shadow + # a future apply. + state.pre_tidy_schema = None + + return {"success": True, "grounding_repair": summary} + + +@app.get("/api/grounding/status") +async def grounding_status(state: AppState = Depends(get_state)): + """Static grounding-drift summary for the always-on header pill. + + The header pill polls this on project load so the user sees + "grounding stale" any time the files are out of sync — not only + after they just ran an apply. Static check, no LLM, fast. + + Returns ``{available: false}`` for sessions where the check + doesn't apply (no project loaded, ephemeral, non-DuckDB) so the + UI can hide the pill rather than showing a confusing error. 
+ """ + if not state.project_loaded or not state.project_dir or state.is_ephemeral: + return {"available": False, "reason": "no_project"} + if state.sql_runner is None: + return {"available": False, "reason": "no_runner"} + + from datasight.runner import DuckDBRunner + + runner = state.sql_runner + if isinstance(runner, CachingSqlRunner): + runner = runner._inner + if not isinstance(runner, DuckDBRunner): + return {"available": False, "reason": "non_duckdb"} + db_path = runner._database_path + + new_schema: dict[str, set[str]] = { + t["name"]: {c["name"] for c in t["columns"]} + for t in state.schema_info + } + summary = _summarize_grounding_drift( + project_dir=state.project_dir, + db_path=db_path, + new_schema=new_schema, + ) + # Tell the UI whether a snapshot exists so the pill can choose + # whether to enable the repair button — without one, the user has + # to use the CLI's ``--from-csv`` fallback. + has_snapshot = ( + state.pre_tidy_schema is not None + or read_snapshot(state.project_dir) is not None + ) + return { + "available": True, + "needs_repair": summary["needs_repair"], + "drift_items": summary["drift_items"], + "has_snapshot": has_snapshot, + } + + @app.get("/api/measures/editor") async def get_measure_overrides_editor(state: AppState = Depends(get_state)): """Return editable measure override YAML for the active project.""" diff --git a/tests/test_web_tidy.py b/tests/test_web_tidy.py index 25a9cbe9..5d48e9dd 100644 --- a/tests/test_web_tidy.py +++ b/tests/test_web_tidy.py @@ -458,16 +458,12 @@ def test_apply_creates_table_with_replace_disposition(loaded_state): assert any(c["name"] == "year" for c in sales_cols) -def test_apply_repairs_stale_grounding_when_llm_configured(loaded_state, monkeypatch): +def test_apply_returns_grounding_drift_summary(loaded_state, monkeypatch): """When grounding files reference columns the reshape removes, the - apply endpoint should detect drift, call the LLM, and rewrite the - affected files atomically — surfacing what 
changed in the response. - - Monkeypatches :func:`repair_grounding` so the test exercises the - wiring (drift detection → repair call → atomic write → response - shaping) without depending on a live LLM. The endpoint's TestClient - startup re-initializes ``state.llm_client`` from the env, so a - state-level stub gets clobbered before the request runs. + apply endpoint runs the static drift check inline and returns a + summary so the UI can prompt the user — but the slow LLM rewrite is + NOT called from /api/tidy/apply (it moved to a separate endpoint + so the apply response can return immediately). """ project_dir = Path(loaded_state.project_dir) queries_path = project_dir / "queries.yaml" @@ -476,6 +472,53 @@ def test_apply_repairs_stale_grounding_when_llm_configured(loaded_state, monkeyp " sql: SELECT region, sales_2020 FROM sales;\n", encoding="utf-8", ) + + repair_called: list[Any] = [] + + async def fake_repair(*args, **kwargs): + repair_called.append((args, kwargs)) + raise AssertionError("repair_grounding must not run during apply") + + monkeypatch.setattr(web_app, "repair_grounding", fake_repair) + + suggestion = _detect_period_groups(loaded_state.schema_info[0])[0] + proposal = _suggestion_to_proposal_dict(suggestion) + + with TestClient(web_app.app) as client: + response = client.post( + "/api/tidy/apply", + json={ + "proposal": proposal, + "mode": "table", + "disposition": {"mode": "drop"}, + }, + ) + + body = response.json() + assert body["success"], body + drift = body["grounding_drift"] + assert drift is not None, body + assert drift["needs_repair"] is True + assert drift["drift_items"] >= 1 + assert repair_called == [], "repair_grounding should not be called from apply" + # Apply leaves the file untouched — the user has to opt in to repair. + assert "sales_2020" in queries_path.read_text(encoding="utf-8") + # The pre-tidy schema snapshot is stashed for the repair endpoint. 
+ assert loaded_state.pre_tidy_schema is not None + assert "sales_2020" in loaded_state.pre_tidy_schema["sales"] + + +def test_grounding_repair_endpoint_rewrites_stale_files(loaded_state, monkeypatch): + """Calling /api/tidy/grounding/repair after a tidy apply runs the + LLM rewrite and atomically writes the corrected files, using the + pre-tidy schema snapshot stashed on AppState during the apply.""" + project_dir = Path(loaded_state.project_dir) + queries_path = project_dir / "queries.yaml" + queries_path.write_text( + "- question: Stale wide reference\n" + " sql: SELECT region, sales_2020 FROM sales;\n", + encoding="utf-8", + ) repaired_yaml = ( "- question: Long-form sales\n" " sql: SELECT region, period, sales FROM sales_long;\n" @@ -504,7 +547,7 @@ async def fake_repair(project_dir, old_schema, new_schema, drift, **kwargs): proposal = _suggestion_to_proposal_dict(suggestion) with TestClient(web_app.app) as client: - response = client.post( + apply_resp = client.post( "/api/tidy/apply", json={ "proposal": proposal, @@ -512,11 +555,13 @@ async def fake_repair(project_dir, old_schema, new_schema, drift, **kwargs): "disposition": {"mode": "drop"}, }, ) + assert apply_resp.json()["success"], apply_resp.json() - body = response.json() + repair_resp = client.post("/api/tidy/grounding/repair") + + body = repair_resp.json() assert body["success"], body summary = body["grounding_repair"] - assert summary is not None, body assert summary["applied"] is True, summary assert summary["files_written"] == ["queries.yaml"] assert summary["drift_items"] >= 1 @@ -524,12 +569,16 @@ async def fake_repair(project_dir, old_schema, new_schema, drift, **kwargs): rewritten = queries_path.read_text(encoding="utf-8") assert "sales_2020" not in rewritten assert "sales_long" in rewritten + # Snapshot is consumed after a successful write so a second click + # can't reuse a stale pre-tidy schema. 
+ assert loaded_state.pre_tidy_schema is None -def test_apply_skips_grounding_repair_when_no_llm_configured(loaded_state, monkeypatch): - """No llm_client → repair flow short-circuits before the drift check. - Response carries ``grounding_repair: None`` and ``repair_grounding`` - is never called even when drift exists on disk.""" +def test_grounding_repair_endpoint_requires_llm_client(loaded_state, monkeypatch): + """The static drift check runs even without an LLM, so apply still + surfaces drift. But the repair endpoint needs an LLM client — it + returns an error so the UI can surface it instead of silently + failing.""" project_dir = Path(loaded_state.project_dir) queries_path = project_dir / "queries.yaml" original_text = "- question: Stale\n sql: SELECT sales_2020 FROM sales;\n" @@ -539,7 +588,7 @@ def test_apply_skips_grounding_repair_when_no_llm_configured(loaded_state, monke async def fake_repair(*args, **kwargs): repair_called.append((args, kwargs)) - raise AssertionError("repair_grounding should not be called without llm_client") + raise AssertionError("repair_grounding must not run without llm_client") monkeypatch.setattr(web_app, "repair_grounding", fake_repair) @@ -547,9 +596,8 @@ async def fake_repair(*args, **kwargs): proposal = _suggestion_to_proposal_dict(suggestion) with TestClient(web_app.app) as client: - # Drop the LLM client after startup has re-initialized it from env. loaded_state.llm_client = None - response = client.post( + apply_resp = client.post( "/api/tidy/apply", json={ "proposal": proposal, @@ -557,16 +605,20 @@ async def fake_repair(*args, **kwargs): "disposition": {"mode": "drop"}, }, ) + # Apply still reports drift — the static check doesn't need an LLM. 
+ assert apply_resp.json()["grounding_drift"]["needs_repair"] is True - body = response.json() - assert body["success"], body - assert body["grounding_repair"] is None, body - assert repair_called == [], "repair_grounding was called unexpectedly" - # The stale queries.yaml stays exactly as written. + loaded_state.llm_client = None + repair_resp = client.post("/api/tidy/grounding/repair") + + body = repair_resp.json() + assert body["success"] is False, body + assert "LLM client" in body.get("error", ""), body + assert repair_called == [] assert queries_path.read_text(encoding="utf-8") == original_text -def test_apply_grounding_repair_surfaces_validation_errors(loaded_state, monkeypatch): +def test_grounding_repair_endpoint_surfaces_validation_errors(loaded_state, monkeypatch): """When the LLM proposal validates dirty, the file is left untouched and the summary surfaces the validation errors so the UI can show them to the user.""" @@ -597,7 +649,7 @@ async def fake_repair(project_dir, old_schema, new_schema, drift, **kwargs): proposal = _suggestion_to_proposal_dict(suggestion) with TestClient(web_app.app) as client: - response = client.post( + apply_resp = client.post( "/api/tidy/apply", json={ "proposal": proposal, @@ -605,17 +657,21 @@ async def fake_repair(project_dir, old_schema, new_schema, drift, **kwargs): "disposition": {"mode": "drop"}, }, ) + assert apply_resp.json()["success"], apply_resp.json() - body = response.json() + repair_resp = client.post("/api/tidy/grounding/repair") + + body = repair_resp.json() assert body["success"], body summary = body["grounding_repair"] - assert summary is not None, body assert summary["applied"] is False assert summary["files_written"] == [] assert summary["validation_errors"], summary assert any("nonexistent_col" in err for err in summary["validation_errors"]) # File contents untouched — write_repair_atomic skipped the bad file. 
assert queries_path.read_text(encoding="utf-8") == original_text + # Snapshot retained on failure so the user can retry without re-applying. + assert loaded_state.pre_tidy_schema is not None def test_apply_creates_table_with_bare_drop_disposition(loaded_state): From 132905ca8394e1240c25f65eb0ab69ec67e4e59e Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 11 May 2026 07:33:13 -0600 Subject: [PATCH 07/10] Fix ruff lint errors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit EM101 in the two new "must not run" assertions added by the test split — assign the message to a variable first, matching the project-wide convention. C901 noqa added to three functions intentionally over the McCabe threshold; pre-existing but uncovered when the previous CI run skipped them on path filters. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/datasight/grounding.py | 4 ++-- src/datasight/grounding_repair.py | 2 +- tests/test_web_tidy.py | 6 ++++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/datasight/grounding.py b/src/datasight/grounding.py index f9f6ef94..8912161a 100644 --- a/src/datasight/grounding.py +++ b/src/datasight/grounding.py @@ -169,7 +169,7 @@ def check_grounding_drift( return report -def _check_queries( +def _check_queries( # noqa: C901 path: Path, schema_truth: dict[str, set[str]], report: DriftReport ) -> None: """Walk every ``sql:`` block, flag unresolved table/column references.""" @@ -252,7 +252,7 @@ def _check_queries( )) -def _check_schema_description( +def _check_schema_description( # noqa: C901 path: Path, schema_truth: dict[str, set[str]], enum_values: set[str], diff --git a/src/datasight/grounding_repair.py b/src/datasight/grounding_repair.py index 94abbf34..4fef02da 100644 --- a/src/datasight/grounding_repair.py +++ b/src/datasight/grounding_repair.py @@ -475,7 +475,7 @@ def _parse_repair_json(text: str) -> dict[str, str]: return out -async def _validate_repair( +async def 
_validate_repair( # noqa: C901 files: list[RepairFile], *, run_sql: Callable[[str], Awaitable[Any]], diff --git a/tests/test_web_tidy.py b/tests/test_web_tidy.py index 5d48e9dd..033e7557 100644 --- a/tests/test_web_tidy.py +++ b/tests/test_web_tidy.py @@ -477,7 +477,8 @@ def test_apply_returns_grounding_drift_summary(loaded_state, monkeypatch): async def fake_repair(*args, **kwargs): repair_called.append((args, kwargs)) - raise AssertionError("repair_grounding must not run during apply") + msg = "repair_grounding must not run during apply" + raise AssertionError(msg) monkeypatch.setattr(web_app, "repair_grounding", fake_repair) @@ -588,7 +589,8 @@ def test_grounding_repair_endpoint_requires_llm_client(loaded_state, monkeypatch async def fake_repair(*args, **kwargs): repair_called.append((args, kwargs)) - raise AssertionError("repair_grounding must not run without llm_client") + msg = "repair_grounding must not run without llm_client" + raise AssertionError(msg) monkeypatch.setattr(web_app, "repair_grounding", fake_repair) From ebc7a46d9a9619f032b4d49d529822e6262036e7 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 11 May 2026 08:15:25 -0600 Subject: [PATCH 08/10] Address PR #72 review feedback and CI failures - Run ruff format across grounding-touched files (10 files reformatted) - Stub llm_client inside TestClient context for the two grounding-repair endpoint tests so FastAPI startup's init_llm_client doesn't clear it in CI (no ANTHROPIC_API_KEY) - Sanitize stack-trace exposure: return generic messages from the grounding-repair endpoint instead of str(exc); keep server logs intact (CodeQL CWE-209) - Validate qualified column references against the referenced table's columns rather than the union, so renamed columns in one table aren't shadowed by their presence in another - Flag markdown table.column references where the table itself is missing instead of silently skipping them (the failure this check exists to catch) - Type-check group_columns before 
iterating in time_series drift check, matching load_time_series_config's contract - Quote identifiers via _quote_identifier and bound the enum-values scan with LIMIT max_per_column + 1 so a free-text column on a large table can't turn the pre-flight into a full scan - Make llm_retries semantics consistent: report the number of retries performed (attempt index on success, max_retries on give-up) - Exit with an error for --static-only when the DuckDB file is missing, instead of falling through to the LLM phase - Release state_lock around the slow LLM call in tidy_grounding_repair; re-acquire only for the small post-write critical section, with a project-switch guard so a concurrent project change can't be clobbered Co-Authored-By: Claude Opus 4.7 (1M context) --- src/datasight/cli_commands/grounding.py | 10 +- src/datasight/cli_commands/tidy.py | 8 +- src/datasight/cli_commands/verify.py | 13 +- src/datasight/cost.py | 7 +- src/datasight/grounding.py | 318 +++++++++++++++++------- src/datasight/grounding_repair.py | 40 +-- src/datasight/web/app.py | 94 +++---- tests/test_grounding.py | 131 +++++----- tests/test_grounding_repair.py | 128 +++++++--- tests/test_web_tidy.py | 19 +- 10 files changed, 498 insertions(+), 270 deletions(-) diff --git a/src/datasight/cli_commands/grounding.py b/src/datasight/cli_commands/grounding.py index f59467a7..a70ff886 100644 --- a/src/datasight/cli_commands/grounding.py +++ b/src/datasight/cli_commands/grounding.py @@ -104,8 +104,7 @@ def grounding_check(project_dir: str) -> None: settings, _ = cli.resolve_settings(project_dir) if settings.database.mode != "duckdb": click.echo( - "grounding check requires DuckDB; database.mode is " - f"{settings.database.mode!r}.", + f"grounding check requires DuckDB; database.mode is {settings.database.mode!r}.", err=True, ) sys.exit(2) @@ -200,8 +199,7 @@ def grounding_repair( # noqa: C901 settings, resolved_model = cli.resolve_settings(project_dir, model) if settings.database.mode != "duckdb": 
click.echo( - "grounding repair requires DuckDB; database.mode is " - f"{settings.database.mode!r}.", + f"grounding repair requires DuckDB; database.mode is {settings.database.mode!r}.", err=True, ) sys.exit(2) @@ -266,9 +264,7 @@ def grounding_repair( # noqa: C901 click.echo(f"Running repair with model: {resolved_model}") try: result = asyncio.run( - _run_repair( - project_dir, old_schema, new_schema, drift, settings, resolved_model - ) + _run_repair(project_dir, old_schema, new_schema, drift, settings, resolved_model) ) except Exception as exc: # noqa: BLE001 — surface to user click.echo(f"Repair failed: {exc}", err=True) diff --git a/src/datasight/cli_commands/tidy.py b/src/datasight/cli_commands/tidy.py index a0b5db98..159f5114 100644 --- a/src/datasight/cli_commands/tidy.py +++ b/src/datasight/cli_commands/tidy.py @@ -707,9 +707,7 @@ def _offer_grounding_repair( # noqa: C901 if old_schema == new_schema: return - drift = check_grounding_drift( - Path(project_dir), new_schema, enum_values=enum_values - ) + drift = check_grounding_drift(Path(project_dir), new_schema, enum_values=enum_values) if drift.is_clean: return @@ -726,9 +724,7 @@ def _offer_grounding_repair( # noqa: C901 ) return - if not click.confirm( - "Repair grounding files with the configured LLM?", default=False - ): + if not click.confirm("Repair grounding files with the configured LLM?", default=False): return try: diff --git a/src/datasight/cli_commands/verify.py b/src/datasight/cli_commands/verify.py index a7ebfdc8..cec0cea2 100644 --- a/src/datasight/cli_commands/verify.py +++ b/src/datasight/cli_commands/verify.py @@ -104,9 +104,7 @@ def verify(project_dir, model, queries_path, static_only, skip_grounding_check): enums = build_enum_values_sync(conn, truth) finally: conn.close() - report = check_grounding_drift( - Path(project_dir), truth, enum_values=enums - ) + report = check_grounding_drift(Path(project_dir), truth, enum_values=enums) if not report.is_clean: 
click.echo(format_drift_report(report), err=True) click.echo("", err=True) @@ -122,6 +120,15 @@ def verify(project_dir, model, queries_path, static_only, skip_grounding_check): elif static_only: click.echo("grounding clean: no drift detected.") sys.exit(0) + elif static_only: + # --static-only can't run without a live DB to introspect. + # Fail loudly instead of falling through to the LLM phase, + # which would contradict the flag's semantics. + click.echo( + f"Error: Database file not found: {resolved_db_path}", + err=True, + ) + sys.exit(1) elif static_only: click.echo( "--static-only requires DuckDB; database.mode is " diff --git a/src/datasight/cost.py b/src/datasight/cost.py index 7ed895e2..0e72259c 100644 --- a/src/datasight/cost.py +++ b/src/datasight/cost.py @@ -66,9 +66,7 @@ def build_cost_data( if elapsed_seconds is not None and elapsed_seconds > 0: data["elapsed_seconds"] = round(elapsed_seconds, 4) data["output_tokens_per_sec"] = round(output_tokens / elapsed_seconds, 2) - data["total_tokens_per_sec"] = round( - (input_tokens + output_tokens) / elapsed_seconds, 2 - ) + data["total_tokens_per_sec"] = round((input_tokens + output_tokens) / elapsed_seconds, 2) if provider is None or provider in _PROVIDERS_WITH_PRICING: pricing = MODEL_PRICING.get(model) if pricing: @@ -112,8 +110,7 @@ def log_query_cost( rate_str = "" if "elapsed_seconds" in data: rate_str = ( - f" elapsed={data['elapsed_seconds']:.2f}s" - f" out_tps={data['output_tokens_per_sec']:.1f}" + f" elapsed={data['elapsed_seconds']:.2f}s out_tps={data['output_tokens_per_sec']:.1f}" ) logger.info( f"[tokens] QUERY TOTAL: api_calls={api_calls} " diff --git a/src/datasight/grounding.py b/src/datasight/grounding.py index 8912161a..049dd1f6 100644 --- a/src/datasight/grounding.py +++ b/src/datasight/grounding.py @@ -37,6 +37,8 @@ import yaml from sqlglot import exp +from datasight.schema import _quote_identifier + # Backtick-quoted lowercase identifier in markdown. 
Matches ``foo``, # ``foo_bar``, ``foo.bar``. Anything that isn't a snake_case identifier @@ -48,18 +50,77 @@ # inside ``schema_description.md`` — used to suppress false positives # from the markdown scan. Not exhaustive: only the words that show up in # prose. Anything not here AND not in the current schema gets flagged. -_SQL_KEYWORDS: frozenset[str] = frozenset({ - "all", "and", "as", "asc", "avg", "between", "boolean", "by", - "case", "cast", "ceil", "coalesce", "corr", "count", "current_date", - "date", "date_trunc", "datetime", "day", "desc", "distinct", - "double", "else", "end", "extract", "false", "floor", "from", - "group", "having", "in", "inner", "integer", "is", "join", "left", - "limit", "max", "min", "month", "not", "now", "null", "offset", - "on", "or", "order", "outer", "over", "regr_intercept", "regr_r2", - "regr_slope", "right", "round", "row_number", "select", "sum", - "then", "timestamp", "to_date", "true", "union", "varchar", - "when", "where", "with", "year", -}) +_SQL_KEYWORDS: frozenset[str] = frozenset( + { + "all", + "and", + "as", + "asc", + "avg", + "between", + "boolean", + "by", + "case", + "cast", + "ceil", + "coalesce", + "corr", + "count", + "current_date", + "date", + "date_trunc", + "datetime", + "day", + "desc", + "distinct", + "double", + "else", + "end", + "extract", + "false", + "floor", + "from", + "group", + "having", + "in", + "inner", + "integer", + "is", + "join", + "left", + "limit", + "max", + "min", + "month", + "not", + "now", + "null", + "offset", + "on", + "or", + "order", + "outer", + "over", + "regr_intercept", + "regr_r2", + "regr_slope", + "right", + "round", + "row_number", + "select", + "sum", + "then", + "timestamp", + "to_date", + "true", + "union", + "varchar", + "when", + "where", + "with", + "year", + } +) @dataclass @@ -177,10 +238,15 @@ def _check_queries( # noqa: C901 try: docs = yaml.safe_load(text) or [] except yaml.YAMLError as exc: - report.items.append(DriftItem( - file=str(path), line=None, 
kind="parse_error", - claim="", detail=f"yaml parse error: {exc}", - )) + report.items.append( + DriftItem( + file=str(path), + line=None, + kind="parse_error", + claim="", + detail=f"yaml parse error: {exc}", + ) + ) return if not isinstance(docs, list): return @@ -229,27 +295,57 @@ def _check_queries( # noqa: C901 name = tref.name if not name or name in cte_names or name in all_tables: continue - report.items.append(DriftItem( - file=str(path), line=None, kind="table", - claim=name, - detail=f"table '{name}' not in current schema", - suggestion=_nearest(name, all_tables), - )) + report.items.append( + DriftItem( + file=str(path), + line=None, + kind="table", + claim=name, + detail=f"table '{name}' not in current schema", + suggestion=_nearest(name, all_tables), + ) + ) for cref in stmt.find_all(exp.Column): name = cref.name - if not name or name in all_columns or name in output_aliases: + if not name or name in output_aliases: + continue + qualifier = cref.table + if qualifier: + # Qualified-but-unknown table prefixes are caught by + # the table check above; ignore the column part to + # avoid double-reporting. + if qualifier not in all_tables: + continue + # Validate against the referenced table's column set + # rather than the union. A column moved/renamed in + # one table but still present elsewhere would + # otherwise be a false negative. + table_cols = schema_truth.get(qualifier, set()) + if name in table_cols: + continue + report.items.append( + DriftItem( + file=str(path), + line=None, + kind="column", + claim=f"{qualifier}.{name}", + detail=f"column '{name}' not in table '{qualifier}'", + suggestion=_nearest(name, table_cols), + ) + ) continue - # Qualified-but-unknown table prefixes are caught by the - # table check above; ignore the column part to avoid - # double-reporting. 
- if cref.table and cref.table not in all_tables: + if name in all_columns: continue - report.items.append(DriftItem( - file=str(path), line=None, kind="column", - claim=name, - detail=f"column '{name}' not in any table", - suggestion=_nearest(name, all_columns), - )) + report.items.append( + DriftItem( + file=str(path), + line=None, + kind="column", + claim=name, + detail=f"column '{name}' not in any table", + suggestion=_nearest(name, all_columns), + ) + ) def _check_schema_description( # noqa: C901 @@ -271,13 +367,28 @@ def _check_schema_description( # noqa: C901 for m in _MD_BACKTICK_IDENT.finditer(line): ident = m.group(1).lower() - # ``table.column`` — check the column against that table's - # column list. Unknown tables are silently ignored here to - # avoid noise from prose that mentions tables from other DBs. + # ``table.column`` — check both the table and the column + # against the live schema. A reference to a renamed/dropped + # table is exactly the failure this check is meant to catch, + # so flag it instead of silently skipping. 
parts = ident.split(".") if len(parts) == 2: table, col = parts if table not in all_tables: + key = (lineno, ident) + if key in seen_on_line: + continue + seen_on_line.add(key) + report.items.append( + DriftItem( + file=str(path), + line=lineno, + kind="table", + claim=ident, + detail=f"`{ident}` references unknown table '{table}'", + suggestion=_nearest(table, all_tables), + ) + ) continue if col not in schema_truth.get(table, set()): key = (lineno, ident) @@ -285,12 +396,16 @@ def _check_schema_description( # noqa: C901 continue seen_on_line.add(key) suggestion = _nearest(col, schema_truth.get(table, set())) - report.items.append(DriftItem( - file=str(path), line=lineno, kind="column", - claim=ident, - detail=f"`{ident}` not a column of '{table}'", - suggestion=f"{table}.{suggestion}" if suggestion else None, - )) + report.items.append( + DriftItem( + file=str(path), + line=lineno, + kind="column", + claim=ident, + detail=f"`{ident}` not a column of '{table}'", + suggestion=f"{table}.{suggestion}" if suggestion else None, + ) + ) continue if ident in known: @@ -304,15 +419,19 @@ def _check_schema_description( # noqa: C901 if key in seen_on_line: continue seen_on_line.add(key) - report.items.append(DriftItem( - file=str(path), line=lineno, kind="column", - claim=ident, - detail=f"`{ident}` not in current schema (column or table)", - suggestion=_nearest(ident, all_columns | all_tables), - )) - - -def _check_time_series( + report.items.append( + DriftItem( + file=str(path), + line=lineno, + kind="column", + claim=ident, + detail=f"`{ident}` not in current schema (column or table)", + suggestion=_nearest(ident, all_columns | all_tables), + ) + ) + + +def _check_time_series( # noqa: C901 path: Path, schema_truth: dict[str, set[str]], report: DriftReport ) -> None: """Verify each entry's ``table`` / ``timestamp_column`` / ``group_columns``.""" @@ -320,10 +439,15 @@ def _check_time_series( try: docs = yaml.safe_load(text) or [] except yaml.YAMLError as exc: - 
report.items.append(DriftItem( - file=str(path), line=None, kind="parse_error", - claim="", detail=f"yaml parse error: {exc}", - )) + report.items.append( + DriftItem( + file=str(path), + line=None, + kind="parse_error", + claim="", + detail=f"yaml parse error: {exc}", + ) + ) return if not isinstance(docs, list): return @@ -335,29 +459,48 @@ def _check_time_series( if not table: continue if table not in schema_truth: - report.items.append(DriftItem( - file=str(path), line=None, kind="ts_table", - claim=str(table), - detail=f"time_series table '{table}' not in current schema", - suggestion=_nearest(str(table), set(schema_truth.keys())), - )) + report.items.append( + DriftItem( + file=str(path), + line=None, + kind="ts_table", + claim=str(table), + detail=f"time_series table '{table}' not in current schema", + suggestion=_nearest(str(table), set(schema_truth.keys())), + ) + ) continue ts_col = entry.get("timestamp_column") if ts_col and ts_col not in schema_truth[table]: - report.items.append(DriftItem( - file=str(path), line=None, kind="ts_column", - claim=str(ts_col), - detail=f"time_series timestamp_column '{ts_col}' not a column of '{table}'", - suggestion=_nearest(str(ts_col), schema_truth[table]), - )) - for col in entry.get("group_columns") or []: + report.items.append( + DriftItem( + file=str(path), + line=None, + kind="ts_column", + claim=str(ts_col), + detail=f"time_series timestamp_column '{ts_col}' not a column of '{table}'", + suggestion=_nearest(str(ts_col), schema_truth[table]), + ) + ) + # Match ``config.load_time_series_config``'s contract: only + # iterate when ``group_columns`` is actually a list. A scalar + # (e.g. ``group_columns: region``) would otherwise iterate + # characters and produce nonsense drift items. 
+ group_cols = entry.get("group_columns") + if not isinstance(group_cols, list): + continue + for col in group_cols: if col not in schema_truth[table]: - report.items.append(DriftItem( - file=str(path), line=None, kind="ts_column", - claim=str(col), - detail=f"time_series group_column '{col}' not a column of '{table}'", - suggestion=_nearest(str(col), schema_truth[table]), - )) + report.items.append( + DriftItem( + file=str(path), + line=None, + kind="ts_column", + claim=str(col), + detail=f"time_series group_column '{col}' not a column of '{table}'", + suggestion=_nearest(str(col), schema_truth[table]), + ) + ) def _nearest(claim: str, candidates: set[str]) -> str | None: @@ -411,18 +554,25 @@ def build_enum_values_sync( if "char" not in str(dtype).lower() and "string" not in str(dtype).lower(): continue try: - count = conn.execute( - f"SELECT COUNT(DISTINCT {col}) FROM {table}" - ).fetchone() - except Exception: # noqa: BLE001 — never let one bad column abort the whole scan - continue - if count is None or count[0] > max_per_column: + qcol = _quote_identifier(col) + qtable = _quote_identifier(table) + except ValueError: + # Identifier with embedded control chars; not safe to embed. continue + # Bound the scan so a free-text column on a huge table can't turn + # the "cheap pre-flight" grounding check into a full-table scan. + # ``LIMIT max_per_column + 1`` lets us decide whether the column + # qualifies (<= max_per_column distinct values) without counting + # all the rest. try: values = conn.execute( - f"SELECT DISTINCT {col} FROM {table} WHERE {col} IS NOT NULL" + f"SELECT DISTINCT {qcol} FROM {qtable} " + f"WHERE {qcol} IS NOT NULL LIMIT {max_per_column + 1}" ).fetchall() - except Exception: # noqa: BLE001 + except Exception: # noqa: BLE001 — never let one bad column abort the whole scan + continue + if len(values) > max_per_column: + # Likely free-text or high-cardinality; skip without adding. 
continue for (v,) in values: if isinstance(v, str): @@ -481,9 +631,7 @@ async def build_schema_truth_async( return out -def format_drift_report( - report: DriftReport, *, max_items_per_file: int = 20 -) -> str: +def format_drift_report(report: DriftReport, *, max_items_per_file: int = 20) -> str: """Render a DriftReport as a multi-line string for terminal output. Truncates per-file listings beyond ``max_items_per_file`` with a diff --git a/src/datasight/grounding_repair.py b/src/datasight/grounding_repair.py index 4fef02da..1e025387 100644 --- a/src/datasight/grounding_repair.py +++ b/src/datasight/grounding_repair.py @@ -91,7 +91,10 @@ def write_snapshot(project_dir: Path | str, schema: dict[str, set[str]]) -> Path target = snapshot_path(project_dir) target.parent.mkdir(parents=True, exist_ok=True) with tempfile.NamedTemporaryFile( - mode="w", encoding="utf-8", dir=target.parent, delete=False, + mode="w", + encoding="utf-8", + dir=target.parent, + delete=False, prefix=".grounding-snapshot-", ) as tmp: json.dump(payload, tmp, indent=2, sort_keys=True) @@ -255,11 +258,14 @@ async def repair_grounding( proposed: dict[str, str] | None = None last_error: str | None = None - retries = 0 for attempt in range(max_retries + 1): - user_prompt = prompt if last_error is None else ( - f"{prompt}\n\nYour previous response had validation errors. " - f"Fix them and return the full corrected JSON object:\n\n{last_error}" + user_prompt = ( + prompt + if last_error is None + else ( + f"{prompt}\n\nYour previous response had validation errors. 
" + f"Fix them and return the full corrected JSON object:\n\n{last_error}" + ) ) response = await llm_client.create_message( model=model, @@ -273,7 +279,6 @@ async def repair_grounding( proposed = _parse_repair_json(text) except ValueError as exc: last_error = f"Could not parse JSON from your response: {exc}" - retries = attempt + 1 logger.warning(f"repair attempt {attempt + 1}: {last_error}") continue @@ -285,6 +290,9 @@ async def repair_grounding( await _validate_repair(files, run_sql=run_sql) if all(f.ok for f in files): + # ``attempt`` is the 0-based index of the call that succeeded, + # which equals the number of retries performed (0 = success + # on the first call, 1 = one retry, ...). return RepairResult(files=files, llm_retries=attempt) # Build a summarized error report for the next retry. @@ -293,14 +301,12 @@ async def repair_grounding( for err in f.validation_errors: error_lines.append(f"- {f.name}: {err}") last_error = "\n".join(error_lines) - retries = attempt + 1 - logger.warning( - f"repair attempt {attempt + 1}: {len(error_lines)} validation error(s)" - ) + logger.warning(f"repair attempt {attempt + 1}: {len(error_lines)} validation error(s)") # Out of retries — return the last attempt with its errors so the - # caller can fall back to manual edit mode. - return RepairResult(files=files, llm_retries=retries) + # caller can fall back to manual edit mode. ``max_retries`` is the + # number of retries performed beyond the initial call. + return RepairResult(files=files, llm_retries=max_retries) def write_repair_atomic(result: RepairResult, project_dir: Path) -> list[Path]: @@ -332,7 +338,11 @@ def write_repair_atomic(result: RepairResult, project_dir: Path) -> list[Path]: # NamedTemporaryFile with delete=False so we can keep the path # after closing; os.replace then atomically swaps it in. 
with tempfile.NamedTemporaryFile( - mode="w", encoding="utf-8", dir=parent, delete=False, prefix=".grounding-", + mode="w", + encoding="utf-8", + dir=parent, + delete=False, + prefix=".grounding-", ) as tmp: tmp.write(f.new_text) tmp.flush() @@ -507,9 +517,7 @@ async def _validate_repair( # noqa: C901 await run_sql(sql) except Exception as exc: # noqa: BLE001 question = entry.get("question", "(no question)") - f.validation_errors.append( - f"query {i} ({question!r}) failed: {exc}" - ) + f.validation_errors.append(f"query {i} ({question!r}) failed: {exc}") elif f.name == "time_series.yaml": try: yaml.safe_load(f.new_text) diff --git a/src/datasight/web/app.py b/src/datasight/web/app.py index 6e600d65..cef5da0b 100644 --- a/src/datasight/web/app.py +++ b/src/datasight/web/app.py @@ -3361,8 +3361,7 @@ async def tidy_apply(request: Request, state: AppState = Depends(get_state)): # # on AppState because grounding repair runs as a separate request # (the LLM call is too slow to block the apply response). old_schema: dict[str, set[str]] = { - t["name"]: {c["name"] for c in t["columns"]} - for t in state.schema_info + t["name"]: {c["name"] for c in t["columns"]} for t in state.schema_info } state.pre_tidy_schema = old_schema # Persist alongside the in-memory snapshot so the user can retry @@ -3477,14 +3476,9 @@ async def tidy_apply(request: Request, state: AppState = Depends(get_state)): # # of the apply path so the response returns immediately. 
grounding_drift: dict[str, Any] | None = None new_schema: dict[str, set[str]] = { - t["name"]: {c["name"] for c in t["columns"]} - for t in state.schema_info + t["name"]: {c["name"] for c in t["columns"]} for t in state.schema_info } - if ( - state.project_dir - and new_schema != old_schema - and not state.is_ephemeral - ): + if state.project_dir and new_schema != old_schema and not state.is_ephemeral: grounding_drift = _summarize_grounding_drift( project_dir=state.project_dir, db_path=db_path, @@ -3530,9 +3524,7 @@ def _summarize_grounding_drift( except Exception: # noqa: BLE001 enum_values = set() - drift = check_grounding_drift( - Path(project_dir), new_schema, enum_values=enum_values - ) + drift = check_grounding_drift(Path(project_dir), new_schema, enum_values=enum_values) return { "needs_repair": not drift.is_clean, "drift_items": len(drift.items), @@ -3571,9 +3563,7 @@ async def _run_grounding_repair( # noqa: C901 except Exception: # noqa: BLE001 enum_values = set() - drift = check_grounding_drift( - Path(project_dir), new_schema, enum_values=enum_values - ) + drift = check_grounding_drift(Path(project_dir), new_schema, enum_values=enum_values) if drift.is_clean: return { "drift_items": 0, @@ -3592,13 +3582,13 @@ async def _run_grounding_repair( # noqa: C901 model=model, run_sql=run_sql, ) - except Exception as exc: # noqa: BLE001 + except Exception: # noqa: BLE001 logger.exception("Grounding repair LLM call failed") return { "drift_items": len(drift.items), "files_written": [], "applied": False, - "error": str(exc), + "error": "Grounding repair failed; see server logs for details.", } if not result.any_changes: @@ -3610,9 +3600,7 @@ async def _run_grounding_repair( # noqa: C901 } if not result.overall_ok: validation_errors = [ - f"{f.name}: {err}" - for f in result.files - for err in f.validation_errors + f"{f.name}: {err}" for f in result.files for err in f.validation_errors ] return { "drift_items": len(drift.items), @@ -3623,13 +3611,13 @@ async def 
_run_grounding_repair( # noqa: C901 try: written = write_repair_atomic(result, Path(project_dir)) - except OSError as exc: + except OSError: logger.exception("Grounding repair atomic write failed") return { "drift_items": len(drift.items), "files_written": [], "applied": False, - "error": str(exc), + "error": "Failed to write repaired grounding files; see server logs for details.", } return { "drift_items": len(drift.items), @@ -3683,28 +3671,46 @@ async def tidy_grounding_repair(state: AppState = Depends(get_state)): db_path = runner._database_path new_schema: dict[str, set[str]] = { - t["name"]: {c["name"] for c in t["columns"]} - for t in state.schema_info + t["name"]: {c["name"] for c in t["columns"]} for t in state.schema_info } - async with state.state_lock: - summary = await _run_grounding_repair( - project_dir=state.project_dir, - db_path=db_path, - old_schema=old_schema, - new_schema=new_schema, - llm_client=state.llm_client, - model=state.model, - run_sql=state.sql_runner.run_sql, - ) + # Snapshot the inputs we need before releasing the lock so the LLM + # call doesn't see a half-mutated state if a concurrent project + # change races us. ``project_dir`` is the guard — if it differs when + # we re-acquire the lock, the user switched projects mid-call and + # we must not touch the new project's state. + project_dir_snapshot = state.project_dir + llm_client = state.llm_client + model = state.model + run_sql = state.sql_runner.run_sql + + # Run the slow LLM + validation + write outside ``state_lock`` so + # other state-level operations (project load/unload, settings + # updates, tidy apply) aren't blocked for the minutes a local model + # can take. The repair itself only touches grounding files on disk; + # it doesn't mutate AppState. 
+ summary = await _run_grounding_repair( + project_dir=project_dir_snapshot, + db_path=db_path, + old_schema=old_schema, + new_schema=new_schema, + llm_client=llm_client, + model=model, + run_sql=run_sql, + ) - # Reload the user-facing prose so the system prompt picks up the - # repaired schema_description.md. schema_info itself didn't - # change in this endpoint — only the grounding files on disk did. - if summary.get("applied"): - tables = await introspect_schema( - state.sql_runner.run_sql, runner=state.sql_runner - ) + # Re-acquire the lock only for the small critical section that + # mutates AppState. Guard against a concurrent project switch: if + # the loaded project changed while the LLM ran, the schema we + # rebuilt no longer matches what's loaded — drop the mutation. + async with state.state_lock: + if ( + summary.get("applied") + and state.project_loaded + and state.project_dir == project_dir_snapshot + and state.sql_runner is not None + ): + tables = await introspect_schema(state.sql_runner.run_sql, runner=state.sql_runner) state.schema_text = format_schema_context( tables, user_description=_load_user_description(state), @@ -3745,8 +3751,7 @@ async def grounding_status(state: AppState = Depends(get_state)): db_path = runner._database_path new_schema: dict[str, set[str]] = { - t["name"]: {c["name"] for c in t["columns"]} - for t in state.schema_info + t["name"]: {c["name"] for c in t["columns"]} for t in state.schema_info } summary = _summarize_grounding_drift( project_dir=state.project_dir, @@ -3757,8 +3762,7 @@ async def grounding_status(state: AppState = Depends(get_state)): # whether to enable the repair button — without one, the user has # to use the CLI's ``--from-csv`` fallback. 
has_snapshot = ( - state.pre_tidy_schema is not None - or read_snapshot(state.project_dir) is not None + state.pre_tidy_schema is not None or read_snapshot(state.project_dir) is not None ) return { "available": True, diff --git a/tests/test_grounding.py b/tests/test_grounding.py index db844ac8..c7dfb089 100644 --- a/tests/test_grounding.py +++ b/tests/test_grounding.py @@ -27,9 +27,7 @@ def _make_db(tmp_path: Path, rows: list[tuple]) -> str: "time_year BIGINT, energy_mwh DOUBLE)" ) for row in rows: - conn.execute( - "INSERT INTO load_data VALUES (?, ?, ?, ?, ?)", row - ) + conn.execute("INSERT INTO load_data VALUES (?, ?, ?, ?, ?)", row) conn.close() return str(db_path) @@ -43,11 +41,14 @@ def test_build_schema_truth_sync_returns_table_to_columns(): def test_build_enum_values_sync_collects_distinct_strings(tmp_path): - db_path = _make_db(tmp_path, [ - ("pacific", "elec", "heating", 2020, 1.0), - ("pacific", "ng", "cooling", 2020, 2.0), - ("south_atlantic", "elec", "heating", 2020, 3.0), - ]) + db_path = _make_db( + tmp_path, + [ + ("pacific", "elec", "heating", 2020, 1.0), + ("pacific", "ng", "cooling", 2020, 2.0), + ("south_atlantic", "elec", "heating", 2020, 3.0), + ], + ) conn = duckdb.connect(db_path, read_only=True) truth = build_schema_truth_sync(conn) values = build_enum_values_sync(conn, truth) @@ -74,24 +75,28 @@ def test_build_enum_values_sync_skips_high_cardinality(tmp_path): def test_check_clean_grounding_reports_no_drift(tmp_path): truth = {"load_data": {"geography", "fuel_type", "energy_mwh"}} - (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + (tmp_path / "queries.yaml").write_text( + textwrap.dedent(""" - question: "Total energy by region" sql: | SELECT geography, SUM(energy_mwh) AS total FROM load_data WHERE fuel_type = 'elec' GROUP BY geography; - """).strip()) + """).strip() + ) report = check_grounding_drift(tmp_path, truth, enum_values={"elec"}) assert report.is_clean def test_queries_yaml_missing_column_is_flagged(tmp_path): truth = 
{"load_data": {"geography", "fuel_type", "energy_mwh"}} - (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + (tmp_path / "queries.yaml").write_text( + textwrap.dedent(""" - question: "Stale" sql: SELECT elec_heating FROM load_data; - """).strip()) + """).strip() + ) report = check_grounding_drift(tmp_path, truth) assert not report.is_clean assert any(item.claim == "elec_heating" for item in report.items) @@ -99,41 +104,44 @@ def test_queries_yaml_missing_column_is_flagged(tmp_path): def test_queries_yaml_missing_table_is_flagged(tmp_path): truth = {"load_data": {"geography"}} - (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + (tmp_path / "queries.yaml").write_text( + textwrap.dedent(""" - question: "Wrong table" sql: SELECT geography FROM missing_table; - """).strip()) - report = check_grounding_drift(tmp_path, truth) - assert any( - item.kind == "table" and item.claim == "missing_table" - for item in report.items + """).strip() ) + report = check_grounding_drift(tmp_path, truth) + assert any(item.kind == "table" and item.claim == "missing_table" for item in report.items) def test_queries_yaml_cte_name_is_not_flagged_as_missing_table(tmp_path): truth = {"load_data": {"geography", "energy_mwh"}} - (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + (tmp_path / "queries.yaml").write_text( + textwrap.dedent(""" - question: "CTE chain" sql: | WITH yearly AS ( SELECT geography, SUM(energy_mwh) AS total FROM load_data GROUP BY geography ) SELECT * FROM yearly; - """).strip()) + """).strip() + ) report = check_grounding_drift(tmp_path, truth) assert report.is_clean, [item.detail for item in report.items] def test_queries_yaml_output_alias_is_not_flagged(tmp_path): truth = {"load_data": {"geography", "energy_mwh"}} - (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + (tmp_path / "queries.yaml").write_text( + textwrap.dedent(""" - question: "Aliased output" sql: | SELECT geography, SUM(energy_mwh) AS total_energy FROM load_data GROUP 
BY geography ORDER BY total_energy DESC; - """).strip()) + """).strip() + ) report = check_grounding_drift(tmp_path, truth) assert report.is_clean, [item.detail for item in report.items] @@ -144,10 +152,7 @@ def test_schema_description_md_flags_missing_column_reference(tmp_path): "# Schema\n\nThe `elec_heating` column tracks electricity heating.\n" ) report = check_grounding_drift(tmp_path, truth) - assert any( - item.claim == "elec_heating" and item.line == 3 - for item in report.items - ) + assert any(item.claim == "elec_heating" and item.line == 3 for item in report.items) def test_schema_description_md_qualified_table_column_resolves(tmp_path): @@ -181,45 +186,42 @@ def test_schema_description_md_enum_values_are_allowlisted(tmp_path): def test_time_series_yaml_missing_table_is_flagged(tmp_path): truth = {"load_data": {"geography"}} - (tmp_path / "time_series.yaml").write_text(textwrap.dedent(""" + (tmp_path / "time_series.yaml").write_text( + textwrap.dedent(""" - table: missing_table timestamp_column: ts frequency: PT1H - """).strip()) - report = check_grounding_drift(tmp_path, truth) - assert any( - item.kind == "ts_table" and item.claim == "missing_table" - for item in report.items + """).strip() ) + report = check_grounding_drift(tmp_path, truth) + assert any(item.kind == "ts_table" and item.claim == "missing_table" for item in report.items) def test_time_series_yaml_missing_timestamp_column_is_flagged(tmp_path): truth = {"load_data": {"geography"}} - (tmp_path / "time_series.yaml").write_text(textwrap.dedent(""" + (tmp_path / "time_series.yaml").write_text( + textwrap.dedent(""" - table: load_data timestamp_column: ts frequency: PT1H - """).strip()) - report = check_grounding_drift(tmp_path, truth) - assert any( - item.kind == "ts_column" and item.claim == "ts" - for item in report.items + """).strip() ) + report = check_grounding_drift(tmp_path, truth) + assert any(item.kind == "ts_column" and item.claim == "ts" for item in report.items) def 
test_time_series_yaml_missing_group_column_is_flagged(tmp_path): truth = {"load_data": {"geography", "ts"}} - (tmp_path / "time_series.yaml").write_text(textwrap.dedent(""" + (tmp_path / "time_series.yaml").write_text( + textwrap.dedent(""" - table: load_data timestamp_column: ts group_columns: [geography, missing_dim] frequency: PT1H - """).strip()) - report = check_grounding_drift(tmp_path, truth) - assert any( - item.kind == "ts_column" and item.claim == "missing_dim" - for item in report.items + """).strip() ) + report = check_grounding_drift(tmp_path, truth) + assert any(item.kind == "ts_column" and item.claim == "missing_dim" for item in report.items) def test_missing_grounding_files_are_silently_skipped(tmp_path): @@ -229,16 +231,25 @@ def test_missing_grounding_files_are_silently_skipped(tmp_path): def test_format_drift_report_shows_per_file_breakdown(): - report = DriftReport(items=[ - DriftItem( - file="a/queries.yaml", line=None, kind="column", - claim="foo", detail="foo not found", suggestion="bar", - ), - DriftItem( - file="a/schema_description.md", line=10, kind="column", - claim="baz", detail="baz not found", - ), - ]) + report = DriftReport( + items=[ + DriftItem( + file="a/queries.yaml", + line=None, + kind="column", + claim="foo", + detail="foo not found", + suggestion="bar", + ), + DriftItem( + file="a/schema_description.md", + line=10, + kind="column", + claim="baz", + detail="baz not found", + ), + ] + ) text = format_drift_report(report) assert "queries.yaml" in text assert "schema_description.md" in text @@ -253,11 +264,13 @@ def test_format_drift_report_clean_returns_clean_message(): def test_drift_report_by_file_groups_items(): - report = DriftReport(items=[ - DriftItem(file="x", line=None, kind="column", claim="a", detail=""), - DriftItem(file="y", line=None, kind="column", claim="b", detail=""), - DriftItem(file="x", line=None, kind="column", claim="c", detail=""), - ]) + report = DriftReport( + items=[ + DriftItem(file="x", line=None, 
kind="column", claim="a", detail=""), + DriftItem(file="y", line=None, kind="column", claim="b", detail=""), + DriftItem(file="x", line=None, kind="column", claim="c", detail=""), + ] + ) grouped = report.by_file() assert list(grouped.keys()) == ["x", "y"] assert len(grouped["x"]) == 2 diff --git a/tests/test_grounding_repair.py b/tests/test_grounding_repair.py index d9679bfd..48c8e1fb 100644 --- a/tests/test_grounding_repair.py +++ b/tests/test_grounding_repair.py @@ -110,11 +110,13 @@ def test_parse_repair_json_rejects_malformed(): def test_parse_repair_json_drops_unknown_keys(): - text = json.dumps({ - "queries.yaml": "x", - "comment": "ignore me", - "schema_description.md": "y", - }) + text = json.dumps( + { + "queries.yaml": "x", + "comment": "ignore me", + "schema_description.md": "y", + } + ) result = _parse_repair_json(text) assert set(result.keys()) == {"queries.yaml", "schema_description.md"} @@ -204,10 +206,12 @@ def test_format_repair_summary_with_no_changes(): def test_repair_grounding_happy_path(tmp_path): """LLM proposal validates cleanly on first try; result is well-formed.""" db_path = _long_format_db(tmp_path) - (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + (tmp_path / "queries.yaml").write_text( + textwrap.dedent(""" - question: "Old top regions" sql: SELECT * FROM load_data WHERE elec_heating > 0; - """).strip()) + """).strip() + ) new_queries = textwrap.dedent(""" - question: "Top regions" @@ -217,17 +221,31 @@ def test_repair_grounding_happy_path(tmp_path): client = _FakeLLMClient([llm_response]) run_sql = _make_run_sql(db_path) - drift = DriftReport(items=[DriftItem( - file=str(tmp_path / "queries.yaml"), line=None, kind="column", - claim="elec_heating", detail="missing", - )]) + drift = DriftReport( + items=[ + DriftItem( + file=str(tmp_path / "queries.yaml"), + line=None, + kind="column", + claim="elec_heating", + detail="missing", + ) + ] + ) old_schema = {"load_data": {"elec_heating", "geography"}} new_schema = 
{"load_data": {"geography", "fuel_type", "end_use", "energy_mwh"}} - result = asyncio.run(repair_grounding( - tmp_path, old_schema, new_schema, drift, - llm_client=client, model="test", run_sql=run_sql, - )) + result = asyncio.run( + repair_grounding( + tmp_path, + old_schema, + new_schema, + drift, + llm_client=client, + model="test", + run_sql=run_sql, + ) + ) assert result.overall_ok assert result.any_changes assert result.llm_retries == 0 @@ -239,28 +257,46 @@ def test_repair_grounding_happy_path(tmp_path): def test_repair_grounding_retries_on_invalid_sql(tmp_path): """First proposal fails to execute; second one validates.""" db_path = _long_format_db(tmp_path) - (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + (tmp_path / "queries.yaml").write_text( + textwrap.dedent(""" - question: "Stale" sql: SELECT foo FROM load_data; - """).strip()) + """).strip() + ) - broken = json.dumps({"queries.yaml": "- question: q\n sql: SELECT still_broken FROM load_data;"}) + broken = json.dumps( + {"queries.yaml": "- question: q\n sql: SELECT still_broken FROM load_data;"} + ) good = json.dumps({"queries.yaml": "- question: q\n sql: SELECT geography FROM load_data;"}) client = _FakeLLMClient([broken, good]) run_sql = _make_run_sql(db_path) - drift = DriftReport(items=[DriftItem( - file=str(tmp_path / "queries.yaml"), line=None, kind="column", - claim="foo", detail="missing", - )]) + drift = DriftReport( + items=[ + DriftItem( + file=str(tmp_path / "queries.yaml"), + line=None, + kind="column", + claim="foo", + detail="missing", + ) + ] + ) old_schema = {"load_data": {"foo"}} new_schema = {"load_data": {"geography", "fuel_type", "end_use", "energy_mwh"}} - result = asyncio.run(repair_grounding( - tmp_path, old_schema, new_schema, drift, - llm_client=client, model="test", run_sql=run_sql, - max_retries=2, - )) + result = asyncio.run( + repair_grounding( + tmp_path, + old_schema, + new_schema, + drift, + llm_client=client, + model="test", + run_sql=run_sql, + 
max_retries=2, + ) + ) assert result.overall_ok assert result.llm_retries == 1 # The client should have been called twice and the second user prompt @@ -273,10 +309,12 @@ def test_repair_grounding_retries_on_invalid_sql(tmp_path): def test_repair_grounding_gives_up_after_max_retries(tmp_path): """When every proposal is broken, the result surfaces validation errors.""" db_path = _long_format_db(tmp_path) - (tmp_path / "queries.yaml").write_text(textwrap.dedent(""" + (tmp_path / "queries.yaml").write_text( + textwrap.dedent(""" - question: "Stale" sql: SELECT foo FROM load_data; - """).strip()) + """).strip() + ) bad = json.dumps({"queries.yaml": "- question: q\n sql: SELECT nope_a FROM load_data;"}) worse = json.dumps({"queries.yaml": "- question: q\n sql: SELECT nope_b FROM load_data;"}) @@ -288,11 +326,18 @@ def test_repair_grounding_gives_up_after_max_retries(tmp_path): old_schema: dict[str, set[str]] = {} new_schema = {"load_data": {"geography"}} - result = asyncio.run(repair_grounding( - tmp_path, old_schema, new_schema, drift, - llm_client=client, model="test", run_sql=run_sql, - max_retries=2, - )) + result = asyncio.run( + repair_grounding( + tmp_path, + old_schema, + new_schema, + drift, + llm_client=client, + model="test", + run_sql=run_sql, + max_retries=2, + ) + ) assert not result.overall_ok q_file = next(f for f in result.files if f.name == "queries.yaml") assert q_file.validation_errors @@ -313,10 +358,17 @@ def test_repair_files_are_loaded_from_disk(tmp_path): drift = DriftReport(items=[]) new_schema = {"load_data": {"geography"}} - result = asyncio.run(repair_grounding( - tmp_path, {}, new_schema, drift, - llm_client=client, model="test", run_sql=run_sql, - )) + result = asyncio.run( + repair_grounding( + tmp_path, + {}, + new_schema, + drift, + llm_client=client, + model="test", + run_sql=run_sql, + ) + ) file_names = {f.name for f in result.files} assert file_names == {"queries.yaml", "schema_description.md"} assert "time_series.yaml" not in 
file_names diff --git a/tests/test_web_tidy.py b/tests/test_web_tidy.py index 033e7557..48a0ab28 100644 --- a/tests/test_web_tidy.py +++ b/tests/test_web_tidy.py @@ -468,8 +468,7 @@ def test_apply_returns_grounding_drift_summary(loaded_state, monkeypatch): project_dir = Path(loaded_state.project_dir) queries_path = project_dir / "queries.yaml" queries_path.write_text( - "- question: Stale wide reference\n" - " sql: SELECT region, sales_2020 FROM sales;\n", + "- question: Stale wide reference\n sql: SELECT region, sales_2020 FROM sales;\n", encoding="utf-8", ) @@ -516,13 +515,11 @@ def test_grounding_repair_endpoint_rewrites_stale_files(loaded_state, monkeypatc project_dir = Path(loaded_state.project_dir) queries_path = project_dir / "queries.yaml" queries_path.write_text( - "- question: Stale wide reference\n" - " sql: SELECT region, sales_2020 FROM sales;\n", + "- question: Stale wide reference\n sql: SELECT region, sales_2020 FROM sales;\n", encoding="utf-8", ) repaired_yaml = ( - "- question: Long-form sales\n" - " sql: SELECT region, period, sales FROM sales_long;\n" + "- question: Long-form sales\n sql: SELECT region, period, sales FROM sales_long;\n" ) async def fake_repair(project_dir, old_schema, new_schema, drift, **kwargs): @@ -548,6 +545,11 @@ async def fake_repair(project_dir, old_schema, new_schema, drift, **kwargs): proposal = _suggestion_to_proposal_dict(suggestion) with TestClient(web_app.app) as client: + # FastAPI startup clears llm_client when no API key is configured + # (CI case). Stub it inside the context so the repair endpoint + # sees a configured client. 
+ loaded_state.llm_client = cast(Any, object()) + loaded_state.model = "stub" apply_resp = client.post( "/api/tidy/apply", json={ @@ -651,6 +653,11 @@ async def fake_repair(project_dir, old_schema, new_schema, drift, **kwargs): proposal = _suggestion_to_proposal_dict(suggestion) with TestClient(web_app.app) as client: + # FastAPI startup clears llm_client when no API key is configured + # (CI case). Stub it inside the context so the repair endpoint + # sees a configured client. + loaded_state.llm_client = cast(Any, object()) + loaded_state.model = "stub" apply_resp = client.post( "/api/tidy/apply", json={ From 81dd922278592ea37b529e08e657852e754ea25d Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 11 May 2026 09:13:31 -0600 Subject: [PATCH 09/10] Address Copilot review: lock ordering and flag combo - verify: reject --static-only + --skip-grounding-check as mutually exclusive (exit 2). The static branch is gated on `not skip_grounding_check`, so the combination silently fell through to the LLM phase, contradicting --static-only's documented semantics. - tidy_apply: move pre-apply schema snapshot (old_schema capture, state.pre_tidy_schema assignment, write_snapshot) inside state_lock, so a concurrent project switch can't desync the snapshot from the DB the apply mutates. - tidy_grounding_repair: wrap the prerequisite checks and input snapshot in `async with state.state_lock` to match what the pre-existing "before releasing the lock" comment claimed. The slow LLM call still runs outside the lock; the final mutation section already re-acquires it and guards on project_dir_snapshot. - Add a CLI test covering the new flag rejection. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/datasight/cli_commands/verify.py | 13 +++ src/datasight/web/app.py | 142 +++++++++++++++------------ tests/test_cli_commands.py | 23 +++++ 3 files changed, 113 insertions(+), 65 deletions(-) diff --git a/src/datasight/cli_commands/verify.py b/src/datasight/cli_commands/verify.py index cec0cea2..6e7fb174 100644 --- a/src/datasight/cli_commands/verify.py +++ b/src/datasight/cli_commands/verify.py @@ -89,6 +89,19 @@ def verify(project_dir, model, queries_path, static_only, skip_grounding_check): project_dir = str(Path(project_dir).resolve()) + # ``--static-only`` runs *only* the drift check; ``--skip-grounding-check`` + # asks to skip that very check. The combination has no coherent + # meaning — accepting it silently would either run nothing or fall + # through to the LLM phase (contradicting ``--static-only``). Reject + # it up front so the user fixes their invocation. + if static_only and skip_grounding_check: + click.echo( + "Error: --static-only and --skip-grounding-check are mutually " + "exclusive (one runs the static check; the other skips it).", + err=True, + ) + sys.exit(2) + # Static drift check first. Cheap, no LLM, no async — runs against a # direct DuckDB connection. For non-DuckDB backends we skip the # static check (information_schema.columns availability varies) and diff --git a/src/datasight/web/app.py b/src/datasight/web/app.py index cef5da0b..c37076e3 100644 --- a/src/datasight/web/app.py +++ b/src/datasight/web/app.py @@ -3354,28 +3354,31 @@ async def tidy_apply(request: Request, state: AppState = Depends(get_state)): # db_path = runner._database_path - # Snapshot the pre-apply schema so a follow-up grounding-repair call - # can show the LLM both old and new shapes. Without an explicit - # snapshot the repair prompt degenerates to "regenerate from scratch" - # and loses any human customizations in the grounding files. 
Stashed - # on AppState because grounding repair runs as a separate request - # (the LLM call is too slow to block the apply response). - old_schema: dict[str, set[str]] = { - t["name"]: {c["name"] for c in t["columns"]} for t in state.schema_info - } - state.pre_tidy_schema = old_schema - # Persist alongside the in-memory snapshot so the user can retry - # grounding repair after a server restart, or run it from the CLI - # against this same baseline. Best-effort: a write failure here - # shouldn't block the apply itself (the in-memory copy still works - # for the immediate banner flow). - if state.project_dir and not state.is_ephemeral: - try: - write_snapshot(state.project_dir, old_schema) - except OSError as exc: - logger.warning(f"grounding snapshot write failed: {exc}") - async with state.state_lock: + # Snapshot the pre-apply schema so a follow-up grounding-repair + # call can show the LLM both old and new shapes. Without an + # explicit snapshot the repair prompt degenerates to "regenerate + # from scratch" and loses any human customizations in the + # grounding files. Stashed on AppState because grounding repair + # runs as a separate request (the LLM call is too slow to block + # the apply response). Capture this under the lock so a + # concurrent project switch can't desync the snapshot from the + # DB the apply is about to mutate. + old_schema: dict[str, set[str]] = { + t["name"]: {c["name"] for c in t["columns"]} for t in state.schema_info + } + state.pre_tidy_schema = old_schema + # Persist alongside the in-memory snapshot so the user can retry + # grounding repair after a server restart, or run it from the + # CLI against this same baseline. Best-effort: a write failure + # here shouldn't block the apply itself (the in-memory copy + # still works for the immediate banner flow). 
+ if state.project_dir and not state.is_ephemeral: + try: + write_snapshot(state.project_dir, old_schema) + except OSError as exc: + logger.warning(f"grounding snapshot write failed: {exc}") + # Cached DataFrames can keep DuckDB buffers alive and block the # read-write reopen of the same file; drop them before closing. # See /api/add-files for the same dance. @@ -3635,54 +3638,63 @@ async def tidy_grounding_repair(state: AppState = Depends(get_state)): ``grounding_drift.needs_repair`` in the apply response. Reads the pre-tidy schema snapshot stashed on AppState during the apply. """ - if not state.project_loaded or not state.project_dir: - return {"success": False, "error": "No project loaded"} - if state.is_ephemeral: - return {"success": False, "error": "Grounding repair not supported in explore sessions"} - if state.llm_client is None or not state.model: - return {"success": False, "error": "LLM client not configured"} - if state.sql_runner is None: - return {"success": False, "error": "No SQL runner available"} - # Prefer the in-memory snapshot from the current server's apply, but - # fall back to the on-disk snapshot so the user can retry repair - # after a restart (or after the LLM call timed out and nothing was - # written). The CLI's `datasight grounding repair` reads the same - # file — they're interchangeable baselines. - old_schema = state.pre_tidy_schema - if old_schema is None: - old_schema = read_snapshot(state.project_dir) - if old_schema is None: - return { - "success": False, - "error": ( - "No pre-tidy schema snapshot available. Run a tidy apply " - "first, or use `datasight grounding repair --from-csv ...` " - "to supply a baseline from a source CSV." 
- ), - } - from datasight.runner import DuckDBRunner - runner = state.sql_runner - if isinstance(runner, CachingSqlRunner): - runner = runner._inner - if not isinstance(runner, DuckDBRunner): - return {"success": False, "error": "Grounding repair requires a project DuckDB runner"} - db_path = runner._database_path + # Snapshot the inputs we need under ``state_lock`` so a concurrent + # project switch can't hand us a mix of fields from before and after + # the swap. ``project_dir_snapshot`` is the guard — if it differs + # when we re-acquire the lock after the LLM call, the user switched + # projects mid-flight and we must not touch the new project's state. + # The lock is released before the slow LLM/validation/write work so + # other state-level operations aren't blocked for minutes. + async with state.state_lock: + if not state.project_loaded or not state.project_dir: + return {"success": False, "error": "No project loaded"} + if state.is_ephemeral: + return { + "success": False, + "error": "Grounding repair not supported in explore sessions", + } + if state.llm_client is None or not state.model: + return {"success": False, "error": "LLM client not configured"} + if state.sql_runner is None: + return {"success": False, "error": "No SQL runner available"} + # Prefer the in-memory snapshot from the current server's apply, + # but fall back to the on-disk snapshot so the user can retry + # repair after a restart (or after the LLM call timed out and + # nothing was written). The CLI's `datasight grounding repair` + # reads the same file — they're interchangeable baselines. + old_schema = state.pre_tidy_schema + if old_schema is None: + old_schema = read_snapshot(state.project_dir) + if old_schema is None: + return { + "success": False, + "error": ( + "No pre-tidy schema snapshot available. Run a tidy apply " + "first, or use `datasight grounding repair --from-csv ...` " + "to supply a baseline from a source CSV." 
+ ), + } - new_schema: dict[str, set[str]] = { - t["name"]: {c["name"] for c in t["columns"]} for t in state.schema_info - } + runner = state.sql_runner + if isinstance(runner, CachingSqlRunner): + runner = runner._inner + if not isinstance(runner, DuckDBRunner): + return { + "success": False, + "error": "Grounding repair requires a project DuckDB runner", + } + db_path = runner._database_path + + new_schema: dict[str, set[str]] = { + t["name"]: {c["name"] for c in t["columns"]} for t in state.schema_info + } - # Snapshot the inputs we need before releasing the lock so the LLM - # call doesn't see a half-mutated state if a concurrent project - # change races us. ``project_dir`` is the guard — if it differs when - # we re-acquire the lock, the user switched projects mid-call and - # we must not touch the new project's state. - project_dir_snapshot = state.project_dir - llm_client = state.llm_client - model = state.model - run_sql = state.sql_runner.run_sql + project_dir_snapshot = state.project_dir + llm_client = state.llm_client + model = state.model + run_sql = state.sql_runner.run_sql # Run the slow LLM + validation + write outside ``state_lock`` so # other state-level operations (project load/unload, settings diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py index bd12c7a2..bcd192bd 100644 --- a/tests/test_cli_commands.py +++ b/tests/test_cli_commands.py @@ -731,3 +731,26 @@ def test_report_run_saved_report(tv_project_isolated): result = runner.invoke(cli, ["report", "delete", "1", "--project-dir", tv_project_isolated]) assert result.exit_code == 0 + + +def test_verify_rejects_static_only_with_skip_grounding_check(tv_project_isolated): + """`--static-only` and `--skip-grounding-check` contradict each other. + + Accepting the combination silently would fall through to the LLM + phase (the static branch is gated by `not skip_grounding_check`), + contradicting `--static-only`'s documented semantics. 
Make sure the + CLI rejects it with a usage error before any LLM work happens. + """ + runner = CliRunner() + result = runner.invoke( + cli, + [ + "verify", + "--project-dir", + tv_project_isolated, + "--static-only", + "--skip-grounding-check", + ], + ) + assert result.exit_code == 2, result.output + assert "mutually exclusive" in result.output From bd642292fdaaf888628f647723b694c901cfae72 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 11 May 2026 09:31:30 -0600 Subject: [PATCH 10/10] Address Copilot follow-up: schema_map sync and fenced-JSON regex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - tidy_grounding_repair: assign state.schema_info from the freshly introspected `tables` before rebuilding schema_map. Previously schema_text was built from `tables` while schema_map was built from the pre-introspection schema_info, so the two could desync if the DB schema ever drifted from schema_info at that point. Mirrors the assignment order in tidy_apply. - grounding_repair: loosen _FENCED_JSON so terse outputs like ```json\n{...}``` (no newline before the closing fence) match on the first try instead of falling through to the bare-JSON path — which previously left trailing backticks in the candidate string and forced a wasted retry. Whitespace on either side of the body is now optional. - Add regression tests covering both fenced-block edge cases. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/datasight/grounding_repair.py | 7 +++++-- src/datasight/web/app.py | 18 ++++++++++++++++++ tests/test_grounding_repair.py | 16 ++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/datasight/grounding_repair.py b/src/datasight/grounding_repair.py index 1e025387..f671734a 100644 --- a/src/datasight/grounding_repair.py +++ b/src/datasight/grounding_repair.py @@ -155,8 +155,11 @@ def read_snapshot(project_dir: Path | str) -> dict[str, set[str]] | None: # Match an opening ```json or ``` fence, then anything (non-greedy) up # to the closing ```. Used to peel a JSON object out of an LLM response -# that wraps it in a markdown code block. -_FENCED_JSON = re.compile(r"```(?:json)?\s*\n(.*?)\n```", re.DOTALL) +# that wraps it in a markdown code block. Whitespace on either side of +# the body is optional so we also handle terse outputs like +# ``` ```json{...}``` ``` (no newline before the closing fence) without +# burning a retry. The captured body is .strip()'d at the callsite. +_FENCED_JSON = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL) @dataclass diff --git a/src/datasight/web/app.py b/src/datasight/web/app.py index c37076e3..62f80ae0 100644 --- a/src/datasight/web/app.py +++ b/src/datasight/web/app.py @@ -3722,7 +3722,25 @@ async def tidy_grounding_repair(state: AppState = Depends(get_state)): and state.project_dir == project_dir_snapshot and state.sql_runner is not None ): + # Re-introspect so the prompt + map reflect any DB-side + # changes (the repair itself only rewrites grounding files, + # so in the common case this matches the existing + # schema_info — but we don't want to rely on that + # invariant). Update schema_info first, then build the map + # from it, so schema_text and schema_map can't desync. + # Mirrors the assignment order in tidy_apply. 
tables = await introspect_schema(state.sql_runner.run_sql, runner=state.sql_runner) + state.schema_info = [ + { + "name": t.name, + "row_count": t.row_count, + "columns": [ + {"name": c.name, "dtype": c.dtype, "nullable": c.nullable} + for c in t.columns + ], + } + for t in tables + ] state.schema_text = format_schema_context( tables, user_description=_load_user_description(state), diff --git a/tests/test_grounding_repair.py b/tests/test_grounding_repair.py index 48c8e1fb..e2df818d 100644 --- a/tests/test_grounding_repair.py +++ b/tests/test_grounding_repair.py @@ -92,6 +92,22 @@ def test_parse_repair_json_fenced_block(): assert _parse_repair_json(text) == {"queries.yaml": "a"} +def test_parse_repair_json_fenced_block_no_trailing_newline(): + # Some models emit the closing fence directly after ``}`` with no + # newline. The regex must still match so we don't waste a retry on + # a recoverable formatting quirk. + text = '```json\n{"queries.yaml": "a"}```' + assert _parse_repair_json(text) == {"queries.yaml": "a"} + + +def test_parse_repair_json_fenced_block_no_leading_newline(): + # Symmetric: opening fence on the same line as the JSON body. The + # plain ``` (no language tag) variant covers the case where the + # LLM omitted the language hint too. + text = '```{"queries.yaml": "a"}```' + assert _parse_repair_json(text) == {"queries.yaml": "a"} + + def test_parse_repair_json_with_prose_prefix(): text = 'Here you go:\n{"queries.yaml": "a"}' assert _parse_repair_json(text) == {"queries.yaml": "a"}