Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,22 @@ Format follows [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

---

## [Unreleased]

### Added
- `memory/summary.py`: SummaryMemory backend — rolling compression memory with dual-mode support:
- **LLM mode** (when `GROQ_API_KEY` is set): Groq-powered abstractive summarisation
- **Extractive fallback** (zero API cost): regex-based fact-pattern extraction
- 6 new tests in `tests/test_pipeline.py` covering SummaryMemory: recall, compression, context structure, reset, token cost, and benchmark registration
- `SummaryMemory` registered as `"summary"` in `evaluation/benchmark.py`

### Results (extractive mode, 100 turns)
| Backend | Recall@100 | Tokens/Query |
|---------|:----------:|:------------:|
| SummaryMemory | 100% | 318 |

---

## [0.2.0] — 2026-05-22

### Added
Expand Down
10 changes: 8 additions & 2 deletions evaluation/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from memory.naive import NaiveMemory
from memory.rag import RAGMemory
from memory.cascading import CascadingTemporalMemory
from memory.summary import SummaryMemory
from memory.base import BaseMemory
from evaluation.metrics import (
recall_at_t, temporal_drift_score, memory_noise_ratio, precision_at_k,
Expand Down Expand Up @@ -35,12 +36,17 @@ class BackendResult:

def _make_memory(name: str) -> BaseMemory:
if name == "naive":
# Limit to ~1,500 tokens to simulate a realistic context window budget,
# Limit to ~1,200 tokens to simulate a realistic context window budget,
# forcing oldest messages to be evicted as conversation grows.
return NaiveMemory(max_context_tokens=1200)
if name == "rag":
return RAGMemory()
return CascadingTemporalMemory()
if name == "cascading":
return CascadingTemporalMemory()
if name == "summary":
# use_llm=None → auto-detect from GROQ_API_KEY env var
return SummaryMemory(window_size=20, use_llm=None)
raise ValueError(f"Unknown backend: '{name}'. Choose from: naive, rag, cascading, summary")
Comment on lines +46 to +49


def run_benchmark(
Expand Down
179 changes: 179 additions & 0 deletions memory/summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""
SummaryMemory — rolling LLM-generated compression memory backend.

Strategy:
Keep the last `window_size` messages verbatim.
Every time the buffer exceeds `window_size`, compress the overflow
into a running summary using either:
- LLM (Groq) when GROQ_API_KEY is set → high fidelity
- Extractive otherwise → zero-cost fallback

This is conceptually how long-horizon chat assistants work:
recent context stays sharp, old context becomes a compressed narrative.
"""

import os
import re
from typing import List, Dict

from .base import BaseMemory


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

_FACT_PATTERNS = re.compile(
r"(my \w[\w\s]+ is |i am |i'm |changed to |updated to |now is |"
r"name|city|age|occupation|company|hobby|language|food|score|subject)",
re.IGNORECASE,
)

_COMPRESS_SYSTEM = (
"You are a memory compressor for a conversational AI. "
"Given a batch of conversation messages, extract and preserve EVERY personal fact, "
"preference, update, and important detail. "
"Merge these with the existing summary if one is provided. "
"Output a single, compact paragraph of key facts — no filler, no opinions. "
"Always prefer the NEWER value when a fact has been updated."
)


def _extractive_compress(messages: List[Dict], existing_summary: str = "") -> str:
"""
Zero-cost fallback: keep only lines that look like personal facts.
Merges with any existing summary.
"""
kept: List[str] = []

# Re-include existing summary lines
if existing_summary:
kept.append(existing_summary)

for msg in messages:
content = msg.get("content", "")
if _FACT_PATTERNS.search(content):
kept.append(content.strip())

merged = " | ".join(kept)
return merged[:800] if merged else ""


def _llm_compress(messages: List[Dict], existing_summary: str, model: str) -> str:
"""LLM-powered compression via Groq."""
from utils.llm import chat

batch_text = "\n".join(
f"{m['role'].upper()}: {m['content']}" for m in messages
)
user_content = ""
if existing_summary:
user_content += f"Existing summary:\n{existing_summary}\n\n"
user_content += f"New messages to absorb:\n{batch_text}"

result = chat(
[
{"role": "system", "content": _COMPRESS_SYSTEM},
{"role": "user", "content": user_content},
],
model=model,
temperature=0.0,
max_tokens=200,
)
# Fallback if LLM call failed
if result.startswith("[LLM_ERROR"):
return _extractive_compress(messages, existing_summary)
return result.strip()


# ---------------------------------------------------------------------------
# SummaryMemory
# ---------------------------------------------------------------------------

class SummaryMemory(BaseMemory):
"""
Rolling-summary memory backend.

Parameters
----------
window_size : int
Number of most-recent messages kept verbatim.
use_llm : bool | None
True → always use Groq for compression.
False → always use extractive fallback.
None → auto-detect from GROQ_API_KEY env var.
model : str
Groq model name used for compression calls.
"""

name = "summary"

def __init__(
self,
window_size: int = 20,
use_llm: bool | None = None,
model: str = "llama-3.1-8b-instant",
) -> None:
self.window_size = window_size
self.model = model
self._use_llm: bool = (
bool(os.getenv("GROQ_API_KEY")) if use_llm is None else use_llm
)

self.recent: List[Dict] = []
self.summary: str = ""

# ------------------------------------------------------------------
# BaseMemory interface
# ------------------------------------------------------------------

def add_message(self, role: str, content: str, turn: int) -> None:
self.recent.append({"role": role, "content": content, "turn": turn})
# Compress whenever the verbatim buffer grows past the window
if len(self.recent) > self.window_size:
self._compress()
Comment on lines +130 to +134

def get_context(self, query: str, current_turn: int) -> List[Dict]:
context: List[Dict] = []
if self.summary:
context.append({
"role": "system",
"content": f"[Conversation summary] {self.summary}",
})
for msg in self.recent:
context.append({"role": msg["role"], "content": msg["content"]})
return context

def reset(self) -> None:
self.recent = []
self.summary = ""

# ------------------------------------------------------------------
# Internal
# ------------------------------------------------------------------

def _compress(self) -> None:
"""Move the overflow (everything before the window) into the summary."""
overflow = self.recent[: len(self.recent) - self.window_size]
self.recent = self.recent[-self.window_size :]

if self._use_llm:
self.summary = _llm_compress(overflow, self.summary, self.model)
else:
self.summary = _extractive_compress(overflow, self.summary)

# ------------------------------------------------------------------
# Diagnostics
# ------------------------------------------------------------------

@property
def mode(self) -> str:
return "llm" if self._use_llm else "extractive"

def __repr__(self) -> str:
return (
f"SummaryMemory(window={self.window_size}, "
f"mode={self.mode}, "
f"recent={len(self.recent)}, "
f"summary_len={len(self.summary)})"
)
1 change: 1 addition & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from memory.naive import NaiveMemory
from memory.rag import RAGMemory
from memory.cascading import CascadingTemporalMemory
from memory.summary import SummaryMemory
from evaluation.metrics import (
recall_at_t, precision_at_k, temporal_drift_score,
memory_noise_ratio, cascade_efficiency,
Expand Down
72 changes: 72 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from memory.naive import NaiveMemory
from memory.rag import RAGMemory
from memory.cascading import CascadingTemporalMemory
from memory.summary import SummaryMemory
from evaluation.metrics import (
recall_at_t, temporal_drift_score, memory_noise_ratio, precision_at_k
)
Expand Down Expand Up @@ -114,6 +115,70 @@ def test_noise_ratio_range():
print(f"PASS: noise ratio in range ({noise:.2f})")


# ── SummaryMemory tests ────────────────────────────────────────────────────

def test_summary_extractive_fallback_recall_early():
"""SummaryMemory with extractive compression (no LLM) recalls facts at T=15."""
mem = SummaryMemory(window_size=20, use_llm=False)
_populate(mem, BENCHMARK_FACTS, 15)
active = [f for f in BENCHMARK_FACTS if f.injected_at < 15]
results = [recall_at_t(mem, f, 14) for f in active]
rate = sum(r["recalled"] for r in results) / len(results)
assert rate >= 0.75, f"Expected >=75% recall at T=15 for summary, got {rate:.0%}"
print(f"PASS: summary extractive recall early ({rate:.0%})")


def test_summary_compresses_overflow():
"""After enough messages, summary should be non-empty and recent buffer bounded."""
mem = SummaryMemory(window_size=10, use_llm=False)
_populate(mem, BENCHMARK_FACTS, 30)
assert len(mem.recent) <= mem.window_size, (
f"recent buffer {len(mem.recent)} exceeds window_size {mem.window_size}"
)
assert len(mem.summary) > 0, "summary should be non-empty after overflow"
print(f"PASS: summary compression (recent={len(mem.recent)}, summary_len={len(mem.summary)})")


def test_summary_context_contains_summary_and_recent():
"""get_context() must return the summary block followed by recent messages."""
mem = SummaryMemory(window_size=6, use_llm=False)
_populate(mem, BENCHMARK_FACTS, 20)
ctx = mem.get_context("What is my name?", 19)
roles = [m["role"] for m in ctx]
assert "system" in roles, "context should include a system summary block"
assert "user" in roles, "context should include recent user messages"
print(f"PASS: summary context structure (chunks={len(ctx)}, roles={set(roles)})")


def test_summary_reset_clears_state():
"""reset() must clear both recent buffer and summary string."""
mem = SummaryMemory(window_size=10, use_llm=False)
_populate(mem, BENCHMARK_FACTS, 30)
mem.reset()
assert len(mem.recent) == 0, "recent buffer should be empty after reset"
assert mem.summary == "", "summary should be empty string after reset"
print("PASS: summary reset clears state")


def test_summary_token_cost_bounded():
"""SummaryMemory tokens/query should stay roughly constant after compression."""
mem = SummaryMemory(window_size=20, use_llm=False)
_populate(mem, BENCHMARK_FACTS, 100)
name_fact = BENCHMARK_FACTS[0]
tokens = mem.token_count(name_fact.query_text(), 99)
# Should NOT grow linearly with history — bounded by window + summary
assert tokens < 2000, f"token cost {tokens} seems unbounded (expected < 2000)"
print(f"PASS: summary token cost bounded ({tokens} tokens at T=100)")


def test_summary_benchmark_registration():
"""'summary' backend must be resolvable from the benchmark runner."""
from evaluation.benchmark import _make_memory
mem = _make_memory("summary")
assert mem.name == "summary"
print(f"PASS: summary registered in benchmark runner ({mem!r})")


if __name__ == "__main__":
tests = [
test_conversation_generator,
Expand All @@ -124,6 +189,13 @@ def test_noise_ratio_range():
test_temporal_drift_after_update,
test_token_count_ordering,
test_noise_ratio_range,
# SummaryMemory
test_summary_extractive_fallback_recall_early,
test_summary_compresses_overflow,
test_summary_context_contains_summary_and_recent,
test_summary_reset_clears_state,
test_summary_token_cost_bounded,
test_summary_benchmark_registration,
]
failed = 0
for t in tests:
Expand Down
Loading