Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 223 additions & 0 deletions scripts/ablate_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
"""Retrieval-config ablation harness (rrf_k / weights / fusion / rerank sweep).

The snapshot cache bakes ``RetrievalConfig`` at first ingest and reloads the
*stale* baked config on a cache hit (dikw-core#250), so a config sweep under
``--cache read_write`` silently re-reports the first variant's numbers. This
harness therefore **forces ``--cache off``** for every variant — the documented
#250 workaround — at the cost of a fresh embed per variant. When #250 lands,
flip ``FORCED_CACHE`` to ``"rebuild"`` (one fresh snapshot per variant without
re-embedding unrelated runs) and delete this note.

Retrieval config lives in the eval base's ``dikw.yml`` ``retrieval:`` block, not
on the eval CLI, so each variant runs against a temp base carrying its overrides.

Pure helpers (override-merge, label, table) are unit-tested; the run loop shells
out to ``dikw client eval`` via ``run_eval``'s command builder.

Usage:
uv run python scripts/ablate_retrieval.py --dataset domain-bilingual-v1 \\
--variants '[{"rrf_k": 30}, {"rrf_k": 60}, {"rrf_k": 90}]' --dry-run
"""

from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

import yaml

PROJECT_ROOT = Path(__file__).resolve().parents[1]
for _candidate in (PROJECT_ROOT, PROJECT_ROOT / "src"):
_entry = str(_candidate)
if _entry not in sys.path:
sys.path.insert(0, _entry)

from scripts.run_eval import ( # noqa: E402
DEFAULT_BASE_TEMPLATE,
build_eval_command,
ensure_base,
merge_env,
parse_eval_report,
)

from dikw_data.config import load_dotenv # noqa: E402

# dikw-core#250 workaround. The cache key omits RetrievalConfig, so only a fresh
# snapshot reflects a changed retrieval block. Flip to "rebuild" once #250 ships.
FORCED_CACHE = "off"

DEFAULT_METRIC_KEYS: tuple[str, ...] = (
"hit_at_3",
"hit_at_10",
"mrr",
"ndcg_at_10",
"recall_at_100",
)
DEFAULT_ENV = PROJECT_ROOT / ".env"
DEFAULT_REQUIRED_KEYS: tuple[str, ...] = ("MINIMAX_API_KEY", "GITEE_API_KEY")


# --- pure helpers (unit-tested) --------------------------------------------


def apply_retrieval_overrides(base_yaml: str, overrides: dict[str, Any]) -> str:
"""Merge ``overrides`` into the base ``dikw.yml``'s ``retrieval:`` block.

Keys present in ``overrides`` replace (or add) entries under ``retrieval``;
every other block (``provider``, ``storage``, …) is preserved verbatim. An
empty override returns an equivalent document (round-tripped through yaml).
"""
doc = yaml.safe_load(base_yaml) or {}
if not isinstance(doc, dict):
raise ValueError("base dikw.yml is not a mapping")
retrieval = dict(doc.get("retrieval") or {})
retrieval.update(overrides)
doc["retrieval"] = retrieval
return yaml.safe_dump(doc, sort_keys=False, allow_unicode=True)


def variant_label(overrides: dict[str, Any]) -> str:
"""Stable, filesystem-safe label for an override set, e.g. ``rrf_k=60``.

Empty overrides label as ``baseline`` (the template's own retrieval config).
"""
if not overrides:
return "baseline"
return ",".join(f"{k}={overrides[k]}" for k in sorted(overrides))


def metrics_table(rows: list[dict[str, Any]], metric_keys: tuple[str, ...] = DEFAULT_METRIC_KEYS) -> str:
"""Markdown comparison table, one row per variant.

Each ``rows`` item is ``{"label": str, "exit_code": int, "metrics": {...}}``.
Missing metrics render as ``-`` so a partial sweep is still readable.
"""
head = "| variant | exit | " + " | ".join(metric_keys) + " |"
sep = "|" + "---|" * (2 + len(metric_keys))
lines = [head, sep]
for row in rows:
metrics = row.get("metrics") or {}
cells = " | ".join(
f"{metrics[k]:.3f}" if isinstance(metrics.get(k), (int, float)) else "-"
for k in metric_keys
)
lines.append(f"| {row['label']} | {row.get('exit_code', '')} | {cells} |")
return "\n".join(lines)


