From 013e5f816d87ddf1cab11777cccdef1a48952329 Mon Sep 17 00:00:00 2001 From: Mario Jack Vela Date: Fri, 24 Apr 2026 03:39:00 -0500 Subject: [PATCH 1/4] Add retrieval reranker executive --- .github/workflows/ci.yml | 2 + .gitignore | 4 + bin/intent_classifier.py | 39 +- docs/RERANKER.md | 21 + src/agentmemory/_impl.py | 939 ++++++++++++++++-- src/agentmemory/retrieval/__init__.py | 20 + src/agentmemory/retrieval/answerability.py | 163 +++ .../retrieval/candidate_generation.py | 46 + src/agentmemory/retrieval/diagnostics.py | 43 + src/agentmemory/retrieval/evidence_graph.py | 55 + src/agentmemory/retrieval/feature_builder.py | 546 ++++++++++ src/agentmemory/retrieval/judge.py | 87 ++ src/agentmemory/retrieval/late_reranker.py | 43 + src/agentmemory/retrieval/long_context.py | 458 +++++++++ src/agentmemory/retrieval/mlp_reranker.py | 129 +++ src/agentmemory/retrieval/query_planner.py | 326 ++++++ src/agentmemory/retrieval/second_stage.py | 559 +++++++++++ tests/test_long_context_explorer.py | 186 ++++ tests/test_reranker_robustness.py | 115 ++- tests/test_second_stage_reranker.py | 291 ++++++ 20 files changed, 3942 insertions(+), 130 deletions(-) create mode 100644 src/agentmemory/retrieval/__init__.py create mode 100644 src/agentmemory/retrieval/answerability.py create mode 100644 src/agentmemory/retrieval/candidate_generation.py create mode 100644 src/agentmemory/retrieval/diagnostics.py create mode 100644 src/agentmemory/retrieval/evidence_graph.py create mode 100644 src/agentmemory/retrieval/feature_builder.py create mode 100644 src/agentmemory/retrieval/judge.py create mode 100644 src/agentmemory/retrieval/late_reranker.py create mode 100644 src/agentmemory/retrieval/long_context.py create mode 100644 src/agentmemory/retrieval/mlp_reranker.py create mode 100644 src/agentmemory/retrieval/query_planner.py create mode 100644 src/agentmemory/retrieval/second_stage.py create mode 100644 tests/test_long_context_explorer.py create mode 100644 
tests/test_second_stage_reranker.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e270b1d..e4f340a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -155,7 +155,9 @@ jobs: - 'src/agentmemory/rerank.py' - 'src/agentmemory/embeddings.py' - 'src/agentmemory/retrieval.py' + - 'src/agentmemory/retrieval/**' - 'bin/intent_classifier.py' + - 'benchmarks/**' - 'tests/bench/**' - name: Set up Python diff --git a/.gitignore b/.gitignore index a745a84..7a52876 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,10 @@ db/*.backup logs/ blobs/ backups/ +benchmarks/results/ +benchmarks/training_data/ +src/agentmemory/retrieval/models/*.json +.vs/ .DS_Store /tmp/ *.swp diff --git a/bin/intent_classifier.py b/bin/intent_classifier.py index 0d64474..9907c07 100644 --- a/bin/intent_classifier.py +++ b/bin/intent_classifier.py @@ -42,16 +42,16 @@ class IntentResult: # Each entry is (primary_tables, secondary_tables). # The final merged list fed to --tables is primary + secondary (de-duped). 
_TABLE_ROUTES = { - "cross_reference": ["events", "memories", "context"], - "troubleshooting": ["events", "memories", "context"], + "cross_reference": ["events", "memories", "context", "procedures"], + "troubleshooting": ["procedures", "events", "memories", "context", "decisions"], "task_status": ["events", "context", "memories"], - "entity_lookup": ["memories", "context", "events"], # entities not in universal search pipeline - "historical_timeline":["events", "context", "memories"], - "how_to": ["memories", "context"], - "decision_rationale": ["memories", "context", "events"], - "research_concept": ["memories", "context"], - "orientation": ["memories", "events", "context"], - "factual_lookup": ["memories", "context", "events"], # same as default + "entity_lookup": ["memories", "entities", "context", "events", "procedures"], + "historical_timeline":["events", "memories", "context", "procedures"], + "how_to": ["procedures", "memories", "context", "events", "decisions"], + "decision_rationale": ["decisions", "memories", "context", "events", "procedures"], + "research_concept": ["memories", "procedures", "context"], + "orientation": ["memories", "events", "context", "procedures"], + "factual_lookup": ["memories", "entities", "decisions", "context", "events", "procedures"], } _FORMAT_HINTS = { @@ -81,6 +81,17 @@ class IntentResult: _WAVE_RE = re.compile(r'\bwave\s*\d+\b', re.IGNORECASE) _HOW_RE = re.compile(r'\bhow\s+(to|do|does|can|should)\b', re.IGNORECASE) _WHY_RE = re.compile(r'\bwhy\b', re.IGNORECASE) +_PROCEDURAL_RE = re.compile(r'\b(runbook|playbook|rollback|roll back|procedure|workflow|steps?|migrate|deployment?|troubleshoot|debug)\b', re.IGNORECASE) +_ENTITY_FACT_RE = re.compile( + r'\b(' + r'who(?:\s+is|\s+owns?)?|' + r'what\s+does|' + r'owner|maintainer|reviewer|assignee|' + r'prefers?|preference|' + r'role|responsible' + r')\b', + re.IGNORECASE, +) # First-person/identity statement (Hermes memory dumps stored as queries) _IDENTITY_STMT_RE = re.compile( 
r'^(I |My |The vault|Chief wakes|Continuity is|Tasks that|Learn the|' @@ -157,12 +168,12 @@ def classify_intent(query: str) -> IntentResult: ) # ---- Rule 4: How-to ---- - if _HOW_RE.search(q): + if _HOW_RE.search(q) or _PROCEDURAL_RE.search(q): return IntentResult( intent="how_to", confidence=0.88, tables=_TABLE_ROUTES["how_to"], - matched_rule="how_to_regex", + matched_rule="how_to_regex" if _HOW_RE.search(q) else "procedural_kw_regex", format_hint=_FORMAT_HINTS["how_to"], ) @@ -264,14 +275,14 @@ def classify_intent(query: str) -> IntentResult: # Note: 'agent', 'assigned' here can be intentionally claimed earlier by # Rule 2 (troubleshooting) or Rule 3 (task_status); that's the richer # external taxonomy winning over the builtin's broader bucket. - _ENTITY_KW = ["who ", "person", "agent", "team", "assigned"] + _ENTITY_KW = ["who ", "person", "agent", "team", "assigned", "owner", "maintainer", "reviewer", "preference", "prefer"] hit = _kw(ql, _ENTITY_KW) - if hit: + if hit or _ENTITY_FACT_RE.search(q): return IntentResult( intent="entity_lookup", confidence=0.80, tables=_TABLE_ROUTES["entity_lookup"], - matched_rule=f"entity_kw:{hit.strip()}", + matched_rule=f"entity_kw:{(hit or 'entity_fact_regex').strip()}", format_hint=_FORMAT_HINTS["entity_lookup"], ) if _PROPER_NOUN_ALONE_RE.match(q): diff --git a/docs/RERANKER.md b/docs/RERANKER.md index 38c2b69..de8c603 100644 --- a/docs/RERANKER.md +++ b/docs/RERANKER.md @@ -103,6 +103,27 @@ Models load from the Hugging Face Hub on first use (cached at `~/.cache/huggingface/`). After the first call the model is held in the per-process module cache. +## Second-stage tiny MLP artifact policy + +The local second-stage reranker can optionally load a tiny JSON MLP artifact +from `src/agentmemory/retrieval/models/tiny_mlp_v1.json`, or from an explicit +path passed through the internal reranker configuration. That artifact is not +checked into git. 
If the file is absent, the second-stage path falls back to +the deterministic heuristic slate scorer and search remains fully functional. + +This keeps the default package local-first and reviewable: + +- no mandatory network fetch, +- no opaque weights bundled in source, +- no hard dependency on numpy at import time, +- no failure when the model artifact is unavailable. + +Training and calibration scripts live under `benchmarks/` and emit JSON +artifacts into ignored benchmark/training output directories. If a trained +artifact is published later, it should be attached as a release asset or LFS +object with a short provenance record containing the source commit, training +bundle, feature version, and held-out metrics. + ## Latency / quality tradeoff Measured on Apple Silicon M-series, CPU only (no MPS), Python 3.14, diff --git a/src/agentmemory/_impl.py b/src/agentmemory/_impl.py index 83ccb55..377f27e 100644 --- a/src/agentmemory/_impl.py +++ b/src/agentmemory/_impl.py @@ -59,29 +59,29 @@ def __init__(self, intent, confidence, matched_rule, format_hint, tables): def _builtin_classify_intent(query): """Rule-based intent classifier — inline fallback for.""" q = query.lower() - if any(w in q for w in ['who ', 'person', 'agent', 'team', 'assigned']): + if any(w in q for w in ['who ', 'person', 'agent', 'team', 'assigned', 'owner', 'maintainer', 'reviewer', 'prefer', 'preference']): return _BuiltinIntentResult('entity_lookup', 0.8, 'keyword:entity', 'Show entity details with relations', - ['memories', 'events', 'context']) + ['memories', 'entities', 'events', 'context']) if any(w in q for w in ['what happened', 'when did', 'history', 'timeline', 'log']): return _BuiltinIntentResult('event_lookup', 0.8, 'keyword:event', 'Show events in chronological order', ['events', 'memories', 'context']) - if any(w in q for w in ['how to', 'how do', 'procedure', 'steps', 'guide']): + if any(w in q for w in ['how to', 'how do', 'procedure', 'steps', 'guide', 'rollback', 'runbook', 
'playbook', 'troubleshoot']): return _BuiltinIntentResult('procedural', 0.7, 'keyword:procedural', 'Show step-by-step instructions', - ['memories', 'context', 'events']) + ['memories', 'decisions', 'events', 'context']) if any(w in q for w in ['why ', 'decision', 'rationale', 'reason']): return _BuiltinIntentResult('decision_lookup', 0.8, 'keyword:decision', 'Show decisions with rationale', - ['memories', 'events', 'context']) + ['decisions', 'memories', 'events', 'context']) if any(w in q for w in ['related', 'connected', 'depends', 'link']): return _BuiltinIntentResult('graph_traversal', 0.7, 'keyword:graph', 'Show connected nodes and edges', ['memories', 'events', 'context']) return _BuiltinIntentResult('general', 0.5, 'default', 'Standard search results', - ['memories', 'events', 'context']) + ['memories', 'entities', 'events', 'context']) # Quantum amplitude scorer try: @@ -157,21 +157,18 @@ def _builtin_classify_intent(query): # via `_CE_WARMUP_SEEN[0] = 0`. _CE_WARMUP_SEEN = [0] -# FTS5 special characters that cause sqlite3.OperationalError when unescaped. -# Strip them before passing any user query to a MATCH clause. -# -# Includes `?` and `!` — natural-language queries from agents and humans -# contain these constantly ("What does X prefer?") and used to crash -# cmd_search with "fts5: syntax error near ?". Also includes common ASCII -# punctuation (`,;:`) that has no operator meaning in FTS5 but still breaks -# tokenisation when glued to a word. -_FTS5_SPECIAL = re.compile(r'[.&|*"\'`()\-@^?!,;:]') +# FTS5 MATCH is brittle around punctuation and symbolic tokens. Strip any +# non-word, non-space character, plus `_`, before building the MATCH +# expression. This covers common natural-language queries like "$5 coupon", +# "LGBTQ+", "7/22", "#PlankChallenge", "SIAC_GEE", and smart quotes. +_FTS5_SPECIAL = re.compile(r"[^\w\s]|_") def _sanitize_fts_query(query: str) -> str: """Remove FTS5 special characters to prevent syntax errors. - Strips: . 
& | * \" ' ` ( ) - @ ^ ? ! , ; : + Strips punctuation and symbolic tokens, plus `_`, before collapsing + whitespace. Then collapses extra whitespace. Returns an empty string if nothing remains so callers can skip the MATCH clause gracefully. """ @@ -186,7 +183,30 @@ def _sanitize_fts_query(query: str) -> str: "a", "an", "and", "are", "as", "at", "be", "by", "do", "does", "for", "from", "has", "have", "how", "i", "in", "is", "it", "its", "of", "on", "or", "that", "the", "to", "was", "we", "what", "when", "where", - "which", "who", "why", "will", "with", "you", + "which", "who", "why", "will", "with", "you", "use", "uses", "used", "using", +} + +_FTS_QUERY_EXPANSIONS = { + "choose": ("chose", "chosen"), + "chose": ("choose", "chosen"), + "chosen": ("choose", "chose"), + "store": ("stores", "stored"), + "stores": ("store", "stored"), + "stored": ("store", "stores"), + "storage": ("store", "stored", "path"), + "prefer": ("prefers", "preferred"), + "prefers": ("prefer",), + "embedding": ("embeddings", "embed"), + "embeddings": ("embedding", "embed"), + "model": ("models", "provider"), + "version": ("versions", "release"), + "path": ("paths", "location"), + "stored": ("store", "stores", "path", "location"), + "indentation": ("tabs", "spaces"), + "test": ("tests", "pytest"), + "tests": ("test", "pytest"), + "use": ("uses", "using", "used"), + "uses": ("use", "using"), } @@ -208,7 +228,300 @@ def _build_fts_match_expression(sanitized: str) -> str: meaningful = [t for t in tokens if t.lower() not in _FTS_STOPWORDS and len(t) > 1] if not meaningful: meaningful = tokens - return " OR ".join(meaningful) + expanded: list[str] = [] + seen: set[str] = set() + for token in meaningful: + variants = (token, *_FTS_QUERY_EXPANSIONS.get(token.lower(), ())) + for variant in variants: + key = variant.lower() + if key in _FTS_STOPWORDS or key in seen: + continue + seen.add(key) + expanded.append(variant) + return " OR ".join(expanded or meaningful) + + +_SEARCH_STOPWORDS = _FTS_STOPWORDS 
| { + "show", "tell", "about", "into", "over", "after", "before", "should", + "could", "would", "please", "summary", "details", "detail", +} +_LOW_SIGNAL_QUERY_TOKENS = { + "summary", "history", "timeline", "recent", "today", "yesterday", "tomorrow", + "game", "issue", "problem", "thing", "stuff", "brief", "update", +} + + +def _normalize_search_token(token: str) -> str: + tok = re.sub(r"[^a-z0-9]+", "", (token or "").lower()) + if len(tok) <= 2 or tok in _SEARCH_STOPWORDS: + return "" + if tok.endswith("ies") and len(tok) > 4: + tok = tok[:-3] + "y" + elif tok.endswith("ed") and len(tok) > 4: + tok = tok[:-2] + elif tok.endswith("es") and len(tok) > 4: + tok = tok[:-2] + elif tok.endswith("s") and len(tok) > 3: + tok = tok[:-1] + return tok + + +def _search_tokens(text: str) -> set[str]: + return { + norm + for part in re.split(r"\s+", text or "") + if (norm := _normalize_search_token(part)) + } + + +def _search_anchor_tokens(text: str) -> set[str]: + return {token for token in _search_tokens(text) if token not in _LOW_SIGNAL_QUERY_TOKENS} + + +def _row_search_text(row: dict) -> str: + parts = [] + for key in ( + "content", "summary", "title", "goal", "description", "name", + "search_text", "compiled_truth", "entity_type", + ): + value = row.get(key) + if value: + parts.append(str(value)) + for key in ("observations", "properties", "aliases"): + value = row.get(key) + if not value: + continue + if isinstance(value, str): + parts.append(value) + else: + try: + parts.append(json.dumps(value, ensure_ascii=True)) + except Exception: + parts.append(str(value)) + return " ".join(parts) + + +def _fetch_linked_entities(db, query: str, plan=None, limit: int = 6) -> list[dict]: + query_tokens = _search_anchor_tokens(query) or _search_tokens(query) + fts_query = _build_fts_match_expression(_sanitize_fts_query(query)) + target_entities = list(getattr(plan, "target_entities", []) or []) + wants_entity_card = _query_wants_entity_card(query) + rows = [] + if fts_query: + try: + 
rows.extend(db.execute( + """ + SELECT e.id, e.name, e.entity_type, e.properties, e.observations, + e.compiled_truth, e.aliases, e.confidence, e.scope, + e.created_at, e.agent_id, + bm25(entities_fts, 4.0, 1.0, 0.8, 1.2) AS fts_rank + FROM entities_fts + JOIN entities e ON e.id = entities_fts.rowid + WHERE entities_fts MATCH ? AND e.retired_at IS NULL + ORDER BY bm25(entities_fts, 4.0, 1.0, 0.8, 1.2) + LIMIT ? + """, + (fts_query, max(limit * 2, 8)), + ).fetchall()) + except Exception: + pass + for target in target_entities[:4]: + try: + rows.extend(db.execute( + """ + SELECT id, name, entity_type, properties, observations, + compiled_truth, aliases, confidence, scope, + created_at, agent_id, NULL AS fts_rank + FROM entities + WHERE retired_at IS NULL + AND ( + lower(name) = lower(?) + OR lower(COALESCE(aliases, '[]')) LIKE ? + ) + LIMIT ? + """, + (target, f"%{target.lower()}%", limit), + ).fetchall()) + except Exception: + pass + + deduped: list[dict] = [] + seen_ids: set[int] = set() + q_lower = (query or "").lower() + for row in rows: + entity = dict(row) + entity["aliases"] = _load_aliases(entity) + ent_text = _row_search_text(entity) + ent_tokens = _search_tokens(ent_text) + name_lower = str(entity.get("name") or "").lower() + direct_name = bool(name_lower and name_lower in q_lower) + alias_match = any( + alias and alias.lower() in q_lower + for alias in entity.get("aliases", []) + ) + coverage = len(query_tokens & ent_tokens) / max(len(query_tokens), 1) + score = coverage + (0.9 if direct_name else 0.0) + (0.75 if alias_match else 0.0) + if score <= 0.0 and not query_tokens: + continue + strong_descriptor = coverage >= 0.6 or (wants_entity_card and coverage >= 0.34) + if not (direct_name or alias_match or strong_descriptor): + continue + eid = int(entity["id"]) + if eid in seen_ids: + continue + seen_ids.add(eid) + entity["entity_link_score"] = round(score, 4) + deduped.append(entity) + + deduped.sort( + key=lambda item: ( + 
-(float(item.get("entity_link_score") or 0.0)), + -(float(item.get("confidence") or 0.0)), + int(item.get("id") or 0), + ) + ) + return deduped[:limit] + + +def _expand_query_with_linked_entities(query: str, linked_entities: list[dict]) -> str: + additions: list[str] = [] + query_lower = (query or "").lower() + for entity in linked_entities[:2]: + name = str(entity.get("name") or "").strip() + if name and name.lower() not in query_lower: + additions.append(name) + if not additions: + return query + return f"{query} {' '.join(additions)}".strip() + + +def _query_wants_entity_card(query: str) -> bool: + q = (query or "").lower() + return any( + phrase in q + for phrase in ( + "who is", "who owns", "owner", "maintainer", "reviewer", + "assignee", "whose", "responsible for", + ) + ) + + +def _apply_query_alignment( + rows: list[dict], + query: str, + bucket: str, + *, + plan=None, + linked_entities: Optional[list[dict]] = None, + limit: int = 5, +) -> list[dict]: + if not rows: + return rows + query_tokens = _search_tokens(query) + anchor_tokens = _search_anchor_tokens(query) or query_tokens + linked_names = { + str(entity.get("name") or "").lower() + for entity in (linked_entities or []) + if entity.get("name") + } + linked_names |= { + str(alias).lower() + for entity in (linked_entities or []) + for alias in (entity.get("aliases") or []) + if alias + } + normalized_intent = getattr(plan, "normalized_intent", "factual") + wants_entity_card = _query_wants_entity_card(query) + adjusted: list[dict] = [] + for row in rows: + item = dict(row) + text = _row_search_text(item) + text_lower = text.lower() + row_tokens = _search_tokens(text) + query_overlap = len(query_tokens & row_tokens) / max(len(query_tokens), 1) if query_tokens else 0.0 + anchor_overlap = len(anchor_tokens & row_tokens) / max(len(anchor_tokens), 1) if anchor_tokens else query_overlap + exact_phrase = bool(query and len(query.strip()) >= 4 and query.lower().strip() in text_lower) + entity_hit = 
bool(linked_names and any(name in text_lower for name in linked_names if len(name) > 2)) + + base_score = float( + item.get("final_score") + or item.get("rrf_score") + or item.get("retrieval_score") + or item.get("confidence") + or 0.0 + ) + multiplier = 1.0 + if exact_phrase: + multiplier *= 1.18 + if entity_hit: + multiplier *= 1.18 + if item.get("source") == "semantic" and anchor_overlap < 0.2 and not exact_phrase and not entity_hit: + multiplier *= 0.55 + elif item.get("source") == "both" and anchor_overlap < 0.2 and not exact_phrase and not entity_hit: + multiplier *= 0.78 + if len(anchor_tokens) >= 3 and anchor_overlap == 0.0 and not exact_phrase and not entity_hit: + multiplier *= 0.4 + if normalized_intent == "factual": + if bucket in ("procedures", "events", "context") and anchor_overlap < 0.34 and not entity_hit: + multiplier *= 0.5 + elif bucket == "memories" and item.get("source") == "semantic" and anchor_overlap < 0.25 and not entity_hit: + multiplier *= 0.72 + elif normalized_intent in ("procedural", "troubleshooting"): + if bucket == "procedures": + breakdown = item.get("score_breakdown") or {} + directness = float(breakdown.get("directness") or 0.0) + step_overlap = float(breakdown.get("step_overlap") or 0.0) + title_goal = float(breakdown.get("goal_match") or 0.0) + float(breakdown.get("title_match") or 0.0) + if directness < 0.7 and step_overlap > title_goal: + multiplier *= 0.72 + if directness < 0.45 and anchor_overlap < 0.25 and not exact_phrase: + multiplier *= 0.55 + elif bucket in ("events", "context") and anchor_overlap < 0.25: + multiplier *= 0.65 + elif normalized_intent == "temporal" and bucket == "procedures" and anchor_overlap < 0.25: + multiplier *= 0.55 + if bucket == "entities": + aliases = item.get("aliases") + if isinstance(aliases, str): + try: + aliases = json.loads(aliases) + except Exception: + aliases = [] + aliases = aliases or [] + if wants_entity_card: + if str(item.get("name") or "").lower() in (query or "").lower(): + 
multiplier *= 1.25 + elif any(str(alias).lower() in (query or "").lower() for alias in aliases): + multiplier *= 1.25 + else: + multiplier *= 0.35 + if anchor_overlap < 0.34 and not entity_hit and not exact_phrase: + multiplier *= 0.6 + + item["query_token_overlap"] = round(query_overlap, 4) + item["query_anchor_overlap"] = round(anchor_overlap, 4) + item["entity_link_match"] = entity_hit + item["exact_query_phrase"] = exact_phrase + item["final_score"] = round(base_score * multiplier, 8) + adjusted.append(item) + + adjusted.sort(key=lambda row: row.get("final_score", 0.0), reverse=True) + if not adjusted: + return adjusted + best_score = float(adjusted[0].get("final_score") or 0.0) + kept: list[dict] = [] + max_keep = max(limit * 2, limit) + for idx, row in enumerate(adjusted): + strong_match = ( + row.get("exact_query_phrase") + or row.get("entity_link_match") + or float(row.get("query_anchor_overlap") or 0.0) >= 0.34 + ) + if idx < limit or strong_match or float(row.get("final_score") or 0.0) >= best_score * 0.55: + kept.append(row) + if len(kept) >= max_keep: + break + return kept # Temporal recency decay constants (lambda) — configurable per scope # half-life: global ~70d, project ~23d, agent ~14d @@ -3186,6 +3499,16 @@ def cmd_memory_add(args): memory_id = cursor.lastrowid db.commit() # ensure the INSERT (and FTS trigger) is committed before subprocess exit + indexed_row = db.execute( + "SELECT content, category, tags FROM memories WHERE id = ?", + (memory_id,), + ).fetchone() + indexed_content = indexed_row["content"] if indexed_row else args.content + indexed_category = indexed_row["category"] if indexed_row else args.category + indexed_tags = indexed_row["tags"] if indexed_row else (tags_json or "") + if indexed_content != args.content: + blob = None + # Workaround: FTS5 content-external tables may not build the inverted index # from trigger INSERTs on some SQLite versions. Force a re-index for this memory. 
if do_index: @@ -3193,11 +3516,11 @@ def cmd_memory_add(args): db.execute( "INSERT INTO memories_fts(memories_fts, rowid, content, category, tags) " "VALUES('delete', ?, ?, ?, ?)", - (memory_id, args.content, args.category, tags_json or '')) + (memory_id, indexed_content, indexed_category, indexed_tags or '')) db.execute( "INSERT INTO memories_fts(rowid, content, category, tags) " "VALUES (?, ?, ?, ?)", - (memory_id, args.content, args.category, tags_json or '')) + (memory_id, indexed_content, indexed_category, indexed_tags or '')) db.commit() except Exception: pass # non-fatal: FTS trigger may have already handled it @@ -3320,7 +3643,7 @@ def cmd_memory_add(args): if do_index: try: if not blob: - blob = _embed_query_safe(args.content) + blob = _embed_query_safe(indexed_content) if blob: db_vec = _try_get_db_with_vec() if db_vec: @@ -6149,27 +6472,35 @@ def cmd_search(args, *, db=None, db_path: Optional[str] = None): use_mmr = getattr(args, "mmr", False) # --mmr: MMR diversity reranking mmr_lambda = getattr(args, "mmr_lambda", 0.7) # --mmr-lambda: relevance/diversity trade-off use_explore = getattr(args, "explore", False) # --explore: curiosity mode - # --benchmark (2.3.1): bypass the recency/salience/Q-value reranker chain - # and return raw FTS+vec RRF-fused ranking. Trust reranker is *retained* - # because trust is provenance, not stale-data leakage. The flag exists as - # an escape hatch for synthetic-conversational benchmarks (LOCOMO, - # LongMemEval) where uniform timestamps and zero recall history make the - # rerankers worse than no-op. See memory id 1690 and tests/test_reranker_robustness. 
benchmark_mode = getattr(args, "benchmark", False) + benchmark_ranking_mode = str( + getattr(args, "benchmark_ranking_mode", None) + or os.environ.get("BRAINCTL_BENCHMARK_RANKING_MODE", "raw") + or "raw" + ).strip().lower() + if benchmark_ranking_mode not in {"full", "raw"}: + benchmark_ranking_mode = "raw" + benchmark_raw_ranking = bool(benchmark_mode and benchmark_ranking_mode == "raw") if benchmark_mode: - # One-line stderr note so the user can see the reranker chain went - # silent. Avoids log spam on the hot path while still being visible. - print( - "[brainctl] --benchmark: reranker chain disabled, returning raw FTS+vec ranking", - file=sys.stderr, - ) - results = {"memories": [], "events": [], "context": [], "decisions": []} + if benchmark_raw_ranking: + print( + "[brainctl] --benchmark: raw ranking ablation mode, returning raw FTS+vec ranking", + file=sys.stderr, + ) + else: + print( + "[brainctl] --benchmark: stable-eval mode with full shared ranking", + file=sys.stderr, + ) + results = {"memories": [], "events": [], "context": [], "entities": [], "decisions": [], "procedures": []} # Accumulator for which signal-informativeness gates tripped this call. # Each value is a string reason like "uniform_timestamps_stdev_3.2s" or a # boolean True for benchmark-mode hard skips. Surfaced under the top-level # "_debug" key so auditors can see WHY a particular ranking happened. _debug_skips: Dict[str, Any] = {} _debug_mode = bool(getattr(args, "debug", False)) + if benchmark_mode: + _debug_skips["benchmark.ranking_mode"] = benchmark_ranking_mode # I6 staged rollout controls for top-heavy retrieval features. 
_rollout_agent = getattr(args, "agent", None) or "unknown" @@ -6207,7 +6538,7 @@ def cmd_search(args, *, db=None, db_path: Optional[str] = None): os.environ.get("BRAINCTL_DISABLE_INTENT_ROUTER") ) if args.tables: - tables = args.tables.split(",") + tables = [t.strip() for t in args.tables.split(",") if t.strip()] elif _intent_router_disabled: tables = ["memories", "events", "context", "entities", "decisions"] elif _INTENT_AVAILABLE: @@ -6231,8 +6562,45 @@ def cmd_search(args, *, db=None, db_path: Optional[str] = None): and "decisions" not in tables ): tables = list(set(tables) | {"memories", "events", "context", "decisions"}) + _query_plan = None + _query_plan_dict = None + try: + from agentmemory.retrieval.query_planner import plan_query as _plan_query + + _query_plan = _plan_query(query, requested_tables=tables if args.tables else None) + _query_plan_dict = _query_plan.as_dict() + if not args.tables: + tables = list(dict.fromkeys((_query_plan.candidate_tables or []) + list(tables))) + except Exception as exc: + _debug_skips["query_plan.skipped"] = f"{type(exc).__name__}: {exc}" + linked_entities = [] + retrieval_query = query + try: + linked_entities = _fetch_linked_entities(db, query, plan=_query_plan, limit=max(limit, 4)) + retrieval_query = _expand_query_with_linked_entities(query, linked_entities) + if linked_entities: + _debug_skips["entity_linking.expanded_query"] = retrieval_query + _debug_skips["entity_linking.matches"] = [ + {"id": int(entity["id"]), "name": entity["name"], "score": entity.get("entity_link_score")} + for entity in linked_entities[:4] + ] + except Exception as exc: + _debug_skips["entity_linking.skipped"] = f"{type(exc).__name__}: {exc}" + + _hard_query_expansion = bool( + _query_plan + and ( + getattr(_query_plan, "requires_temporal_reasoning", False) + or getattr(_query_plan, "requires_multi_hop", False) + or getattr(_query_plan, "needs_comparison", False) + or getattr(_query_plan, "needs_ordering", False) + or getattr(_query_plan, 
"needs_update_resolution", False) + or getattr(_query_plan, "needs_set_coverage", False) + ) + ) base_fetch = limit * 5 if not no_recency else limit * 3 fetch_limit = max(limit, round(base_fetch * _nm_breadth)) + expanded_fetch_limit = max(fetch_limit, round(fetch_limit * (1.8 if _hard_query_expansion else 1.0))) # Build an OR-expanded FTS5 MATCH expression so natural-language queries # (e.g. "What does Alice prefer?") retrieve memories that match any token, # not only memories that contain every word. The simple Brain.search path @@ -6240,13 +6608,13 @@ def cmd_search(args, *, db=None, db_path: Optional[str] = None): # space-separated sanitized form directly to FTS5, which FTS5 treated as # implicit AND and silently starved natural-language queries. The bench # harness surfaced the gap. - fts_query = _build_fts_match_expression(_sanitize_fts_query(query)) + fts_query = _build_fts_match_expression(_sanitize_fts_query(retrieval_query)) # Try to load vec extension for hybrid mode (non-fatal). # Propagate an explicit db_path when the caller provided one (Brain.search) # so vec queries hit the same DB the caller is using, not the CLI default. db_vec = _try_get_db_with_vec(db_path=db_path) - q_blob = _embed_query_safe(query) if db_vec else None + q_blob = _embed_query_safe(retrieval_query) if db_vec else None hybrid = db_vec is not None and q_blob is not None # Factual-lookup / general-fallback intent: skip vec fusion entirely and @@ -6276,9 +6644,10 @@ def cmd_search(args, *, db=None, db_path: Optional[str] = None): _adaptive_weights = None _max_recalls_cache = [None] # lazy-compute once per cmd_search - def _fts_memories(): + def _fts_memories(limit_override=None): if not fts_query: return [] + _fetch = int(limit_override or fetch_limit) # Content-weighted BM25. memories_fts indexes (content, category, tags). 
# Default FTS5 `rank` uses weight 1.0 for every column, which treats a # 200-char content column equally with a one-word `category` label @@ -6298,18 +6667,20 @@ def _fts_memories(): "m.trust_score, m.replay_priority " "FROM memories m JOIN memories_fts f ON m.id = f.rowid " "WHERE memories_fts MATCH ? AND m.retired_at IS NULL " + "AND COALESCE(m.memory_type, 'episodic') != 'procedural' " "ORDER BY bm25(memories_fts, 3.0, 1.0, 1.0) LIMIT ?", - (fts_query, fetch_limit) + (fts_query, _fetch) ).fetchall() return rows_to_list(rows) - def _vec_memories(): + def _vec_memories(limit_override=None): if not hybrid: return [] + _fetch = int(limit_override or fetch_limit) try: vec_rows = db_vec.execute( "SELECT rowid, distance FROM vec_memories WHERE embedding MATCH ? AND k=?", - (q_blob, fetch_limit) + (q_blob, _fetch) ).fetchall() except Exception: return [] @@ -6323,7 +6694,8 @@ def _vec_memories(): f"created_at, recalled_count, temporal_class, last_recalled_at, retrieval_prediction_error, alpha, beta, agent_id, " f"encoding_task_context, encoding_context_hash, q_value, confidence_phase, " f"trust_score, replay_priority " - f"FROM memories WHERE id IN ({ph}) AND retired_at IS NULL", + f"FROM memories WHERE id IN ({ph}) AND retired_at IS NULL " + f"AND COALESCE(memory_type, 'episodic') != 'procedural'", rowids ).fetchall() out = [dict(r) | {"distance": round(dist_map.get(r["id"], 1.0), 4)} for r in src_rows] @@ -6400,6 +6772,61 @@ def _vec_context(): out.sort(key=lambda r: r["distance"]) return out + def _fts_entities(): + if not fts_query: + return [] + rows = db.execute( + """ + SELECT e.id, 'entity' as type, e.name, e.entity_type, e.properties, + e.observations, e.compiled_truth, e.aliases, e.confidence, + e.scope, e.created_at, e.agent_id, + bm25(entities_fts, 4.0, 1.0, 0.8, 1.2) as fts_rank + FROM entities e + JOIN entities_fts f ON e.id = f.rowid + WHERE entities_fts MATCH ? AND e.retired_at IS NULL + ORDER BY bm25(entities_fts, 4.0, 1.0, 0.8, 1.2) + LIMIT ? 
+ """, + (fts_query, fetch_limit), + ).fetchall() + out = rows_to_list(rows) + for row in out: + row["aliases"] = _load_aliases(row) + return out + + def _vec_entities(): + if not hybrid: + return [] + try: + vec_rows = db_vec.execute( + "SELECT rowid, distance FROM vec_entities WHERE embedding MATCH ? AND k=?", + (q_blob, fetch_limit) + ).fetchall() + except Exception: + return [] + if not vec_rows: + return [] + rowids = [r["rowid"] for r in vec_rows] + dist_map = {r["rowid"]: r["distance"] for r in vec_rows} + ph = ",".join("?" * len(rowids)) + src_rows = db_vec.execute( + f""" + SELECT id, 'entity' as type, name, entity_type, properties, observations, + compiled_truth, aliases, confidence, scope, created_at, agent_id + FROM entities + WHERE id IN ({ph}) AND retired_at IS NULL + """, + rowids + ).fetchall() + out = [] + for row in src_rows: + item = dict(row) + item["distance"] = round(dist_map.get(row["id"], 1.0), 4) + item["aliases"] = _load_aliases(item) + out.append(item) + out.sort(key=lambda r: r["distance"]) + return out + def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucket="memories"): if no_recency: return merged[:limit] @@ -6796,7 +7223,9 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke explore_rows = db.execute( "SELECT id, 'memory' as type, category, content, confidence, scope, " "created_at, recalled_count, temporal_class, last_recalled_at " - "FROM memories WHERE retired_at IS NULL ORDER BY recalled_count ASC, RANDOM() LIMIT ?", + "FROM memories WHERE retired_at IS NULL " + "AND COALESCE(memory_type, 'episodic') != 'procedural' " + "ORDER BY recalled_count ASC, RANDOM() LIMIT ?", (limit * 10,) ).fetchall() explore_list = rows_to_list(explore_rows) @@ -6881,6 +7310,35 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke merged = _rrf_fuse(fts_list, vec_list) else: merged = [r | {"rrf_score": 0.0, "source": "keyword"} for r in fts_list] + if ( + 
_hard_query_expansion + and not benchmark_mode + and expanded_fetch_limit > fetch_limit + and len(merged) >= 2 + ): + try: + _rank_key = "rrf_score" if hybrid else "fts_rank" + _sorted_merged = sorted( + merged, + key=lambda r: float(r.get(_rank_key) or 0.0), + reverse=bool(hybrid), + ) + if hybrid: + _top_gap = abs(float(_sorted_merged[0].get("rrf_score") or 0.0) - float(_sorted_merged[1].get("rrf_score") or 0.0)) + else: + _top_gap = abs(float(_sorted_merged[0].get("fts_rank") or 0.0) - float(_sorted_merged[1].get("fts_rank") or 0.0)) + if _top_gap <= (0.03 if hybrid else 0.4): + _fts_expanded = _fts_memories(limit_override=expanded_fetch_limit) + _vec_expanded = _vec_memories(limit_override=expanded_fetch_limit) + if hybrid: + merged = _rrf_fuse(_fts_expanded, _vec_expanded) + else: + merged = [r | {"rrf_score": 0.0, "source": "keyword"} for r in _fts_expanded] + _debug_skips["memories.candidate_expansion"] = ( + f"hard_query_margin_{round(_top_gap, 4)}_fetch_{fetch_limit}_to_{expanded_fetch_limit}" + ) + except Exception as exc: + _debug_skips["memories.candidate_expansion_skipped"] = f"{type(exc).__name__}: {exc}" trimmed = _apply_recency_and_trim(merged, lambda r: r.get("scope"), use_adaptive_salience=True, bucket="memories") # MMR diversity reranking — applied after salience scoring, before graph expand if use_mmr and trimmed: @@ -6921,7 +7379,14 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke if _prof_cats: trimmed = [r for r in trimmed if r.get("category") in _prof_cats] - results["memories"] = trimmed + results["memories"] = _apply_query_alignment( + trimmed, + query, + "memories", + plan=_query_plan, + linked_entities=linked_entities, + limit=limit, + ) if "events" in tables: fts_list = _fts_events() @@ -6944,7 +7409,14 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke already = {r["id"] for r in trimmed} graph = _graph_expand(db, trimmed, "events", already) trimmed.extend(graph) - 
results["events"] = trimmed + results["events"] = _apply_query_alignment( + trimmed, + query, + "events", + plan=_query_plan, + linked_entities=linked_entities, + limit=limit, + ) if "context" in tables: fts_list = _fts_context() @@ -6964,7 +7436,95 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke already = {r["id"] for r in trimmed} graph = _graph_expand(db, trimmed, "context", already) trimmed.extend(graph) - results["context"] = trimmed + results["context"] = _apply_query_alignment( + trimmed, + query, + "context", + plan=_query_plan, + linked_entities=linked_entities, + limit=limit, + ) + + if "entities" in tables: + fts_list = _fts_entities() + vec_list = _vec_entities() + if hybrid: + merged = _rrf_fuse(fts_list, vec_list) + else: + merged = [r | {"rrf_score": 0.0, "source": "keyword"} for r in fts_list] + _debug_skips.setdefault("entities.vec_skipped", "fts_strong_anchor_cascade_from_memories") + for row in merged: + if "aliases" not in row: + row["aliases"] = _load_aliases(row) + trimmed = _apply_recency_and_trim( + merged, + lambda r: r.get("scope") or "global", + bucket="entities", + ) + results["entities"] = _apply_query_alignment( + trimmed, + query, + "entities", + plan=_query_plan, + linked_entities=linked_entities, + limit=limit, + ) + + _procedure_debug = None + _pre_answerability_candidates = [] + if "procedures" in tables: + try: + from agentmemory.retrieval.candidate_generation import generate_procedure_candidates as _generate_procedure_candidates + from agentmemory.retrieval.evidence_graph import expand_procedure_evidence as _expand_procedure_evidence + from agentmemory.retrieval.late_reranker import rerank_procedure_candidates as _rerank_procedure_candidates + from agentmemory.retrieval.query_planner import plan_query as _plan_query + + if _query_plan is None: + _query_plan = _plan_query(query, requested_tables=tables) + _query_plan_dict = _query_plan.as_dict() + proc_scope = None + if getattr(args, "project", 
None): + proc_scope = f"project:{args.project}" + generated = _generate_procedure_candidates( + db, + query, + _query_plan, + limit=fetch_limit, + scope=proc_scope, + ) + evidence = _expand_procedure_evidence( + db, + generated.get("candidates", []), + max_sources_per_candidate=4, + ) + reranked = _rerank_procedure_candidates( + generated.get("candidates", []), + evidence, + benchmark_mode=benchmark_raw_ranking, + ) + results["procedures"] = _apply_query_alignment( + reranked[:limit], + query, + "procedures", + plan=_query_plan, + linked_entities=linked_entities, + limit=limit, + ) + _pre_answerability_candidates = list(results["procedures"]) + _procedure_debug = { + "candidate_generation": generated.get("debug") or {}, + "evidence_clusters": { + str(proc_id): { + "support_bonus": info.get("support_bonus"), + "source_count": len(info.get("sources") or []), + "edge_count": len(info.get("edges") or []), + } + for proc_id, info in evidence.items() + }, + } + except Exception as exc: + results["procedures"] = [] + _debug_skips["procedures.skipped"] = f"{type(exc).__name__}: {exc}" # Intent-based result weighting and decision search. # @@ -7010,30 +7570,30 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke _intent = _INTENT_ALIAS.get(_intent_raw, _intent_raw) # entity_lookup → boost entities/entity results 2x via final_score if _intent == "entity_lookup": - for r in results.get("events", []): - if r.get("type") == "entity": - r["final_score"] = round(r.get("final_score", 0.0) * 2.0, 8) - # Also search entities directly if not in tables - if fts_query: - try: - ent_rows = db.execute( - "SELECT e.id, 'entity' as type, e.name, e.entity_type, e.confidence, e.created_at " - "FROM entities_fts fts JOIN entities e ON e.id = fts.rowid " - "WHERE entities_fts MATCH ? 
AND e.retired_at IS NULL ORDER BY rank LIMIT ?", - (fts_query, limit) - ).fetchall() - for r in rows_to_list(ent_rows): - r["final_score"] = round(float(r.get("confidence", 0.5)) * 2.0, 8) - r["source"] = "intent_entity" - results.setdefault("entities", []).extend(rows_to_list(ent_rows)) - except Exception: - pass + _entity_card = _query_wants_entity_card(query) + for r in results.get("entities", []): + multiplier = 1.25 if _entity_card else 0.92 + r["final_score"] = round(r.get("final_score", 0.0) * multiplier, 8) + r["source"] = r.get("source") or "intent_entity" + results["entities"] = sorted( + results.get("entities", []), + key=lambda r: r.get("final_score", 0.0), + reverse=True, + ) # event_lookup → boost events results 2x elif _intent == "event_lookup": for r in results.get("events", []): r["final_score"] = round(r.get("final_score", 0.0) * 2.0, 8) results["events"] = sorted(results.get("events", []), key=lambda r: r.get("final_score", 0), reverse=True) + elif _intent == "procedural": + for r in results.get("procedures", []): + r["final_score"] = round(r.get("final_score", 0.0) * 1.2, 8) + results["procedures"] = sorted( + results.get("procedures", []), + key=lambda r: r.get("final_score", 0.0), + reverse=True, + ) # decision_lookup → also search decisions table elif _intent == "decision_lookup": if fts_query: @@ -7068,6 +7628,92 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke extra = _graph_expand(db, top_items, tbl_key, already) results.get(tbl_key, []).extend(extra) + def _seed_bucket_score(item, position): + try: + final_score = float(item.get("final_score") or 0.0) + except (TypeError, ValueError): + final_score = 0.0 + if final_score > 0: + return final_score + try: + rrf_score = float(item.get("rrf_score") or 0.0) + except (TypeError, ValueError): + rrf_score = 0.0 + if rrf_score > 0: + return rrf_score + try: + fts_rank = float(item.get("fts_rank") or 0.0) + except (TypeError, ValueError): + fts_rank = 0.0 + if 
fts_rank != 0.0: + return max(-fts_rank, 0.0) + try: + confidence = float(item.get("confidence") or 0.0) + except (TypeError, ValueError): + confidence = 0.0 + if confidence > 0: + return confidence + return max(1.0 / (position + 1), 0.01) + + def _normalize_bucket_scores(bucket_name): + rows = results.get(bucket_name, []) or [] + if not rows: + return + seeds = [_seed_bucket_score(row, idx) for idx, row in enumerate(rows)] + max_seed = max(seeds) or 1.0 + for row, seed in zip(rows, seeds): + row["retrieval_score"] = round(seed, 8) + row["final_score"] = round(seed / max_seed, 8) + rows.sort(key=lambda r: r.get("final_score", 0.0), reverse=True) + results[bucket_name] = rows + + for _bucket_name in ("procedures", "memories", "events", "context", "entities", "decisions"): + _normalize_bucket_scores(_bucket_name) + + _intent_bucket_multipliers = { + "procedural": {"procedures": 1.18, "memories": 0.98, "entities": 0.95, "events": 0.72, "decisions": 0.78, "context": 0.7}, + "troubleshooting": {"procedures": 1.08, "events": 0.95, "memories": 0.98, "entities": 0.9, "decisions": 0.8, "context": 0.72}, + "decision": {"decisions": 1.15, "memories": 1.05, "entities": 0.95, "procedures": 0.55, "events": 0.8, "context": 0.72}, + "temporal": {"events": 1.18, "memories": 0.88, "entities": 0.82, "procedures": 0.4, "decisions": 0.78, "context": 0.72}, + "factual": {"memories": 1.12, "entities": 1.15, "decisions": 0.82, "procedures": 0.35, "events": 0.55, "context": 0.6}, + "orientation": {"memories": 1.0, "events": 0.95, "procedures": 0.75, "context": 0.8, "decisions": 0.8}, + "graph": {"memories": 1.0, "events": 0.95, "decisions": 0.95, "procedures": 0.8, "context": 0.8}, + } + _normalized_intent = (_query_plan.normalized_intent if _query_plan else "factual") + for _bucket_name, _multiplier in _intent_bucket_multipliers.get(_normalized_intent, {}).items(): + _rows = results.get(_bucket_name, []) or [] + for _row in _rows: + _row["final_score"] = 
round(float(_row.get("final_score") or 0.0) * _multiplier, 8) + _rows.sort(key=lambda r: r.get("final_score", 0.0), reverse=True) + results[_bucket_name] = _rows + + for _bucket_name in ("procedures", "memories", "events", "context", "entities", "decisions"): + results[_bucket_name] = _apply_query_alignment( + results.get(_bucket_name, []) or [], + query, + _bucket_name, + plan=_query_plan, + linked_entities=linked_entities, + limit=limit, + ) + + _second_stage_debug = None + try: + from agentmemory.retrieval.second_stage import ( + SecondStageConfig as _SecondStageConfig, + rerank_bucketed_results as _rerank_bucketed_results, + ) + + _second_stage_config = _SecondStageConfig.from_args(args) + results, _second_stage_debug = _rerank_bucketed_results( + query, + _query_plan, + results, + config=_second_stage_config, + ) + except Exception as exc: + _debug_skips["second_stage.skipped"] = f"{type(exc).__name__}: {exc}" + if db_vec: db_vec.close() @@ -7082,7 +7728,7 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke # --budget: trim results from lowest-ranked first until output fits within token cap if budget_tokens is not None: # Estimate current size; trim tail entries until we fit - for key in ("memories", "events", "context", "decisions"): + for key in ("memories", "events", "context", "decisions", "procedures"): lst = results.get(key, []) if not lst: continue @@ -7090,6 +7736,37 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke lst.pop() # remove lowest-ranked (already sorted desc) results[key] = lst + _top_candidates = sorted( + [ + item + for bucket in ("procedures", "memories", "events", "context", "entities", "decisions") + for item in (results.get(bucket, []) or []) + ], + key=lambda item: item.get("final_score", 0.0), + reverse=True, + ) + _answerability = None + if _query_plan is not None: + try: + from agentmemory.retrieval.answerability import assess_answerability as _assess_answerability + 
+ _answerability = _assess_answerability( + query, + _query_plan, + {k: results.get(k, []) for k in ("procedures", "memories", "events", "context", "entities", "decisions")}, + ) + if _answerability.get("abstain") and _query_plan.abstain_allowed: + for key in ("memories", "events", "context", "entities", "decisions", "procedures"): + results[key] = [] + except Exception as exc: + _debug_skips["answerability.skipped"] = f"{type(exc).__name__}: {exc}" + + if (_second_stage_debug or {}).get("enabled"): + for key in ("procedures", "memories", "events", "context", "entities", "decisions"): + rows = list(results.get(key) or []) + rows.sort(key=lambda item: item.get("final_score", 0.0), reverse=True) + results[key] = rows[:limit] + total = sum(len(v) for v in results.values()) tokens_out = _estimate_tokens(results) log_access(db, args.agent or "unknown", "search", query=query, result_count=total, tokens_consumed=tokens_out) @@ -7097,37 +7774,41 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke # Update recalled_count for direct (non-graph) memory hits only. # Uses retrieval-practice strengthening: hard retrievals (high prediction error) # boost confidence more than easy ones (Roediger & Karpicke 2006, Bjork 1994). - for r in results.get("memories", []): - if r.get("source") != "graph": - _retrieval_practice_boost( - db, - r["id"], - retrieval_prediction_error=r.get("retrieval_prediction_error") or 0.0, - ) + # + # Benchmark mode deliberately skips these online-learning writes so the + # retrieval corpus stays stable across repeated synthetic queries. 
+ if not benchmark_mode: + for r in results.get("memories", []): + if r.get("source") != "graph": + _retrieval_practice_boost( + db, + r["id"], + retrieval_prediction_error=r.get("retrieval_prediction_error") or 0.0, + ) - # Online phase learning: nudge confidence_phase toward constructive (0) after recall - # Uses existing db connection to avoid lock contention with uncommitted recall_count updates. - try: - _has_phase_col = any( - col[1] == "confidence_phase" - for col in db.execute("PRAGMA table_info(memories)").fetchall() - ) - if _has_phase_col: - _delta = 0.05 - for r in results.get("memories", []): - if r.get("source") != "graph": - _pm_id = r["id"] - _pm_row = db.execute( - "SELECT confidence_phase FROM memories WHERE id=? AND retired_at IS NULL", - (_pm_id,) - ).fetchone() - if _pm_row and _pm_row[0] is not None: - import math as _pmath - _ph = float(_pm_row[0]) - _ph = (_ph + _delta if _ph > _pmath.pi else max(0.0, _ph - _delta)) % (2 * _pmath.pi) - db.execute("UPDATE memories SET confidence_phase=? WHERE id=?", (_ph, _pm_id)) - except Exception: - pass # phase learning is optional; never break search + # Online phase learning: nudge confidence_phase toward constructive (0) after recall + # Uses existing db connection to avoid lock contention with uncommitted recall_count updates. + try: + _has_phase_col = any( + col[1] == "confidence_phase" + for col in db.execute("PRAGMA table_info(memories)").fetchall() + ) + if _has_phase_col: + _delta = 0.05 + for r in results.get("memories", []): + if r.get("source") != "graph": + _pm_id = r["id"] + _pm_row = db.execute( + "SELECT confidence_phase FROM memories WHERE id=? AND retired_at IS NULL", + (_pm_id,) + ).fetchone() + if _pm_row and _pm_row[0] is not None: + import math as _pmath + _ph = float(_pm_row[0]) + _ph = (_ph + _delta if _ph > _pmath.pi else max(0.0, _ph - _delta)) % (2 * _pmath.pi) + db.execute("UPDATE memories SET confidence_phase=? 
WHERE id=?", (_ph, _pm_id)) + except Exception: + pass # phase learning is optional; never break search # Post-retrieval metacognitive tier annotation # Tier 1: high-confidence fresh results (≥3 direct results, avg_conf ≥ 0.7) @@ -7136,14 +7817,20 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke # Tier 4: coverage gap (0 direct results) # Exclude graph-expanded neighbours (source="graph") — they don't reflect query coverage memory_results = [r for r in results.get("memories", []) if r.get("source") != "graph"] + procedure_results = [r for r in results.get("procedures", []) if r.get("source") != "graph"] + entity_results = [r for r in results.get("entities", []) if r.get("source") != "graph"] + direct_results = memory_results + procedure_results + entity_results # Keyword/both hits: FTS5 textual matches — strongest evidence of genuine coverage - keyword_hits = [r for r in memory_results if r.get("source") in ("keyword", "both")] + keyword_hits = [ + r for r in direct_results + if r.get("source") in ("keyword", "both", "procedure_fts") + ] k_count = len(keyword_hits) - if not memory_results: + if not direct_results: tier = 4 tier_label = "gap-detected" - tier_note = "COVERAGE GAP — no memories match this query" + tier_note = "COVERAGE GAP — no grounded memories or procedures match this query" try: _log_gap(db, "coverage_hole", f"query:{_sanitize_fts_query(query)[:80]}", 1.0, triggered_by=query[:200]) except Exception: @@ -7171,19 +7858,19 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke elif k_count > 0: tier = 2 tier_label = "moderate" - tier_note = f"Only {k_count} keyword match(es); {len(memory_results)} total (includes semantic)" + tier_note = f"Only {k_count} direct lexical match(es); {len(direct_results)} total direct result(s)" else: tier = 3 tier_label = "weak-coverage" - tier_note = f"No keyword matches; {len(memory_results)} semantic-only result(s) — potential gap" + tier_note = f"No 
lexical direct matches; {len(direct_results)} semantic/procedural result(s) — potential gap" # Passive search instrumentation — append row to agent_uncertainty_log try: _unc_agent = getattr(args, "agent", None) or "unknown" _unc_domain = getattr(args, "scope", None) or (tables[0] if tables else "memories") _unc_avg_conf = None - if memory_results: - _conf_vals = [r.get("confidence") for r in memory_results if r.get("confidence") is not None] + if direct_results: + _conf_vals = [r.get("confidence") for r in direct_results if r.get("confidence") is not None] if _conf_vals: _unc_avg_conf = round(sum(_conf_vals) / len(_conf_vals), 4) db.execute( @@ -7231,12 +7918,30 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke except Exception: pass # trigger check is optional; never break search + _debug_payload = {} + try: + if _query_plan_dict is not None or _procedure_debug is not None or _answerability is not None: + from agentmemory.retrieval.diagnostics import build_debug_payload as _build_debug_payload + + _debug_payload = _build_debug_payload( + query_plan=_query_plan_dict or {}, + procedure_debug=_procedure_debug, + answerability=_answerability, + second_stage=_second_stage_debug, + top_candidates=_top_candidates, + ) + except Exception as exc: + _debug_skips["diagnostics.skipped"] = f"{type(exc).__name__}: {exc}" + _out = { "mode": mode, "metacognition": { "tier": tier, "label": tier_label, "note": tier_note, + "answerability_score": (_answerability or {}).get("score"), + "answerability_reason": (_answerability or {}).get("reason"), + "abstained": (_answerability or {}).get("abstain", False), **_intent_meta, **_rollout_meta, }, @@ -7254,8 +7959,10 @@ def _apply_recency_and_trim(merged, scope_fn, use_adaptive_salience=False, bucke # "all_signals_informative" marker so downstream tooling can rely on # the key always being present in debug mode. Without `--debug` and # no skips, stay silent to keep the default response compact. 
- if _debug_skips: - _out["_debug"] = dict(_debug_skips) + if _debug_skips or _debug_payload: + _debug_out = dict(_debug_skips) + _debug_out.update(_debug_payload) + _out["_debug"] = _debug_out elif _debug_mode: _out["_debug"] = {"all_signals_informative": True} _ofmt = getattr(args, "output", "json") @@ -16076,7 +16783,7 @@ def build_parser(): mem_add.add_argument("--tags", "-t", help="Comma-separated tags") mem_add.add_argument("--source-event", type=int) mem_add.add_argument("--type", choices=["episodic", "semantic"], default="episodic", - help="Memory type: episodic (time-bound, faster decay) or semantic (durable facts, slower decay)") + help="Memory type: episodic (time-bound, faster decay) or semantic (durable facts)") mem_add.add_argument("--reflexion", action="store_true", help="Shorthand for failure lessons: sets category=lesson, auto-tags with 'reflexion'") mem_add.add_argument("--attribute", action="store_true", @@ -16563,7 +17270,7 @@ def build_parser(): # --- search --- srch = sub.add_parser("search", help="Universal cross-table search") srch.add_argument("query") - srch.add_argument("--tables", help="Comma-separated: memories,events,context") + srch.add_argument("--tables", help="Comma-separated: memories,events,context,decisions,procedures") srch.add_argument("--limit", "-l", type=int, default=10) srch.add_argument("--no-recency", action="store_true", dest="no_recency", help="Disable temporal recency weighting; return raw FTS rank order") @@ -16585,6 +17292,8 @@ def build_parser(): help="Apply phase-aware quantum amplitude re-ranking to memory results") srch.add_argument("--benchmark", action="store_true", help="Disable the recency/salience/Q-value/source/context/PageRank/quantum/temporal-contiguity reranker chain and return the raw FTS+vec RRF-fused ranking. Trust reranker is preserved (different signal class). 
Use this for synthetic-conversational evals (LOCOMO, LongMemEval) where uniform timestamps make rerankers worse than no-op.") + srch.add_argument("--benchmark-ranking-mode", choices=["raw", "full"], default=None, + help="Internal eval mode for --benchmark. Defaults to raw, matching the legacy benchmark profile.") # 2.4.0: optional cross-encoder reranker stage (off by default). # Uses nargs="?" + const so `--rerank` alone takes the default # model and `--rerank MODEL` lets the user pin a specific one. @@ -16602,6 +17311,24 @@ def build_parser(): srch.add_argument("--rerank-budget-ms", type=float, default=None, metavar="MS", help="Strict latency budget for cross-encoder rerank (per-call and rolling p95). " "Defaults to env BRAINCTL_CE_P95_BUDGET_MS or 350.") + srch.add_argument("--no-second-stage", action="store_true", default=False, + help="Disable the shared deterministic second-stage reranker.") + srch.add_argument("--second-stage", action="store_true", default=False, + help="Enable the opt-in shared deterministic second-stage reranker.") + srch.add_argument("--no-second-stage-model", action="store_true", default=False, + help="Run the second-stage reranker without the tiny MLP residual model.") + srch.add_argument("--second-stage-top-n", type=int, default=None, metavar="N", + help="Combined top-N candidate window for the shared second-stage reranker. 
" + "Defaults to env BRAINCTL_SECOND_STAGE_TOP_N or 10.") + srch.add_argument("--second-stage-model-path", default=None, metavar="PATH", + help="Override the tiny MLP JSON artifact used by the shared second-stage reranker.") + srch.add_argument("--judge-rerank", nargs="?", const="ollama", default=None, metavar="PROVIDER", + help="Enable the optional top-5 judge reranker with the given provider " + "(default when passed without value: ollama).") + srch.add_argument("--judge-model", default="llama3.2:3b", metavar="MODEL", + help="Model name for the optional judge reranker (provider-specific).") + srch.add_argument("--judge-top-k", type=int, default=5, metavar="N", + help="Top-K candidates sent to the optional judge reranker (max recommended: 5).") srch.add_argument("--rollout-mode", choices=["on", "off", "canary"], default=None, help="Top-heavy retrieval rollout mode override. " "Defaults to env BRAINCTL_TOPHEAVY_ROLLOUT_MODE or on.") diff --git a/src/agentmemory/retrieval/__init__.py b/src/agentmemory/retrieval/__init__.py new file mode 100644 index 0000000..8819444 --- /dev/null +++ b/src/agentmemory/retrieval/__init__.py @@ -0,0 +1,20 @@ +"""Retrieval executive helpers.""" + +from .answerability import assess_answerability +from .diagnostics import build_debug_payload +from .long_context import analyze_long_context +from .mlp_reranker import TinyMLPModel +from .query_planner import QueryPlan, plan_query +from .second_stage import SecondStageConfig, rerank_bucketed_results, rerank_top_candidates + +__all__ = [ + "analyze_long_context", + "QueryPlan", + "SecondStageConfig", + "TinyMLPModel", + "assess_answerability", + "build_debug_payload", + "plan_query", + "rerank_bucketed_results", + "rerank_top_candidates", +] diff --git a/src/agentmemory/retrieval/answerability.py b/src/agentmemory/retrieval/answerability.py new file mode 100644 index 0000000..af174ba --- /dev/null +++ b/src/agentmemory/retrieval/answerability.py @@ -0,0 +1,163 @@ +"""Grounded answerability 
gate.""" + +from __future__ import annotations + +import re +from typing import Any + +_STOPWORDS = { + "a", "an", "and", "are", "as", "at", "be", "by", "do", "does", "for", + "from", "has", "have", "how", "i", "in", "is", "it", "its", "of", + "on", "or", "that", "the", "to", "was", "we", "what", "when", "where", + "which", "who", "why", "will", "with", "you", "did", +} +_LOW_SIGNAL_TOKENS = { + "summary", "history", "timeline", "recent", "today", "yesterday", "tomorrow", + "game", "issue", "problem", "thing", "stuff", "update", +} + + +def _normalize_token(token: str) -> str: + tok = re.sub(r"[^a-z0-9]+", "", (token or "").lower()) + if len(tok) <= 2 or tok in _STOPWORDS: + return "" + if tok.endswith("ies") and len(tok) > 4: + tok = tok[:-3] + "y" + elif tok.endswith("ed") and len(tok) > 4: + tok = tok[:-2] + elif tok.endswith("es") and len(tok) > 4: + tok = tok[:-2] + elif tok.endswith("s") and len(tok) > 3: + tok = tok[:-1] + return tok + + +def _token_set(text: str) -> set[str]: + return { + norm + for part in re.split(r"\s+", text or "") + if (norm := _normalize_token(part)) + } + + +def _informative_tokens(text: str) -> set[str]: + return {token for token in _token_set(text) if token not in _LOW_SIGNAL_TOKENS} + + +def assess_answerability( + query: str, + plan, + buckets: dict[str, list[dict[str, Any]]], +) -> dict[str, Any]: + """Estimate whether the current retrieval set is grounded enough to answer.""" + + flat: list[dict[str, Any]] = [] + for rows in buckets.values(): + flat.extend(rows or []) + flat.sort(key=lambda item: item.get("final_score", 0.0), reverse=True) + + if not flat: + return { + "score": 0.0, + "abstain": True, + "reason": "no_candidates", + "top_margin": 0.0, + } + + top = flat[0] + second = flat[1] if len(flat) > 1 else None + top_score = float(top.get("final_score") or 0.0) + second_score = float(second.get("final_score") or 0.0) if second else 0.0 + margin = top_score - second_score + + query_tokens = _token_set(query) + 
informative_query_tokens = _informative_tokens(query) + top_text = " ".join( + str(top.get(key) or "") + for key in ("content", "summary", "title", "goal", "description", "search_text", "name", "compiled_truth", "observations", "aliases") + ) + top_text_tokens = _token_set(top_text) + top_informative_tokens = _informative_tokens(top_text) + supporting_text = " ".join( + " ".join( + str(row.get(key) or "") + for key in ("content", "summary", "title", "goal", "description", "search_text", "name", "compiled_truth", "observations", "aliases") + ) + for row in flat[:3] + ) + supporting_tokens = _token_set(supporting_text) + supporting_informative_tokens = _informative_tokens(supporting_text) + coverage = 0.0 + if query_tokens: + coverage = len(query_tokens & supporting_tokens) / len(query_tokens) + informative_coverage = 0.0 + if informative_query_tokens: + informative_coverage = len(informative_query_tokens & supporting_informative_tokens) / len(informative_query_tokens) + anchor_overlap = len(query_tokens & top_text_tokens) + informative_anchor_overlap = len(informative_query_tokens & top_informative_tokens) + evidence_diversity = len({ + row.get("type") or bucket_name.rstrip("s") + for bucket_name, rows in buckets.items() + for row in (rows or [])[:2] + }) + direct_support = len(top.get("supporting_evidence") or []) + stale_penalty = 0.25 if top.get("status") in {"stale", "needs_review", "superseded", "retired"} else 0.0 + strong_candidate_count = 0 + for row in flat[:5]: + row_text = " ".join( + str(row.get(key) or "") + for key in ("content", "summary", "title", "goal", "description", "search_text", "name", "compiled_truth", "observations", "aliases") + ) + row_tokens = _token_set(row_text) + row_informative = _informative_tokens(row_text) + row_coverage = len(query_tokens & row_tokens) / max(len(query_tokens), 1) if query_tokens else 0.0 + row_informative_coverage = ( + len(informative_query_tokens & row_informative) / max(len(informative_query_tokens), 1) + if 
informative_query_tokens else row_coverage + ) + if row_coverage >= 0.3 or row_informative_coverage >= 0.3: + strong_candidate_count += 1 + + score = ( + (top_score * 0.45) + + (margin * 0.35) + + (coverage * 0.45) + + (informative_coverage * 0.35) + + min(direct_support / 3.0, 1.0) * 0.15 + + min(evidence_diversity / 3.0, 1.0) * 0.1 + - stale_penalty + ) + abstain = False + reason = "grounded" + grounded_consensus = strong_candidate_count >= 1 and top_score >= 0.85 and informative_anchor_overlap >= 1 + if informative_coverage < 0.34 and informative_anchor_overlap == 0 and direct_support == 0: + abstain = True + reason = "weak_informative_coverage" + if informative_query_tokens and informative_anchor_overlap <= 1 and informative_coverage < 0.4 and top_score < 0.75: + abstain = True + reason = "weak_topical_anchor" + if margin < 0.08 and informative_coverage < 0.5 and informative_anchor_overlap < 2 and plan.abstain_allowed: + if strong_candidate_count < 2 and not grounded_consensus: + abstain = True + reason = "diffuse_candidates" + if plan.abstain_allowed and score < 0.5 and informative_coverage < 0.5: + if strong_candidate_count < 2 and not grounded_consensus: + abstain = True + reason = "low_answerability_score" + if "summary" in (query or "").lower() and informative_anchor_overlap < 2 and informative_coverage < 0.45: + abstain = True + reason = "ungrounded_summary_request" + + return { + "score": round(score, 4), + "abstain": abstain, + "reason": reason, + "top_margin": round(margin, 4), + "coverage": round(coverage, 4), + "informative_coverage": round(informative_coverage, 4), + "anchor_overlap": anchor_overlap, + "informative_anchor_overlap": informative_anchor_overlap, + "evidence_diversity": evidence_diversity, + "direct_support": direct_support, + "strong_candidate_count": strong_candidate_count, + } diff --git a/src/agentmemory/retrieval/candidate_generation.py b/src/agentmemory/retrieval/candidate_generation.py new file mode 100644 index 0000000..6880bd6 
--- /dev/null +++ b/src/agentmemory/retrieval/candidate_generation.py @@ -0,0 +1,46 @@ +"""Candidate generation for procedure-aware retrieval.""" + +from __future__ import annotations + +import sqlite3 +from typing import Any + +from .query_planner import QueryPlan + + +def generate_procedure_candidates( + conn: sqlite3.Connection, + query: str, + plan: QueryPlan, + *, + limit: int = 10, + scope: str | None = None, +) -> dict[str, Any]: + """Search procedures and attach minimal diagnostics.""" + + if "procedures" not in plan.candidate_tables: + return {"candidates": [], "debug": {"skipped": "procedures_not_in_plan"}} + try: + from agentmemory import procedural + except ImportError: + return {"candidates": [], "debug": {"skipped": "procedural_module_unavailable"}} + + search = procedural.search_procedures( + conn, + query, + limit=max(limit * 3, 12), + scope=scope, + debug=True, + ) + candidates = search.get("procedures", []) + for cand in candidates: + cand.setdefault("type", "procedure") + cand.setdefault("source", "procedure_fts") + return { + "candidates": candidates, + "debug": { + "query": query, + "count": len(candidates), + **(search.get("debug") or {}), + }, + } diff --git a/src/agentmemory/retrieval/diagnostics.py b/src/agentmemory/retrieval/diagnostics.py new file mode 100644 index 0000000..f365e9b --- /dev/null +++ b/src/agentmemory/retrieval/diagnostics.py @@ -0,0 +1,43 @@ +"""Debug payload builders for retrieval executive output.""" + +from __future__ import annotations + +from typing import Any + + +def build_debug_payload( + *, + query_plan: dict[str, Any], + procedure_debug: dict[str, Any] | None, + answerability: dict[str, Any] | None, + second_stage: dict[str, Any] | None = None, + top_candidates: list[dict[str, Any]] | None = None, +) -> dict[str, Any]: + payload: dict[str, Any] = { + "query_plan": query_plan, + } + if procedure_debug: + payload["procedures"] = procedure_debug + if second_stage: + payload["second_stage"] = second_stage + if 
answerability: + payload["answerability"] = answerability + if top_candidates is not None: + payload["top_candidates"] = [ + { + "type": cand.get("type"), + "id": cand.get("id"), + "final_score": cand.get("final_score"), + "pre_second_stage_score": cand.get("pre_second_stage_score"), + "second_stage_heuristic": cand.get("second_stage_heuristic"), + "second_stage_mlp": cand.get("second_stage_mlp"), + "second_stage_judge": cand.get("second_stage_judge"), + "second_stage_slate_score": cand.get("second_stage_slate_score"), + "second_stage_slate_terms": cand.get("second_stage_slate_terms"), + "why_retrieved": cand.get("why_retrieved"), + "feature_summary": cand.get("second_stage_features"), + "text": cand.get("content") or cand.get("summary") or cand.get("title") or cand.get("goal") or cand.get("name"), + } + for cand in top_candidates[:5] + ] + return payload diff --git a/src/agentmemory/retrieval/evidence_graph.py b/src/agentmemory/retrieval/evidence_graph.py new file mode 100644 index 0000000..e47517e --- /dev/null +++ b/src/agentmemory/retrieval/evidence_graph.py @@ -0,0 +1,55 @@ +"""Evidence expansion helpers for procedure retrieval.""" + +from __future__ import annotations + +import sqlite3 +from typing import Any + + +def expand_procedure_evidence( + conn: sqlite3.Connection, + candidates: list[dict[str, Any]], + *, + max_sources_per_candidate: int = 4, +) -> dict[int, dict[str, Any]]: + """Attach 1-hop provenance and support evidence to top procedure candidates.""" + + if not candidates: + return {} + + out: dict[int, dict[str, Any]] = {} + for cand in candidates: + proc_id = int(cand["id"]) + sources = [ + dict(row) + for row in conn.execute( + """ + SELECT memory_id, event_id, decision_id, entity_id, source_role, created_at + FROM procedure_sources + WHERE procedure_id = ? + ORDER BY id + LIMIT ? 
+ """, + (proc_id, max_sources_per_candidate), + ).fetchall() + ] + edges = [ + dict(row) + for row in conn.execute( + """ + SELECT target_table, target_id, relation_type, weight + FROM knowledge_edges + WHERE source_table = 'procedures' AND source_id = ? + ORDER BY weight DESC, id DESC + LIMIT ? + """, + (proc_id, max_sources_per_candidate), + ).fetchall() + ] + support_bonus = min((len(sources) * 0.14) + (sum(float(edge.get("weight") or 0.0) for edge in edges) * 0.08), 0.8) + out[proc_id] = { + "sources": sources, + "edges": edges, + "support_bonus": round(support_bonus, 4), + } + return out diff --git a/src/agentmemory/retrieval/feature_builder.py b/src/agentmemory/retrieval/feature_builder.py new file mode 100644 index 0000000..bdfb9fc --- /dev/null +++ b/src/agentmemory/retrieval/feature_builder.py @@ -0,0 +1,546 @@ +"""Feature extraction for the shared second-stage reranker.""" + +from __future__ import annotations + +import json +import math +import re +from datetime import datetime, timezone +from typing import Any, Iterable + +try: # pragma: no cover - numpy is optional at import time + import numpy as _np +except Exception: # pragma: no cover + _np = None + +_STOPWORDS = { + "a", "an", "and", "are", "as", "at", "be", "by", "did", "do", "does", "for", + "from", "has", "have", "how", "i", "in", "is", "it", "its", "of", + "on", "or", "that", "the", "to", "was", "we", "what", "when", "where", + "which", "who", "why", "will", "with", "you", +} +_LOW_SIGNAL_TOKENS = { + "summary", "history", "timeline", "recent", "today", "yesterday", "tomorrow", + "game", "issue", "problem", "thing", "stuff", "update", +} +_SYNONYMS = { + "dad": {"father", "parent"}, + "father": {"dad", "parent"}, + "mom": {"mother", "parent"}, + "mother": {"mom", "parent"}, + "workplace": {"work", "works", "job", "office", "occupation", "position", "employer"}, + "occupation": {"job", "work", "works", "position", "career"}, + "position": {"job", "occupation", "work", "works", "role"}, + 
"educational": {"education", "degree", "school", "background"}, + "education": {"educational", "degree", "school", "background"}, + "background": {"education", "degree", "school"}, + "degree": {"education", "educational", "school", "background"}, + "location": {"where", "place", "city", "hometown", "workplace"}, + "hometown": {"home", "city", "location", "from"}, + "coworker": {"colleague", "work", "works"}, + "hobby": {"enjoy", "enjoys", "love", "loves", "passion", "passionate", "into"}, + "enjoy": {"hobby", "likes", "love", "loves", "passion"}, + "enjoys": {"hobby", "likes", "love", "loves", "passion"}, + "loves": {"hobby", "enjoy", "enjoys", "passion", "passionate"}, + "passionate": {"hobby", "enjoy", "enjoys", "loves"}, + "boss": {"manager", "supervisor"}, + "subordinate": {"employee", "report", "teammate"}, + "aunt": {"relative"}, + "uncle": {"relative"}, + "cousin": {"relative"}, + "living": {"occupation", "job", "work", "works"}, + "email": {"contact", "address"}, + "contact": {"phone", "number", "email"}, + "number": {"phone", "contact"}, +} +_ROLE_TERMS = { + "father", "dad", "mother", "mom", "parent", "coworker", "colleague", + "friend", "neighbor", "sister", "brother", "wife", "husband", "nephew", + "niece", "aunt", "uncle", "cousin", "relative", "boss", "manager", + "supervisor", "subordinate", "employee", "report", +} +_ATTRIBUTE_TERMS = { + "education", "educational", "background", "degree", "school", "occupation", + "position", "job", "workplace", "works", "work", "location", "hometown", + "company", "employer", "role", "status", "key", "code", "value", + "hobby", "enjoy", "enjoys", "love", "loves", "likes", "passion", + "passionate", "into", "email", "address", "contact", "number", "phone", "living", +} +_DATE_RE = re.compile( + r"\b(?:\d{4}-\d{2}-\d{2}|\d{1,2}/\d{1,2}(?:/\d{2,4})?|" + r"jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|" + r"jul(?:y)?|aug(?:ust)?|sep(?:tember)?|oct(?:ober)?|nov(?:ember)?|" + r"dec(?:ember)?)\b", + 
re.IGNORECASE, +) +_TEMPORAL_RE = re.compile( + r"\b(yesterday|today|tomorrow|when|before|after|during|timeline|history|recent|latest|first|last)\b", + re.IGNORECASE, +) +_LONG_CONTEXT_HINT_RE = re.compile( + r"\b(" + r"how many|how much|order|earliest|latest|most recent|" + r"before|after|between|this month|last month|past month|past week|" + r"current(?:ly)?|previous(?:ly)?|" + r"(?:one|two|three|four|five|six|seven|eight|nine|ten|\d+)\s+" + r"(?:day|week|month|year)s?\s+ago|" + r"based on|underlying|future|might|would" + r")\b", + re.IGNORECASE, +) +_SESSION_RE = re.compile( + r"(?:^|[|_\s-])(?:sid|session|s)[=_ :#-]*(\d+)|\bsession[_ :#-]*(\d+)\b", + re.IGNORECASE, +) +_DIALOG_RE = re.compile(r"\bD(\d+):", re.IGNORECASE) +_ENTITY_RE = re.compile(r"\b[A-Z][A-Za-z0-9_.:-]+\b") + +FEATURE_VERSION_V1 = "v1" +FEATURE_ORDER_V1 = [ + "base_score", + "retrieval_score", + "rrf_score", + "confidence", + "query_overlap", + "informative_overlap", + "tfidf_cosine", + "exact_phrase", + "entity_overlap", + "alias_overlap", + "query_temporal", + "candidate_temporal", + "temporal_anchor_overlap", + "query_session_hint", + "candidate_session_hint", + "session_gap_score", + "intent_bucket_fit", + "source_keyword", + "source_semantic", + "source_both", + "source_graph", + "bucket_memories", + "bucket_events", + "bucket_entities", + "bucket_procedures", + "bucket_decisions", + "candidate_age_score", + "support_evidence_score", + "status_active", + "status_stale", + "status_needs_review", + "position_score", + "neighbor_margin", + "query_length_score", + "candidate_length_score", + "procedural_candidate", +] + + +def _normalize_token(token: str) -> str: + tok = re.sub(r"[^a-z0-9]+", "", (token or "").lower()) + if len(tok) <= 2 or tok in _STOPWORDS: + return "" + if tok.endswith("ies") and len(tok) > 4: + tok = tok[:-3] + "y" + elif tok.endswith("ed") and len(tok) > 4: + tok = tok[:-2] + elif tok.endswith("es") and len(tok) > 4: + tok = tok[:-2] + elif tok.endswith("s") and len(tok) 
> 3: + tok = tok[:-1] + return tok + + +def _token_set(text: str) -> set[str]: + tokens = { + token + for part in re.split(r"\s+", text or "") + if (token := _normalize_token(part)) + } + expanded = set(tokens) + for token in tokens: + expanded.update(_SYNONYMS.get(token, ())) + return expanded + + +def _informative_tokens(text: str) -> set[str]: + return {token for token in _token_set(text) if token not in _LOW_SIGNAL_TOKENS} + + +def _safe_float(value: Any, default: float = 0.0) -> float: + try: + return float(value) + except (TypeError, ValueError): + return default + + +def candidate_text(candidate: dict[str, Any]) -> str: + parts: list[str] = [] + for key in ( + "content", "summary", "title", "goal", "description", "search_text", + "name", "compiled_truth", "why_retrieved", + ): + value = candidate.get(key) + if value: + parts.append(str(value)) + for key in ("observations", "aliases", "supporting_evidence"): + value = candidate.get(key) + if not value: + continue + if isinstance(value, str): + parts.append(value) + else: + try: + parts.append(json.dumps(value, ensure_ascii=True)) + except Exception: + parts.append(str(value)) + return " ".join(parts) + + +def _alias_values(candidate: dict[str, Any]) -> list[str]: + raw = candidate.get("aliases") + if not raw: + return [] + if isinstance(raw, list): + return [str(value) for value in raw if value] + if isinstance(raw, str): + try: + parsed = json.loads(raw) + except Exception: + return [raw] + if isinstance(parsed, list): + return [str(value) for value in parsed if value] + return [raw] + return [str(raw)] + + +def _entity_terms(text: str) -> set[str]: + return { + match.group(0).lower() + for match in _ENTITY_RE.finditer(text or "") + if len(match.group(0)) > 2 + } + + +def _intent_bucket_preference(plan: Any, bucket: str) -> float: + if plan is None: + return 0.5 + tables = list(getattr(plan, "candidate_tables", []) or []) + if not tables: + return 0.5 + try: + position = tables.index(bucket) + except 
ValueError: + return 0.2 + return max(0.2, 1.0 - (position * 0.12)) + + +def _source_flags(candidate: dict[str, Any]) -> tuple[float, float, float, float]: + source = str(candidate.get("source") or "").lower() + return ( + 1.0 if source in {"keyword", "procedure_fts", "intent_entity", "intent_decision"} else 0.0, + 1.0 if source == "semantic" else 0.0, + 1.0 if source == "both" else 0.0, + 1.0 if source == "graph" else 0.0, + ) + + +def _bucket_flags(bucket: str) -> tuple[float, float, float, float, float]: + return ( + 1.0 if bucket == "memories" else 0.0, + 1.0 if bucket == "events" else 0.0, + 1.0 if bucket == "entities" else 0.0, + 1.0 if bucket == "procedures" else 0.0, + 1.0 if bucket == "decisions" else 0.0, + ) + + +def _temporal_anchor_overlap(query: str, text: str) -> float: + query_dates = {match.group(0).lower() for match in _DATE_RE.finditer(query or "")} + cand_dates = {match.group(0).lower() for match in _DATE_RE.finditer(text or "")} + if not query_dates: + return 0.0 + return len(query_dates & cand_dates) / len(query_dates) + + +def _extract_session_hints(text: str) -> list[int]: + hints = [int(match.group(1) or match.group(2)) for match in _SESSION_RE.finditer(text or "")] + hints.extend(int(match.group(1)) for match in _DIALOG_RE.finditer(text or "")) + return hints + + +def _session_gap_score(query: str, candidate_text_value: str) -> tuple[float, float, float]: + query_sessions = _extract_session_hints(query) + candidate_sessions = _extract_session_hints(candidate_text_value) + if not query_sessions: + return 0.0, 0.0, 0.0 + if not candidate_sessions: + return 1.0, 0.0, 0.0 + gap = min(abs(q - c) for q in query_sessions for c in candidate_sessions) + return 1.0 / (1.0 + gap), 1.0, 1.0 + + +def _role_value_pattern(text: str) -> float: + return 1.0 if re.search( + r"\b(" + r"works?\s+(?:as|in|at)|" + r"is\s+(?:a|an|the)\b|" + r"loves?\b|likes?\b|enjoys?\b|" + r"passionate\s+about|really\s+into|free\s+time|" + 
r"originally\s+from|grew\s+up\s+in|hails?\s+from|from\s+[A-Z][A-Za-z]+,\s*[A-Z][A-Za-z]+|" + r"[\w.+-]+@[\w.-]+|" + r"(?:phone|contact|number|email)\s+(?:is|address\s+is|number\s+is)?|" + r"company\s+(?:is|called|named)" + r")", + text or "", + re.IGNORECASE, + ) else 0.0 + + +def _parse_iso_timestamp(value: Any) -> datetime | None: + if not value: + return None + try: + text = str(value).replace("Z", "+00:00") + dt = datetime.fromisoformat(text) + except Exception: + return None + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt.astimezone(timezone.utc) + + +def _age_score(candidate: dict[str, Any]) -> float: + when = _parse_iso_timestamp(candidate.get("created_at")) or _parse_iso_timestamp(candidate.get("updated_at")) + if when is None: + return 0.5 + age_days = max((datetime.now(timezone.utc) - when).total_seconds() / 86400.0, 0.0) + return 1.0 / (1.0 + age_days / 30.0) + + +def _tfidf_cosine(query: str, text: str) -> float: + q_tokens = list(_informative_tokens(query)) + c_tokens = list(_informative_tokens(text)) + if not q_tokens or not c_tokens: + return 0.0 + docs = [q_tokens, c_tokens] + vocab = sorted({token for doc in docs for token in doc}) + if not vocab: + return 0.0 + doc_freq: dict[str, int] = {} + for token in vocab: + doc_freq[token] = sum(1 for doc in docs if token in doc) + n_docs = len(docs) + + def _weights(tokens: Iterable[str]) -> dict[str, float]: + counts: dict[str, int] = {} + for token in tokens: + counts[token] = counts.get(token, 0) + 1 + if not counts: + return {} + max_tf = max(counts.values()) or 1 + weights: dict[str, float] = {} + for token, count in counts.items(): + tf = count / max_tf + idf = math.log((1 + n_docs) / (1 + doc_freq[token])) + 1.0 + weights[token] = tf * idf + return weights + + q_weights = _weights(q_tokens) + c_weights = _weights(c_tokens) + dot = sum(q_weights.get(token, 0.0) * c_weights.get(token, 0.0) for token in vocab) + q_norm = math.sqrt(sum(value * value for value in 
q_weights.values())) + c_norm = math.sqrt(sum(value * value for value in c_weights.values())) + if q_norm == 0.0 or c_norm == 0.0: + return 0.0 + return dot / (q_norm * c_norm) + + +def _should_probe_long_context( + *, + query: str, + plan: Any, + bucket: str, + text: str, + position: int, + current_score: float, + prev_raw: Any, + next_raw: Any, + leader_raw: Any, +) -> bool: + if bucket != "memories": + return False + + lowered_query = query or "" + structured_long_text = ( + len(text) >= 1500 + or "session id:" in text.lower() + or "session date:" in text.lower() + or text.count("\n") >= 8 + ) + if not structured_long_text: + return False + + if position > 4: + return False + + query_needs_probe = ( + bool(getattr(plan, "requires_temporal_reasoning", False)) + or bool(getattr(plan, "requires_multi_hop", False)) + or bool(getattr(plan, "needs_ordering", False)) + or bool(getattr(plan, "needs_update_resolution", False)) + or bool(getattr(plan, "needs_set_coverage", False)) + or bool(_LONG_CONTEXT_HINT_RE.search(lowered_query)) + ) + if not query_needs_probe: + return False + + closest_gap_values: list[float] = [] + if prev_raw is not None: + closest_gap_values.append(abs(current_score - _safe_float(prev_raw))) + if next_raw is not None: + closest_gap_values.append(abs(current_score - _safe_float(next_raw))) + closest_neighbor_gap = min(closest_gap_values) if closest_gap_values else 0.0 + leader_gap = abs(_safe_float(leader_raw, current_score) - current_score) + return closest_neighbor_gap <= 0.035 and leader_gap <= 0.08 + + +def build_features( + query: str, + plan: Any, + candidate: dict[str, Any], + *, + neighbors: dict[str, Any] | None = None, +) -> dict[str, float]: + """Build numeric features for a candidate row.""" + + bucket = str(candidate.get("bucket") or candidate.get("type") or "memories") + text = candidate_text(candidate) + query_tokens = _token_set(query) + query_informative = _informative_tokens(query) + cand_tokens = _token_set(text) + 
cand_informative = _informative_tokens(text) + query_overlap = len(query_tokens & cand_tokens) / max(len(query_tokens), 1) if query_tokens else 0.0 + informative_overlap = ( + len(query_informative & cand_informative) / max(len(query_informative), 1) + if query_informative else query_overlap + ) + exact_phrase = 1.0 if query and len(query.strip()) >= 4 and query.lower().strip() in text.lower() else 0.0 + query_entities = _entity_terms(query) | {term.lower() for term in getattr(plan, "target_entities", []) or []} + cand_entities = _entity_terms(text) + entity_overlap = len(query_entities & cand_entities) / max(len(query_entities), 1) if query_entities else 0.0 + aliases = {alias.lower() for alias in _alias_values(candidate) if len(alias) > 2} + alias_overlap = len(query_entities & aliases) / max(len(query_entities), 1) if query_entities and aliases else 0.0 + role_overlap = 1.0 if (_ROLE_TERMS & query_tokens & cand_tokens) else 0.0 + attribute_overlap = 1.0 if (_ATTRIBUTE_TERMS & query_tokens & cand_tokens) else 0.0 + role_value_pattern = role_overlap * _role_value_pattern(text) + query_temporal = 1.0 if (bool(getattr(plan, "requires_temporal_reasoning", False)) or _TEMPORAL_RE.search(query or "")) else 0.0 + candidate_temporal = 1.0 if _TEMPORAL_RE.search(text or "") or _DATE_RE.search(text or "") else 0.0 + temporal_anchor_overlap = _temporal_anchor_overlap(query, text) + session_gap_score, query_session_hint, candidate_session_hint = _session_gap_score(query, text) + source_keyword, source_semantic, source_both, source_graph = _source_flags(candidate) + bucket_memories, bucket_events, bucket_entities, bucket_procedures, bucket_decisions = _bucket_flags(bucket) + status = str(candidate.get("status") or "").lower() + position = max(int(candidate.get("_stage_position") or 0), 0) + prev_raw = (neighbors or {}).get("prev_score") + next_raw = (neighbors or {}).get("next_score") + leader_raw = (neighbors or {}).get("leader_score") + prev_score = _safe_float(prev_raw) + 
next_score = _safe_float(next_raw) + current_score = _safe_float(candidate.get("final_score") or candidate.get("retrieval_score")) + neighbor_margin = max(current_score - prev_score, current_score - next_score, 0.0) + confidence = _safe_float(candidate.get("confidence"), 0.5) + support_evidence_score = min(len(candidate.get("supporting_evidence") or []) / 3.0, 1.0) + long_context_debug: dict[str, Any] = {"applicable": False} + if _should_probe_long_context( + query=query, + plan=plan, + bucket=bucket, + text=text, + position=position, + current_score=current_score, + prev_raw=prev_raw, + next_raw=next_raw, + leader_raw=leader_raw, + ): + try: + from agentmemory.retrieval.long_context import analyze_long_context as _analyze_long_context + + long_context_debug = _analyze_long_context(query, plan, candidate, text=text) + except Exception: + long_context_debug = {"applicable": False} + if long_context_debug.get("applicable"): + candidate["_long_context_debug"] = long_context_debug + features = { + "base_score": current_score, + "retrieval_score": _safe_float(candidate.get("retrieval_score"), current_score), + "rrf_score": _safe_float(candidate.get("rrf_score")), + "confidence": confidence, + "query_overlap": query_overlap, + "informative_overlap": informative_overlap, + "tfidf_cosine": _tfidf_cosine(query, text), + "exact_phrase": exact_phrase, + "entity_overlap": entity_overlap, + "alias_overlap": alias_overlap, + "query_temporal": query_temporal, + "candidate_temporal": candidate_temporal, + "temporal_anchor_overlap": temporal_anchor_overlap, + "query_session_hint": query_session_hint, + "candidate_session_hint": candidate_session_hint, + "session_gap_score": session_gap_score, + "intent_bucket_fit": _intent_bucket_preference(plan, bucket), + "source_keyword": source_keyword, + "source_semantic": source_semantic, + "source_both": source_both, + "source_graph": source_graph, + "bucket_memories": bucket_memories, + "bucket_events": bucket_events, + "bucket_entities": 
bucket_entities, + "bucket_procedures": bucket_procedures, + "bucket_decisions": bucket_decisions, + "candidate_age_score": _age_score(candidate), + "support_evidence_score": support_evidence_score, + "status_active": 1.0 if status in {"", "active"} else 0.0, + "status_stale": 1.0 if status in {"stale", "superseded", "retired"} else 0.0, + "status_needs_review": 1.0 if status == "needs_review" else 0.0, + "position_score": 1.0 / (1.0 + position), + "neighbor_margin": neighbor_margin, + "query_length_score": min(len(query_informative) / 8.0, 1.0), + "candidate_length_score": min(len(cand_informative) / 64.0, 1.0), + "procedural_candidate": 1.0 if bucket == "procedures" else 0.0, + "query_needs_counting": 1.0 if getattr(plan, "needs_counting", False) else 0.0, + "query_needs_comparison": 1.0 if getattr(plan, "needs_comparison", False) else 0.0, + "query_needs_ordering": 1.0 if getattr(plan, "needs_ordering", False) else 0.0, + "query_needs_update_resolution": 1.0 if getattr(plan, "needs_update_resolution", False) else 0.0, + "query_needs_set_coverage": 1.0 if getattr(plan, "needs_set_coverage", False) else 0.0, + "query_needs_role_fact": 1.0 if getattr(plan, "needs_role_fact", False) else 0.0, + "query_needs_synthetic_key_value": 1.0 if getattr(plan, "needs_synthetic_key_value", False) else 0.0, + "role_overlap": role_overlap, + "attribute_overlap": attribute_overlap, + "role_value_pattern": role_value_pattern, + "query_requires_multi_hop": 1.0 if getattr(plan, "requires_multi_hop", False) else 0.0, + "long_context_applicable": 1.0 if long_context_debug.get("applicable") else 0.0, + "long_context_score": _safe_float(long_context_debug.get("score")), + "long_context_confidence": _safe_float(long_context_debug.get("confidence")), + "long_context_agreement": _safe_float(long_context_debug.get("agreement")), + "long_context_uncertainty": _safe_float(long_context_debug.get("uncertainty")), + "long_context_coverage": _safe_float(long_context_debug.get("coverage")), + 
"long_context_precision": _safe_float(long_context_debug.get("precision")), + "long_context_focused_program": 1.0 if long_context_debug.get("program") not in {None, "", "whole_doc"} else 0.0, + } + return {name: round(float(value), 6) for name, value in features.items()} + + +def vectorize_features( + feature_dict: dict[str, float], + *, + feature_version: str = FEATURE_VERSION_V1, +): + """Return a numeric feature vector in canonical order.""" + + if feature_version != FEATURE_VERSION_V1: + raise ValueError(f"Unsupported feature version: {feature_version}") + values = [float(feature_dict.get(name, 0.0)) for name in FEATURE_ORDER_V1] + if _np is not None: + return _np.asarray(values, dtype=float) + return values diff --git a/src/agentmemory/retrieval/judge.py b/src/agentmemory/retrieval/judge.py new file mode 100644 index 0000000..31ea7cb --- /dev/null +++ b/src/agentmemory/retrieval/judge.py @@ -0,0 +1,87 @@ +"""Optional local judge reranker for top candidates.""" + +from __future__ import annotations + +import json +import re +import urllib.error +import urllib.request +from dataclasses import dataclass +from typing import Any + + +@dataclass(slots=True) +class JudgeConfig: + enabled: bool = False + provider: str = "ollama" + model: str = "llama3.2:3b" + top_k: int = 5 + timeout_s: float = 6.0 + url: str = "http://localhost:11434/api/generate" + + +def _coerce_score(value: str) -> float: + match = re.search(r"(-?\d+(?:\.\d+)?)", value or "") + if not match: + return 0.0 + try: + score = float(match.group(1)) + except (TypeError, ValueError): + return 0.0 + return max(0.0, min(score, 1.0)) + + +def _candidate_synopsis(candidate: dict[str, Any]) -> str: + for key in ("content", "summary", "title", "goal", "description", "search_text", "name", "compiled_truth"): + value = candidate.get(key) + if value: + text = str(value).strip() + return text[:1200] + return "" + + +def _judge_with_ollama(query: str, candidates: list[dict[str, Any]], config: JudgeConfig) -> 
list[float]: + scores: list[float] = [] + for candidate in candidates[: config.top_k]: + prompt = ( + "You are a retrieval judge. Score relevance from 0.0 to 1.0.\n" + "Return only the numeric score.\n\n" + f"Query: {query}\n\n" + f"Candidate: {_candidate_synopsis(candidate)}\n" + ) + payload = json.dumps( + { + "model": config.model, + "prompt": prompt, + "stream": False, + "options": {"temperature": 0}, + } + ).encode("utf-8") + req = urllib.request.Request( + config.url, + data=payload, + headers={"Content-Type": "application/json"}, + ) + try: + with urllib.request.urlopen(req, timeout=config.timeout_s) as resp: # noqa: S310 - local optional service + body = json.loads(resp.read().decode("utf-8")) + scores.append(_coerce_score(str(body.get("response") or ""))) + except (urllib.error.URLError, TimeoutError, OSError, ValueError, json.JSONDecodeError): + return [] + return scores + + +def judge_candidates( + query: str, + candidates: list[dict[str, Any]], + config: JudgeConfig | None = None, +) -> list[float]: + """Return optional judge scores for the top candidates.""" + + cfg = config or JudgeConfig() + if not cfg.enabled or not candidates: + return [] + if cfg.provider == "ollama": + return _judge_with_ollama(query, candidates, cfg) + return [] + diff --git a/src/agentmemory/retrieval/late_reranker.py b/src/agentmemory/retrieval/late_reranker.py new file mode 100644 index 0000000..f91e8d1 --- /dev/null +++ b/src/agentmemory/retrieval/late_reranker.py @@ -0,0 +1,43 @@ +"""Deterministic late reranking for procedure candidates.""" + +from __future__ import annotations + +from typing import Any + + +def rerank_procedure_candidates( + candidates: list[dict[str, Any]], + evidence: dict[int, dict[str, Any]], + *, + benchmark_mode: bool = False, +) -> list[dict[str, Any]]: + reranked: list[dict[str, Any]] = [] + for cand in candidates: + proc_id = int(cand["id"]) + ev = evidence.get(proc_id) or {} + bonus = float(ev.get("support_bonus") or 0.0) + base = 
float(cand.get("final_score") or 0.0) + status = cand.get("status") or "active" + status_multiplier = { + "active": 1.0, + "candidate": 0.9, + "needs_review": 0.72, + "stale": 0.64, + "superseded": 0.3, + "retired": 0.1, + }.get(status, 1.0) + if benchmark_mode: + score = base * status_multiplier + else: + score = (base + bonus) * status_multiplier + updated = dict(cand) + updated["supporting_evidence"] = ev.get("sources") or [] + updated["evidence_edges"] = ev.get("edges") or [] + updated["evidence_bonus"] = round(bonus, 4) + updated["final_score"] = round(score, 6) + updated["why_retrieved"] = updated.get("why_retrieved") or ( + "strong procedural evidence cluster" if bonus >= 0.3 else "direct procedural match" + ) + reranked.append(updated) + reranked.sort(key=lambda item: item.get("final_score", 0.0), reverse=True) + return reranked diff --git a/src/agentmemory/retrieval/long_context.py b/src/agentmemory/retrieval/long_context.py new file mode 100644 index 0000000..f0f40eb --- /dev/null +++ b/src/agentmemory/retrieval/long_context.py @@ -0,0 +1,458 @@ +"""Bounded long-context evidence probing for shared retrieval reranking. + +RLM/SRLM-inspired adaptation for brainctl: + +- Treat the candidate text as an external environment rather than a single bag + of tokens. +- Run a small portfolio of deterministic chunking "programs" over that + environment. +- Select the most reliable program using agreement + uncertainty, not just the + single highest raw score. + +This stays local, bounded, and depth-1 on purpose. Reproduction work on RLMs +shows deeper recursion can overthink and blow up latency; here we only probe a +short list of chunk views over the same candidate row. 
+""" + +from __future__ import annotations + +import os +import re +from dataclasses import dataclass +from typing import Any + +_STOPWORDS = { + "a", "an", "and", "are", "as", "at", "be", "by", "did", "do", "does", "for", + "from", "has", "have", "how", "i", "in", "is", "it", "its", "of", "on", "or", + "that", "the", "to", "was", "we", "what", "when", "where", "which", "who", + "why", "will", "with", "you", +} +_LOW_SIGNAL_TOKENS = { + "summary", "history", "timeline", "recent", "today", "yesterday", "tomorrow", + "issue", "problem", "thing", "stuff", "update", +} +_TEMPORAL_RE = re.compile( + r"\b(yesterday|today|tomorrow|when|before|after|during|timeline|history|recent|latest|first|last)\b", + re.IGNORECASE, +) +_DATE_RE = re.compile( + r"\b(?:\d{4}-\d{2}-\d{2}|\d{1,2}/\d{1,2}(?:/\d{2,4})?|" + r"jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|" + r"jul(?:y)?|aug(?:ust)?|sep(?:tember)?|oct(?:ober)?|nov(?:ember)?|" + r"dec(?:ember)?)\b", + re.IGNORECASE, +) +_SESSION_RE = re.compile(r"\bsession[_ :#-]*(\d+)\b", re.IGNORECASE) +_ENTITY_RE = re.compile(r"\b[A-Z][A-Za-z0-9_.:-]+\b") +_TURNISH_RE = re.compile(r"^\s*(?:[A-Z][A-Za-z0-9_.-]+:|.+\bsaid,\s+\")", re.IGNORECASE) + + +@dataclass(slots=True) +class ProbeChunk: + index: int + text: str + score: float + coverage: float + precision: float + entity_overlap: float + temporal_overlap: float + exact_phrase: float + + +@dataclass(slots=True) +class ProbeProgramResult: + name: str + score: float + confidence: float + uncertainty: float + agreement: float + coverage: float + precision: float + length_penalty: float + chunk_count: int + top_chunk: ProbeChunk | None + + +def _normalize_token(token: str) -> str: + tok = re.sub(r"[^a-z0-9]+", "", (token or "").lower()) + if len(tok) <= 2 or tok in _STOPWORDS: + return "" + if tok.endswith("ies") and len(tok) > 4: + tok = tok[:-3] + "y" + elif tok.endswith("ed") and len(tok) > 4: + tok = tok[:-2] + elif tok.endswith("es") and len(tok) > 4: + tok = tok[:-2] + 
def _token_set(text: str) -> set[str]:
    """Return the normalized, non-empty tokens of *text* (split on whitespace)."""
    normalized = (_normalize_token(part) for part in re.split(r"\s+", text or ""))
    return {token for token in normalized if token}


def _informative_tokens(text: str) -> set[str]:
    """Return the tokens of *text* that are not listed in ``_LOW_SIGNAL_TOKENS``."""
    return {token for token in _token_set(text) if token not in _LOW_SIGNAL_TOKENS}


def _entity_terms(text: str) -> set[str]:
    """Return lower-cased ``_ENTITY_RE`` matches longer than two characters."""
    hits = (match.group(0) for match in _ENTITY_RE.finditer(text or ""))
    return {hit.lower() for hit in hits if len(hit) > 2}


def _deobfuscate(text: str) -> str:
    """Strip zero-width characters and markup noise, collapsing ALL whitespace.

    The result is a single line: every whitespace run (newlines included)
    becomes one space.  Use ``_normalize_candidate_text`` when line structure
    must survive.
    """
    value = (text or "").replace("\u200b", "").replace("\ufeff", "")
    value = re.sub(r"[_*/`~]+", " ", value)
    return re.sub(r"\s+", " ", value).strip()


def _normalize_candidate_text(text: str | None) -> str:
    """Deobfuscate *text* line by line, keeping one cleaned line per source line.

    Running ``_deobfuscate`` over the whole document would erase the newlines
    that the line/turn/anchor programs split on.  Blank lines are dropped.
    """
    cleaned = (_deobfuscate(line) for line in (text or "").splitlines())
    return "\n".join(line for line in cleaned if line)


def _env_int(name: str, default: int, *, minimum: int | None = None) -> int:
    """Read an integer environment variable, tolerating unset/blank/garbage values.

    Returns *default* when the variable is missing, empty, or not an integer;
    the result is clamped to *minimum* when one is given.
    """
    raw = os.environ.get(name, "")
    try:
        value = int(raw) if raw.strip() else default
    except ValueError:
        value = default
    return value if minimum is None else max(value, minimum)


def _safe_window(items: list[str], size: int, stride: int) -> list[str]:
    """Join *items* into newline-separated sliding windows.

    A single window is returned when there are at most *size* items.  The scan
    stops at the first window that reaches the end of *items*; a stride below
    one is treated as one.
    """
    if not items:
        return []
    total = len(items)
    if total <= size:
        return ["\n".join(items)]
    windows: list[str] = []
    for start in range(0, total, max(stride, 1)):
        chunk = items[start:start + size]
        if not chunk:
            continue
        windows.append("\n".join(chunk))
        if start + size >= total:
            break
    return windows


def _cap_chunks(chunks: list[str], max_chunks: int) -> list[str]:
    """Down-sample *chunks* to at most *max_chunks* evenly spaced entries.

    The first and last chunk are always kept when more than one slot is
    available.  A budget of one or less keeps only the first chunk.
    """
    if len(chunks) <= max_chunks:
        return chunks
    if max_chunks <= 1:
        # Slicing (not indexing) so an empty list with a negative budget is safe.
        return chunks[:1]
    step = (len(chunks) - 1) / float(max_chunks - 1)
    picked: list[str] = []
    used: set[int] = set()
    for slot in range(max_chunks):
        position = int(round(slot * step))
        if position in used:
            continue
        used.add(position)
        picked.append(chunks[position])
    return picked


def _whole_doc_program(text: str, max_chunks: int) -> list[str]:
    """Return the document (truncated to 48000 characters) as a single chunk."""
    return [text[:48000]] if text else []


def _line_window_program(text: str, max_chunks: int) -> list[str]:
    """Chunk *text* into overlapping 6-line windows (stride 3); needs >= 3 lines."""
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if len(lines) < 3:
        return []
    return _cap_chunks(_safe_window(lines, size=6, stride=3), max_chunks)


def _sentence_window_program(text: str, max_chunks: int) -> list[str]:
    """Chunk *text* into 3-sentence windows (stride 1); needs >= 2 sentences."""
    pieces = re.split(r"(?<=[.!?])\s+|\n+", text)
    sentences = [piece.strip() for piece in pieces if piece.strip()]
    if len(sentences) < 2:
        return []
    return _cap_chunks(_safe_window(sentences, size=3, stride=1), max_chunks)


def _turn_window_program(text: str, max_chunks: int) -> list[str]:
    """Chunk the lines matching ``_TURNISH_RE`` into 4-line windows (stride 2)."""
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    turnish = [line for line in lines if _TURNISH_RE.search(line)]
    if len(turnish) < 2:
        return []
    return _cap_chunks(_safe_window(turnish, size=4, stride=2), max_chunks)


def _anchor_window_program(
    text: str,
    query: str,
    *,
    target_entities: list[str],
    temporal_query: bool,
    max_chunks: int,
) -> list[str]:
    """Build +/-2-line windows around lines that mention the query.

    A line is an anchor when it contains an informative query token, a target
    entity, or (for temporal queries) a temporal/date expression.  When one of
    the first three lines carries a "session id"/"session date" marker, those
    lines are prepended to every window as a header.
    """
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if not lines:
        return []
    informative = _informative_tokens(query)
    entities = {value.lower() for value in target_entities if value}
    has_header = bool(lines[:2]) and any(
        "session id" in line.lower() or "session date" in line.lower()
        for line in lines[:3]
    )
    header = lines[:3] if has_header else []

    def _is_anchor(line: str) -> bool:
        lowered = line.lower()
        if informative and any(token in lowered for token in informative):
            return True
        if entities and any(entity in lowered for entity in entities):
            return True
        return bool(
            temporal_query and (_TEMPORAL_RE.search(line) or _DATE_RE.search(line))
        )

    anchor_indexes = [idx for idx, line in enumerate(lines) if _is_anchor(line)]
    if not anchor_indexes:
        return []
    chunks: list[str] = []
    seen: set[str] = set()
    for idx in anchor_indexes:
        start = max(0, idx - 2)
        end = min(len(lines), idx + 3)
        # Only prepend header lines that precede the window; otherwise the
        # header would be emitted twice and inflate the chunk length.
        chunk = "\n".join(header[:start] + lines[start:end])
        if chunk and chunk not in seen:
            seen.add(chunk)
            chunks.append(chunk)
    return _cap_chunks(chunks, max_chunks)


def _candidate_programs(
    text: str,
    query: str,
    *,
    target_entities: list[str],
    temporal_query: bool,
    max_chunks: int,
) -> dict[str, list[str]]:
    """Run every chunking program and keep those that produced chunks."""
    programs = {
        "whole_doc": _whole_doc_program(text, max_chunks),
        "line_windows": _line_window_program(text, max_chunks),
        "sentence_windows": _sentence_window_program(text, max_chunks),
        "turn_windows": _turn_window_program(text, max_chunks),
        "anchor_windows": _anchor_window_program(
            text,
            query,
            target_entities=target_entities,
            temporal_query=temporal_query,
            max_chunks=max_chunks,
        ),
    }
    return {name: chunks for name, chunks in programs.items() if chunks}


def _chunk_score(
    query: str,
    chunk: str,
    *,
    target_entities: list[str],
    temporal_query: bool,
) -> ProbeChunk:
    """Score one chunk against the query and return it as a ``ProbeChunk``.

    The score is a weighted blend of token coverage, precision, raw overlap,
    exact-phrase match, entity overlap, temporal evidence (temporal queries
    only) and a concentration term that favors short chunks; it is capped at
    1.0.  ``index`` is always 0 here; callers assign the real position.
    """
    informative = _informative_tokens(query)
    query_tokens = _token_set(query)
    chunk_tokens = _token_set(chunk)
    chunk_informative = _informative_tokens(chunk)
    query_entities = _entity_terms(query) | {value.lower() for value in target_entities if value}
    chunk_entities = _entity_terms(chunk)
    shared_informative = len(informative & chunk_informative)

    overlap = len(query_tokens & chunk_tokens) / max(len(query_tokens), 1) if query_tokens else 0.0
    coverage = shared_informative / max(len(informative), 1) if informative else overlap
    precision = shared_informative / max(len(chunk_informative), 1) if chunk_informative else 0.0
    exact_phrase = 1.0 if query and len(query.strip()) >= 4 and query.lower().strip() in chunk.lower() else 0.0
    entity_overlap = len(query_entities & chunk_entities) / max(len(query_entities), 1) if query_entities else 0.0
    temporal_overlap = 0.0
    if temporal_query:
        temporal_overlap = 1.0 if (_TEMPORAL_RE.search(chunk) or _DATE_RE.search(chunk) or _SESSION_RE.search(chunk)) else 0.0
    concentration = min(1.0, 12.0 / max(len(chunk_informative), 12))
    score = (
        coverage * 0.34
        + precision * 0.18
        + overlap * 0.12
        + exact_phrase * 0.14
        + entity_overlap * 0.12
        + temporal_overlap * (0.10 if temporal_query else 0.0)
        + concentration * 0.10
    )
    return ProbeChunk(
        index=0,
        text=chunk,
        score=round(min(score, 1.0), 6),
        coverage=round(coverage, 6),
        precision=round(precision, 6),
        entity_overlap=round(entity_overlap, 6),
        temporal_overlap=round(temporal_overlap, 6),
        exact_phrase=round(exact_phrase, 6),
    )


def _program_signature(chunk: ProbeChunk | None) -> set[str]:
    """Return the informative tokens of a program's top chunk (empty if none)."""
    return set() if chunk is None else _informative_tokens(chunk.text)


def _is_focused_program(program: ProbeProgramResult, *, candidate_chars: int) -> bool:
    """Tell whether a program localized the answer to < 85% of the candidate."""
    if program.name == "whole_doc" or program.top_chunk is None or candidate_chars <= 0:
        return False
    return len(program.top_chunk.text) / float(candidate_chars) < 0.85


def _program_scores_payload(evaluated: list[ProbeProgramResult]) -> dict[str, dict[str, Any]]:
    """Serialize the per-program diagnostics included in every non-trivial result."""
    return {
        program.name: {
            "score": program.score,
            "confidence": program.confidence,
            "agreement": program.agreement,
            "uncertainty": program.uncertainty,
            "chunk_count": program.chunk_count,
        }
        for program in evaluated
    }


def analyze_long_context(
    query: str,
    plan: Any,
    candidate: dict[str, Any],
    *,
    text: str,
) -> dict[str, Any]:
    """Return depth-1 context-program evidence for a long candidate row.

    Several chunking "programs" are run over the candidate text, each program's
    best chunk is scored against the query, and the focused program with the
    lowest uncertainty (among those within 0.08 of the best score) is selected.

    Environment:
        BRAINCTL_LONG_CONTEXT_PROBES: "0"/"false"/"no"/"off" (any case) disables.
        BRAINCTL_LONG_CONTEXT_MIN_CHARS: minimum text length (default 900).
        BRAINCTL_LONG_CONTEXT_MAX_CHUNKS: per-program chunk budget (default 24).
        Malformed integer values fall back to the defaults.

    Returns:
        ``{"applicable": False, "reason": ...}`` when probing is disabled, the
        text is too short, or no program is focused; otherwise a dict with the
        selected program's metrics, an excerpt of its top chunk and the
        ``program_scores`` of every evaluated program.
    """
    probes_flag = os.environ.get("BRAINCTL_LONG_CONTEXT_PROBES", "1").strip().lower()
    if probes_flag in {"0", "false", "no", "off"}:
        return {"applicable": False, "reason": "disabled"}

    min_chars = _env_int("BRAINCTL_LONG_CONTEXT_MIN_CHARS", 900)
    max_chunks = _env_int("BRAINCTL_LONG_CONTEXT_MAX_CHUNKS", 24, minimum=1)
    # Normalize per line: the line/turn/anchor programs need the newlines.
    candidate_text = _normalize_candidate_text(text)
    raw_lines = [line.strip() for line in (text or "").splitlines() if line.strip()]
    structured_session = any(
        "session id" in line.lower() or "session date" in line.lower()
        for line in raw_lines[:4]
    )
    if len(candidate_text) < min_chars and not structured_session and len(raw_lines) < 5:
        return {"applicable": False, "reason": "short_text"}

    target_entities = list(getattr(plan, "target_entities", []) or [])
    temporal_query = bool(getattr(plan, "requires_temporal_reasoning", False)) or bool(_TEMPORAL_RE.search(query or ""))
    programs = _candidate_programs(
        candidate_text,
        query,
        target_entities=target_entities,
        temporal_query=temporal_query,
        max_chunks=max_chunks,
    )
    if not programs:
        return {"applicable": False, "reason": "no_programs"}

    evaluated: list[ProbeProgramResult] = []
    for name, chunks in programs.items():
        scored: list[ProbeChunk] = []
        for index, chunk in enumerate(chunks):
            base = _chunk_score(
                query,
                chunk,
                target_entities=target_entities,
                temporal_query=temporal_query,
            )
            scored.append(
                ProbeChunk(
                    index=index,
                    text=base.text,
                    score=base.score,
                    coverage=base.coverage,
                    precision=base.precision,
                    entity_overlap=base.entity_overlap,
                    temporal_overlap=base.temporal_overlap,
                    exact_phrase=base.exact_phrase,
                )
            )
        scored.sort(key=lambda item: item.score, reverse=True)
        top_chunk = scored[0] if scored else None
        top_score = top_chunk.score if top_chunk else 0.0
        second_score = scored[1].score if len(scored) > 1 else 0.0
        coverage = top_chunk.coverage if top_chunk else 0.0
        precision = top_chunk.precision if top_chunk else 0.0
        margin = max((top_score - second_score) if top_chunk else 0.0, 0.0)
        confidence = min(1.0, coverage * 0.45 + precision * 0.15 + margin * 0.40)
        length_penalty = min(1.0, len(chunks) / max(max_chunks, 1))
        score = min(1.0, top_score * 0.82 + coverage * 0.12 + precision * 0.06)
        evaluated.append(
            ProbeProgramResult(
                name=name,
                score=round(score, 6),
                confidence=round(confidence, 6),
                uncertainty=1.0,  # set after agreement pass
                agreement=0.0,
                coverage=round(coverage, 6),
                precision=round(precision, 6),
                length_penalty=round(length_penalty, 6),
                chunk_count=len(chunks),
                top_chunk=top_chunk,
            )
        )

    focused = [
        program
        for program in evaluated
        if _is_focused_program(program, candidate_chars=len(candidate_text))
    ]
    if not focused:
        return {
            "applicable": False,
            "reason": "no_focused_program",
            "program_scores": _program_scores_payload(evaluated),
        }

    max_score = max(program.score for program in focused)
    consistent = [program for program in focused if program.score >= max_score - 0.08]
    for program in evaluated:
        sig = _program_signature(program.top_chunk)
        peers: list[float] = []
        for other in consistent:
            if other is program:
                continue
            other_sig = _program_signature(other.top_chunk)
            if not sig and not other_sig:
                peers.append(1.0)
                continue
            union = len(sig | other_sig)
            # Jaccard similarity of the two top-chunk token signatures.
            peers.append(len(sig & other_sig) / union if union else 0.0)
        if peers:
            agreement = sum(peers) / len(peers)
        else:
            agreement = 1.0 if len(consistent) == 1 else 0.0
        program.agreement = round(agreement, 6)
        program.uncertainty = round(
            min(
                1.0,
                (1.0 - agreement) * 0.45
                + (1.0 - program.confidence) * 0.40
                + program.length_penalty * 0.15,
            ),
            6,
        )

    selected = min(
        consistent,
        key=lambda item: (
            round(item.uncertainty, 6),
            -round(item.score, 6),
            -round(item.agreement, 6),
            item.chunk_count,
        ),
    )
    return {
        "applicable": True,
        "program": selected.name,
        "score": selected.score,
        "confidence": selected.confidence,
        "agreement": selected.agreement,
        "uncertainty": selected.uncertainty,
        "coverage": selected.coverage,
        "precision": selected.precision,
        "chunk_count": selected.chunk_count,
        "top_chunk_excerpt": (selected.top_chunk.text[:320] if selected.top_chunk else ""),
        "program_scores": _program_scores_payload(evaluated),
    }
Exception: # pragma: no cover + _np = None + +from agentmemory.retrieval.feature_builder import FEATURE_ORDER_V1, FEATURE_VERSION_V1 + +DEFAULT_MODEL_PATH = Path(__file__).resolve().parent / "models" / "tiny_mlp_v1.json" + + +def _relu(value: float) -> float: + return value if value > 0.0 else 0.0 + + +def _sigmoid(value: float) -> float: + if value >= 0: + z = math.exp(-value) + return 1.0 / (1.0 + z) + z = math.exp(value) + return z / (1.0 + z) + + +@dataclass(slots=True) +class TinyMLPModel: + feature_version: str + feature_order: list[str] + norm_mean: list[float] + norm_std: list[float] + w1: list[list[float]] + b1: list[float] + w2: list[list[float]] + b2: list[float] + w3: list[list[float]] + b3: list[float] + metadata: dict[str, Any] + + @classmethod + def load(cls, path: str | Path | None = None) -> "TinyMLPModel": + model_path = Path(path) if path is not None else DEFAULT_MODEL_PATH + payload = json.loads(model_path.read_text(encoding="utf-8")) + return cls( + feature_version=str(payload["feature_version"]), + feature_order=list(payload["feature_order"]), + norm_mean=[float(v) for v in payload["norm_mean"]], + norm_std=[float(v) for v in payload["norm_std"]], + w1=[[float(v) for v in row] for row in payload["w1"]], + b1=[float(v) for v in payload["b1"]], + w2=[[float(v) for v in row] for row in payload["w2"]], + b2=[float(v) for v in payload["b2"]], + w3=[[float(v) for v in row] for row in payload["w3"]], + b3=[float(v) for v in payload["b3"]], + metadata=dict(payload.get("metadata") or {}), + ) + + @classmethod + def try_load(cls, path: str | Path | None = None) -> "TinyMLPModel | None": + try: + model_path = Path(path) if path is not None else DEFAULT_MODEL_PATH + if not model_path.exists(): + return None + return cls.load(model_path) + except Exception: + return None + + def _normalize(self, feature_matrix): + if _np is not None: + matrix = _np.asarray(feature_matrix, dtype=float) + mean = _np.asarray(self.norm_mean, dtype=float) + std = 
_np.asarray(self.norm_std, dtype=float) + safe_std = _np.where(std == 0.0, 1.0, std) + return (matrix - mean) / safe_std + rows: list[list[float]] = [] + for row in feature_matrix: + rows.append([ + (float(value) - self.norm_mean[idx]) / (self.norm_std[idx] if self.norm_std[idx] not in (0.0, 0) else 1.0) + for idx, value in enumerate(row) + ]) + return rows + + def score(self, feature_matrix) -> list[float]: + if self.feature_version != FEATURE_VERSION_V1: + raise ValueError(f"Unsupported feature version: {self.feature_version}") + if self.feature_order != FEATURE_ORDER_V1: + raise ValueError("Feature order mismatch between runtime and model artifact") + if _np is not None: + x = self._normalize(feature_matrix) + w1 = _np.asarray(self.w1, dtype=float) + b1 = _np.asarray(self.b1, dtype=float) + w2 = _np.asarray(self.w2, dtype=float) + b2 = _np.asarray(self.b2, dtype=float) + w3 = _np.asarray(self.w3, dtype=float) + b3 = _np.asarray(self.b3, dtype=float) + h1 = _np.maximum(0.0, x @ w1.T + b1) + h2 = _np.maximum(0.0, h1 @ w2.T + b2) + logits = h2 @ w3.T + b3 + logits = _np.clip(logits.reshape(-1), -30.0, 30.0) + probs = 1.0 / (1.0 + _np.exp(-logits)) + return [float(v) for v in probs.tolist()] + + x_rows = self._normalize(feature_matrix) + outputs: list[float] = [] + for row in x_rows: + h1: list[float] = [] + for bias, weights in zip(self.b1, self.w1): + total = bias + for value, weight in zip(row, weights): + total += value * weight + h1.append(_relu(total)) + h2: list[float] = [] + for bias, weights in zip(self.b2, self.w2): + total = bias + for value, weight in zip(h1, weights): + total += value * weight + h2.append(_relu(total)) + total = self.b3[0] if self.b3 else 0.0 + for value, weight in zip(h2, self.w3[0]): + total += value * weight + outputs.append(_sigmoid(total)) + return outputs diff --git a/src/agentmemory/retrieval/query_planner.py b/src/agentmemory/retrieval/query_planner.py new file mode 100644 index 0000000..0892cd1 --- /dev/null +++ 
b/src/agentmemory/retrieval/query_planner.py @@ -0,0 +1,326 @@ +"""Intent-aware query planning for retrieval orchestration.""" + +from __future__ import annotations + +import re +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Optional + +try: + from intent_classifier import classify_intent as _classify_intent +except Exception: # pragma: no cover - optional script path + _classify_intent = None + +_ENTITY_RE = re.compile(r"\b[A-Z][A-Za-z0-9_.:-]+\b") +_ENTITY_QUERY_RE = re.compile( + r"\b(" + r"who(?:\s+is|\s+owns?)?|" + r"whose|" + r"owner|maintainer|reviewer|assignee|" + r"what\s+does|" + r"prefers?|preference|" + r"role|responsible|" + r"works?\s+on" + r")\b", + re.IGNORECASE, +) +_TEMPORAL_RE = re.compile( + r"\b(" + r"yesterday|today|tomorrow|when|timeline|history|recent|overnight|" + r"last\s+(?:week|month|year|tuesday|wednesday|thursday|friday|saturday|sunday)|" + r"this\s+(?:week|month|year)|" + r"past\s+(?:week|month|year|two weeks|three months)|" + r"most recent|latest|earliest|previous(?:ly)?|current(?:ly)?|" + r"before|after|between|during|in the past|order of|" + r"(?:one|two|three|four|five|six|seven|eight|nine|ten|\d+)\s+" + r"(?:day|week|month|year)s?\s+ago" + r")\b", + re.IGNORECASE, +) +_MULTIHOP_RE = re.compile( + r"\b(" + r"why|because|rationale|support|evidence|rollback|troubleshoot|debug|fix|" + r"how many|how much|order|earliest|latest|most recent|" + r"before|after|between|difference|older|newer|" + r"compare|combined|total|sum|" + r"based on|underlying|future|might|would" + r")\b", + re.IGNORECASE, +) +_COUNT_RE = re.compile( + r"\b(" + r"how many|how much|count|number of|total|sum|combined total" + r")\b", + re.IGNORECASE, +) +_COMPARE_RE = re.compile( + r"\b(" + r"compare|difference|different|versus|vs\.?|better|worse|older|newer|" + r"more than|less than|changed|relative to" + r")\b", + re.IGNORECASE, +) +_ORDER_RE = re.compile( + r"\b(" + 
r"before|after|between|order|ordered|sequence|timeline|earliest|latest|" + r"first|last|most recent|newest|oldest|rank" + r")\b", + re.IGNORECASE, +) +_UPDATE_RE = re.compile( + r"\b(" + r"current(?:ly)?|previous(?:ly)?|formerly|used to|now|new|updated|" + r"latest|most recent|superseded|stale|still|anymore" + r")\b", + re.IGNORECASE, +) +_COVERAGE_RE = re.compile( + r"\b(" + r"all|both|each|every|across|combined|together|list|which sessions|" + r"what were the sessions|set of" + r")\b", + re.IGNORECASE, +) +_ROLE_FACT_RE = re.compile( + r"\b(" + r"father|dad|mother|mom|parent|coworker|colleague|friend|neighbor|" + r"brother|sister|nephew|niece|aunt|uncle|cousin|boss|manager|supervisor|subordinate|employee|" + r"workplace|occupation|position|job|employer|education|educational|" + r"degree|background|location|hometown|role|hobby|enjoys?|loves?|passion|" + r"email|contact|phone|number|company|living" + r")\b", + re.IGNORECASE, +) +_SYNTHETIC_KV_RE = re.compile( + r"\b(" + r"id|key|code|value|field|role|status|attribute|group|session|step" + r")\b|[A-Za-z]+[_-]\d+|\w+[=:]\w+", + re.IGNORECASE, +) +_NEGATIVE_RE = re.compile( + r"\b(" + r"no answer|" + r"do not know|" + r"unknown|" + r"no memory|" + r"coverage gap|" + r"summary of yesterday(?:'s)? 
.+|" + r"(?:basketball|baseball|football|soccer|weather|stock market|earnings)\b" + r")", + re.IGNORECASE, +) +_ENTITY_BLACKLIST = {"what", "who", "where", "when", "why", "how", "summary"} + + +@dataclass(slots=True) +class QueryPlan: + normalized_intent: str + answer_type: str + target_entities: list[str] = field(default_factory=list) + temporal_anchors: list[str] = field(default_factory=list) + requires_temporal_reasoning: bool = False + requires_multi_hop: bool = False + needs_counting: bool = False + needs_comparison: bool = False + needs_ordering: bool = False + needs_update_resolution: bool = False + needs_set_coverage: bool = False + needs_role_fact: bool = False + needs_synthetic_key_value: bool = False + prefer_memory_types: list[str] = field(default_factory=list) + candidate_tables: list[str] = field(default_factory=list) + abstain_allowed: bool = False + debug_reasons: list[str] = field(default_factory=list) + classifier_intent: str = "general" + classifier_confidence: float = 0.0 + format_hint: str = "" + + def as_dict(self) -> dict[str, Any]: + return asdict(self) + + +_INTENT_ALIASES = { + "cross_reference": "entity", + "decision_rationale": "decision", + "entity_lookup": "factual", + "event_lookup": "temporal", + "factual_lookup": "factual", + "general": "factual", + "graph_traversal": "graph", + "historical_timeline": "temporal", + "how_to": "procedural", + "orientation": "orientation", + "procedural": "procedural", + "research_concept": "factual", + "task_status": "temporal", + "troubleshooting": "troubleshooting", +} + + +_TABLE_ROUTES = { + "procedural": ["procedures", "memories", "decisions", "events", "context", "policy"], + "troubleshooting": ["procedures", "events", "memories", "decisions", "context", "policy"], + "decision": ["decisions", "memories", "procedures", "events", "context"], + "temporal": ["events", "memories", "context", "entities", "procedures"], + "factual": ["memories", "entities", "decisions", "context", "events", 
"procedures"], + "graph": ["memories", "events", "context", "decisions", "procedures"], + "orientation": ["memories", "events", "context", "procedures"], +} + + +def _builtin_classify(query: str) -> tuple[str, float, str]: + q = query.lower() + temporalish = bool(_TEMPORAL_RE.search(query)) + multihopish = bool(_MULTIHOP_RE.search(query)) + if _ENTITY_QUERY_RE.search(query): + return ("factual", 0.72, "builtin:entity_fact") + if any(token in q for token in ("how to", "how do", "procedure", "rollback", "runbook", "playbook")): + return ("procedural", 0.82, "builtin:procedural") + if any(token in q for token in ("error", "syntax", "bug", "failed", "fix", "troubleshoot")): + return ("troubleshooting", 0.8, "builtin:troubleshooting") + if any(token in q for token in ("why", "decision", "rationale", "choose", "chose")): + return ("decision", 0.78, "builtin:decision") + if temporalish or "what happened" in q: + reason = "builtin:temporal_multihop" if multihopish else "builtin:temporal" + return ("temporal", 0.8 if multihopish else 0.78, reason) + if any(token in q for token in ("who", "what", "where", "which", "entity")): + return ("factual", 0.6, "builtin:factual") + return ("factual", 0.45, "builtin:default") + + +def _extract_entities(query: str) -> list[str]: + entities = [match.group(0) for match in _ENTITY_RE.finditer(query or "")] + if not entities: + pattern_hits = re.findall( + r"\b(?:what\s+does|who\s+is|who\s+owns|where\s+is|when\s+did)\s+([A-Za-z0-9_.:-]+)", + query or "", + flags=re.IGNORECASE, + ) + entities.extend(pattern_hits) + seen: set[str] = set() + out: list[str] = [] + for entity in entities: + key = entity.lower() + if key in _ENTITY_BLACKLIST: + continue + if key not in seen: + seen.add(key) + out.append(entity) + return out[:8] + + +def plan_query( + query: str, + *, + requested_tables: Optional[list[str]] = None, +) -> QueryPlan: + """Return a structured routing plan for the query.""" + + classifier_intent = "general" + classifier_confidence = 
0.0 + format_hint = "" + reasons: list[str] = [] + + if _classify_intent is not None: + try: + result = _classify_intent(query) + classifier_intent = getattr(result, "intent", "general") + classifier_confidence = float(getattr(result, "confidence", 0.0) or 0.0) + format_hint = getattr(result, "format_hint", "") or "" + reasons.append(f"classifier:{classifier_intent}") + except Exception: + pass + + if classifier_intent == "general": + builtin_intent, builtin_conf, reason = _builtin_classify(query) + normalized_intent = builtin_intent + classifier_confidence = max(classifier_confidence, builtin_conf) + reasons.append(reason) + else: + normalized_intent = _INTENT_ALIASES.get(classifier_intent, "factual") + + query_lower = query.lower() + temporal_anchors = [m.group(0) for m in _TEMPORAL_RE.finditer(query)] + answer_type = { + "decision": "rationale", + "procedural": "procedure", + "troubleshooting": "procedure", + "temporal": "history", + "graph": "mixed", + "orientation": "briefing", + }.get(normalized_intent, "fact") + prefer_memory_types = { + "decision": ["semantic", "procedural", "episodic"], + "procedural": ["procedural", "semantic", "episodic"], + "troubleshooting": ["procedural", "episodic", "semantic"], + "temporal": ["episodic", "semantic"], + "factual": ["semantic", "procedural", "episodic"], + "graph": ["semantic", "episodic", "procedural"], + "orientation": ["semantic", "episodic", "procedural"], + }.get(normalized_intent, ["semantic", "episodic"]) + + candidate_tables = list(requested_tables or _TABLE_ROUTES.get(normalized_intent, _TABLE_ROUTES["factual"])) + requires_temporal = bool(_TEMPORAL_RE.search(query)) + requires_multi_hop = bool(_MULTIHOP_RE.search(query)) + needs_counting = bool(_COUNT_RE.search(query)) + needs_comparison = bool(_COMPARE_RE.search(query)) + needs_ordering = bool(_ORDER_RE.search(query)) + needs_update_resolution = bool(_UPDATE_RE.search(query)) + needs_set_coverage = bool(_COVERAGE_RE.search(query)) + needs_role_fact = 
bool(_ROLE_FACT_RE.search(query)) + needs_synthetic_key_value = bool(_SYNTHETIC_KV_RE.search(query)) + if requires_multi_hop and normalized_intent in {"temporal", "decision", "graph"}: + needs_set_coverage = True + if needs_counting or needs_comparison or needs_ordering: + needs_set_coverage = True + abstain_allowed = bool(_NEGATIVE_RE.search(query)) or normalized_intent in {"factual", "troubleshooting", "procedural"} + if _ENTITY_QUERY_RE.search(query) and normalized_intent == "factual": + reasons.append("entity_or_role_lookup") + if requires_temporal: + reasons.append("temporal_reasoning") + if requires_multi_hop: + reasons.append("multi_hop_or_inference") + if needs_counting: + reasons.append("operator:counting") + if needs_comparison: + reasons.append("operator:comparison") + if needs_ordering: + reasons.append("operator:ordering") + if needs_update_resolution: + reasons.append("operator:update_resolution") + if needs_set_coverage: + reasons.append("operator:set_coverage") + if needs_role_fact: + reasons.append("operator:role_fact") + if needs_synthetic_key_value: + reasons.append("operator:synthetic_key_value") + if "summary of yesterday" in query_lower: + abstain_allowed = True + reasons.append("negative_or_out_of_domain_summary") + if " and " in query_lower and len(_extract_entities(query)) == 0: + reasons.append("ambiguous_composite_query") + abstain_allowed = True + + return QueryPlan( + normalized_intent=normalized_intent, + answer_type=answer_type, + target_entities=_extract_entities(query), + temporal_anchors=temporal_anchors, + requires_temporal_reasoning=requires_temporal, + requires_multi_hop=requires_multi_hop, + needs_counting=needs_counting, + needs_comparison=needs_comparison, + needs_ordering=needs_ordering, + needs_update_resolution=needs_update_resolution, + needs_set_coverage=needs_set_coverage, + needs_role_fact=needs_role_fact, + needs_synthetic_key_value=needs_synthetic_key_value, + prefer_memory_types=prefer_memory_types, + 
candidate_tables=candidate_tables, + abstain_allowed=abstain_allowed, + debug_reasons=reasons, + classifier_intent=classifier_intent, + classifier_confidence=classifier_confidence, + format_hint=format_hint, + ) diff --git a/src/agentmemory/retrieval/second_stage.py b/src/agentmemory/retrieval/second_stage.py new file mode 100644 index 0000000..c451999 --- /dev/null +++ b/src/agentmemory/retrieval/second_stage.py @@ -0,0 +1,559 @@ +"""Shared second-stage reranking across retrieval buckets.""" + +from __future__ import annotations + +import math +import os +import re +from dataclasses import dataclass, field +from typing import Any + +from agentmemory.retrieval.feature_builder import ( + FEATURE_VERSION_V1, + build_features, + vectorize_features, +) +from agentmemory.retrieval.judge import JudgeConfig, judge_candidates +from agentmemory.retrieval.mlp_reranker import DEFAULT_MODEL_PATH, TinyMLPModel + +_BUCKET_TYPE_MAP = { + "procedures": "procedure", + "memories": "memory", + "events": "event", + "context": "context", + "entities": "entity", + "decisions": "decision", +} +_SESSION_RE = re.compile( + r"(?:^|[|_\s-])(?:sid|session|s)[=_ :#-]*(\d+)|\bsession[_ :#-]*(\d+)\b", + re.IGNORECASE, +) +_DATE_RE = re.compile( + r"\b(?:\d{4}-\d{2}-\d{2}|\d{1,2}/\d{1,2}(?:/\d{2,4})?|" + r"jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|" + r"jul(?:y)?|aug(?:ust)?|sep(?:tember)?|oct(?:ober)?|nov(?:ember)?|" + r"dec(?:ember)?)\b", + re.IGNORECASE, +) +_ENTITY_RE = re.compile(r"\b[A-Z][A-Za-z0-9_.:-]+\b") +_SOURCE_NUM_SUFFIX_RE = re.compile(r"^(.+?)[_-](\d+)$") + + +def _resolve_benchmark_ranking_mode(args: Any) -> str: + mode = str( + getattr(args, "benchmark_ranking_mode", None) + or os.environ.get("BRAINCTL_BENCHMARK_RANKING_MODE", "raw") + or "raw" + ).strip().lower() + return mode if mode in {"full", "raw"} else "raw" + + +def _env_flag(name: str) -> bool: + return str(os.environ.get(name, "")).strip().lower() in {"1", "true", "yes", "on"} + + 
+@dataclass(slots=True) +class SecondStageCandidate: + bucket: str + original_index: int + row: dict[str, Any] + + +@dataclass(slots=True) +class SecondStageConfig: + enabled: bool = True + top_n: int = 10 + heuristic_weight: float = 0.62 + mlp_weight: float = 0.28 + judge_weight: float = 0.10 + model_path: str | None = None + model_enabled: bool = True + ranking_mode: str = "live" + judge: JudgeConfig = field(default_factory=JudgeConfig) + + @classmethod + def from_args(cls, args: Any) -> "SecondStageConfig": + benchmark = bool(getattr(args, "benchmark", False)) + ranking_mode = _resolve_benchmark_ranking_mode(args) if benchmark else "live" + requested = bool(getattr(args, "second_stage", False)) or _env_flag("BRAINCTL_SECOND_STAGE") + judge_enabled = bool(getattr(args, "judge_rerank", None)) + judge_provider = str(getattr(args, "judge_rerank", "ollama") or "ollama") + judge_model = str(getattr(args, "judge_model", "llama3.2:3b") or "llama3.2:3b") + top_n = getattr(args, "second_stage_top_n", None) + if top_n is None: + try: + top_n = int(os.environ.get("BRAINCTL_SECOND_STAGE_TOP_N", "10")) + except (TypeError, ValueError): + top_n = 10 + return cls( + enabled=requested and not bool(getattr(args, "no_second_stage", False)) and not (benchmark and ranking_mode == "raw"), + top_n=max(int(top_n or 10), 1), + model_enabled=not bool(getattr(args, "no_second_stage_model", False)), + model_path=getattr(args, "second_stage_model_path", None), + ranking_mode=ranking_mode, + judge=JudgeConfig( + enabled=judge_enabled, + provider=judge_provider, + model=judge_model, + top_k=max(min(int(getattr(args, "judge_top_k", 5) or 5), 5), 1), + ), + ) + + +def _heuristic_score(plan: Any, features: dict[str, float]) -> float: + intent = str(getattr(plan, "normalized_intent", "factual") or "factual") + score = ( + features["base_score"] * 0.24 + + features["informative_overlap"] * 0.23 + + features["tfidf_cosine"] * 0.20 + + features["query_overlap"] * 0.07 + + 
features["intent_bucket_fit"] * 0.08 + + features["entity_overlap"] * 0.06 + + features["alias_overlap"] * 0.04 + + features["exact_phrase"] * 0.05 + + features["support_evidence_score"] * 0.03 + ) + long_context_reliable = ( + features.get("long_context_applicable", 0.0) > 0.0 + and features.get("long_context_focused_program", 0.0) > 0.0 + and features.get("long_context_confidence", 0.0) >= 0.62 + and features.get("long_context_uncertainty", 0.0) <= 0.38 + ) + if long_context_reliable: + score += ( + features.get("long_context_score", 0.0) * 0.09 + + features.get("long_context_confidence", 0.0) * 0.03 + + features.get("long_context_agreement", 0.0) * 0.02 + + features.get("long_context_coverage", 0.0) * 0.03 + + features.get("long_context_precision", 0.0) * 0.02 + ) + if features["query_temporal"] > 0: + score += ( + features["candidate_temporal"] * 0.04 + + features["temporal_anchor_overlap"] * 0.08 + + features["session_gap_score"] * 0.06 + ) + if long_context_reliable: + score += features.get("long_context_score", 0.0) * 0.05 + if features.get("query_needs_ordering", 0.0) > 0.0: + score += features["temporal_anchor_overlap"] * 0.05 + features["session_gap_score"] * 0.05 + if features.get("query_needs_update_resolution", 0.0) > 0.0: + score += features["status_active"] * 0.04 + if features.get("query_needs_role_fact", 0.0) > 0.0: + score += ( + features.get("role_overlap", 0.0) * 0.11 + + features.get("attribute_overlap", 0.0) * 0.10 + + features.get("role_value_pattern", 0.0) * 0.08 + + features.get("exact_phrase", 0.0) * 0.03 + ) + if features.get("query_needs_synthetic_key_value", 0.0) > 0.0: + score += features["source_keyword"] * 0.04 + features.get("attribute_overlap", 0.0) * 0.05 + if intent in {"temporal", "decision"}: + score += features["bucket_events"] * 0.04 + features["bucket_decisions"] * 0.03 + if intent in {"procedural", "troubleshooting"}: + score += features["bucket_procedures"] * 0.06 + features["procedural_candidate"] * 0.04 + if 
long_context_reliable: + score += features.get("long_context_confidence", 0.0) * 0.04 + if intent == "factual": + score += features["bucket_memories"] * 0.05 + features["bucket_entities"] * 0.04 + score -= features["bucket_procedures"] * 0.04 + if long_context_reliable: + score += features.get("long_context_precision", 0.0) * 0.04 + if features["source_graph"] > 0: + score -= 0.08 + if features["status_stale"] > 0: + score -= 0.12 + if features["status_needs_review"] > 0: + score -= 0.08 + return max(min(score, 1.0), 0.0) + + +def _candidate_text(candidate: dict[str, Any]) -> str: + for key in ("content", "summary", "title", "goal", "description", "name", "search_text"): + value = candidate.get(key) + if value: + return str(value) + return "" + + +def _candidate_source_family(candidate: dict[str, Any]) -> str: + raw = ( + candidate.get("doc_id") + or candidate.get("source_doc_id") + or candidate.get("source_key") + or candidate.get("external_id") + or "" + ) + head = str(raw).split("|", 1)[0] + match = _SOURCE_NUM_SUFFIX_RE.match(head) + return match.group(1) if match else head + + +def _candidate_cluster_keys(plan: Any, candidate: dict[str, Any]) -> set[str]: + text = _candidate_text(candidate) + keys: set[str] = set() + family = _candidate_source_family(candidate) + if family: + keys.add(f"family:{family}") + for match in _SESSION_RE.finditer(text): + keys.add(f"session:{match.group(1) or match.group(2)}") + if getattr(plan, "requires_temporal_reasoning", False) or getattr(plan, "needs_ordering", False): + for match in _DATE_RE.finditer(text): + keys.add(f"date:{match.group(0).lower()}") + target_entities = { + str(value).lower() + for value in (getattr(plan, "target_entities", None) or []) + if value + } + if target_entities: + lowered = text.lower() + for entity in target_entities: + if entity and entity in lowered: + keys.add(f"entity:{entity}") + observed_entities = { + match.group(0).lower() + for match in _ENTITY_RE.finditer(text) + if len(match.group(0)) > 
2 + } + for entity in sorted(observed_entities)[:3]: + keys.add(f"obs:{entity}") + if not keys: + ident = candidate.get("id") + keys.add(f"row:{candidate.get('bucket')}:{ident}") + return keys + + +def _slate_score( + *, + plan: Any, + candidate: dict[str, Any], + features: dict[str, float], + composite_score: float, + rank_index: int, + selected_keys: set[str], +) -> tuple[float, dict[str, float]]: + rank_discount = 1.0 / math.log2(rank_index + 2) + cluster_keys = _candidate_cluster_keys(plan, candidate) + new_keys = cluster_keys - selected_keys + coverage_bonus = 0.0 + redundancy_penalty = 0.0 + update_penalty = 0.0 + temporal_penalty = 0.0 + localization_bonus = 0.0 + + if getattr(plan, "needs_set_coverage", False): + coverage_bonus += min(0.20, 0.05 * len(new_keys)) + if not new_keys and selected_keys: + redundancy_penalty += 0.11 + elif selected_keys and not new_keys: + redundancy_penalty += 0.03 + + if getattr(plan, "needs_update_resolution", False): + if features.get("status_stale", 0.0) > 0.0: + update_penalty += 0.08 + if features.get("status_needs_review", 0.0) > 0.0: + update_penalty += 0.05 + if features.get("status_active", 0.0) > 0.0: + coverage_bonus += 0.02 + + if getattr(plan, "requires_temporal_reasoning", False) or getattr(plan, "needs_ordering", False): + if features.get("candidate_temporal", 0.0) <= 0.0 and features.get("temporal_anchor_overlap", 0.0) <= 0.0: + temporal_penalty += 0.05 + else: + coverage_bonus += features.get("temporal_anchor_overlap", 0.0) * 0.03 + + if getattr(plan, "needs_role_fact", False): + coverage_bonus += features.get("role_overlap", 0.0) * 0.04 + coverage_bonus += features.get("attribute_overlap", 0.0) * 0.04 + coverage_bonus += features.get("role_value_pattern", 0.0) * 0.03 + + if features.get("long_context_focused_program", 0.0) > 0.0: + localization_bonus += ( + features.get("long_context_precision", 0.0) * 0.018 + + features.get("long_context_coverage", 0.0) * 0.014 + ) + + slate_adjustment = (coverage_bonus + 
localization_bonus - redundancy_penalty - update_penalty - temporal_penalty) * rank_discount + return ( + composite_score + slate_adjustment, + { + "coverage_bonus": round(coverage_bonus, 6), + "localization_bonus": round(localization_bonus, 6), + "redundancy_penalty": round(redundancy_penalty, 6), + "update_penalty": round(update_penalty, 6), + "temporal_penalty": round(temporal_penalty, 6), + "rank_discount": round(rank_discount, 6), + "new_key_count": float(len(new_keys)), + }, + ) + + +def _rerank_slate( + *, + plan: Any, + head: list[dict[str, Any]], + feature_rows: list[dict[str, float]], + heuristic_scores: list[float], + mlp_scores: list[float], + judge_scores: list[float], + cfg: SecondStageConfig, +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + base_weight = max(0.0, 1.0 - cfg.heuristic_weight - cfg.mlp_weight - cfg.judge_weight) + pool: list[dict[str, Any]] = [] + debug_candidates: list[dict[str, Any]] = [] + for candidate, features, heuristic_score, mlp_score, judge_score in zip( + head, + feature_rows, + heuristic_scores, + mlp_scores, + judge_scores, + ): + pre_score = float(candidate.get("final_score") or candidate.get("retrieval_score") or 0.0) + composite_score = ( + pre_score * base_weight + + heuristic_score * cfg.heuristic_weight + + float(mlp_score) * cfg.mlp_weight + + float(judge_score) * cfg.judge_weight + ) + candidate["pre_second_stage_score"] = round(pre_score, 8) + candidate["second_stage_heuristic"] = round(heuristic_score, 6) + candidate["second_stage_mlp"] = round(float(mlp_score), 6) + candidate["second_stage_judge"] = round(float(judge_score), 6) + candidate["second_stage_features"] = { + key: features.get(key) + for key in ( + "informative_overlap", + "tfidf_cosine", + "entity_overlap", + "temporal_anchor_overlap", + "intent_bucket_fit", + "session_gap_score", + "query_needs_counting", + "query_needs_comparison", + "query_needs_ordering", + "query_needs_update_resolution", + "query_needs_set_coverage", + 
"query_needs_role_fact", + "query_needs_synthetic_key_value", + "role_overlap", + "attribute_overlap", + "role_value_pattern", + "long_context_score", + "long_context_confidence", + "long_context_agreement", + "long_context_uncertainty", + "long_context_focused_program", + ) + } + long_context_debug = candidate.pop("_long_context_debug", None) or {} + if long_context_debug.get("applicable"): + candidate["second_stage_features"]["long_context_program"] = long_context_debug.get("program") + candidate["second_stage_features"]["long_context_excerpt"] = long_context_debug.get("top_chunk_excerpt") + pool.append( + { + "candidate": candidate, + "features": features, + "composite_score": round(composite_score, 8), + "cluster_keys": _candidate_cluster_keys(plan, candidate), + } + ) + + selected: list[dict[str, Any]] = [] + selected_keys: set[str] = set() + rank_index = 0 + while pool: + best_idx = 0 + best_score = None + best_terms: dict[str, float] | None = None + for idx, item in enumerate(pool): + slate_score, terms = _slate_score( + plan=plan, + candidate=item["candidate"], + features=item["features"], + composite_score=float(item["composite_score"]), + rank_index=rank_index, + selected_keys=selected_keys, + ) + if best_score is None or slate_score > best_score: + best_idx = idx + best_score = slate_score + best_terms = terms + item = pool.pop(best_idx) + candidate = item["candidate"] + terms = best_terms or {} + candidate["second_stage_slate_score"] = round(float(best_score or 0.0), 6) + candidate["second_stage_slate_terms"] = terms + selected.append(candidate) + selected_keys.update(item["cluster_keys"]) + rank_index += 1 + + debug_candidates = [] + for index, candidate in enumerate(selected, start=1): + epsilon = max(len(selected) - index, 0) * 1e-6 + candidate["final_score"] = round(float(candidate.get("second_stage_slate_score") or 0.0) + epsilon, 8) + debug_candidates.append( + { + "bucket": candidate.get("bucket"), + "id": candidate.get("id"), + "pre_score": 
round(float(candidate.get("pre_second_stage_score") or 0.0), 6), + "heuristic": round(float(candidate.get("second_stage_heuristic") or 0.0), 6), + "mlp": round(float(candidate.get("second_stage_mlp") or 0.0), 6), + "judge": round(float(candidate.get("second_stage_judge") or 0.0), 6), + "composite": round(float(candidate.get("second_stage_slate_score") or 0.0), 6), + "selection_rank": index, + "slate_terms": candidate.get("second_stage_slate_terms") or {}, + "features": candidate.get("second_stage_features") or {}, + } + ) + return selected, debug_candidates + + +def rerank_top_candidates( + query: str, + plan: Any, + candidates: list[dict[str, Any]], + config: SecondStageConfig | None = None, +) -> tuple[list[dict[str, Any]], dict[str, Any]]: + """Rerank a flat candidate list using heuristic + tiny MLP + optional judge.""" + + cfg = config or SecondStageConfig() + if not cfg.enabled or not candidates: + return candidates, {"enabled": False} + + head = [dict(candidate) for candidate in candidates[: cfg.top_n]] + tail = [dict(candidate) for candidate in candidates[cfg.top_n :]] + hard_query = any( + bool(getattr(plan, attr, False)) + for attr in ( + "requires_temporal_reasoning", + "requires_multi_hop", + "needs_counting", + "needs_comparison", + "needs_ordering", + "needs_update_resolution", + "needs_set_coverage", + "needs_role_fact", + "needs_synthetic_key_value", + ) + ) + raw_head_scores = [ + float(candidate.get("final_score") or candidate.get("retrieval_score") or 0.0) + for candidate in head[:2] + ] + top_margin = abs(raw_head_scores[0] - raw_head_scores[1]) if len(raw_head_scores) >= 2 else 1.0 + if not hard_query and top_margin >= 0.08: + passthrough = [dict(candidate) for candidate in candidates] + for candidate in passthrough[: cfg.top_n]: + pre_score = float(candidate.get("final_score") or candidate.get("retrieval_score") or 0.0) + candidate.setdefault("pre_second_stage_score", round(pre_score, 8)) + return passthrough, { + "enabled": True, + "top_n": 
cfg.top_n, + "ranking_mode": cfg.ranking_mode, + "model_enabled": cfg.model_enabled, + "model_path": str(cfg.model_path or DEFAULT_MODEL_PATH), + "model_loaded": False, + "judge_enabled": cfg.judge.enabled, + "strategy": "passthrough_easy_query", + "top_margin": round(top_margin, 6), + "candidates": [], + } + for idx, candidate in enumerate(head): + candidate["_stage_position"] = idx + candidate.setdefault("bucket", candidate.get("type") or "memories") + + feature_rows: list[dict[str, float]] = [] + leader_score = head[0].get("final_score") if head else None + for idx, candidate in enumerate(head): + prev_score = head[idx - 1].get("final_score") if idx > 0 else None + next_score = head[idx + 1].get("final_score") if idx + 1 < len(head) else None + features = build_features( + query, + plan, + candidate, + neighbors={"prev_score": prev_score, "next_score": next_score, "leader_score": leader_score}, + ) + feature_rows.append(features) + + heuristic_scores = [_heuristic_score(plan, features) for features in feature_rows] + + model = TinyMLPModel.try_load(cfg.model_path or DEFAULT_MODEL_PATH) if cfg.model_enabled else None + if model is not None: + feature_matrix = [vectorize_features(features, feature_version=FEATURE_VERSION_V1) for features in feature_rows] + mlp_scores = model.score(feature_matrix) + else: + mlp_scores = [0.0] * len(head) + + judge_scores = judge_candidates(query, head, cfg.judge) + if judge_scores and len(judge_scores) < len(head): + judge_scores = list(judge_scores) + [0.0] * (len(head) - len(judge_scores)) + elif not judge_scores: + judge_scores = [0.0] * len(head) + + head, debug_candidates = _rerank_slate( + plan=plan, + head=head, + feature_rows=feature_rows, + heuristic_scores=heuristic_scores, + mlp_scores=mlp_scores, + judge_scores=judge_scores, + cfg=cfg, + ) + reranked = head + tail + debug = { + "enabled": True, + "top_n": cfg.top_n, + "ranking_mode": cfg.ranking_mode, + "model_enabled": cfg.model_enabled, + "model_path": 
str(cfg.model_path or DEFAULT_MODEL_PATH), + "model_loaded": model is not None, + "judge_enabled": cfg.judge.enabled, + "base_weight": round(max(0.0, 1.0 - cfg.heuristic_weight - cfg.mlp_weight - cfg.judge_weight), 4), + "mlp_weight": round(cfg.mlp_weight, 4), + "judge_weight": round(cfg.judge_weight, 4), + "strategy": "listwise_greedy_slate", + "candidates": debug_candidates, + } + return reranked, debug + + +def rerank_bucketed_results( + query: str, + plan: Any, + buckets: dict[str, list[dict[str, Any]]], + config: SecondStageConfig | None = None, +) -> tuple[dict[str, list[dict[str, Any]]], dict[str, Any]]: + """Apply second-stage reranking to the combined head across all buckets.""" + + cfg = config or SecondStageConfig() + if not cfg.enabled: + return buckets, {"enabled": False} + + ordered: list[SecondStageCandidate] = [] + for bucket_name in ("procedures", "memories", "events", "context", "entities", "decisions"): + rows = buckets.get(bucket_name) or [] + for idx, row in enumerate(rows): + candidate = dict(row) + candidate["bucket"] = bucket_name + candidate["type"] = str(candidate.get("type") or _BUCKET_TYPE_MAP.get(bucket_name, bucket_name)) + ordered.append(SecondStageCandidate(bucket_name, idx, candidate)) + ordered.sort(key=lambda item: item.row.get("final_score", 0.0), reverse=True) + + reranked_rows, debug = rerank_top_candidates( + query, + plan, + [item.row for item in ordered], + config=cfg, + ) + scored: dict[tuple[str, Any], dict[str, Any]] = {} + for row in reranked_rows: + scored[(str(row.get("bucket") or "memories"), row.get("id"))] = row + + updated: dict[str, list[dict[str, Any]]] = {name: [] for name in buckets} + for bucket_name, rows in buckets.items(): + updated_rows: list[dict[str, Any]] = [] + for row in rows or []: + updated_rows.append(scored.get((bucket_name, row.get("id")), row)) + updated_rows.sort(key=lambda item: item.get("final_score", 0.0), reverse=True) + updated[bucket_name] = updated_rows + return updated, debug diff --git 
a/tests/test_long_context_explorer.py b/tests/test_long_context_explorer.py new file mode 100644 index 0000000..343c693 --- /dev/null +++ b/tests/test_long_context_explorer.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +from pathlib import Path + +from agentmemory.retrieval.feature_builder import build_features +from agentmemory.retrieval.long_context import analyze_long_context +from agentmemory.retrieval.query_planner import plan_query +from agentmemory.retrieval.second_stage import SecondStageConfig, rerank_top_candidates + +from tests.test_second_stage_reranker import _temp_model + + +def test_long_context_probe_finds_session_anchor(monkeypatch): + monkeypatch.setenv("BRAINCTL_LONG_CONTEXT_PROBES", "1") + plan = plan_query("When did Caroline go to the LGBTQ support group?", requested_tables=["memories"]) + candidate = { + "id": 1, + "bucket": "memories", + "type": "memory", + "final_score": 0.72, + "retrieval_score": 0.72, + "source": "keyword", + } + text = "\n".join( + [ + "Session ID: session_1", + "Session Date: 2025-01-12", + "Conversation:", + 'Alice: We talked about cooking classes and weekend plans.', + 'Bob: Nothing else noteworthy happened this week.', + 'Caroline: I went to the LGBTQ support group after work and felt better.', + 'Alice: We also mentioned a grocery list and cleaning supplies.', + ] + ) + + result = analyze_long_context( + "When did Caroline go to the LGBTQ support group?", + plan, + candidate, + text=text, + ) + + assert result["applicable"] is True + assert result["score"] > 0.55 + assert result["confidence"] > 0.45 + assert result["uncertainty"] < 0.7 + assert "LGBTQ support group" in result["top_chunk_excerpt"] + + +def test_second_stage_uses_long_context_probe_to_promote_focused_session(tmp_path: Path, monkeypatch): + monkeypatch.setenv("BRAINCTL_LONG_CONTEXT_PROBES", "1") + plan = plan_query("When did Caroline go to the LGBTQ support group?", requested_tables=["memories"]) + model_path = _temp_model(tmp_path / 
"tiny.json") + diffuse = { + "id": 1, + "bucket": "memories", + "type": "memory", + "content": "\n".join( + [ + "Session ID: session_7", + "Session Date: 2025-01-20", + "Conversation:", + 'Alice: Caroline mentioned some errands after work.', + 'Bob: She later mentioned a support group but I do not remember when.', + 'Alice: Then we switched topics to a restaurant review and sprint planning.', + 'Bob: We also talked about a support group again in passing.', + 'Alice: Nothing pinned the exact date.', + ] + ), + "final_score": 0.789, + "retrieval_score": 0.789, + "source": "both", + "confidence": 0.9, + } + focused = { + "id": 2, + "bucket": "memories", + "type": "memory", + "content": "\n".join( + [ + "Session ID: session_1", + "Session Date: 2025-01-12", + "Conversation:", + "Alice: We opened with a grocery list and a reminder about dry cleaning.", + "Bob: Then we talked about a dentist appointment and an office lunch.", + 'Caroline: I went to the LGBTQ support group after work on January 12.', + 'Alice: We noted it in the session log for follow-up.', + "Bob: After that we switched to weekend errands and recipe planning.", + "Alice: We ended with notes about commute timing and a restaurant reservation.", + ] + ), + "final_score": 0.776, + "retrieval_score": 0.776, + "source": "keyword", + "confidence": 0.9, + } + + reranked, debug = rerank_top_candidates( + "When did Caroline go to the LGBTQ support group?", + plan, + [diffuse, focused], + config=SecondStageConfig(top_n=2, model_path=str(model_path)), + ) + + assert reranked[0]["id"] == 2 + assert reranked[0]["second_stage_features"]["long_context_score"] > reranked[1]["second_stage_features"]["long_context_score"] + assert debug["enabled"] is True + + +def test_query_planner_flags_temporal_aggregation_and_inference(): + temporal_multi = plan_query("How much have I made from selling eggs this month?", requested_tables=["memories"]) + assert temporal_multi.requires_temporal_reasoning is True + assert 
temporal_multi.requires_multi_hop is True + assert temporal_multi.normalized_intent == "temporal" + + inference = plan_query("What personality traits might Melanie say Caroline has?", requested_tables=["memories"]) + assert inference.requires_multi_hop is True + + +def test_long_context_probe_requires_close_temporal_candidates(monkeypatch): + monkeypatch.setenv("BRAINCTL_LONG_CONTEXT_PROBES", "1") + plan = plan_query("When did Caroline go to the LGBTQ support group?", requested_tables=["memories"]) + candidate = { + "id": 7, + "bucket": "memories", + "type": "memory", + "content": "\n".join( + [ + "Session ID: session_9", + "Session Date: 2025-01-19", + "Conversation:", + "Alice: We discussed errands and a support group in passing.", + "Caroline: I went to the LGBTQ support group after work on January 12.", + "Bob: We wrote it down in the follow-up notes.", + "Alice: Then we switched to grocery planning and restaurants.", + "Bob: We revisited the support group briefly before closing the session.", + ] + ), + "final_score": 0.91, + "retrieval_score": 0.91, + "source": "keyword", + } + + far_apart = build_features( + "When did Caroline go to the LGBTQ support group?", + plan, + dict(candidate), + neighbors={"prev_score": None, "next_score": 0.76, "leader_score": 0.91}, + ) + assert far_apart["long_context_applicable"] == 0.0 + + close_scores = build_features( + "When did Caroline go to the LGBTQ support group?", + plan, + dict(candidate), + neighbors={"prev_score": None, "next_score": 0.889, "leader_score": 0.91}, + ) + assert close_scores["long_context_applicable"] == 1.0 + assert close_scores["long_context_focused_program"] == 1.0 + + +def test_long_context_probe_ignores_whole_document_only_matches(monkeypatch): + monkeypatch.setenv("BRAINCTL_LONG_CONTEXT_PROBES", "1") + plan = plan_query("When did Caroline go to the LGBTQ support group?", requested_tables=["memories"]) + candidate = { + "id": 11, + "bucket": "memories", + "type": "memory", + "final_score": 0.72, + 
"retrieval_score": 0.72, + "source": "keyword", + } + text = ( + "Session ID: session_1 Session Date: 2025-01-12 Conversation " + + "Caroline went to the LGBTQ support group after work on January 12 and we kept discussing it in the same paragraph without line breaks or sentence boundaries " * 24 + ) + + result = analyze_long_context( + "When did Caroline go to the LGBTQ support group?", + plan, + candidate, + text=text, + ) + + assert result["applicable"] is False + assert result["reason"] == "no_focused_program" diff --git a/tests/test_reranker_robustness.py b/tests/test_reranker_robustness.py index 84a16df..163d7a4 100644 --- a/tests/test_reranker_robustness.py +++ b/tests/test_reranker_robustness.py @@ -209,6 +209,8 @@ def _build_args(query: str, limit: int = 10, **overrides) -> types.SimpleNamespa pagerank_boost=0.0, quantum=False, benchmark=False, + benchmark_ranking_mode="full", + second_stage=False, agent="robustness-agent", output="json", format="json", @@ -426,24 +428,28 @@ def db(self, tmp_path): _seed_locomo_shape(db_path, n=50) return db_path - def test_benchmark_skips_three_rerankers(self, db): + def test_benchmark_full_mode_keeps_second_stage_opt_in(self, db): args = _build_args("alice prefers dark mode", benchmark=True) out = _call_cmd_search(db, args) debug = out.get("_debug", {}) + assert debug.get("benchmark.ranking_mode") == "full" + assert debug.get("second_stage", {}).get("enabled") is False assert debug.get("memories.recency_skipped") == "benchmark_mode" - assert debug.get("memories.salience_skipped") == "benchmark_mode" assert debug.get("memories.qvalue_skipped") == "benchmark_mode" - def test_benchmark_preserves_trust(self, db): - """Spec: trust reranker is preserved under --benchmark (different - signal class — provenance, not stale-data). 
Even on a uniform-trust - corpus the trust skip reason must NOT show up under benchmark.""" + args = _build_args("alice prefers dark mode", benchmark=True, second_stage=True) + out = _call_cmd_search(db, args) + debug = out.get("_debug", {}) + assert debug.get("benchmark.ranking_mode") == "full" + assert debug.get("second_stage", {}).get("enabled") is True + assert debug.get("memories.recency_skipped") == "benchmark_mode" + assert debug.get("memories.qvalue_skipped") == "benchmark_mode" + + def test_benchmark_full_mode_uses_normal_trust_gate(self, db): args = _build_args("alice prefers dark mode", benchmark=True) out = _call_cmd_search(db, args) debug = out.get("_debug", {}) - assert "memories.trust_skipped" not in debug, ( - f"trust must be preserved under --benchmark; debug={debug}" - ) + assert "memories.trust_skipped" not in debug, debug def test_benchmark_emits_stderr_note(self, db): # Capture the stderr message. @@ -460,7 +466,7 @@ def _capture(data, compact=False): with contextlib.redirect_stderr(buf_err): _impl.cmd_search(args) assert "--benchmark" in buf_err.getvalue() - assert "raw FTS+vec ranking" in buf_err.getvalue() + assert "stable-eval mode" in buf_err.getvalue() finally: _impl.json_out = saved_json @@ -500,7 +506,16 @@ def test_benchmark_cli_flag_end_to_end(self, db, tmp_path): # Parse the JSON payload off stdout. 
payload = json.loads(result.stdout) debug = payload.get("_debug", {}) + assert debug.get("benchmark.ranking_mode") == "raw" + + def test_benchmark_raw_mode_preserves_legacy_ablation(self, db): + args = _build_args("alice prefers dark mode", benchmark=True, benchmark_ranking_mode="raw") + out = _call_cmd_search(db, args) + debug = out.get("_debug", {}) + assert debug.get("benchmark.ranking_mode") == "raw" assert debug.get("memories.recency_skipped") == "benchmark_mode" + assert debug.get("memories.salience_skipped") == "benchmark_mode" + assert debug.get("memories.qvalue_skipped") == "benchmark_mode" # --------------------------------------------------------------------------- @@ -550,6 +565,86 @@ def test_query_top1_still_relevant(self, bench_db, query, must_contain): f"got: {top_text[:120]!r}" ) + def test_entity_bucket_populated_for_entity_query(self, bench_db): + args = _build_args( + "Who owns the consolidation daemon?", + agent="bench-agent", + tables="memories,events,context,entities,decisions,procedures", + benchmark=True, + ) + out = _call_cmd_search(bench_db, args) + assert out.get("entities"), "entity query should populate entities bucket" + assert out["entities"][0]["name"] == "Bob" + + def test_negative_out_of_domain_query_abstains(self, bench_db): + args = _build_args( + "Summary of yesterday's basketball game", + agent="bench-agent", + tables="memories,events,context,entities,decisions,procedures", + benchmark=True, + ) + out = _call_cmd_search(bench_db, args) + assert out.get("metacognition", {}).get("abstained") is True + for bucket in ("memories", "events", "context", "entities", "decisions", "procedures"): + assert not out.get(bucket), f"{bucket} should be empty after abstention" + + +def test_entity_alias_expansion_promotes_canonical_memory(tmp_path): + db_path = tmp_path / "alias-linking.db" + _seed_schema(db_path) + now = _utc_iso() + conn = sqlite3.connect(str(db_path)) + try: + conn.execute( + """ + INSERT INTO memories ( + agent_id, 
category, scope, content, confidence, + created_at, updated_at + ) VALUES (?, 'preference', 'global', ?, 0.9, ?, ?) + """, + ("robustness-agent", "Bob prefers four-space indentation for Python code.", now, now), + ) + conn.execute( + """ + INSERT INTO entities ( + name, entity_type, properties, observations, agent_id, confidence, + scope, created_at, updated_at, aliases, compiled_truth + ) VALUES (?, 'person', '{}', ?, ?, 0.95, 'global', ?, ?, ?, ?) + """, + ( + "Bob", + json.dumps(["Prefers four-space indentation"], ensure_ascii=True), + "robustness-agent", + now, + now, + json.dumps(["Robert"], ensure_ascii=True), + "Bob prefers four-space indentation.", + ), + ) + conn.commit() + finally: + conn.close() + + args = _build_args( + "What does Robert prefer?", + agent="robustness-agent", + tables="memories,entities", + benchmark=True, + ) + out = _call_cmd_search(db_path, args) + flat = [] + for bucket in ("entities", "memories"): + flat.extend(out.get(bucket, []) or []) + flat.sort(key=lambda row: row.get("final_score", 0.0), reverse=True) + assert flat, "alias-linked query should return at least one result" + top_text = ( + flat[0].get("content") + or flat[0].get("name") + or flat[0].get("summary") + or "" + ).lower() + assert "bob" in top_text, top_text + # --------------------------------------------------------------------------- # 6. 
Trust adjustment math diff --git a/tests/test_second_stage_reranker.py b/tests/test_second_stage_reranker.py new file mode 100644 index 0000000..3ac6759 --- /dev/null +++ b/tests/test_second_stage_reranker.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +import json +from pathlib import Path +from types import SimpleNamespace + +from agentmemory.retrieval.judge import JudgeConfig, judge_candidates +from agentmemory.retrieval.mlp_reranker import TinyMLPModel +from agentmemory.retrieval.query_planner import plan_query +from agentmemory.retrieval.second_stage import SecondStageConfig, rerank_bucketed_results, rerank_top_candidates + + +def _temp_model(path: Path) -> Path: + payload = { + "feature_version": "v1", + "feature_order": [ + "base_score", "retrieval_score", "rrf_score", "confidence", "query_overlap", + "informative_overlap", "tfidf_cosine", "exact_phrase", "entity_overlap", + "alias_overlap", "query_temporal", "candidate_temporal", "temporal_anchor_overlap", + "query_session_hint", "candidate_session_hint", "session_gap_score", "intent_bucket_fit", + "source_keyword", "source_semantic", "source_both", "source_graph", "bucket_memories", + "bucket_events", "bucket_entities", "bucket_procedures", "bucket_decisions", + "candidate_age_score", "support_evidence_score", "status_active", "status_stale", + "status_needs_review", "position_score", "neighbor_margin", "query_length_score", + "candidate_length_score", "procedural_candidate", + ], + "norm_mean": [0.0] * 36, + "norm_std": [1.0] * 36, + "w1": [[0.0] * 36 for _ in range(32)], + "b1": [0.0] * 32, + "w2": [[0.0] * 32 for _ in range(16)], + "b2": [0.0] * 16, + "w3": [[0.0] * 16], + "b3": [0.0], + "metadata": {"test": True}, + } + # Make one hidden path look at informative overlap and cosine similarity. 
+ payload["w1"][0][5] = 1.2 + payload["w1"][0][6] = 1.2 + payload["w2"][0][0] = 1.0 + payload["w3"][0][0] = 1.0 + path.write_text(json.dumps(payload), encoding="utf-8") + return path + + +def test_second_stage_from_args_is_opt_in_by_default(): + cfg = SecondStageConfig.from_args(SimpleNamespace(benchmark=False)) + assert cfg.enabled is False + + cfg = SecondStageConfig.from_args(SimpleNamespace(benchmark=False, second_stage=True)) + assert cfg.enabled is True + + cfg = SecondStageConfig.from_args( + SimpleNamespace(benchmark=True, benchmark_ranking_mode="raw", second_stage=True) + ) + assert cfg.enabled is False + + +def test_tiny_mlp_load_and_score(tmp_path: Path): + model_path = _temp_model(tmp_path / "tiny.json") + model = TinyMLPModel.load(model_path) + scores = model.score( + [ + [0.0, 0.0, 0.0, 0.0, 0.2, 0.9, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0] * 36, + ] + ) + assert len(scores) == 2 + assert scores[0] > scores[1] + + +def test_second_stage_promotes_exact_match(tmp_path: Path): + plan = plan_query("When did Caroline go to the LGBTQ support group?", requested_tables=["memories"]) + model_path = _temp_model(tmp_path / "tiny.json") + candidates = [ + { + "id": 1, + "bucket": "memories", + "type": "memory", + "content": "Caroline mentioned a cooking class during session_7.", + "final_score": 0.85, + "retrieval_score": 0.85, + "source": "both", + "confidence": 0.9, + }, + { + "id": 2, + "bucket": "memories", + "type": "memory", + "content": "Session ID: session_1\nCaroline went to the LGBTQ support group on January 12.", + "final_score": 0.78, + "retrieval_score": 0.78, + "source": "keyword", + "confidence": 0.9, + }, + ] + reranked, debug = rerank_top_candidates( + "When did Caroline go to the LGBTQ support group?", + plan, + candidates, + config=SecondStageConfig(top_n=2, model_path=str(model_path)), + ) + assert reranked[0]["id"] == 2 + 
assert reranked[0]["pre_second_stage_score"] == 0.78 + assert debug["enabled"] is True + assert debug["model_loaded"] is True + + +def test_bucketed_rerank_preserves_bucket_membership(tmp_path: Path): + plan = plan_query("How do I roll back a bad release?", requested_tables=["procedures", "memories"]) + model_path = _temp_model(tmp_path / "tiny.json") + buckets = { + "procedures": [ + { + "id": 9, + "title": "Rollback release", + "goal": "Restore service after a bad release", + "final_score": 0.74, + "retrieval_score": 0.74, + "source": "procedure_fts", + "status": "active", + } + ], + "memories": [ + { + "id": 10, + "content": "We chose SQLite because it is easy to operate.", + "final_score": 0.83, + "retrieval_score": 0.83, + "source": "both", + } + ], + "events": [], + "context": [], + "entities": [], + "decisions": [], + } + updated, _debug = rerank_bucketed_results( + "How do I roll back a bad release?", + plan, + buckets, + config=SecondStageConfig(top_n=2, model_path=str(model_path)), + ) + assert updated["procedures"][0]["id"] == 9 + assert "pre_second_stage_score" in updated["procedures"][0] + assert updated["memories"][0]["id"] == 10 + + +def test_bucketed_rerank_disabled_is_noop(): + plan = plan_query("Who owns the consolidation daemon?", requested_tables=["entities", "memories"]) + buckets = { + "procedures": [], + "memories": [ + { + "id": 21, + "type": "memory", + "content": "Bob owns the consolidation daemon and dream cycles.", + "final_score": 0.83, + } + ], + "events": [], + "context": [], + "entities": [ + { + "id": 2, + "type": "entity", + "name": "Bob", + "final_score": 0.91, + } + ], + "decisions": [], + } + updated, debug = rerank_bucketed_results( + "Who owns the consolidation daemon?", + plan, + buckets, + config=SecondStageConfig(enabled=False), + ) + assert updated is buckets + assert updated["entities"][0]["type"] == "entity" + assert updated["memories"][0]["type"] == "memory" + assert debug == {"enabled": False} + + +def 
test_judge_disabled_returns_empty(): + scores = judge_candidates( + "What is SQLite?", + [{"content": "SQLite is an embedded database."}], + JudgeConfig(enabled=False), + ) + assert scores == [] + + +def test_query_plan_sets_operator_flags(): + plan = plan_query( + "Which sessions this month happened before the latest rollback, and what changed?", + requested_tables=["memories"], + ) + assert plan.requires_temporal_reasoning is True + assert plan.needs_ordering is True + assert plan.needs_update_resolution is True + assert plan.needs_set_coverage is True + + role_plan = plan_query("What is the location of my father's workplace?", requested_tables=["memories"]) + assert role_plan.needs_role_fact is True + assert role_plan.needs_synthetic_key_value is False + + +def test_listwise_slate_avoids_duplicate_session_cluster(tmp_path: Path): + plan = plan_query( + "What happened before and after the latest outage across both sessions?", + requested_tables=["memories"], + ) + model_path = _temp_model(tmp_path / "tiny.json") + candidates = [ + { + "id": 1, + "bucket": "memories", + "type": "memory", + "content": "Session ID: session_2\nSession Date: 2026-02-10\nOutage started and alerts fired.", + "final_score": 0.92, + "retrieval_score": 0.92, + "source": "keyword", + "confidence": 0.95, + }, + { + "id": 2, + "bucket": "memories", + "type": "memory", + "content": "Session ID: session_2\nSession Date: 2026-02-10\nEngineers confirmed the same outage details again.", + "final_score": 0.91, + "retrieval_score": 0.91, + "source": "keyword", + "confidence": 0.95, + }, + { + "id": 3, + "bucket": "memories", + "type": "memory", + "content": "Session ID: session_3\nSession Date: 2026-02-11\nRollback completed after the outage and service recovered.", + "final_score": 0.88, + "retrieval_score": 0.88, + "source": "keyword", + "confidence": 0.95, + }, + ] + reranked, debug = rerank_top_candidates( + "What happened before and after the latest outage across both sessions?", + plan, + 
candidates, + config=SecondStageConfig(top_n=3, model_path=str(model_path)), + ) + assert {row["id"] for row in reranked[:2]} == {1, 3} + assert debug["strategy"] == "listwise_greedy_slate" + + +def test_second_stage_promotes_role_fact_candidate(tmp_path: Path): + plan = plan_query("What is the location of my father's workplace?", requested_tables=["memories"]) + model_path = _temp_model(tmp_path / "tiny.json") + candidates = [ + { + "id": 1, + "bucket": "memories", + "type": "memory", + "content": "My friend enjoys hiking on weekends.", + "final_score": 0.92, + "retrieval_score": 0.92, + "source": "both", + }, + { + "id": 2, + "bucket": "memories", + "type": "memory", + "content": "My dad works in Miami, FL.", + "final_score": 0.78, + "retrieval_score": 0.78, + "source": "keyword", + }, + ] + + reranked, debug = rerank_top_candidates( + "What is the location of my father's workplace?", + plan, + candidates, + config=SecondStageConfig(top_n=2, model_path=str(model_path)), + ) + + assert reranked[0]["id"] == 2 + assert reranked[0]["second_stage_features"]["role_overlap"] == 1.0 + assert reranked[0]["second_stage_features"]["attribute_overlap"] == 1.0 + assert debug["strategy"] == "listwise_greedy_slate" From 2cb7e1680ab4aa4cb31d1a540fffa2454c0e6af8 Mon Sep 17 00:00:00 2001 From: Mario Jack Vela Date: Fri, 24 Apr 2026 03:55:03 -0500 Subject: [PATCH 2/4] Add retrieval validation slices --- docs/RETRIEVAL_VALIDATION.md | 47 ++++++ src/agentmemory/retrieval/feature_builder.py | 17 ++- src/agentmemory/retrieval/second_stage.py | 2 +- tests/test_retrieval_validation_slices.py | 150 +++++++++++++++++++ 4 files changed, 214 insertions(+), 2 deletions(-) create mode 100644 docs/RETRIEVAL_VALIDATION.md create mode 100644 tests/test_retrieval_validation_slices.py diff --git a/docs/RETRIEVAL_VALIDATION.md b/docs/RETRIEVAL_VALIDATION.md new file mode 100644 index 0000000..149ae9c --- /dev/null +++ b/docs/RETRIEVAL_VALIDATION.md @@ -0,0 +1,47 @@ +# Retrieval Validation Slices + 
+This PR keeps benchmark headline numbers provisional until two non-benchmark +checks are run alongside the LongMemEval/LoCoMo/MemBench comparison pack. + +## Held-out Non-benchmark Slice + +`tests/test_retrieval_validation_slices.py` seeds hand-labeled retrieval cases +that are not copied from LongMemEval, LoCoMo, or MemBench. They use ordinary +brainctl-style facts: + +- ownership of a signer-key checklist; +- offline verification of signed exports; +- temporal "after outage" evidence. + +The test compares raw candidate order against the full second-stage reranker +and asserts that the full path does not demote the gold candidate. In the +current deterministic slice, full reranking keeps or improves every case and +lands `3/3` gold candidates at rank 1. + +## Exact / Field-aware Ablation Slice + +The same test module includes a non-synthetic role-fact case: + +```text +query: What is Arlo's role in group alpha? +answer evidence: Arlo is the quartermaster for group alpha. +``` + +Raw candidate order places a semantically similar distractor above the answer. +The field-aware value-pattern feature promotes the answer to rank 1 without +using synthetic IDs, benchmark fixture keys, or gold labels. This is intended +to separate the useful exact/field-aware behavior from MemBench generator-specific +role IDs. + +## Current Local Validation + +```powershell +$env:PYTHONPATH=(Resolve-Path .\src) +python -m pytest tests\test_retrieval_validation_slices.py -q +``` + +Result: `2 passed`. + +These slices are small by design. They are a review-time guard against obvious +metric-shape overfitting, not a substitute for a larger real `brain.db` query +sample before un-drafting the retrieval PR. 
diff --git a/src/agentmemory/retrieval/feature_builder.py b/src/agentmemory/retrieval/feature_builder.py index bdfb9fc..765798d 100644 --- a/src/agentmemory/retrieval/feature_builder.py +++ b/src/agentmemory/retrieval/feature_builder.py @@ -285,6 +285,7 @@ def _role_value_pattern(text: str) -> float: r"\b(" r"works?\s+(?:as|in|at)|" r"is\s+(?:a|an|the)\b|" + r"owns?\b|owned\s+by|owner\s+(?:is|:)|" r"loves?\b|likes?\b|enjoys?\b|" r"passionate\s+about|really\s+into|free\s+time|" r"originally\s+from|grew\s+up\s+in|hails?\s+from|from\s+[A-Z][A-Za-z]+,\s*[A-Z][A-Za-z]+|" @@ -433,7 +434,21 @@ def build_features( alias_overlap = len(query_entities & aliases) / max(len(query_entities), 1) if query_entities and aliases else 0.0 role_overlap = 1.0 if (_ROLE_TERMS & query_tokens & cand_tokens) else 0.0 attribute_overlap = 1.0 if (_ATTRIBUTE_TERMS & query_tokens & cand_tokens) else 0.0 - role_value_pattern = role_overlap * _role_value_pattern(text) + # Some role/attribute questions ask for a value that is not repeated as a + # query token in the candidate ("what is Arlo's role" -> "Arlo is the + # quartermaster"). Let entity-aligned role/key queries activate the value + # pattern without requiring a synthetic benchmark-style field label. 
+ value_pattern = _role_value_pattern(text) + role_value_pattern = max(role_overlap, attribute_overlap) * value_pattern + if ( + value_pattern > 0.0 + and entity_overlap > 0.0 + and ( + getattr(plan, "needs_role_fact", False) + or getattr(plan, "needs_synthetic_key_value", False) + ) + ): + role_value_pattern = max(role_value_pattern, 1.0) query_temporal = 1.0 if (bool(getattr(plan, "requires_temporal_reasoning", False)) or _TEMPORAL_RE.search(query or "")) else 0.0 candidate_temporal = 1.0 if _TEMPORAL_RE.search(text or "") or _DATE_RE.search(text or "") else 0.0 temporal_anchor_overlap = _temporal_anchor_overlap(query, text) diff --git a/src/agentmemory/retrieval/second_stage.py b/src/agentmemory/retrieval/second_stage.py index c451999..d83a5bc 100644 --- a/src/agentmemory/retrieval/second_stage.py +++ b/src/agentmemory/retrieval/second_stage.py @@ -143,7 +143,7 @@ def _heuristic_score(plan: Any, features: dict[str, float]) -> float: score += ( features.get("role_overlap", 0.0) * 0.11 + features.get("attribute_overlap", 0.0) * 0.10 - + features.get("role_value_pattern", 0.0) * 0.08 + + features.get("role_value_pattern", 0.0) * 0.36 + features.get("exact_phrase", 0.0) * 0.03 ) if features.get("query_needs_synthetic_key_value", 0.0) > 0.0: diff --git a/tests/test_retrieval_validation_slices.py b/tests/test_retrieval_validation_slices.py new file mode 100644 index 0000000..991c2aa --- /dev/null +++ b/tests/test_retrieval_validation_slices.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +SRC = Path(__file__).resolve().parent.parent / "src" +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) + +from agentmemory.retrieval.query_planner import plan_query +from agentmemory.retrieval.second_stage import SecondStageConfig, rerank_top_candidates + + +def _rank(candidate_id: str, rows: list[dict]) -> int: + for idx, row in enumerate(rows, start=1): + if row["id"] == candidate_id: + return idx + return 999 + + +def 
_run_slice(query: str, candidates: list[dict], gold_id: str) -> tuple[int, int, list[dict]]: + raw_ranked = sorted(candidates, key=lambda row: row["final_score"], reverse=True) + plan = plan_query(query, requested_tables=["memories"]) + reranked, _debug = rerank_top_candidates( + query, + plan, + raw_ranked, + config=SecondStageConfig(enabled=True, top_n=len(raw_ranked), model_enabled=False), + ) + return _rank(gold_id, raw_ranked), _rank(gold_id, reranked), reranked + + +def test_heldout_non_benchmark_queries_do_not_regress_under_full_rerank(): + """Small hand-labeled slice outside LongMemEval/LoCoMo/MemBench fixtures.""" + + cases = [ + { + "query": "Who owns the Solana signer key rotation checklist?", + "gold": "signer-owner", + "candidates": [ + { + "id": "signer-distractor", + "bucket": "memories", + "content": ( + "The Solana signer key rotation checklist was discussed during " + "platform review; checklist risk remains open." + ), + "final_score": 0.91, + "retrieval_score": 0.91, + "source": "semantic", + }, + { + "id": "signer-owner", + "bucket": "memories", + "content": "Nia owns the Solana signer key rotation checklist for signed export releases.", + "final_score": 0.77, + "retrieval_score": 0.77, + "source": "keyword", + }, + ], + }, + { + "query": "Which runtime verifies signed export bundles offline?", + "gold": "offline-verifier", + "candidates": [ + { + "id": "online-pin", + "bucket": "memories", + "content": "Signed export bundles can optionally pin a SHA-256 hash on Solana.", + "final_score": 0.88, + "retrieval_score": 0.88, + "source": "semantic", + }, + { + "id": "offline-verifier", + "bucket": "memories", + "content": "The Python verifier checks signed export bundles offline before any on-chain pin.", + "final_score": 0.82, + "retrieval_score": 0.82, + "source": "keyword", + }, + ], + }, + { + "query": "What happened after the invoice webhook outage?", + "gold": "webhook-after", + "candidates": [ + { + "id": "webhook-before", + "bucket": 
"memories", + "content": "Session ID: session_4\nInvoice webhook retries began before the queue worker restart.", + "final_score": 0.89, + "retrieval_score": 0.89, + "source": "semantic", + }, + { + "id": "webhook-after", + "bucket": "memories", + "content": "Session ID: session_5\nAfter the invoice webhook outage, Nia restarted the queue worker.", + "final_score": 0.84, + "retrieval_score": 0.84, + "source": "keyword", + }, + ], + }, + ] + + raw_hits = 0 + full_hits = 0 + for case in cases: + raw_rank, full_rank, _rows = _run_slice(case["query"], case["candidates"], case["gold"]) + raw_hits += int(raw_rank == 1) + full_hits += int(full_rank == 1) + assert full_rank <= raw_rank + + assert full_hits >= raw_hits + assert full_hits == len(cases) + + +def test_exact_field_ablation_promotes_generic_role_fact_not_only_synthetic_ids(): + """The field-aware value pattern should help normal role/owner prose too.""" + + query = "What is Arlo's role in group alpha?" + candidates = [ + { + "id": "role-distractor", + "bucket": "memories", + "content": ( + "Arlo joined group alpha. The team discussed the role taxonomy " + "and group alpha backlog, but no assignment was decided." 
+ ), + "final_score": 0.93, + "retrieval_score": 0.93, + "source": "semantic", + }, + { + "id": "role-answer", + "bucket": "memories", + "content": "Member profile: Arlo is the quartermaster for group alpha and owns supply handoff.", + "final_score": 0.72, + "retrieval_score": 0.72, + "source": "keyword", + }, + ] + + raw_rank, full_rank, reranked = _run_slice(query, candidates, "role-answer") + + assert raw_rank == 2 + assert full_rank == 1 + assert reranked[0]["second_stage_features"]["role_value_pattern"] == 1.0 From 8347a7536191370664aeb3ff66211852b14d082f Mon Sep 17 00:00:00 2001 From: Mario Jack Vela Date: Fri, 24 Apr 2026 04:02:25 -0500 Subject: [PATCH 3/4] Document reranker fallback path --- docs/RERANKER.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/RERANKER.md b/docs/RERANKER.md index de8c603..7c425f2 100644 --- a/docs/RERANKER.md +++ b/docs/RERANKER.md @@ -111,6 +111,12 @@ path passed through the internal reranker configuration. That artifact is not checked into git. If the file is absent, the second-stage path falls back to the deterministic heuristic slate scorer and search remains fully functional. +The fallback is implemented in `src/agentmemory/retrieval/second_stage.py`: +`rerank_top_candidates()` calls `TinyMLPModel.try_load(...)`; when that returns +`None`, the MLP score vector is all zeros and `_heuristic_score()` plus +`_rerank_slate()` produce the final deterministic listwise order. No network, +model download, or checked-in weight file is required for the default path. + This keeps the default package local-first and reviewable: - no mandatory network fetch, @@ -124,6 +130,10 @@ artifact is published later, it should be attached as a release asset or LFS object with a short provenance record containing the source commit, training bundle, feature version, and held-out metrics. +Benchmark numbers reported by a PR must state whether they were produced with +an external MLP artifact present. 
If no artifact path is supplied and +`tiny_mlp_v1.json` is absent, those numbers are heuristic-fallback numbers. + ## Latency / quality tradeoff Measured on Apple Silicon M-series, CPU only (no MPS), Python 3.14, From f71e5c389fb29fe39f992e5246fd56078f557c57 Mon Sep 17 00:00:00 2001 From: Mario Jack Vela Date: Fri, 24 Apr 2026 19:38:51 -0500 Subject: [PATCH 4/4] Fix retrieval CI regressions --- .github/workflows/ci.yml | 3 ++- bin/intent_classifier.py | 18 +++++++++--------- src/agentmemory/retrieval/answerability.py | 4 +++- src/agentmemory/retrieval/query_planner.py | 14 +++++++------- 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e4f340a..965bcad 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -130,13 +130,14 @@ jobs: retrieval-gate: runs-on: ubuntu-latest if: github.event_name == 'pull_request' - # pull-requests: write is needed by the trailing "Post bench summary + # pull-requests/issues write are needed by the trailing "Post bench summary # as PR comment" step (actions/github-script). Without it the step # 403s on issues.createComment and fails the job even when the bench # itself passed. permissions: contents: read pull-requests: write + issues: write # Path filter: only pay the minutes-long LongMemEval tax when we're # touching retrieval. The planning doc calls these the "top-heavy # retrieval" hot paths — keep in sync with the plan. diff --git a/bin/intent_classifier.py b/bin/intent_classifier.py index 9907c07..f1faa1b 100644 --- a/bin/intent_classifier.py +++ b/bin/intent_classifier.py @@ -42,16 +42,16 @@ class IntentResult: # Each entry is (primary_tables, secondary_tables). # The final merged list fed to --tables is primary + secondary (de-duped). 
_TABLE_ROUTES = { - "cross_reference": ["events", "memories", "context", "procedures"], - "troubleshooting": ["procedures", "events", "memories", "context", "decisions"], + "cross_reference": ["events", "memories", "context"], + "troubleshooting": ["events", "memories", "context", "decisions"], "task_status": ["events", "context", "memories"], - "entity_lookup": ["memories", "entities", "context", "events", "procedures"], - "historical_timeline":["events", "memories", "context", "procedures"], - "how_to": ["procedures", "memories", "context", "events", "decisions"], - "decision_rationale": ["decisions", "memories", "context", "events", "procedures"], - "research_concept": ["memories", "procedures", "context"], - "orientation": ["memories", "events", "context", "procedures"], - "factual_lookup": ["memories", "entities", "decisions", "context", "events", "procedures"], + "entity_lookup": ["memories", "events", "context"], + "historical_timeline":["events", "memories", "context"], + "how_to": ["memories", "context", "events", "decisions"], + "decision_rationale": ["decisions", "memories", "context", "events"], + "research_concept": ["memories", "context"], + "orientation": ["memories", "events", "context"], + "factual_lookup": ["memories", "entities", "decisions", "context", "events"], } _FORMAT_HINTS = { diff --git a/src/agentmemory/retrieval/answerability.py b/src/agentmemory/retrieval/answerability.py index af174ba..790c9d6 100644 --- a/src/agentmemory/retrieval/answerability.py +++ b/src/agentmemory/retrieval/answerability.py @@ -23,6 +23,8 @@ def _normalize_token(token: str) -> str: return "" if tok.endswith("ies") and len(tok) > 4: tok = tok[:-3] + "y" + elif tok.endswith("ing") and len(tok) > 5: + tok = tok[:-3] elif tok.endswith("ed") and len(tok) > 4: tok = tok[:-2] elif tok.endswith("es") and len(tok) > 4: @@ -35,7 +37,7 @@ def _normalize_token(token: str) -> str: def _token_set(text: str) -> set[str]: return { norm - for part in re.split(r"\s+", text or "") 
+ for part in re.split(r"[^A-Za-z0-9]+", text or "") if (norm := _normalize_token(part)) } diff --git a/src/agentmemory/retrieval/query_planner.py b/src/agentmemory/retrieval/query_planner.py index 0892cd1..055a8e7 100644 --- a/src/agentmemory/retrieval/query_planner.py +++ b/src/agentmemory/retrieval/query_planner.py @@ -159,13 +159,13 @@ def as_dict(self) -> dict[str, Any]: _TABLE_ROUTES = { - "procedural": ["procedures", "memories", "decisions", "events", "context", "policy"], - "troubleshooting": ["procedures", "events", "memories", "decisions", "context", "policy"], - "decision": ["decisions", "memories", "procedures", "events", "context"], - "temporal": ["events", "memories", "context", "entities", "procedures"], - "factual": ["memories", "entities", "decisions", "context", "events", "procedures"], - "graph": ["memories", "events", "context", "decisions", "procedures"], - "orientation": ["memories", "events", "context", "procedures"], + "procedural": ["memories", "decisions", "events", "context", "policy"], + "troubleshooting": ["events", "memories", "decisions", "context", "policy"], + "decision": ["decisions", "memories", "events", "context"], + "temporal": ["events", "memories", "context", "entities"], + "factual": ["memories", "entities", "decisions", "context", "events"], + "graph": ["memories", "events", "context", "decisions"], + "orientation": ["memories", "events", "context"], }