diff --git a/scripts/ablate_retrieval.py b/scripts/ablate_retrieval.py new file mode 100644 index 0000000..4850248 --- /dev/null +++ b/scripts/ablate_retrieval.py @@ -0,0 +1,223 @@ +"""Retrieval-config ablation harness (rrf_k / weights / fusion / rerank sweep). + +The snapshot cache bakes ``RetrievalConfig`` at first ingest and reloads the +*stale* baked config on a cache hit (dikw-core#250), so a config sweep under +``--cache read_write`` silently re-reports the first variant's numbers. This +harness therefore **forces ``--cache off``** for every variant — the documented +#250 workaround — at the cost of a fresh embed per variant. When #250 lands, +flip ``FORCED_CACHE`` to ``"rebuild"`` (one fresh snapshot per variant without +re-embedding unrelated runs) and delete this note. + +Retrieval config lives in the eval base's ``dikw.yml`` ``retrieval:`` block, not +on the eval CLI, so each variant runs against a temp base carrying its overrides. + +Pure helpers (override-merge, label, table) are unit-tested; the run loop shells +out to ``dikw client eval`` via ``run_eval``'s command builder. + +Usage: + uv run python scripts/ablate_retrieval.py --dataset domain-bilingual-v1 \\ + --variants '[{"rrf_k": 30}, {"rrf_k": 60}, {"rrf_k": 90}]' --dry-run +""" + +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +import yaml + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +for _candidate in (PROJECT_ROOT, PROJECT_ROOT / "src"): + _entry = str(_candidate) + if _entry not in sys.path: + sys.path.insert(0, _entry) + +from scripts.run_eval import ( # noqa: E402 + DEFAULT_BASE_TEMPLATE, + build_eval_command, + ensure_base, + merge_env, + parse_eval_report, +) + +from dikw_data.config import load_dotenv # noqa: E402 + +# dikw-core#250 workaround. The cache key omits RetrievalConfig, so only a fresh +# snapshot reflects a changed retrieval block. Flip to "rebuild" once #250 ships. +FORCED_CACHE = "off" + +DEFAULT_METRIC_KEYS: tuple[str, ...] = ( + "hit_at_3", + "hit_at_10", + "mrr", + "ndcg_at_10", + "recall_at_100", +) +DEFAULT_ENV = PROJECT_ROOT / ".env" +DEFAULT_REQUIRED_KEYS: tuple[str, ...] = ("MINIMAX_API_KEY", "GITEE_API_KEY") + + +# --- pure helpers (unit-tested) -------------------------------------------- + + +def apply_retrieval_overrides(base_yaml: str, overrides: dict[str, Any]) -> str: + """Merge ``overrides`` into the base ``dikw.yml``'s ``retrieval:`` block. + + Keys present in ``overrides`` replace (or add) entries under ``retrieval``; + every other block (``provider``, ``storage``, …) is preserved verbatim. An + empty override returns an equivalent document (round-tripped through yaml). + """ + doc = yaml.safe_load(base_yaml) or {} + if not isinstance(doc, dict): + raise ValueError("base dikw.yml is not a mapping") + retrieval = dict(doc.get("retrieval") or {}) + retrieval.update(overrides) + doc["retrieval"] = retrieval + return yaml.safe_dump(doc, sort_keys=False, allow_unicode=True) + + +def variant_label(overrides: dict[str, Any]) -> str: + """Stable, filesystem-safe label for an override set, e.g. ``rrf_k=60``. + + Empty overrides label as ``baseline`` (the template's own retrieval config). + """ + if not overrides: + return "baseline" + return ",".join(f"{k}={overrides[k]}" for k in sorted(overrides)) + + +def metrics_table(rows: list[dict[str, Any]], metric_keys: tuple[str, ...] = DEFAULT_METRIC_KEYS) -> str: + """Markdown comparison table, one row per variant. + + Each ``rows`` item is ``{"label": str, "exit_code": int, "metrics": {...}}``. + Missing metrics render as ``-`` so a partial sweep is still readable. + """ + head = "| variant | exit | " + " | ".join(metric_keys) + " |" + sep = "|" + "---|" * (2 + len(metric_keys)) + lines = [head, sep] + for row in rows: + metrics = row.get("metrics") or {} + cells = " | ".join( + f"{metrics[k]:.3f}" if isinstance(metrics.get(k), (int, float)) else "-" + for k in metric_keys + ) + lines.append(f"| {row['label']} | {row.get('exit_code', '')} | {cells} |") + return "\n".join(lines) + + +def parse_variants(spec: str) -> list[dict[str, Any]]: + """Parse a ``--variants`` JSON list of override dicts (validates shape).""" + parsed = json.loads(spec) + if not isinstance(parsed, list) or not all(isinstance(v, dict) for v in parsed): + raise ValueError("--variants must be a JSON list of objects") + return parsed + + +# --- side-effecting orchestration ------------------------------------------ + + +def _utc_stamp() -> str: + return datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + + +def run_variant( + *, + dataset: Path, + overrides: dict[str, Any], + base_template_text: str, + work_dir: Path, + retrieval: str, + run_env: dict[str, str], +) -> dict[str, Any]: + """Materialise a temp base with ``overrides`` and run one eval (cache forced + off). Returns ``{"label", "exit_code", "metrics", "report_path"}``.""" + label = variant_label(overrides) + variant_yaml = apply_retrieval_overrides(base_template_text, overrides) + template_path = work_dir / f"{label}.dikw.yml" + template_path.write_text(variant_yaml, encoding="utf-8") + base_dir = work_dir / f"base-{label}" + ensure_base(base_dir, template_path) + command = build_eval_command( + dataset=dataset, + base=base_dir, + mode="serve-and-run", + retrieval=retrieval, + cache=FORCED_CACHE, + ) + result = subprocess.run(command, capture_output=True, text=True, env=run_env) + report_path = work_dir / f"{label}.ndjson" + report_path.write_text(result.stdout, encoding="utf-8") + report = parse_eval_report(result.stdout) + return { + "label": label, + "overrides": overrides, + "exit_code": result.returncode, + "metrics": report.get("metrics", {}), + "report_path": str(report_path), + } + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--dataset", required=True, help="dataset name (under datasets/) or absolute path") + parser.add_argument("--variants", required=True, help='JSON list of retrieval-override dicts, e.g. \'[{"rrf_k":30}]\'') + parser.add_argument("--retrieval", default="hybrid", choices=["hybrid", "bm25", "vector", "all"]) + parser.add_argument("--base-template", type=Path, default=DEFAULT_BASE_TEMPLATE) + parser.add_argument("--env-file", type=Path, default=DEFAULT_ENV) + parser.add_argument("--out", type=Path, help="work dir (default: reports/-ablation)") + parser.add_argument("--dry-run", action="store_true", help="print variants + commands; no base, no API, no spend") + args = parser.parse_args(argv) + + variants = parse_variants(args.variants) + dataset = Path(args.dataset) + if not dataset.is_absolute(): + candidate = PROJECT_ROOT / "datasets" / args.dataset + dataset = candidate.resolve() if candidate.exists() else dataset.resolve() + base_template_text = Path(args.base_template).read_text(encoding="utf-8") + + if args.dry_run: + print(f"# dry-run: {len(variants)} variant(s) over {dataset.name}, cache FORCED {FORCED_CACHE} (#250)") + for overrides in variants: + label = variant_label(overrides) + cmd = build_eval_command( + dataset=dataset, base=PROJECT_ROOT / "bases" / f"ablation-{label}", + retrieval=args.retrieval, cache=FORCED_CACHE, + ) + print(f"# {label}: retrieval overrides = {overrides}") + print(" " + " ".join(cmd)) + return 0 + + env_values = load_dotenv(args.env_file) + absent = [k for k in DEFAULT_REQUIRED_KEYS if not env_values.get(k)] + if absent: + print(f"ERROR: missing required keys in {args.env_file}: {', '.join(absent)}") + return 2 + run_env = merge_env(dict(os.environ), env_values) + + work_dir = args.out or (PROJECT_ROOT / "reports" / f"{_utc_stamp()}-ablation") + work_dir.mkdir(parents=True, exist_ok=True) + rows: list[dict[str, Any]] = [] + for overrides in variants: + row = run_variant( + dataset=dataset, overrides=overrides, base_template_text=base_template_text, + work_dir=work_dir, retrieval=args.retrieval, run_env=run_env, + ) + print(f"==> {row['label']}: exit={row['exit_code']}") + rows.append(row) + + table = metrics_table(rows) + (work_dir / "ablation.md").write_text(table + "\n", encoding="utf-8") + (work_dir / "ablation.json").write_text(json.dumps(rows, indent=2), encoding="utf-8") + print(table) + print(f"\nwrote {work_dir / 'ablation.md'}") + return 0 if all(r["exit_code"] == 0 for r in rows) else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test_ablate_retrieval.py b/tests/test_ablate_retrieval.py new file mode 100644 index 0000000..a454cc6 --- /dev/null +++ b/tests/test_ablate_retrieval.py @@ -0,0 +1,72 @@ +"""Unit tests for the retrieval-config ablation harness's pure helpers.""" + +from __future__ import annotations + +import pytest +import yaml +from scripts.ablate_retrieval import ( + FORCED_CACHE, + apply_retrieval_overrides, + metrics_table, + parse_variants, + variant_label, +) + +BASE_YAML = """\ +provider: + llm_model: MiniMax-M3 + embedding_dim: 1024 +retrieval: + cjk_tokenizer: jieba +storage: + backend: sqlite +""" + + +def test_forced_cache_is_off_until_250(): + # The #250 workaround: a config sweep must not reuse a stale baked snapshot. + assert FORCED_CACHE == "off" + + +def test_apply_overrides_merges_into_retrieval_only(): + out = apply_retrieval_overrides(BASE_YAML, {"rrf_k": 30, "bm25_weight": 0.5}) + doc = yaml.safe_load(out) + assert doc["retrieval"] == {"cjk_tokenizer": "jieba", "rrf_k": 30, "bm25_weight": 0.5} + # other blocks preserved verbatim + assert doc["provider"]["llm_model"] == "MiniMax-M3" + assert doc["storage"]["backend"] == "sqlite" + + +def test_apply_overrides_can_replace_existing_key(): + out = apply_retrieval_overrides(BASE_YAML, {"cjk_tokenizer": "none"}) + assert yaml.safe_load(out)["retrieval"]["cjk_tokenizer"] == "none" + + +def test_empty_overrides_round_trips(): + doc = yaml.safe_load(apply_retrieval_overrides(BASE_YAML, {})) + assert doc["retrieval"] == {"cjk_tokenizer": "jieba"} + + +def test_variant_label(): + assert variant_label({}) == "baseline" + assert variant_label({"rrf_k": 60}) == "rrf_k=60" + # sorted keys -> stable label regardless of dict order + assert variant_label({"vector_weight": 1.5, "bm25_weight": 0.3}) == "bm25_weight=0.3,vector_weight=1.5" + + +def test_metrics_table_renders_missing_as_dash(): + rows = [ + {"label": "rrf_k=30", "exit_code": 0, "metrics": {"hit_at_3": 0.7, "mrr": 0.61}}, + {"label": "rrf_k=60", "exit_code": 0, "metrics": {"hit_at_3": 0.72}}, # mrr missing + ] + table = metrics_table(rows, metric_keys=("hit_at_3", "mrr")) + assert "| rrf_k=30 | 0 | 0.700 | 0.610 |" in table + assert "| rrf_k=60 | 0 | 0.720 | - |" in table + + +def test_parse_variants_validates_shape(): + assert parse_variants('[{"rrf_k": 30}, {"rrf_k": 60}]') == [{"rrf_k": 30}, {"rrf_k": 60}] + with pytest.raises(ValueError): + parse_variants('{"rrf_k": 30}') # not a list + with pytest.raises(ValueError): + parse_variants("[1, 2, 3]") # not objects diff --git a/tests/test_negatives_separation.py b/tests/test_negatives_separation.py new file mode 100644 index 0000000..07d8a2a --- /dev/null +++ b/tests/test_negatives_separation.py @@ -0,0 +1,84 @@ +"""Unit tests for the pos-vs-neg relevance-score separation tool. + +Two schemas are exercised: the current diagnostic-only rows (no scores → graceful +degrade) and the post-#249 rows that carry a score field (under any probed name). +""" + +from __future__ import annotations + +import json + +import pytest +from tools.negatives_separation import ( + eval_report, + expect_none_satisfaction, + separation, + top1_score, +) + +# Pre-#249: rows carry only q / ranked / expect_any (no score anywhere). +PRE_POS = [ + {"id": "zh-a", "expect_any": ["doc_a"], "ranked": ["doc_a", "x"]}, + {"id": "en-b", "expect_any": ["doc_b"], "ranked": ["doc_b", "y"]}, +] +PRE_NEG = [ + {"q": "off-corpus quantum cooking", "ranked": ["doc_x", "doc_y"]}, + {"q": "off-corpus mars law", "ranked": ["doc_z"]}, +] + +# Post-#249: positives score high, negatives score low. Each side uses a +# *different* candidate key name to prove the probe is name-agnostic. +POST_POS = [ + {"id": "zh-a", "expect_any": ["doc_a"], "ranked": ["doc_a"], "scores": [0.81, 0.40]}, + {"id": "en-b", "expect_any": ["doc_b"], "ranked": ["doc_b"], "scores": [0.79, 0.31]}, +] +POST_NEG = [ + {"q": "off-corpus a", "ranked": ["doc_x"], "top1_score": 0.22}, + {"q": "off-corpus b", "ranked": ["doc_y"], "top1_score": 0.18}, +] + + +def test_top1_score_probes_names_and_ignores_missing(): + assert top1_score({"scores": [0.81, 0.4]}) == pytest.approx(0.81) + assert top1_score({"ranked_scores": [0.5]}) == pytest.approx(0.5) + assert top1_score({"top1_score": 0.22}) == pytest.approx(0.22) + assert top1_score({"ranked": ["doc_a"]}) is None # pre-#249 row + assert top1_score({"scores": []}) is None # empty list + assert top1_score({"top1_score": True}) is None # bool is not a score + + +def test_separation_degrades_without_scores(): + result = separation(PRE_POS, PRE_NEG) + assert result["scores_available"] is False + assert result["counts"] == {"positives": 2, "negatives": 2} + assert result["negative_leaks_sample"][0]["top_stem"] == "doc_x" + + +def test_separation_with_scores_computes_margin(): + result = separation(POST_POS, POST_NEG) + assert result["scores_available"] is True + assert result["positive_top1"]["mean"] == pytest.approx(0.80) + assert result["negative_top1"]["mean"] == pytest.approx(0.20) + assert result["separation_margin"] == pytest.approx(0.60) + # Default cutoff is the pos/neg midpoint (0.50); both negatives fall below it. + assert result["cutoff"] == pytest.approx(0.50) + assert result["expect_none_satisfaction"] == pytest.approx(1.0) + + +def test_expect_none_satisfaction_cutoff(): + assert expect_none_satisfaction(POST_NEG, cutoff=0.20) == pytest.approx(0.5) + assert expect_none_satisfaction(POST_NEG, cutoff=0.10) == pytest.approx(0.0) + assert expect_none_satisfaction(PRE_NEG, cutoff=0.5) is None # no scores + + +def test_eval_report_picks_richest_report_line(tmp_path): + path = tmp_path / "run.ndjson" + lines = [ + {"event": "progress"}, # noise + {"metrics": {"hit_at_3": 1.0}, "per_query": [POST_POS[0]], "negative_diagnostics": []}, + {"metrics": {"hit_at_3": 1.0}, "per_query": POST_POS, "negative_diagnostics": POST_NEG}, + ] + path.write_text("\n".join(json.dumps(obj) for obj in lines), encoding="utf-8") + report = eval_report(str(path)) + assert len(report["per_query"]) == 2 + assert len(report["negative_diagnostics"]) == 2 diff --git a/tools/negatives_separation.py b/tools/negatives_separation.py new file mode 100644 index 0000000..22e25c2 --- /dev/null +++ b/tools/negatives_separation.py @@ -0,0 +1,204 @@ +"""Offline pos-vs-neg relevance-score separation for an eval NDJSON. + +An ``expect_none=True`` (off-corpus) query is "satisfied" when the engine's top +hit for it scores *low* — a healthy engine surfaces nothing relevant. dikw-core +emits negatives as ``negative_diagnostics`` (``{q, ranked}``) and positives as +``per_query`` (``{q, id?, expect_any, ranked}``). Both are **diagnostic-only** +today: neither row carries an absolute relevance score, so "low" is not +measurable (dikw-core#249). This tool reads whatever score field #249 lands — it +probes a small list of candidate key names — and computes the pos-vs-neg top-1 +score separation plus the ``expect_none`` satisfaction at a cutoff. + +Until #249 ships it degrades gracefully: ``separation`` reports +``scores_available: false`` with a rank-only observation (counts + the leaked +top stems) so the caller knows the separation is not yet computable rather than +silently reporting a bogus zero. The moment #249 lands — under any of the probed +key names — the same tool yields the real margin with no code change. + +Mirrors ``tools/split_metrics_by_lang.py``: the pure functions (probe + compute) +are unit-tested; a thin CLI reads the NDJSON's EvalReport line. +""" + +from __future__ import annotations + +import argparse +import json +import statistics +import sys +from collections.abc import Sequence +from typing import Any + +# Candidate names for the per-hit score list #249 will align with ``ranked``, +# and for a precomputed top-1 score. Probed in order; first present wins. Listing +# several keeps the tool robust to the exact key #249 chooses. +_SCORES_KEYS = ("scores", "ranked_scores", "hit_scores", "doc_scores") +_TOP1_KEYS = ("top1_score", "top_score", "max_score") + + +def top1_score(row: dict[str, Any]) -> float | None: + """A row's top-ranked absolute relevance score, or ``None`` when the eval + output carries no score yet (pre-#249). + + Prefers an explicit precomputed top-1 score; otherwise takes the first + element of a per-hit score list aligned with ``ranked``. + """ + for key in _TOP1_KEYS: + value = row.get(key) + if isinstance(value, (int, float)) and not isinstance(value, bool): + return float(value) + for key in _SCORES_KEYS: + value = row.get(key) + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)) and value: + first = value[0] + if isinstance(first, (int, float)) and not isinstance(first, bool): + return float(first) + return None + + +def _scores(rows: list[dict[str, Any]]) -> list[float]: + """Top-1 scores of the rows that carry one (drops scoreless rows).""" + return [s for s in (top1_score(r) for r in rows) if s is not None] + + +def _summary(values: list[float]) -> dict[str, float]: + return { + "mean": statistics.fmean(values), + "median": statistics.median(values), + "min": min(values), + "max": max(values), + } + + +def expect_none_satisfaction(negatives: list[dict[str, Any]], cutoff: float) -> float | None: + """Fraction of negatives whose top-1 score falls below ``cutoff`` (a healthy + off-corpus query scores low). ``None`` when no negative carries a score. + """ + scores = _scores(negatives) + if not scores: + return None + return sum(1 for s in scores if s < cutoff) / len(scores) + + +def _leak_sample(negatives: list[dict[str, Any]], limit: int = 5) -> list[dict[str, Any]]: + """Rank-only observational fallback: the top stem each negative surfaced.""" + out: list[dict[str, Any]] = [] + for row in negatives[:limit]: + ranked = row.get("ranked") or [] + out.append({"q": row.get("q", ""), "top_stem": ranked[0] if ranked else None}) + return out + + +def separation( + per_query: list[dict[str, Any]], + negatives: list[dict[str, Any]], + *, + cutoff: float | None = None, +) -> dict[str, Any]: + """Pos-vs-neg top-1 score separation. + + ``cutoff`` for ``expect_none`` satisfaction defaults to the midpoint between + the negative and positive top-1 means — the natural decision boundary once a + separation exists. Degrades to a rank-only observation (``scores_available: + False``) when the eval output predates #249. + """ + pos = _scores(per_query) + neg = _scores(negatives) + counts = {"positives": len(per_query), "negatives": len(negatives)} + if not pos or not neg: + return { + "scores_available": False, + "counts": counts, + "note": ( + "no absolute relevance scores in eval output (pre-dikw-core#249); " + "pos-vs-neg separation is not computable. Showing rank-only " + "negative observations." + ), + "negative_leaks_sample": _leak_sample(negatives), + } + pos_s, neg_s = _summary(pos), _summary(neg) + if cutoff is None: + cutoff = (pos_s["mean"] + neg_s["mean"]) / 2 + return { + "scores_available": True, + "counts": counts, + "positive_top1": pos_s, + "negative_top1": neg_s, + "separation_margin": pos_s["mean"] - neg_s["mean"], + "cutoff": cutoff, + "expect_none_satisfaction": expect_none_satisfaction(negatives, cutoff), + } + + +def eval_report(ndjson_path: str) -> dict[str, Any]: + """The EvalReport object from an eval NDJSON stream. + + The stream carries progress events plus the final EvalReport; the report is + the dict line carrying ``metrics`` with the most ``per_query`` rows (mirrors + ``run_eval.parse_eval_report`` / ``split_metrics_by_lang.per_query_rows``). + """ + best: dict[str, Any] = {} + best_len = -1 + with open(ndjson_path, encoding="utf-8") as fh: + for raw in fh: + line = raw.strip() + if not line: + continue + try: + obj = json.loads(line) + except json.JSONDecodeError: + continue + if not isinstance(obj, dict) or "metrics" not in obj: + continue + pq = obj.get("per_query") + n = len(pq) if isinstance(pq, list) else 0 + if n >= best_len: + best, best_len = obj, n + return best + + +def format_markdown(name: str, result: dict[str, Any]) -> str: + lines = [f"### {name} — negatives separation", ""] + counts = result["counts"] + lines.append(f"- positives: {counts['positives']} negatives: {counts['negatives']}") + if not result["scores_available"]: + lines.append(f"- **scores unavailable** — {result['note']}") + if result.get("negative_leaks_sample"): + lines.append("- negative top-stems (rank-only):") + for leak in result["negative_leaks_sample"]: + lines.append(f" - `{leak['q']}` -> `{leak['top_stem']}`") + return "\n".join(lines) + pos, neg = result["positive_top1"], result["negative_top1"] + lines += [ + f"- positive top-1 score: mean {pos['mean']:.4f} (min {pos['min']:.4f})", + f"- negative top-1 score: mean {neg['mean']:.4f} (max {neg['max']:.4f})", + f"- **separation margin**: {result['separation_margin']:.4f}", + f"- expect_none satisfaction @ cutoff {result['cutoff']:.4f}: " + f"{result['expect_none_satisfaction']:.3f}", + ] + return "\n".join(lines) + + +def main(argv: list[str] | None = None) -> int: + p = argparse.ArgumentParser( + description="Pos-vs-neg relevance-score separation for an eval NDJSON " + "(degrades to rank-only until dikw-core#249 surfaces scores)." + ) + p.add_argument("ndjson", help="path to an eval NDJSON (per_query + negative_diagnostics)") + p.add_argument("--name", default="dataset") + p.add_argument("--cutoff", type=float, default=None, help="expect_none score cutoff (default: pos/neg midpoint)") + args = p.parse_args(argv) + report = eval_report(args.ndjson) + if not report: + print(f"::error::no EvalReport in {args.ndjson}", file=sys.stderr) + return 1 + result = separation( + report.get("per_query", []), + report.get("negative_diagnostics", []), + cutoff=args.cutoff, + ) + print(format_markdown(args.name, result)) + return 0 + + +if __name__ == "__main__": + sys.exit(main())