def parse_variants(spec: str) -> list[dict[str, Any]]:
"""Parse a ``--variants`` JSON list of override dicts (validates shape)."""
parsed = json.loads(spec)
if not isinstance(parsed, list) or not all(isinstance(v, dict) for v in parsed):
raise ValueError("--variants must be a JSON list of objects")
return parsed


# --- side-effecting orchestration ------------------------------------------


def _utc_stamp() -> str:
return datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")


def run_variant(
*,
dataset: Path,
overrides: dict[str, Any],
base_template_text: str,
work_dir: Path,
retrieval: str,
run_env: dict[str, str],
) -> dict[str, Any]:
"""Materialise a temp base with ``overrides`` and run one eval (cache forced
off). Returns ``{"label", "exit_code", "metrics", "report_path"}``."""
label = variant_label(overrides)
variant_yaml = apply_retrieval_overrides(base_template_text, overrides)
template_path = work_dir / f"{label}.dikw.yml"
template_path.write_text(variant_yaml, encoding="utf-8")
base_dir = work_dir / f"base-{label}"
ensure_base(base_dir, template_path)
command = build_eval_command(
dataset=dataset,
base=base_dir,
mode="serve-and-run",
retrieval=retrieval,
cache=FORCED_CACHE,
)
result = subprocess.run(command, capture_output=True, text=True, env=run_env)
report_path = work_dir / f"{label}.ndjson"
report_path.write_text(result.stdout, encoding="utf-8")
report = parse_eval_report(result.stdout)
return {
"label": label,
"overrides": overrides,
"exit_code": result.returncode,
"metrics": report.get("metrics", {}),
"report_path": str(report_path),
}


def main(argv: list[str] | None = None) -> int:
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument("--dataset", required=True, help="dataset name (under datasets/) or absolute path")
parser.add_argument("--variants", required=True, help='JSON list of retrieval-override dicts, e.g. \'[{"rrf_k":30}]\'')
parser.add_argument("--retrieval", default="hybrid", choices=["hybrid", "bm25", "vector", "all"])
parser.add_argument("--base-template", type=Path, default=DEFAULT_BASE_TEMPLATE)
parser.add_argument("--env-file", type=Path, default=DEFAULT_ENV)
parser.add_argument("--out", type=Path, help="work dir (default: reports/<UTC-ts>-ablation)")
parser.add_argument("--dry-run", action="store_true", help="print variants + commands; no base, no API, no spend")
args = parser.parse_args(argv)

variants = parse_variants(args.variants)
dataset = Path(args.dataset)
if not dataset.is_absolute():
candidate = PROJECT_ROOT / "datasets" / args.dataset
dataset = candidate.resolve() if candidate.exists() else dataset.resolve()
base_template_text = Path(args.base_template).read_text(encoding="utf-8")

if args.dry_run:
print(f"# dry-run: {len(variants)} variant(s) over {dataset.name}, cache FORCED {FORCED_CACHE} (#250)")
for overrides in variants:
label = variant_label(overrides)
cmd = build_eval_command(
dataset=dataset, base=PROJECT_ROOT / "bases" / f"ablation-{label}",
retrieval=args.retrieval, cache=FORCED_CACHE,
)
print(f"# {label}: retrieval overrides = {overrides}")
print(" " + " ".join(cmd))
return 0

env_values = load_dotenv(args.env_file)
absent = [k for k in DEFAULT_REQUIRED_KEYS if not env_values.get(k)]
if absent:
print(f"ERROR: missing required keys in {args.env_file}: {', '.join(absent)}")
return 2
run_env = merge_env(dict(os.environ), env_values)

