diff --git a/.env.example b/.env.example index df42afd..f7099de 100644 --- a/.env.example +++ b/.env.example @@ -1 +1,31 @@ -GROQ_API_KEY=your_groq_api_key_here +# ── MemoryLens environment variables ───────────────────────────────────────── +# Copy this file to .env and fill in the keys you want to use. +# At least ONE provider is needed for --llm / LLM eval mode. +# All providers are optional — without any key the benchmark runs in +# content-only mode (fast, free, no API needed). + +# ── Provider API keys (pick any one or more) ───────────────────────────────── + +# Groq — free tier, very fast (recommended for quick experiments) +# https://console.groq.com/keys +GROQ_API_KEY= + +# OpenAI — gpt-4o-mini by default +# https://platform.openai.com/api-keys +OPENAI_API_KEY= + +# Anthropic — claude-haiku-4-5 by default +# https://console.anthropic.com/settings/keys +ANTHROPIC_API_KEY= + +# OpenRouter — 200+ models via one key, has a free tier +# https://openrouter.ai/settings/keys +OPENROUTER_API_KEY= + +# Ollama — local LLMs, no key needed, just start the server +# Default: http://localhost:11434 Override: +# OLLAMA_HOST=http://localhost:11434 + +# ── Auto-detect override ────────────────────────────────────────────────────── +# Force a specific provider instead of auto-detecting: +# MEMORYLENS_PROVIDER=groq diff --git a/CHANGELOG.md b/CHANGELOG.md index b39cf25..988e53e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,30 @@ Format follows [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [Unreleased] -### Added +### Added — Multi-Provider Real LLM Evaluation (`feat/multi-provider-llm-eval`) +- `utils/providers.py` — unified LLM abstraction layer supporting **5 providers**: + - **Groq** (`GROQ_API_KEY`) — free tier, llama-3.1-8b-instant + - **OpenAI** (`OPENAI_API_KEY`) — gpt-4o-mini + - **Anthropic** (`ANTHROPIC_API_KEY`) — claude-haiku-4-5 + - **OpenRouter** (`OPENROUTER_API_KEY`) — 200+ models via one key + - **Ollama** (local, no key) — any locally running model + - Auto-detection priority: Groq → OpenAI → Anthropic → OpenRouter → Ollama +- **Two-stage LLM evaluation pipeline** in `evaluation/metrics.py`: + - `llm_recall_at_t()` — LLM answers the query; a judge LLM call verifies correctness + - `llm_temporal_drift()` — checks if LLM returns old vs new value after a fact update +- **CLI flags** in `main.py`: + - `--llm` — enable real LLM evaluation pass + - `--provider ` — force a specific provider + - `--list-providers` — print availability of all providers and exit +- **Three-table CLI output**: Content Recall, LLM Recall, and Gap (Content − LLM) +- **Dashboard updates**: + - Provider selector in sidebar (auto-detects available providers) + - Tabbed recall chart: Content Recall / LLM Recall / Gap + - KPI cards show LLM Recall with gap delta when available + - `summary` backend added to backend multiselect +- Updated `.env.example` with all five provider keys and inline documentation + +### Added — SummaryMemory Backend - `memory/summary.py`: SummaryMemory backend — rolling compression memory with dual-mode support: - **LLM mode** (when `GROQ_API_KEY` is set): Groq-powered abstractive summarisation - **Extractive fallback** (zero API cost): regex-based fact-pattern extraction diff --git a/dashboard.py b/dashboard.py index e2ac104..d12511c 100644 --- a/dashboard.py +++ b/dashboard.py @@ -24,10 +24,40 @@ """, unsafe_allow_html=True) -COLORS = {"naive": "#f38ba8", "rag": "#89b4fa", "cascading": "#a6e3a1"} +COLORS = { + "naive": "#f38ba8", + "rag": "#89b4fa", + "cascading": "#a6e3a1", + "summary": "#fab387", +} MONTHLY_QUERIES = 100_000 COST_PER_TOKEN_INR = 83 / 1_000_000 # ~$1 per 1M tokens * 83 INR/USD +_PROVIDER_KEYS = { + "groq": "GROQ_API_KEY", + "openai": "OPENAI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "openrouter": "OPENROUTER_API_KEY", + "ollama": None, # no key needed — just a running server +} + + +def _detect_available_providers() -> List[str]: + """Return provider names whose credentials are currently present.""" + available = [] + for name, env_var in _PROVIDER_KEYS.items(): + if env_var is None: + # Ollama: try a quick ping + import urllib.request + try: + urllib.request.urlopen("http://localhost:11434/api/tags", timeout=1) + available.append(name) + except Exception: + pass + elif os.getenv(env_var): + available.append(name) + return available + # ─── Sidebar ──────────────────────────────────────────────────────────────── with st.sidebar: @@ -46,11 +76,37 @@ ) backends = st.multiselect( "Memory backends", - ["naive", "rag", "cascading"], + ["naive", "rag", "cascading", "summary"], default=["naive", "rag", "cascading"], ) st.divider() + # ── LLM Provider ────────────────────────────────────────────────────── + st.subheader("LLM Evaluation (optional)") + available_providers = _detect_available_providers() + provider_options = ["None (content-only)"] + available_providers + selected_provider_label = st.selectbox( + "Provider", + provider_options, + help=( + "Run a real answer+judge pass on top of content-based metrics. " + "Set the matching API key in your .env file to unlock a provider." + ), + ) + selected_provider = ( + None if selected_provider_label == "None (content-only)" + else selected_provider_label + ) + + if not available_providers: + st.caption( + "No provider detected. Add an API key to `.env`:\n" + "`GROQ_API_KEY`, `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, " + "`OPENROUTER_API_KEY`, or start Ollama." + ) + + st.divider() + col_run, col_demo = st.columns(2) run_btn = col_run.button("▶ Run Live", type="primary", use_container_width=True) demo_btn = col_demo.button("📊 Demo", use_container_width=True) @@ -71,19 +127,33 @@ def load_demo() -> Dict: def render_results(data: Dict, is_demo: bool = False) -> None: cps: List[int] = data["checkpoints"] - present = [b for b in ["naive", "rag", "cascading"] if b in data] + present = [b for b in ["naive", "rag", "cascading", "summary"] if b in data] + has_llm = data.get("has_llm_eval", False) if is_demo: - st.info("📊 Showing pre-computed demo results. Set GROQ_API_KEY and click ▶ Run Live for real evaluation.", icon="ℹ️") + st.info( + "📊 Showing pre-computed demo results. " + "Set an API key in `.env` and click **▶ Run Live** for real LLM evaluation.", + icon="ℹ️", + ) # ── KPI cards ────────────────────────────────────────────────────────── st.subheader("Summary at Final Checkpoint") cols = st.columns(len(present)) for col, name in zip(cols, present): d = data[name] + llm_val = d.get("llm_recall", [None])[-1] with col: st.markdown(f"#### {name.capitalize()}") st.metric("Recall@Final", f"{d['recall'][-1]*100:.1f}%") + if llm_val is not None: + gap = (d["recall"][-1] - llm_val) * 100 + st.metric( + "LLM Recall@Final", + f"{llm_val*100:.1f}%", + delta=f"{gap:+.1f}pp gap", + delta_color="inverse", + ) st.metric("Avg Tokens", f"{d['tokens'][-1]:,}") st.metric("Temporal Drift", f"{d['drift'][-1]*100:.1f}%") st.metric("Precision@K", f"{d['precision'][-1]*100:.1f}%") @@ -92,19 +162,96 @@ def render_results(data: Dict, is_demo: bool = False) -> None: # ── Recall decay ─────────────────────────────────────────────────────── st.subheader("Memory Recall Decay Over Time") - fig = go.Figure() - for name in present: - fig.add_trace(go.Scatter( - x=cps, y=[v * 100 for v in data[name]["recall"]], - name=name.capitalize(), mode="lines+markers", - line=dict(color=COLORS[name], width=3), marker=dict(size=9), - )) - fig.update_layout( - xaxis_title="Conversation Turn", yaxis_title="Recall (%)", - yaxis=dict(range=[0, 105]), template="plotly_dark", - height=360, legend=dict(orientation="h", y=1.12), - ) - st.plotly_chart(fig, use_container_width=True) + + if has_llm: + tab_content, tab_llm, tab_gap = st.tabs( + ["Content Recall", "LLM Recall", "Gap (Content − LLM)"] + ) + else: + (tab_content,) = st.tabs(["Content Recall"]) + tab_llm = tab_gap = None + + with tab_content: + fig = go.Figure() + for name in present: + color = COLORS.get(name, "#cdd6f4") + fig.add_trace(go.Scatter( + x=cps, y=[v * 100 for v in data[name]["recall"]], + name=name.capitalize(), mode="lines+markers", + line=dict(color=color, width=3), marker=dict(size=9), + )) + fig.update_layout( + xaxis_title="Conversation Turn", yaxis_title="Recall (%)", + yaxis=dict(range=[0, 105]), template="plotly_dark", + height=360, legend=dict(orientation="h", y=1.12), + ) + st.plotly_chart(fig, use_container_width=True) + st.caption( + "Content Recall: substring match on retrieved context chunks — " + "fast, reproducible, zero API cost." + ) + + if has_llm and tab_llm is not None: + with tab_llm: + fig_llm = go.Figure() + for name in present: + color = COLORS.get(name, "#cdd6f4") + llm_vals = data[name].get("llm_recall", []) + if any(v is not None for v in llm_vals): + fig_llm.add_trace(go.Scatter( + x=cps, + y=[v * 100 if v is not None else None for v in llm_vals], + name=name.capitalize(), mode="lines+markers", + line=dict(color=color, width=3, dash="dash"), + marker=dict(size=9, symbol="diamond"), + )) + fig_llm.update_layout( + xaxis_title="Conversation Turn", yaxis_title="LLM Recall (%)", + yaxis=dict(range=[0, 105]), template="plotly_dark", + height=360, legend=dict(orientation="h", y=1.12), + ) + st.plotly_chart(fig_llm, use_container_width=True) + provider_used = next( + (data[b].get("provider") for b in present if data[b].get("provider")), + "unknown", + ) + st.caption( + f"LLM Recall: two-stage answer+judge pipeline using **{provider_used}**. " + "The LLM actually answers each question; a judge call verifies correctness." + ) + + if has_llm and tab_gap is not None: + with tab_gap: + fig_gap = go.Figure() + for name in present: + color = COLORS.get(name, "#cdd6f4") + content_vals = data[name]["recall"] + llm_vals = data[name].get("llm_recall", [None] * len(content_vals)) + gaps = [ + (c - l) * 100 if l is not None else None + for c, l in zip(content_vals, llm_vals) + ] + if any(g is not None for g in gaps): + fig_gap.add_trace(go.Bar( + x=[f"T={c}" for c in cps], + y=gaps, + name=name.capitalize(), + marker_color=color, + )) + fig_gap.add_hline(y=0, line_dash="dot", line_color="#cdd6f4") + fig_gap.update_layout( + xaxis_title="Checkpoint", + yaxis_title="Content Recall − LLM Recall (pp)", + template="plotly_dark", height=360, + barmode="group", + legend=dict(orientation="h", y=1.12), + ) + st.plotly_chart(fig_gap, use_container_width=True) + st.caption( + "Positive gap means content recall *overestimates* true answer quality. " + "A large gap signals the backend retrieves the right text but the LLM " + "still fails to extract the correct answer." + ) # ── Drift + Noise ─────────────────────────────────────────────────────── c1, c2 = st.columns(2) @@ -254,9 +401,7 @@ def _latex_table(data: Dict, checkpoints: List[int], present: List[str]) -> str: st.rerun() if run_btn: - if not os.getenv("GROQ_API_KEY"): - st.error("GROQ_API_KEY not found. Add it to a `.env` file in the project root.") - elif not checkpoints: + if not checkpoints: st.warning("Select at least one checkpoint.") else: log_area = st.empty() @@ -269,17 +414,34 @@ def push_log(msg: str) -> None: with st.spinner("Running benchmark…"): from evaluation.benchmark import run_benchmark, results_to_display_dict from evaluation.logger import log_run + + # Resolve provider (None = content-only) + provider_obj = None + if selected_provider: + try: + from utils.providers import get_provider + provider_obj = get_provider(selected_provider) + push_log(f"LLM provider: {provider_obj.name}") + except Exception as e: + st.error(f"Provider error: {e}") + st.stop() + raw = run_benchmark( total_turns=total_turns, eval_checkpoints=sorted(checkpoints), backends=backends, + provider=provider_obj, progress=push_log, ) display = results_to_display_dict(raw) st.session_state.results = display st.session_state.is_demo = False - saved = log_run(display, {"total_turns": total_turns, "backends": backends}) - push_log(f"Results saved → {saved}") + saved = log_run(display, { + "total_turns": total_turns, + "backends": backends, + "provider": provider_obj.name if provider_obj else None, + }) + push_log(f"Results saved -> {saved}") log_area.empty() st.rerun() @@ -297,11 +459,12 @@ def push_log(msg: str) -> None: | Layer | What It Does | |-------|--------------| | **Memory Injection** | Injects personal facts at T=0 and queries them at T=10, 25, 50, 100 | -| **3 Backends** | Naive (full history), RAG (vector retrieval), Cascading Temporal (tiered decay) | +| **4 Backends** | Naive · RAG · Cascading Temporal · SummaryMemory | | **5 Metrics** | Recall@T · Precision@K · Temporal Drift · Memory Noise Ratio · Token Cost | -| **Dashboard** | Decay curves, cost impact, LaTeX-ready research tables | +| **LLM Eval** | Two-stage answer+judge pipeline — 5 providers (Groq, OpenAI, Anthropic, OpenRouter, Ollama) | +| **Dashboard** | Decay curves, content vs LLM recall gap, cost impact, LaTeX export | -**Click 📊 Demo** in the sidebar for instant results, or set `GROQ_API_KEY` and click **▶ Run Live**. +**Click 📊 Demo** in the sidebar for instant results, or configure a provider and click **▶ Run Live**. """) st.markdown("---") diff --git a/evaluation/benchmark.py b/evaluation/benchmark.py index 611b0a1..aeaee47 100644 --- a/evaluation/benchmark.py +++ b/evaluation/benchmark.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Dict, Callable, Optional +from typing import TYPE_CHECKING, List, Dict, Callable, Optional from simulator.facts import Fact, BENCHMARK_FACTS from simulator.conversation import generate_conversation @@ -10,53 +10,73 @@ from memory.base import BaseMemory from evaluation.metrics import ( recall_at_t, temporal_drift_score, memory_noise_ratio, precision_at_k, - cascade_efficiency, + cascade_efficiency, llm_recall_at_t, llm_temporal_drift, ) +if TYPE_CHECKING: + from utils.providers import LLMProvider + OFF_TOPIC_QUERY = "What is the best sorting algorithm for large datasets?" +_NAN = float("nan") + @dataclass class CheckpointResult: - turn: int - recall: float - precision: float - drift: float - noise: float - tokens: int - cascade_eff: float = 1.0 + turn: int + # ── Content-based (always available, fast) ─────────────────────────────── + recall: float # substring match on retrieved chunks + precision: float + drift: float + noise: float + tokens: int + cascade_eff: float = 1.0 + # ── LLM-based (available when a provider is configured) ────────────────── + llm_recall: float = _NAN # actual LLM answer judged correct/wrong + llm_drift: float = _NAN # LLM gives old vs new value after update + has_llm_eval: bool = False @dataclass class BackendResult: name: str checkpoints: List[CheckpointResult] = field(default_factory=list) - raw_recalls: List[Dict] = field(default_factory=list) + raw_recalls: List[Dict] = field(default_factory=list) + provider_name: Optional[str] = None # which LLM was used, if any def _make_memory(name: str) -> BaseMemory: if name == "naive": - # Limit to ~1,200 tokens to simulate a realistic context window budget, - # forcing oldest messages to be evicted as conversation grows. return NaiveMemory(max_context_tokens=1200) if name == "rag": return RAGMemory() if name == "cascading": return CascadingTemporalMemory() if name == "summary": - # use_llm=None → auto-detect from GROQ_API_KEY env var return SummaryMemory(window_size=20, use_llm=None) - raise ValueError(f"Unknown backend: '{name}'. Choose from: naive, rag, cascading, summary") + raise ValueError( + f"Unknown backend: '{name}'. Choose from: naive, rag, cascading, summary" + ) def run_benchmark( - total_turns: int = 100, - eval_checkpoints: Optional[List[int]] = None, - facts: Optional[List[Fact]] = None, - backends: Optional[List[str]] = None, - progress: Optional[Callable[[str], None]] = None, + total_turns: int = 100, + eval_checkpoints: Optional[List[int]] = None, + facts: Optional[List[Fact]] = None, + backends: Optional[List[str]] = None, + provider: Optional["LLMProvider"] = None, + progress: Optional[Callable[[str], None]] = None, ) -> Dict[str, BackendResult]: - + """ + Run the full MemoryLens benchmark. + + Parameters + ---------- + provider : LLMProvider | None + When supplied, a second LLM-evaluation pass runs at every checkpoint + alongside the fast content-based pass. When None, only content-based + metrics are computed. + """ if eval_checkpoints is None: eval_checkpoints = [10, 25, 50, 75, 100] if facts is None: @@ -69,24 +89,33 @@ def run_benchmark( checkpoint_set = set(eval_checkpoints) results: Dict[str, BackendResult] = {} - # Always maintain a paired naive + cascading for cascade_efficiency metric + if provider and progress: + progress(f"LLM provider: {provider.name}") + elif progress: + progress("No LLM provider — running content-only mode (fast)") + + # Shadow memories for cascade_efficiency _naive_shadow = _make_memory("naive") _cascade_shadow = _make_memory("cascading") for backend_name in backends: if progress: - progress(f"▶ Starting backend: {backend_name}") + progress(f"▶ Backend: {backend_name}") memory = _make_memory(backend_name) - result = BackendResult(name=backend_name) + result = BackendResult( + name=backend_name, + provider_name=provider.name if provider else None, + ) known_values: List[str] = [] for event in events: turn = event["turn"] - ack = "Understood." if event["is_fact"] else "I can help with that." + ack = "Understood." if event["is_fact"] else "I can help with that." + memory.add_message("user", event["content"], turn) memory.add_message("assistant", ack, turn) - # Feed shadow memories used only for cascade_efficiency + if backend_name == "naive": _naive_shadow.add_message("user", event["content"], turn) _naive_shadow.add_message("assistant", ack, turn) @@ -95,9 +124,8 @@ def run_benchmark( _cascade_shadow.add_message("assistant", ack, turn) if event["is_fact"]: - key = event["fact_key"] for f in facts: - if f.key == key: + if f.key == event["fact_key"]: val = f.current_value(turn) if val not in known_values: known_values.append(val) @@ -105,61 +133,99 @@ def run_benchmark( if (turn + 1) in checkpoint_set: cp = turn + 1 if progress: - progress(f" Evaluating {backend_name} @ T={cp} ...") + progress(f" Evaluating @ T={cp} ...") active_facts = [f for f in facts if f.injected_at <= turn] - # --- Recall@T --- - recalls = [recall_at_t(memory, f, turn) for f in active_facts] + # ── Content-based pass (always) ─────────────────────────── + recalls = [recall_at_t(memory, f, turn) for f in active_facts] avg_recall = sum(r["recalled"] for r in recalls) / max(1, len(recalls)) - avg_tokens = sum(r["tokens"] for r in recalls) / max(1, len(recalls)) + avg_tokens = sum(r["tokens"] for r in recalls) / max(1, len(recalls)) for r in recalls: result.raw_recalls.append({"turn": cp, **r}) - # --- Precision@K --- prec = precision_at_k(memory, active_facts, turn) - # --- Temporal Drift --- - drift_facts = [f for f in active_facts if f.updated_at and f.updated_at <= turn] - if drift_facts: - drifts = [temporal_drift_score(memory, f, turn)["drift"] for f in drift_facts] - avg_drift = sum(drifts) / len(drifts) - else: - avg_drift = 0.0 + drift_facts = [ + f for f in active_facts + if f.updated_at and f.updated_at <= turn + ] + avg_drift = ( + sum(temporal_drift_score(memory, f, turn)["drift"] for f in drift_facts) + / len(drift_facts) + if drift_facts else 0.0 + ) - # --- Noise Ratio --- noise = memory_noise_ratio(memory, OFF_TOPIC_QUERY, known_values, turn) - # --- Cascade Efficiency (only meaningful for cascading backend) --- eff = 1.0 if backend_name == "cascading" and "naive" in backends: - eff = cascade_efficiency(_cascade_shadow, _naive_shadow, active_facts, turn) + eff = cascade_efficiency( + _cascade_shadow, _naive_shadow, active_facts, turn + ) + + # ── LLM pass (when provider is available) ───────────────── + llm_recall_val = _NAN + llm_drift_val = _NAN + has_llm = False + + if provider: + has_llm = True + llm_results = [ + llm_recall_at_t(memory, f, turn, provider) + for f in active_facts + ] + llm_recall_val = ( + sum(r["llm_recalled"] for r in llm_results) + / max(1, len(llm_results)) + ) + + if drift_facts: + drift_llm_results = [ + llm_temporal_drift(memory, f, turn, provider) + for f in drift_facts + ] + applicable = [r for r in drift_llm_results if r["applicable"]] + llm_drift_val = ( + sum(r["llm_drift"] for r in applicable) / len(applicable) + if applicable else 0.0 + ) + else: + llm_drift_val = 0.0 result.checkpoints.append(CheckpointResult( - turn=cp, - recall=round(avg_recall, 4), - precision=round(prec, 4), - drift=round(avg_drift, 4), - noise=round(noise, 4), - tokens=int(avg_tokens), - cascade_eff=round(eff, 4), + turn = cp, + recall = round(avg_recall, 4), + precision = round(prec, 4), + drift = round(avg_drift, 4), + noise = round(noise, 4), + tokens = int(avg_tokens), + cascade_eff = round(eff, 4), + llm_recall = round(llm_recall_val, 4) if has_llm else _NAN, + llm_drift = round(llm_drift_val, 4) if has_llm else _NAN, + has_llm_eval = has_llm, )) results[backend_name] = result if progress: - progress(f" ✓ {backend_name} complete.") + progress(f" ✓ {backend_name} done.") return results def results_to_display_dict(results: Dict[str, BackendResult]) -> Dict: """Convert BackendResult objects into a JSON-serialisable dict for the dashboard.""" + import math checkpoints = sorted({cp.turn for r in results.values() for cp in r.checkpoints}) - display: Dict = {"checkpoints": checkpoints} + display: Dict = {"checkpoints": checkpoints, "has_llm_eval": False} for name, result in results.items(): cp_map = {cp.turn: cp for cp in result.checkpoints} + has_llm = any(cp.has_llm_eval for cp in result.checkpoints) + if has_llm: + display["has_llm_eval"] = True + display[name] = { "recall": [cp_map[t].recall for t in checkpoints if t in cp_map], "precision": [cp_map[t].precision for t in checkpoints if t in cp_map], @@ -167,6 +233,16 @@ def results_to_display_dict(results: Dict[str, BackendResult]) -> Dict: "noise": [cp_map[t].noise for t in checkpoints if t in cp_map], "tokens": [cp_map[t].tokens for t in checkpoints if t in cp_map], "cascade_eff": [cp_map[t].cascade_eff for t in checkpoints if t in cp_map], + # LLM metrics — None where not available + "llm_recall": [ + None if math.isnan(cp_map[t].llm_recall) else cp_map[t].llm_recall + for t in checkpoints if t in cp_map + ], + "llm_drift": [ + None if math.isnan(cp_map[t].llm_drift) else cp_map[t].llm_drift + for t in checkpoints if t in cp_map + ], + "provider": result.provider_name, } return display diff --git a/evaluation/metrics.py b/evaluation/metrics.py index 6050d74..cdf2ce9 100644 --- a/evaluation/metrics.py +++ b/evaluation/metrics.py @@ -1,7 +1,10 @@ -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from memory.base import BaseMemory from simulator.facts import Fact +if TYPE_CHECKING: + from utils.providers import LLMProvider + def recall_at_t(memory: BaseMemory, fact: Fact, current_turn: int) -> Dict: """ @@ -126,3 +129,131 @@ def _stats(mem: BaseMemory): if naive_rpt == 0: return float("inf") return round(cascading_rpt / naive_rpt, 4) + + +# ───────────────────────────────────────────────────────────────────────────── +# LLM-based metrics (require a provider — degrade gracefully without one) +# ───────────────────────────────────────────────────────────────────────────── + +_ANSWER_SYSTEM = ( + "You are a helpful assistant with access to a conversation history. " + "Answer the user's question using ONLY the information in the context. " + "Reply with the answer value only — no explanation, no extra words." +) + +_JUDGE_SYSTEM = ( + "You are a strict fact-checker. Given a question, the correct answer, " + "and a model's response, reply with ONLY 'correct' or 'wrong'." +) + + +def llm_recall_at_t( + memory: BaseMemory, + fact: Fact, + current_turn: int, + provider: "LLMProvider", +) -> Dict: + """ + LLM Recall@T — the model is actually asked the question and its answer + is judged for correctness. + + Two-stage pipeline: + 1. ANSWER — LLM answers fact.query_text() given memory context + 2. JUDGE — a second LLM call checks if the answer is correct + + Returns + ------- + dict with keys: + llm_recalled : bool — judge says the answer is correct + answer : str — what the LLM actually said + expected : str — ground-truth value + judge_verdict : str — 'correct' | 'wrong' | 'error' + tokens : int — context token estimate + """ + from utils.providers import _clean_messages + + context = memory.get_context(fact.query_text(), current_turn) + expected = fact.current_value(current_turn) + + # ── Stage 1: Answer ────────────────────────────────────────────────────── + messages = _clean_messages( + [{"role": "system", "content": _ANSWER_SYSTEM}] + + context + + [{"role": "user", "content": fact.query_text() + " Answer with just the value."}] + ) + answer = provider.chat(messages, max_tokens=60, temperature=0.0) + tokens = memory.token_count(fact.query_text(), current_turn) + + if answer.startswith("[PROVIDER_ERROR"): + return { + "llm_recalled": False, "answer": answer, + "expected": expected, "judge_verdict": "error", "tokens": tokens, + } + + # ── Stage 2: Judge ─────────────────────────────────────────────────────── + judge_prompt = ( + f"Question: {fact.query_text()}\n" + f"Correct answer: {expected}\n" + f"Model response: {answer}\n" + f"Is the model response correct? Reply with ONLY 'correct' or 'wrong'." + ) + verdict_raw = provider.chat( + [ + {"role": "system", "content": _JUDGE_SYSTEM}, + {"role": "user", "content": judge_prompt}, + ], + max_tokens=10, + temperature=0.0, + ).lower().strip() + + verdict = "correct" if "correct" in verdict_raw else "wrong" + + return { + "llm_recalled": verdict == "correct", + "answer": answer, + "expected": expected, + "judge_verdict": verdict, + "tokens": tokens, + } + + +def llm_temporal_drift( + memory: BaseMemory, + fact: Fact, + current_turn: int, + provider: "LLMProvider", +) -> Dict: + """ + LLM Temporal Drift — asks the LLM for the *current* value of an updated + fact and checks whether it returns the new or old value. + + Only meaningful after fact.updated_at has passed. + """ + from utils.providers import _clean_messages + + if not fact.updated_at or current_turn < fact.updated_at: + return {"llm_drift": 0.0, "applicable": False} + + context = memory.get_context(fact.query_text(), current_turn) + new_val = (fact.updated_value or "").lower() + old_val = fact.value.lower() + + messages = _clean_messages( + [{"role": "system", "content": _ANSWER_SYSTEM}] + + context + + [{"role": "user", "content": + f"What is my current {fact.key.replace('_', ' ')}? " + "Reply with the current value only."}] + ) + answer = provider.chat(messages, max_tokens=30, temperature=0.0).lower() + + using_old = old_val in answer and new_val not in answer + drift = 1.0 if using_old else 0.0 + + return { + "llm_drift": drift, + "answer": answer, + "expected": fact.updated_value, + "old_value": fact.value, + "applicable": True, + } diff --git a/main.py b/main.py index 2ecdb53..8c8be7d 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,22 @@ """ -MemoryLens CLI — run the benchmark from the terminal. +MemoryLens CLI -Usage: - python main.py # default settings - python main.py --turns 50 --backends rag cascading - python main.py --output my_results.json +Usage (content-only, no API key needed): + python main.py + +Usage (full LLM evaluation, auto-detects available provider): + python main.py --llm + +Usage (force a specific provider): + python main.py --llm --provider openai + python main.py --llm --provider anthropic + python main.py --llm --provider groq + python main.py --llm --provider openrouter + python main.py --llm --provider ollama + +Other options: + python main.py --turns 50 --backends naive rag --log + python main.py --list-providers """ import os @@ -18,66 +30,142 @@ def main() -> None: parser = argparse.ArgumentParser( - description="MemoryLens: LLM Memory Decay Evaluation Framework" + description="MemoryLens: End-to-end LLM Memory Decay Benchmark", + formatter_class=argparse.RawDescriptionHelpFormatter, ) - parser.add_argument("--turns", type=int, default=100) - parser.add_argument("--checkpoints", nargs="+", type=int, default=[10, 25, 50, 75, 100]) - parser.add_argument("--backends", nargs="+", default=["naive", "rag", "cascading"]) - parser.add_argument("--output", type=str, default="results.json") - parser.add_argument("--log", action="store_true", help="Save run to experiment_logs/") + parser.add_argument("--turns", type=int, default=100) + parser.add_argument("--checkpoints", nargs="+", type=int, + default=[10, 25, 50, 75, 100]) + parser.add_argument("--backends", nargs="+", + default=["naive", "rag", "cascading"], + help="naive | rag | cascading | summary") + parser.add_argument("--output", type=str, default="results.json") + parser.add_argument("--log", action="store_true", + help="Save run to experiment_logs/") + parser.add_argument("--llm", action="store_true", + help="Run real LLM evaluation pass (needs an API key or Ollama)") + parser.add_argument("--provider", type=str, default=None, + help="Force a provider: groq | openai | anthropic | openrouter | ollama") + parser.add_argument("--list-providers", action="store_true", + help="Print available providers and exit") args = parser.parse_args() - if not os.getenv("GROQ_API_KEY"): - print("ERROR: GROQ_API_KEY not set. Copy .env.example to .env and add your key.") - sys.exit(1) + # ── List providers ──────────────────────────────────────────────────────── + if args.list_providers: + from utils.providers import list_available, _REGISTRY + available = list_available() + print("\nProvider status:") + for name in _REGISTRY: + status = "available" if name in available else "not available" + print(f" {name:<15} {status}") + print() + sys.exit(0) + + # ── Resolve LLM provider ───────────────────────────────────────────────── + provider = None + if args.llm: + from utils.providers import get_provider + try: + provider = get_provider(args.provider) + except (ValueError, RuntimeError) as e: + print(f"ERROR: {e}") + sys.exit(1) + + if provider is None: + print( + "ERROR: --llm requested but no provider is available.\n" + " Set one of: GROQ_API_KEY, OPENAI_API_KEY, ANTHROPIC_API_KEY, " + "OPENROUTER_API_KEY\n" + " or start Ollama locally.\n" + " Run --list-providers to see status." + ) + sys.exit(1) + + # ── Banner ─────────────────────────────────────────────────────────────── + print("=" * 60) + print(" MemoryLens — LLM Memory Decay Benchmark") + print("=" * 60) + print(f" Turns : {args.turns}") + print(f" Checkpoints : {sorted(args.checkpoints)}") + print(f" Backends : {args.backends}") + print(f" LLM eval : {'ON (' + provider.name + ')' if provider else 'OFF (content-only)'}") + print("=" * 60) from evaluation.benchmark import run_benchmark, results_to_display_dict - print("=" * 55) - print(" MemoryLens — LLM Memory Decay Benchmark") - print("=" * 55) - print(f" Turns : {args.turns}") - print(f" Checkpoints: {args.checkpoints}") - print(f" Backends : {args.backends}") - print("=" * 55) - raw = run_benchmark( total_turns=args.turns, eval_checkpoints=sorted(args.checkpoints), backends=args.backends, + provider=provider, progress=print, ) display = results_to_display_dict(raw) checkpoints = display["checkpoints"] - print("\n" + "=" * 55) - print(" RESULTS — Recall@T") - print(" {:20s} {}".format("Backend", " ".join(f"T={c:3d}" for c in checkpoints))) - print("-" * 55) + # ── Results table ───────────────────────────────────────────────────────── + col = " ".join(f"T={c:3d}" for c in checkpoints) + sep = "-" * 60 + + print(f"\n{'CONTENT Recall@T':}") + print(f" {'Backend':<14} {col}") + print(sep) for name in args.backends: if name not in display: continue vals = " ".join(f"{v*100:5.1f}%" for v in display[name]["recall"]) - print(f" {name:20s} {vals}") - - print("\n RESULTS — Avg Tokens/Query") - print(" {:20s} {}".format("Backend", " ".join(f"T={c:3d}" for c in checkpoints))) - print("-" * 55) + print(f" {name:<14} {vals}") + + if display.get("has_llm_eval"): + print(f"\n{'LLM Recall@T (ground truth)':}") + print(f" {'Backend':<14} {col}") + print(sep) + for name in args.backends: + if name not in display: + continue + llm_vals = display[name].get("llm_recall", []) + vals = " ".join( + f"{v*100:5.1f}%" if v is not None else " N/A " + for v in llm_vals + ) + print(f" {name:<14} {vals}") + + print(f"\n Gap = Content Recall - LLM Recall") + print(f" {'Backend':<14} {col}") + print(sep) + for name in args.backends: + if name not in display: + continue + content = display[name]["recall"] + llm = display[name].get("llm_recall", [None]*len(content)) + vals = " ".join( + f"{(c - l)*100:+5.1f}%" if l is not None else " N/A " + for c, l in zip(content, llm) + ) + print(f" {name:<14} {vals}") + + print(f"\n Tokens/Query @ T={checkpoints[-1]}") + print(sep) for name in args.backends: if name not in display: continue - vals = " ".join(f"{v:6d}" for v in display[name]["tokens"]) - print(f" {name:20s} {vals}") + tok = display[name]["tokens"][-1] + print(f" {name:<14} {tok:,}") + # ── Save ───────────────────────────────────────────────────────────────── with open(args.output, "w") as fh: json.dump(display, fh, indent=2) - print(f"\nResults saved → {args.output}") + print(f"\nResults saved -> {args.output}") if args.log: from evaluation.logger import log_run - path = log_run(display, {"total_turns": args.turns, "backends": args.backends}) - print(f"Experiment logged → {path}") + path = log_run(display, { + "total_turns": args.turns, + "backends": args.backends, + "provider": provider.name if provider else None, + }) + print(f"Experiment logged -> {path}") print("Visualise: streamlit run dashboard.py") diff --git a/tests/test_imports.py b/tests/test_imports.py index 74afb26..ba75fc4 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -15,9 +15,20 @@ from evaluation.metrics import ( recall_at_t, precision_at_k, temporal_drift_score, memory_noise_ratio, cascade_efficiency, + llm_recall_at_t, llm_temporal_drift, ) from evaluation.benchmark import run_benchmark, results_to_display_dict from evaluation.logger import log_run, list_runs from evaluation.llm_judge import judge_answer +from utils.providers import get_provider, list_available, LLMProvider, _REGISTRY -print(f"All imports OK | Facts: {len(BENCHMARK_FACTS)}") +# Sanity: registry must expose all five providers +assert set(_REGISTRY.keys()) == {"groq", "openai", "anthropic", "openrouter", "ollama"}, ( + f"Provider registry mismatch: {set(_REGISTRY.keys())}" +) + +# get_provider(None) must return None or a valid LLMProvider (never raise) +p = get_provider(None) +assert p is None or isinstance(p, LLMProvider), f"Unexpected get_provider result: {p!r}" + +print(f"All imports OK | Facts: {len(BENCHMARK_FACTS)} | Providers: {list(_REGISTRY.keys())}") diff --git a/utils/providers.py b/utils/providers.py new file mode 100644 index 0000000..1918fa2 --- /dev/null +++ b/utils/providers.py @@ -0,0 +1,432 @@ +""" +utils/providers.py — Unified LLM provider abstraction for MemoryLens. + +Supports five backends: + - Groq (GROQ_API_KEY) + - OpenAI (OPENAI_API_KEY) + - Anthropic (ANTHROPIC_API_KEY) + - OpenRouter (OPENROUTER_API_KEY) — access 200+ models via one key + - Ollama (local, no key) — any locally running model + +Priority for auto-detection: + Groq → OpenAI → Anthropic → OpenRouter → Ollama → None (content-only mode) + +Usage: + from utils.providers import get_provider, list_available + + provider = get_provider() # auto-detect + provider = get_provider("openai") # force a specific one + + if provider: + answer = provider.chat([{"role": "user", "content": "Hello"}]) + print(provider.name, answer) + else: + print("No provider available — running content-only mode") +""" + +from __future__ import annotations + +import os +import time +from abc import ABC, abstractmethod +from typing import List, Dict, Optional + + +# ───────────────────────────────────────────────────────────────────────────── +# Abstract base +# ───────────────────────────────────────────────────────────────────────────── + +class LLMProvider(ABC): + """Common interface for all LLM backends.""" + + @property + @abstractmethod + def name(self) -> str: + """Human-readable provider name, e.g. 'groq/llama-3.1-8b-instant'.""" + + @property + @abstractmethod + def provider_type(self) -> str: + """Short key: 'groq' | 'openai' | 'anthropic' | 'openrouter' | 'ollama'.""" + + @abstractmethod + def chat( + self, + messages: List[Dict], + max_tokens: int = 256, + temperature: float = 0.1, + ) -> str: + """ + Send a chat request. + Returns the assistant's reply as a plain string. + Returns '[PROVIDER_ERROR: ...]' on failure — never raises. + """ + + @classmethod + @abstractmethod + def is_available(cls) -> bool: + """Return True if the required credentials/service are present.""" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}: {self.name}>" + + +# ───────────────────────────────────────────────────────────────────────────── +# Groq +# ───────────────────────────────────────────────────────────────────────────── + +class GroqProvider(LLMProvider): + DEFAULT_MODEL = "llama-3.1-8b-instant" + + def __init__(self, model: str = DEFAULT_MODEL) -> None: + self.model = model + self._client = None + + @property + def name(self) -> str: + return f"groq/{self.model}" + + @property + def provider_type(self) -> str: + return "groq" + + @classmethod + def is_available(cls) -> bool: + return bool(os.getenv("GROQ_API_KEY")) + + def _get_client(self): + if self._client is None: + from groq import Groq + self._client = Groq(api_key=os.getenv("GROQ_API_KEY")) + return self._client + + def chat(self, messages: List[Dict], max_tokens: int = 256, temperature: float = 0.1) -> str: + for attempt in range(3): + try: + resp = self._get_client().chat.completions.create( + model=self.model, + messages=_clean_messages(messages), + max_tokens=max_tokens, + temperature=temperature, + ) + return resp.choices[0].message.content.strip() + except Exception as e: + if attempt < 2: + time.sleep(2 ** attempt) + else: + return f"[PROVIDER_ERROR: groq — {e}]" + return "[PROVIDER_ERROR: groq — max retries]" + + +# ───────────────────────────────────────────────────────────────────────────── +# OpenAI +# ───────────────────────────────────────────────────────────────────────────── + +class OpenAIProvider(LLMProvider): + DEFAULT_MODEL = "gpt-4o-mini" + + def __init__(self, model: str = DEFAULT_MODEL) -> None: + self.model = model + self._client = None + + @property + def name(self) -> str: + return f"openai/{self.model}" + + @property + def provider_type(self) -> str: + return "openai" + + @classmethod + def is_available(cls) -> bool: + return bool(os.getenv("OPENAI_API_KEY")) + + def _get_client(self): + if self._client is None: + from openai import OpenAI + self._client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + return self._client + + def chat(self, messages: List[Dict], max_tokens: int = 256, temperature: float = 0.1) -> str: + for attempt in range(3): + try: + resp = self._get_client().chat.completions.create( + model=self.model, + messages=_clean_messages(messages), + max_tokens=max_tokens, + temperature=temperature, + ) + return resp.choices[0].message.content.strip() + except Exception as e: + if attempt < 2: + time.sleep(2 ** attempt) + else: + return f"[PROVIDER_ERROR: openai — {e}]" + return "[PROVIDER_ERROR: openai — max retries]" + + +# ───────────────────────────────────────────────────────────────────────────── +# Anthropic +# ───────────────────────────────────────────────────────────────────────────── + +class AnthropicProvider(LLMProvider): + DEFAULT_MODEL = "claude-haiku-4-5" + + def __init__(self, model: str = DEFAULT_MODEL) -> None: + self.model = model + self._client = None + + @property + def name(self) -> str: + return f"anthropic/{self.model}" + + @property + def provider_type(self) -> str: + return "anthropic" + + @classmethod + def is_available(cls) -> bool: + return bool(os.getenv("ANTHROPIC_API_KEY")) + + def _get_client(self): + if self._client is None: + import anthropic + self._client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) + return self._client + + def chat(self, messages: List[Dict], max_tokens: int = 256, temperature: float = 0.1) -> str: + for attempt in range(3): + try: + # Anthropic separates system messages + system_parts = [m["content"] for m in messages if m["role"] == "system"] + user_messages = [m for m in messages if m["role"] != "system"] + system_str = " ".join(system_parts) if system_parts else None + + kwargs: Dict = dict( + model=self.model, + max_tokens=max_tokens, + messages=_clean_messages(user_messages), + ) + if system_str: + kwargs["system"] = system_str + + resp = self._get_client().messages.create(**kwargs) + return resp.content[0].text.strip() + except Exception as e: + if attempt < 2: + time.sleep(2 ** attempt) + else: + return f"[PROVIDER_ERROR: anthropic — {e}]" + return "[PROVIDER_ERROR: anthropic — max retries]" + + +# ───────────────────────────────────────────────────────────────────────────── +# OpenRouter (200+ models via one endpoint) +# ───────────────────────────────────────────────────────────────────────────── + +class OpenRouterProvider(LLMProvider): + DEFAULT_MODEL = "meta-llama/llama-3.1-8b-instruct:free" + BASE_URL = "https://openrouter.ai/api/v1" + + def __init__(self, model: str = DEFAULT_MODEL) -> None: + self.model = model + self._client = None + + @property + def name(self) -> str: + return f"openrouter/{self.model}" + + @property + def provider_type(self) -> str: + return "openrouter" + + @classmethod + def is_available(cls) -> bool: + return bool(os.getenv("OPENROUTER_API_KEY")) + + def _get_client(self): + if self._client is None: + from openai import OpenAI + self._client = OpenAI( + api_key=os.getenv("OPENROUTER_API_KEY"), + base_url=self.BASE_URL, + ) + return self._client + + def chat(self, messages: List[Dict], max_tokens: int = 256, temperature: float = 0.1) -> str: + for attempt in range(3): + try: + resp = self._get_client().chat.completions.create( + model=self.model, + messages=_clean_messages(messages), + max_tokens=max_tokens, + temperature=temperature, + extra_headers={ + "HTTP-Referer": "https://github.com/Neal006/memorylens", + "X-Title": "MemoryLens Benchmark", + }, + ) + return resp.choices[0].message.content.strip() + except Exception as e: + if attempt < 2: + time.sleep(2 ** attempt) + else: + return f"[PROVIDER_ERROR: openrouter — {e}]" + return "[PROVIDER_ERROR: openrouter — max retries]" + + +# ───────────────────────────────────────────────────────────────────────────── +# Ollama (local models — no API key) +# ───────────────────────────────────────────────────────────────────────────── + +class OllamaProvider(LLMProvider): + DEFAULT_MODEL = "llama3.2" + DEFAULT_HOST = "http://localhost:11434" + + def __init__( + self, + model: str = DEFAULT_MODEL, + host: str = DEFAULT_HOST, + ) -> None: + self.model = model + self.host = os.getenv("OLLAMA_HOST", host) + + @property + def name(self) -> str: + return f"ollama/{self.model}" + + @property + def provider_type(self) -> str: + return "ollama" + + @classmethod + def is_available(cls) -> bool: + """Check if the Ollama server is reachable.""" + import urllib.request + host = os.getenv("OLLAMA_HOST", cls.DEFAULT_HOST) + try: + urllib.request.urlopen(f"{host}/api/tags", timeout=2) + return True + except Exception: + return False + + def chat(self, messages: List[Dict], max_tokens: int = 256, temperature: float = 0.1) -> str: + import json + import urllib.request + + payload = json.dumps({ + "model": self.model, + "messages": _clean_messages(messages), + "stream": False, + "options": {"num_predict": max_tokens, "temperature": temperature}, + }).encode() + + for attempt in range(3): + try: + req = urllib.request.Request( + f"{self.host}/api/chat", + data=payload, + headers={"Content-Type": "application/json"}, + ) + with urllib.request.urlopen(req, timeout=30) as resp: + body = json.loads(resp.read()) + return body["message"]["content"].strip() + except Exception as e: + if attempt < 2: + time.sleep(2 ** attempt) + else: + return f"[PROVIDER_ERROR: ollama — {e}]" + return "[PROVIDER_ERROR: ollama — max retries]" + + +# ───────────────────────────────────────────────────────────────────────────── +# Registry + auto-detection +# ───────────────────────────────────────────────────────────────────────────── + +_REGISTRY: Dict[str, type] = { + "groq": GroqProvider, + "openai": OpenAIProvider, + "anthropic": AnthropicProvider, + "openrouter": OpenRouterProvider, + "ollama": OllamaProvider, +} + +_PRIORITY = ["groq", "openai", "anthropic", "openrouter", "ollama"] + + +def get_provider(name: Optional[str] = None) -> Optional[LLMProvider]: + """ + Return a ready-to-use LLMProvider instance. + + Parameters + ---------- + name : str | None + Force a specific provider ('groq', 'openai', 'anthropic', + 'openrouter', 'ollama'). Pass None to auto-detect. + + Returns + ------- + LLMProvider | None + None if no provider is available — caller should fall back to + content-only evaluation. + """ + if name: + name = name.lower() + cls = _REGISTRY.get(name) + if cls is None: + raise ValueError(f"Unknown provider '{name}'. Choose from: {list(_REGISTRY)}") + if not cls.is_available(): + raise RuntimeError( + f"Provider '{name}' is not available. " + f"Check your environment variables / service status." + ) + return cls() + + # Auto-detect + for key in _PRIORITY: + cls = _REGISTRY[key] + if cls.is_available(): + return cls() + + return None + + +def list_available() -> List[str]: + """Return names of all currently available providers.""" + return [k for k, cls in _REGISTRY.items() if cls.is_available()] + + +def provider_from_env() -> Optional[LLMProvider]: + """ + Convenience: read MEMORYLENS_PROVIDER env var, fall back to auto-detect. + MEMORYLENS_PROVIDER=openai → forces OpenAI + MEMORYLENS_PROVIDER= → auto-detect + """ + forced = os.getenv("MEMORYLENS_PROVIDER", "").strip().lower() + return get_provider(forced if forced else None) + + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers +# ───────────────────────────────────────────────────────────────────────────── + +def _clean_messages(messages: List[Dict]) -> List[Dict]: + """ + Return only {role, content} pairs with valid roles. + Merges consecutive messages from the same role to satisfy strict APIs. + """ + valid_roles = {"system", "user", "assistant"} + cleaned: List[Dict] = [] + + for m in messages: + role = m.get("role", "user") + content = m.get("content", "").strip() + if not content or role not in valid_roles: + continue + # Merge consecutive same-role messages + if cleaned and cleaned[-1]["role"] == role: + cleaned[-1]["content"] += "\n" + content + else: + cleaned.append({"role": role, "content": content}) + + return cleaned