work_dir = args.out or (PROJECT_ROOT / "reports" / f"{_utc_stamp()}-ablation")
work_dir.mkdir(parents=True, exist_ok=True)
rows: list[dict[str, Any]] = []
for overrides in variants:
row = run_variant(
dataset=dataset, overrides=overrides, base_template_text=base_template_text,
work_dir=work_dir, retrieval=args.retrieval, run_env=run_env,
)
print(f"==> {row['label']}: exit={row['exit_code']}")
rows.append(row)

table = metrics_table(rows)
(work_dir / "ablation.md").write_text(table + "\n", encoding="utf-8")
(work_dir / "ablation.json").write_text(json.dumps(rows, indent=2), encoding="utf-8")
print(table)
print(f"\nwrote {work_dir / 'ablation.md'}")
return 0 if all(r["exit_code"] == 0 for r in rows) else 1


if __name__ == "__main__":
sys.exit(main())
72 changes: 72 additions & 0 deletions tests/test_ablate_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Unit tests for the retrieval-config ablation harness's pure helpers."""

from __future__ import annotations

import pytest
import yaml
from scripts.ablate_retrieval import (
FORCED_CACHE,
apply_retrieval_overrides,
metrics_table,
parse_variants,
variant_label,
)

BASE_YAML = """\
provider:
llm_model: MiniMax-M3
embedding_dim: 1024
retrieval:
cjk_tokenizer: jieba
storage:
backend: sqlite
"""


def test_forced_cache_is_off_until_250():
# The #250 workaround: a config sweep must not reuse a stale baked snapshot.
assert FORCED_CACHE == "off"


def test_apply_overrides_merges_into_retrieval_only():
out = apply_retrieval_overrides(BASE_YAML, {"rrf_k": 30, "bm25_weight": 0.5})
doc = yaml.safe_load(out)
assert doc["retrieval"] == {"cjk_tokenizer": "jieba", "rrf_k": 30, "bm25_weight": 0.5}
# other blocks preserved verbatim
assert doc["provider"]["llm_model"] == "MiniMax-M3"
assert doc["storage"]["backend"] == "sqlite"


def test_apply_overrides_can_replace_existing_key():
out = apply_retrieval_overrides(BASE_YAML, {"cjk_tokenizer": "none"})
assert yaml.safe_load(out)["retrieval"]["cjk_tokenizer"] == "none"


def test_empty_overrides_round_trips():
doc = yaml.safe_load(apply_retrieval_overrides(BASE_YAML, {}))
assert doc["retrieval"] == {"cjk_tokenizer": "jieba"}


def test_variant_label():
assert variant_label({}) == "baseline"
assert variant_label({"rrf_k": 60}) == "rrf_k=60"
# sorted keys -> stable label regardless of dict order
assert variant_label({"vector_weight": 1.5, "bm25_weight": 0.3}) == "bm25_weight=0.3,vector_weight=1.5"


def test_metrics_table_renders_missing_as_dash():
rows = [
{"label": "rrf_k=30", "exit_code": 0, "metrics": {"hit_at_3": 0.7, "mrr": 0.61}},
{"label": "rrf_k=60", "exit_code": 0, "metrics": {"hit_at_3": 0.72}}, # mrr missing
]
table = metrics_table(rows, metric_keys=("hit_at_3", "mrr"))
assert "| rrf_k=30 | 0 | 0.700 | 0.610 |" in table
assert "| rrf_k=60 | 0 | 0.720 | - |" in table


def test_parse_variants_validates_shape():
assert parse_variants('[{"rrf_k": 30}, {"rrf_k": 60}]') == [{"rrf_k": 30}, {"rrf_k": 60}]
with pytest.raises(ValueError):
parse_variants('{"rrf_k": 30}') # not a list
with pytest.raises(ValueError):
parse_variants("[1, 2, 3]") # not objects
84 changes: 84 additions & 0 deletions tests/test_negatives_separation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Unit tests for the pos-vs-neg relevance-score separation tool.

Two schemas are exercised: the current diagnostic-only rows (no scores → graceful
degrade) and the post-#249 rows that carry a score field (under any probed name).
"""

from __future__ import annotations

import json

import pytest
from tools.negatives_separation import (
eval_report,
expect_none_satisfaction,
separation,
top1_score,
)

# Pre-#249: rows carry only q / ranked / expect_any (no score anywhere).
PRE_POS = [
{"id": "zh-a", "expect_any": ["doc_a"], "ranked": ["doc_a", "x"]},
{"id": "en-b", "expect_any": ["doc_b"], "ranked": ["doc_b", "y"]},
]
PRE_NEG = [
{"q": "off-corpus quantum cooking", "ranked": ["doc_x", "doc_y"]},
{"q": "off-corpus mars law", "ranked": ["doc_z"]},
]

# Post-#249: positives score high, negatives score low. Each side uses a
# *different* candidate key name to prove the probe is name-agnostic.
POST_POS = [
{"id": "zh-a", "expect_any": ["doc_a"], "ranked": ["doc_a"], "scores": [0.81, 0.40]},
{"id": "en-b", "expect_any": ["doc_b"], "ranked": ["doc_b"], "scores": [0.79, 0.31]},
]
POST_NEG = [
{"q": "off-corpus a", "ranked": ["doc_x"], "top1_score": 0.22},
{"q": "off-corpus b", "ranked": ["doc_y"], "top1_score": 0.18},
]


def test_top1_score_probes_names_and_ignores_missing():
assert top1_score({"scores": [0.81, 0.4]}) == pytest.approx(0.81)
assert top1_score({"ranked_scores": [0.5]}) == pytest.approx(0.5)
assert top1_score({"top1_score": 0.22}) == pytest.approx(0.22)
assert top1_score({"ranked": ["doc_a"]}) is None # pre-#249 row
assert top1_score({"scores": []}) is None # empty list
assert top1_score({"top1_score": True}) is None # bool is not a score


def test_separation_degrades_without_scores():
result = separation(PRE_POS, PRE_NEG)
assert result["scores_available"] is False
assert result["counts"] == {"positives": 2, "negatives": 2}
assert result["negative_leaks_sample"][0]["top_stem"] == "doc_x"


def test_separation_with_scores_computes_margin():
result = separation(POST_POS, POST_NEG)
assert result["scores_available"] is True
assert result["positive_top1"]["mean"] == pytest.approx(0.80)
assert result["negative_top1"]["mean"] == pytest.approx(0.20)
assert result["separation_margin"] == pytest.approx(0.60)
# Default cutoff is the pos/neg midpoint (0.50); both negatives fall below it.
assert result["cutoff"] == pytest.approx(0.50)
assert result["expect_none_satisfaction"] == pytest.approx(1.0)


def test_expect_none_satisfaction_cutoff():
assert expect_none_satisfaction(POST_NEG, cutoff=0.20) == pytest.approx(0.5)
assert expect_none_satisfaction(POST_NEG, cutoff=0.10) == pytest.approx(0.0)
assert expect_none_satisfaction(PRE_NEG, cutoff=0.5) is None # no scores


def test_eval_report_picks_richest_report_line(tmp_path):
path = tmp_path / "run.ndjson"
lines = [
{"event": "progress"}, # noise
{"metrics": {"hit_at_3": 1.0}, "per_query": [POST_POS[0]], "negative_diagnostics": []},
{"metrics": {"hit_at_3": 1.0}, "per_query": POST_POS, "negative_diagnostics": POST_NEG},
]
path.write_text("\n".join(json.dumps(obj) for obj in lines), encoding="utf-8")
report = eval_report(str(path))
assert len(report["per_query"]) == 2
assert len(report["negative_diagnostics"]) == 2
Loading