From 5949d7b98009790b4d47e5e1c78fe18fb4979390 Mon Sep 17 00:00:00 2001 From: Ashish-dwi99 Date: Tue, 31 Mar 2026 17:42:21 +0530 Subject: [PATCH] V2.2.3 --- dhee/adapters/base.py | 71 ++++++- dhee/core/session_tracker.py | 346 ++++++++++++++++++++++++++++++++++ dhee/mcp_slim.py | 12 ++ dhee/memory/core.py | 37 +++- dhee/memory/main.py | 11 ++ dhee/simple.py | 81 +++++++- tests/test_auto_lifecycle.py | 143 ++++++++++++++ tests/test_session_tracker.py | 222 ++++++++++++++++++++++ 8 files changed, 911 insertions(+), 12 deletions(-) create mode 100644 dhee/core/session_tracker.py create mode 100644 tests/test_auto_lifecycle.py create mode 100644 tests/test_session_tracker.py diff --git a/dhee/adapters/base.py b/dhee/adapters/base.py index 60c797c..f117139 100644 --- a/dhee/adapters/base.py +++ b/dhee/adapters/base.py @@ -84,7 +84,11 @@ def __init__( buddhi_dir = str(self._engram.data_dir / "buddhi") self._buddhi = Buddhi(data_dir=buddhi_dir) - # Session tracking + # Passive session tracker — auto-context + auto-checkpoint + from dhee.core.session_tracker import SessionTracker + self._tracker = SessionTracker() + + # Session tracking (kept for backward compat with session_start/session_end) self._session_id: Optional[str] = None self._session_start_time: Optional[float] = None @@ -120,7 +124,15 @@ def remember( checks for "remember to X when Y" patterns. 
""" uid = user_id or self._user_id - result = self._engram.add(content, user_id=uid, infer=False, metadata=metadata) + + # Auto-tier memory content + from dhee.core.session_tracker import classify_tier + tier = classify_tier(content) + meta = dict(metadata) if metadata else {} + if tier != "smriti": + meta["tier"] = tier + + result = self._engram.add(content, user_id=uid, infer=False, metadata=meta or None) response: Dict[str, Any] = {"stored": True} memory_id = None @@ -129,6 +141,12 @@ def remember( if rs: memory_id = rs[0].get("id") response["id"] = memory_id + if tier == "shruti": + response["tier"] = "shruti" + + # Session tracking — may trigger auto-context + signals = self._tracker.on_remember(content, memory_id) + self._handle_tracker_signals(signals, uid) # Buddhi: detect intentions, record episode event, create beliefs intention = self._buddhi.on_memory_stored( @@ -152,7 +170,7 @@ def recall( """Search memory for relevant facts. 0 LLM calls. 1 embedding.""" uid = user_id or self._user_id results = self._engram.search(query, user_id=uid, limit=limit) - return [ + formatted = [ { "memory": r.get("memory", r.get("content", "")), "score": round(r.get("composite_score", r.get("score", 0.0)), 3), @@ -161,6 +179,12 @@ def recall( for r in results ] + # Session tracking + signals = self._tracker.on_recall(query, formatted) + self._handle_tracker_signals(signals, uid) + + return formatted + # ------------------------------------------------------------------ # Tool 3: context # ------------------------------------------------------------------ @@ -172,6 +196,7 @@ def context( ) -> Dict[str, Any]: """HyperAgent session bootstrap. Returns everything the agent needs.""" uid = user_id or self._user_id + self._tracker.on_context(task_description) hyper_ctx = self._buddhi.get_hyper_context( user_id=uid, task_description=task_description, @@ -219,6 +244,21 @@ def checkpoint( 8. 
Selective forgetting → utility-based cleanup """ uid = user_id or self._user_id + self._tracker.on_checkpoint() + + # Auto-fill task_type if not provided + if not task_type: + task_type = self._tracker.get_inferred_task_type() + if task_type == "general": + task_type = None + + # Auto-fill outcome if not provided + if outcome_score is None and self._tracker.op_count >= 3: + outcome = self._tracker.get_outcome_signals() + outcome_score = outcome.get("outcome_score") + if not what_worked: + what_worked = outcome.get("what_worked") + result: Dict[str, Any] = {} # 1. Session digest @@ -401,6 +441,31 @@ def session_end( self._session_start_time = None return result + # ------------------------------------------------------------------ + # Auto-lifecycle (driven by SessionTracker) + # ------------------------------------------------------------------ + + def _handle_tracker_signals(self, signals: Dict[str, Any], user_id: str) -> None: + """Process signals from the session tracker.""" + if not signals: + return + + # Auto-checkpoint a timed-out previous session + if signals.get("needs_auto_checkpoint"): + args = signals.get("auto_checkpoint_args", {}) + try: + self.checkpoint(user_id=user_id, **args) + except Exception: + pass + + # Auto-context for new session + if signals.get("needs_auto_context"): + task = signals.get("inferred_task") + try: + self.context(task_description=task, user_id=user_id) + except Exception: + pass + # ------------------------------------------------------------------ # Phase 3: Belief management # ------------------------------------------------------------------ diff --git a/dhee/core/session_tracker.py b/dhee/core/session_tracker.py new file mode 100644 index 0000000..22951cb --- /dev/null +++ b/dhee/core/session_tracker.py @@ -0,0 +1,346 @@ +"""Passive session observer — makes Dhee learn without ceremony. + +Tracks operations (remember/recall/context/checkpoint) and automatically: +1. 
# ── Memory tiers (Shruti / Smriti / Vasana) ──────────────────────────

TIER_SHRUTI = "shruti"  # Core identity — 0% decay
TIER_SMRITI = "smriti"  # Episodic — normal decay
TIER_VASANA = "vasana"  # Latent echo — compressed, low-priority retrieval

# Patterns that mark a memory as core (shruti)
_SHRUTI_PATTERNS = [
    re.compile(r"\b(?:always|never|must|rule|policy|preference|principle)\b", re.I),
    re.compile(r"\b(?:i am|my name|my role|i work|i prefer)\b", re.I),
    re.compile(r"^(?:system|instruction|config|setting)\s*:", re.I),
]


def classify_tier(content: str) -> str:
    """Classify memory content into a tier. 0 LLM calls.

    Returns ``TIER_SHRUTI`` when any pattern in ``_SHRUTI_PATTERNS``
    matches, otherwise ``TIER_SMRITI``. Empty or ``None`` content is
    treated as ordinary (smriti) content instead of raising.
    """
    if not content:
        return TIER_SMRITI
    if any(pat.search(content) for pat in _SHRUTI_PATTERNS):
        return TIER_SHRUTI
    return TIER_SMRITI


# ── Task-type inference ──────────────────────────────────────────────

TASK_PATTERNS: Dict[str, List[str]] = {
    "bug_fix": ["fix", "bug", "error", "crash", "broken", "fail", "debug",
                "issue", "traceback", "exception", "stack trace", "segfault"],
    "code_review": ["review", "pr", "pull request", "approve", "nit",
                    "suggestion", "feedback", "lgtm"],
    "feature": ["add", "implement", "create", "build", "new feature",
                "endpoint", "integrate", "support"],
    "refactor": ["refactor", "rename", "extract", "move", "reorganize",
                 "clean", "simplify", "dedup"],
    "deploy": ["deploy", "release", "production", "staging", "ci/cd",
               "pipeline", "rollback", "ship"],
    "research": ["research", "investigate", "explore", "understand",
                 "how does", "why does", "what is", "learn about"],
    "documentation": ["doc", "readme", "changelog", "comment", "explain",
                      "document", "annotate"],
    "testing": ["test", "pytest", "unittest", "coverage", "assertion",
                "mock", "fixture", "spec"],
}


def _keyword_regex(keyword: str) -> str:
    """Return the regex source used to find *keyword* in lowercased text.

    Every keyword must start at a word boundary. Keywords of up to three
    characters ("pr", "nit", "fix", "add", "doc", "bug") must also end
    the word, allowing a simple inflection (fixes/fixed/fixing), because
    as raw substrings they hit unrelated words ("production", "monitor",
    "fixture", "address", "docker"). Longer keywords stay prefix matches
    so that "deploy" still finds "deployment".
    """
    pattern = r"\b" + re.escape(keyword)
    if len(keyword) <= 3:
        pattern += r"(?:s|es|d|ed|ing)?\b"
    return pattern


def infer_task_type(texts: List[str]) -> str:
    """Infer task type from a list of text snippets. 0 LLM calls.

    Scores each entry of ``TASK_PATTERNS`` by the number of its keywords
    found in the combined, lowercased text (each keyword counts once).
    Returns the first best-scoring type, or "general" when the best
    score is below 2 or there is no text.
    """
    combined = " ".join(t for t in texts if t).lower()
    if not combined.strip():
        return "general"

    best_type = "general"
    best_score = 0
    for task_type, keywords in TASK_PATTERNS.items():
        # re.search keeps compiled patterns in the re module's cache, so
        # the pattern strings are not recompiled on every call.
        score = sum(1 for kw in keywords if re.search(_keyword_regex(kw), combined))
        if score > best_score:  # strict '>' keeps the earliest type on ties
            best_score = score
            best_type = task_type

    return best_type if best_score >= 2 else "general"


# ── Session Tracker ──────────────────────────────────────────────────

class SessionTracker:
    """Passive session observer. Tracks operations and detects boundaries.

    All methods are pure heuristics — zero LLM calls. The owner calls
    the ``on_*`` hooks from remember/recall/context/checkpoint and acts
    on the signal dicts that ``on_remember`` / ``on_recall`` return.
    """

    # Defaults; __init__ stores per-instance values under the same names.
    SESSION_TIMEOUT_SECONDS: float = 1800.0  # 30 min inactivity = boundary
    AUTO_CONTEXT: bool = True
    AUTO_CHECKPOINT: bool = True

    def __init__(
        self,
        session_timeout: Optional[float] = None,
        auto_context: bool = True,
        auto_checkpoint: bool = True,
    ):
        """Create a tracker with no active session.

        Args:
            session_timeout: Inactivity gap in seconds that ends a
                session. ``None`` keeps the class-level default; any
                explicit number, including 0, is honoured.
            auto_context: Emit ``needs_auto_context`` on session start.
            auto_checkpoint: Emit ``needs_auto_checkpoint`` on timeout.
        """
        # 'is not None' rather than 'or': an explicit 0 must not silently
        # turn into the 30-minute default.
        if session_timeout is not None:
            self.SESSION_TIMEOUT_SECONDS = float(session_timeout)
        self.AUTO_CONTEXT = auto_context
        self.AUTO_CHECKPOINT = auto_checkpoint
        self._reset()

    def _reset(self) -> None:
        """Reset all session state."""
        self._session_active = False
        self._session_start_time: float = 0.0
        self._last_activity_time: float = 0.0
        self._op_count: int = 0

        # Content tracking
        self._memories_stored: List[str] = []           # memory IDs
        self._memories_stored_content: List[str] = []   # content snippets
        self._recall_result_ids: List[str] = []         # IDs returned by recall
        self._recall_queries: List[str] = []            # queries
        self._recalled_content: Dict[str, str] = {}     # id → content snippet

        # State flags
        self._context_loaded: bool = False
        self._checkpoint_called: bool = False
        self._task_description: Optional[str] = None

    # ── Lifecycle hooks ──────────────────────────────────────────

    def _record_op(self, task_hint: str) -> Dict[str, Any]:
        """Shared bookkeeping for remember/recall; returns the signals dict.

        Order matters: the timeout check may reset the tracker, after
        which a fresh session is started and this operation is counted
        against the new session.
        """
        signals: Dict[str, Any] = {}
        now = time.time()

        # Session timeout → auto-checkpoint args for the previous session
        timeout_signals = self._check_timeout(now)
        if timeout_signals:
            signals.update(timeout_signals)

        # Start a session if needed; its first op may request auto-context
        if not self._session_active:
            self._start_session(now)
            if self.AUTO_CONTEXT and not self._context_loaded:
                signals["needs_auto_context"] = True
                signals["inferred_task"] = task_hint[:200]

        self._last_activity_time = now
        self._op_count += 1
        return signals

    def on_remember(self, content: str, memory_id: Optional[str] = None) -> Dict[str, Any]:
        """Called after remember(). Returns signals dict.

        Returns:
            A dict that may contain ``needs_auto_checkpoint`` with
            ``auto_checkpoint_args`` (previous session timed out) and/or
            ``needs_auto_context`` with ``inferred_task`` (this op
            started a session without loaded context). Empty otherwise.
        """
        text = content or ""
        signals = self._record_op(text)
        if memory_id:
            self._memories_stored.append(memory_id)
        self._memories_stored_content.append(text[:300])
        return signals

    def on_recall(self, query: str, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Called after recall(). Results should have 'id' and 'memory' keys.

        Returns the same kind of signals dict as ``on_remember``.
        Results without an id are ignored; a missing or ``None``
        'memory' value is recorded as an empty snippet.
        """
        text = query or ""
        signals = self._record_op(text)
        self._recall_queries.append(text)
        for r in results or []:
            rid = r.get("id")
            if rid:
                self._recall_result_ids.append(rid)
                # 'or ""' guards against an explicit None, which
                # .get("memory", "") would pass through to the slice.
                self._recalled_content[rid] = str(r.get("memory") or "")[:200]
        return signals

    def on_context(self, task_description: Optional[str] = None) -> None:
        """Called when context() is invoked; marks context as loaded."""
        now = time.time()
        if not self._session_active:
            self._start_session(now)
        self._context_loaded = True
        self._last_activity_time = now
        self._op_count += 1
        if task_description:
            self._task_description = task_description

    def on_checkpoint(self) -> None:
        """Called when checkpoint() is invoked; suppresses auto-checkpoint."""
        self._checkpoint_called = True
        self._last_activity_time = time.time()
        self._op_count += 1
        # Don't reset — let the caller decide when to start a new session

    def finalize(self) -> Optional[Dict[str, Any]]:
        """Called on shutdown/atexit.

        Returns auto-checkpoint kwargs when a session is active, has at
        least one operation and was not checkpointed; otherwise None.
        """
        if self._session_active and not self._checkpoint_called and self._op_count > 0:
            return self._build_auto_checkpoint()
        return None

    # ── Inference (all heuristic, 0 LLM) ─────────────────────────

    def get_inferred_task_type(self) -> str:
        """Infer task type from recall queries and the first 5 stored snippets."""
        texts = self._recall_queries + self._memories_stored_content[:5]
        return infer_task_type(texts)

    def get_outcome_signals(self) -> Dict[str, Any]:
        """Estimate outcome from usage patterns. 0 LLM calls.

        Returns:
            {"outcome_score": float, "what_worked": str|None, "signals": dict}
        """
        from collections import Counter

        signals: Dict[str, Any] = {}

        # Signal 1: share of recalled IDs that are also IDs stored this session
        recalled_set = set(self._recall_result_ids)
        stored_set = set(self._memories_stored)
        if recalled_set:
            recall_utility = len(recalled_set & stored_set) / len(recalled_set)
        else:
            recall_utility = 0.5  # neutral if no recalls

        # Signal 2: memories stored per minute (duration floored at 60 s),
        # normalised so 5+/min saturates at 1.0
        duration = max(self._last_activity_time - self._session_start_time, 60.0)
        productivity = len(self._memories_stored) / (duration / 60.0)
        prod_score = min(1.0, productivity / 5.0)

        # Signal 3: total operations, saturating at 10
        engagement = min(1.0, self._op_count / 10.0)

        # Weighted average, clamped to [0.1, 1.0]
        outcome_score = (
            0.4 * recall_utility
            + 0.3 * prod_score
            + 0.3 * engagement
        )
        outcome_score = round(max(0.1, min(1.0, outcome_score)), 2)

        signals["recall_utility"] = round(recall_utility, 2)
        signals["productivity"] = round(prod_score, 2)
        signals["engagement"] = round(engagement, 2)

        # what_worked: snippet of the most frequently recalled memory
        what_worked = None
        if self._recalled_content:
            id_counts = Counter(self._recall_result_ids)
            if id_counts:
                top_id = id_counts.most_common(1)[0][0]
                what_worked = self._recalled_content.get(top_id)

        return {
            "outcome_score": outcome_score,
            "what_worked": what_worked,
            "signals": signals,
        }

    # ── Internal ─────────────────────────────────────────────────

    def _start_session(self, now: float) -> None:
        """Mark a session as active, starting at *now*."""
        self._session_active = True
        self._session_start_time = now
        self._last_activity_time = now

    def _check_timeout(self, now: float) -> Optional[Dict[str, Any]]:
        """Detect an inactivity timeout of the current session.

        On timeout the tracker is reset. Returns the auto-checkpoint
        signals for the expired session when auto-checkpoint is enabled,
        no checkpoint was recorded and at least one op happened;
        otherwise None.
        """
        if not self._session_active:
            return None
        if self._last_activity_time == 0:
            return None

        gap = now - self._last_activity_time
        if gap < self.SESSION_TIMEOUT_SECONDS:
            return None

        # Build the args BEFORE resetting — they describe the old session.
        auto_cp = None
        if self.AUTO_CHECKPOINT and not self._checkpoint_called and self._op_count > 0:
            auto_cp = self._build_auto_checkpoint()

        self._reset()
        if auto_cp:
            return {"needs_auto_checkpoint": True, "auto_checkpoint_args": auto_cp}
        return None

    def _build_auto_checkpoint(self) -> Dict[str, Any]:
        """Build checkpoint kwargs from inferred session signals."""
        task_type = self.get_inferred_task_type()
        outcome = self.get_outcome_signals()

        # Summary: explicit task description, else the first recall query,
        # plus a count of stored memories.
        summary_parts = []
        if self._task_description:
            summary_parts.append(self._task_description)
        elif self._recall_queries:
            summary_parts.append(f"Session focused on: {self._recall_queries[0]}")
        if self._memories_stored_content:
            n = len(self._memories_stored_content)
            summary_parts.append(f"{n} memories stored")
        summary = ". ".join(summary_parts) or "Auto-checkpointed session"

        args: Dict[str, Any] = {
            "summary": summary[:500],
            "task_type": task_type,
            "status": "completed",
        }

        # Only include outcome if we have real signals
        if self._op_count >= 3:
            args["outcome_score"] = outcome["outcome_score"]
            if outcome.get("what_worked"):
                args["what_worked"] = outcome["what_worked"]

        return args

    @property
    def session_active(self) -> bool:
        """Whether a session is currently open."""
        return self._session_active

    @property
    def op_count(self) -> int:
        """Number of operations recorded in the current session."""
        return self._op_count

    @property
    def context_loaded(self) -> bool:
        """Whether context() has been recorded for the current session."""
        return self._context_loaded
int(mem.get("access_count", 0)) + # Vasana: memories recalled 3+ times compress instead of dying + if access_count >= 3: + content = mem.get("memory", mem.get("content", "")) + # Compress to first 100 chars + keep keywords + compressed = content[:100].rstrip() + "..." if len(content) > 100 else content + update = { + "strength": self.fade_config.forgetting_threshold + 0.01, + "memory": compressed, + "metadata": json.dumps({**mem_meta, "tier": "vasana"}), + } + self.db.update_memory(mem["id"], update) + decayed += 1 + elif self.fade_config.use_tombstone_deletion: self.db.update_memory(mem["id"], {"tombstone": 1, "strength": new_strength}) + forgotten += 1 else: self.db.delete_memory(mem["id"]) - try: - self.vector_store.delete(mem["id"]) - except Exception: - pass - forgotten += 1 + try: + self.vector_store.delete(mem["id"]) + except Exception: + pass + forgotten += 1 elif should_promote( mem.get("layer", "sml"), int(mem.get("access_count", 0)), diff --git a/dhee/memory/main.py b/dhee/memory/main.py index 2ba6b60..a2a8dad 100644 --- a/dhee/memory/main.py +++ b/dhee/memory/main.py @@ -1980,6 +1980,17 @@ def apply_decay(self, scope: Dict[str, Any] = None) -> Dict[str, Any]: if memory.get("immutable"): continue + # Shruti-tier memories are immune to decay + _tier_md = memory.get("metadata") or {} + if isinstance(_tier_md, str): + import json as _tjson + try: + _tier_md = _tjson.loads(_tier_md) + except (ValueError, TypeError): + _tier_md = {} + if _tier_md.get("tier") == "shruti": + continue + # Task-aware decay: active tasks don't decay if memory.get("memory_type") == "task": _md = memory.get("metadata") or {} diff --git a/dhee/simple.py b/dhee/simple.py index 5cc9ef0..026e673 100644 --- a/dhee/simple.py +++ b/dhee/simple.py @@ -399,6 +399,9 @@ def __init__( data_dir: Optional[Union[str, Path]] = None, user_id: str = "default", in_memory: bool = False, + auto_context: bool = True, + auto_checkpoint: bool = True, + session_timeout: Optional[float] = None, ): 
self._user_id = user_id self._engram = Engram( @@ -410,6 +413,14 @@ def __init__( buddhi_dir = str(self._engram.data_dir / "buddhi") self._buddhi = Buddhi(data_dir=buddhi_dir) + # Passive session tracker — auto-context + auto-checkpoint + from dhee.core.session_tracker import SessionTracker + self._tracker = SessionTracker( + session_timeout=session_timeout, + auto_context=auto_context, + auto_checkpoint=auto_checkpoint, + ) + # ------------------------------------------------------------------ # Tool 1: remember # ------------------------------------------------------------------ @@ -434,12 +445,29 @@ def remember( {"stored": True, "id": ""} """ uid = user_id or self._user_id - result = self._engram.add(content, user_id=uid, infer=False, metadata=metadata) + + # Auto-tier memory content (shruti/smriti) + from dhee.core.session_tracker import classify_tier + tier = classify_tier(content) + meta = dict(metadata) if metadata else {} + if tier != "smriti": + meta["tier"] = tier + + result = self._engram.add(content, user_id=uid, infer=False, metadata=meta or None) response: Dict[str, Any] = {"stored": True} + memory_id = None if isinstance(result, dict): rs = result.get("results", []) if rs: - response["id"] = rs[0].get("id") + memory_id = rs[0].get("id") + response["id"] = memory_id + if tier == "shruti": + response["tier"] = "shruti" + + # Session tracking — may trigger auto-context + signals = self._tracker.on_remember(content, memory_id) + self._handle_tracker_signals(signals, uid) + # Detect intentions in the content intention = self._buddhi.on_memory_stored(content=content, user_id=uid) if intention: @@ -470,7 +498,7 @@ def recall( """ uid = user_id or self._user_id results = self._engram.search(query, user_id=uid, limit=limit) - return [ + formatted = [ { "memory": r.get("memory", r.get("content", "")), "score": round(r.get("composite_score", r.get("score", 0.0)), 3), @@ -479,6 +507,12 @@ def recall( for r in results ] + # Session tracking — may trigger 
auto-context + signals = self._tracker.on_recall(query, formatted) + self._handle_tracker_signals(signals, uid) + + return formatted + # ------------------------------------------------------------------ # Tool 3: context # ------------------------------------------------------------------ @@ -503,6 +537,7 @@ def context( performance, memories, last_session, meta. """ uid = user_id or self._user_id + self._tracker.on_context(task_description) hyper_ctx = self._buddhi.get_hyper_context( user_id=uid, task_description=task_description, @@ -562,6 +597,21 @@ def checkpoint( insights_created, intention_stored. """ uid = user_id or self._user_id + self._tracker.on_checkpoint() + + # Auto-fill task_type if not provided + if not task_type: + task_type = self._tracker.get_inferred_task_type() + if task_type == "general": + task_type = None # don't store noise + + # Auto-fill outcome if not provided and we have enough signals + if outcome_score is None and self._tracker.op_count >= 3: + outcome = self._tracker.get_outcome_signals() + outcome_score = outcome.get("outcome_score") + if not what_worked: + what_worked = outcome.get("what_worked") + result: Dict[str, Any] = {} # 1. 
def _handle_tracker_signals(self, signals: Dict[str, Any], user_id: str) -> None:
    """Run the auto-lifecycle actions requested by the session tracker.

    Args:
        signals: Dict returned by ``SessionTracker.on_remember`` /
            ``on_recall``. Recognised keys:
            ``needs_auto_checkpoint`` + ``auto_checkpoint_args`` (kwargs
            for ``self.checkpoint``) and ``needs_auto_context`` +
            ``inferred_task`` (task description for ``self.context``).
        user_id: User the automatic calls are made for.

    Both actions are best-effort: a failure is logged and never
    propagates into the ``remember()`` / ``recall()`` call that
    triggered it.
    """
    import logging

    if not signals:
        return

    log = logging.getLogger(__name__)
    tracker = self._tracker

    # Auto-checkpoint a timed-out previous session.
    if signals.get("needs_auto_checkpoint"):
        args = signals.get("auto_checkpoint_args") or {}
        # The tracker was reset before these signals were returned, so it
        # already describes the NEW session. checkpoint() calls
        # tracker.on_checkpoint(), which would mark the new session as
        # checkpointed (disabling its own auto-checkpoint) and count as
        # an operation. Snapshot and restore that state.
        saved_ops = tracker._op_count
        saved_checkpointed = tracker._checkpoint_called
        try:
            self.checkpoint(user_id=user_id, **args)
        except Exception:
            log.warning("Auto-checkpoint of previous session failed", exc_info=True)
        finally:
            tracker._op_count = saved_ops
            tracker._checkpoint_called = saved_checkpointed

    # Auto-context for the new session.
    if signals.get("needs_auto_context"):
        task = signals.get("inferred_task")
        # context() calls tracker.on_context(); keep the "context loaded"
        # flag it sets, but do not count the automatic call as an operation.
        saved_ops = tracker._op_count
        try:
            self.context(task_description=task, user_id=user_id)
        except Exception:
            log.warning("Auto-context bootstrap failed", exc_info=True)
        finally:
            tracker._op_count = saved_ops
+""" + +import os +import time +import pytest + +from dhee.simple import Dhee + + +@pytest.fixture +def dhee(tmp_path): + """Create a Dhee instance with in-memory storage and short timeout.""" + d = Dhee( + in_memory=True, + data_dir=str(tmp_path), + session_timeout=1.0, # 1 second for testing + ) + return d + + +class TestAutoContext: + """Verify auto-context fires on first operation.""" + + def test_remember_triggers_auto_context(self, dhee): + """First remember() should auto-bootstrap context.""" + result = dhee.remember("User prefers dark mode") + assert result["stored"] is True + # Tracker should be active with context loaded + assert dhee._tracker.session_active is True + assert dhee._tracker.context_loaded is True + + def test_recall_triggers_auto_context(self, dhee): + """First recall() should auto-bootstrap context.""" + dhee.remember("Python is great") + # Reset tracker to simulate fresh session + dhee._tracker._reset() + dhee.recall("programming language") + assert dhee._tracker.context_loaded is True + + def test_explicit_context_still_works(self, dhee): + """Explicit context() should work and prevent double-bootstrap.""" + ctx = dhee.context("fixing auth bug") + assert isinstance(ctx, dict) + assert dhee._tracker.context_loaded is True + # Subsequent remember shouldn't re-trigger context + dhee.remember("found the bug in login.py") + assert dhee._tracker.op_count == 2 # context + remember + + +class TestAutoCheckpoint: + """Verify auto-checkpoint fires on session timeout.""" + + def test_timeout_auto_checkpoints(self, dhee): + """After timeout, next operation should auto-checkpoint.""" + dhee.remember("session 1 work") + dhee.remember("more session 1 work") + dhee.recall("session 1 query") + + # Simulate timeout + dhee._tracker._last_activity_time = time.time() - 2.0 + + # This should trigger auto-checkpoint of session 1 + start session 2 + dhee.remember("session 2 starts") + + # Session 2 should be active now + assert dhee._tracker.session_active is 
True + assert dhee._tracker.op_count == 1 # only session 2's remember + + def test_explicit_checkpoint_prevents_auto(self, dhee): + """Explicit checkpoint should prevent auto-checkpoint on timeout.""" + dhee.remember("some work") + dhee.checkpoint("done with task", task_type="bug_fix") + + # Simulate timeout + dhee._tracker._last_activity_time = time.time() - 2.0 + + # Should not trigger auto-checkpoint (already checkpointed) + dhee.remember("new session") + assert dhee._tracker.op_count == 1 + + +class TestAutoInference: + """Verify task_type and outcome are auto-inferred.""" + + def test_checkpoint_auto_fills_task_type(self, dhee): + """Checkpoint should auto-fill task_type from session content.""" + dhee.remember("fixing crash in auth module") + dhee.recall("debug error in login") + dhee.remember("found the bug, was a null pointer") + + result = dhee.checkpoint("Fixed auth crash") + # Task type should have been auto-inferred as bug_fix + # (we can't directly verify the inferred type was used, + # but we verify the checkpoint succeeds) + assert isinstance(result, dict) + + def test_checkpoint_auto_fills_outcome(self, dhee): + """Checkpoint should auto-estimate outcome from usage patterns.""" + dhee.remember("step 1") + dhee.remember("step 2") + dhee.remember("step 3") + dhee.recall("what did I do?") + + result = dhee.checkpoint("Finished the task") + assert isinstance(result, dict) + + +class TestShruti: + """Verify shruti-tier memories are auto-detected.""" + + def test_preference_tagged_shruti(self, dhee): + result = dhee.remember("I prefer tabs over spaces") + assert result.get("tier") == "shruti" + + def test_rule_tagged_shruti(self, dhee): + result = dhee.remember("Rule: always write tests first") + assert result.get("tier") == "shruti" + + def test_normal_memory_no_tier_tag(self, dhee): + result = dhee.remember("The meeting is at 3pm") + assert result.get("tier") is None # smriti doesn't get tagged in response + + +class TestDisableAuto: + """Verify auto 
features can be disabled.""" + + def test_disable_auto_context(self, tmp_path): + d = Dhee(in_memory=True, data_dir=str(tmp_path), auto_context=False) + d.remember("hello") + assert d._tracker.context_loaded is False + + def test_disable_auto_checkpoint(self, tmp_path): + d = Dhee( + in_memory=True, data_dir=str(tmp_path), + auto_checkpoint=False, session_timeout=1.0, + ) + d.remember("session 1") + d._tracker._last_activity_time = time.time() - 2.0 + signals = d._tracker.on_remember("session 2", "m2") + assert signals.get("needs_auto_checkpoint") is None diff --git a/tests/test_session_tracker.py b/tests/test_session_tracker.py new file mode 100644 index 0000000..ec66884 --- /dev/null +++ b/tests/test_session_tracker.py @@ -0,0 +1,222 @@ +"""Unit tests for SessionTracker — passive session observer.""" + +import time +import pytest + +from dhee.core.session_tracker import ( + SessionTracker, + classify_tier, + infer_task_type, + TIER_SHRUTI, + TIER_SMRITI, + TIER_VASANA, +) + + +# ── Tier classification ────────────────────────────────────────────── + + +class TestClassifyTier: + def test_normal_content_is_smriti(self): + assert classify_tier("The meeting is at 3pm tomorrow") == TIER_SMRITI + + def test_preference_is_shruti(self): + assert classify_tier("I prefer Python over JavaScript") == TIER_SHRUTI + + def test_rule_is_shruti(self): + assert classify_tier("Rule: always run tests before committing") == TIER_SHRUTI + + def test_never_is_shruti(self): + assert classify_tier("Never deploy on Fridays") == TIER_SHRUTI + + def test_system_instruction_is_shruti(self): + assert classify_tier("System: use dark mode for all UI") == TIER_SHRUTI + + def test_identity_is_shruti(self): + assert classify_tier("I am a backend engineer at Acme Corp") == TIER_SHRUTI + + def test_short_content_is_smriti(self): + assert classify_tier("hello") == TIER_SMRITI + + +# ── Task-type inference ────────────────────────────────────────────── + + +class TestInferTaskType: + def 
test_bug_fix_keywords(self): + assert infer_task_type(["fix the crash in login", "debug auth error"]) == "bug_fix" + + def test_feature_keywords(self): + assert infer_task_type(["implement new endpoint", "add user profile"]) == "feature" + + def test_refactor_keywords(self): + assert infer_task_type(["refactor the auth module", "extract helper"]) == "refactor" + + def test_testing_keywords(self): + assert infer_task_type(["write pytest for auth", "add test coverage"]) == "testing" + + def test_general_fallback(self): + assert infer_task_type(["hello world"]) == "general" + + def test_empty_is_general(self): + assert infer_task_type([]) == "general" + + def test_needs_two_keywords_minimum(self): + # Single keyword match isn't enough + assert infer_task_type(["deploy"]) == "general" + + def test_deploy_with_context(self): + assert infer_task_type(["deploy to production", "release pipeline"]) == "deploy" + + +# ── Session Tracker lifecycle ──────────────────────────────────────── + + +class TestSessionTracker: + def test_first_remember_triggers_auto_context(self): + t = SessionTracker() + signals = t.on_remember("user likes dark mode", "m1") + assert signals.get("needs_auto_context") is True + assert t.session_active is True + + def test_second_remember_no_auto_context(self): + t = SessionTracker() + t.on_remember("first", "m1") + signals = t.on_remember("second", "m2") + assert signals.get("needs_auto_context") is None + + def test_first_recall_triggers_auto_context(self): + t = SessionTracker() + signals = t.on_recall("user preferences", [{"id": "m1", "memory": "dark mode"}]) + assert signals.get("needs_auto_context") is True + + def test_explicit_context_suppresses_auto(self): + t = SessionTracker() + t.on_context("fixing auth bug") + assert t.context_loaded is True + signals = t.on_remember("some fact", "m1") + assert signals.get("needs_auto_context") is None + + def test_auto_context_disabled(self): + t = SessionTracker(auto_context=False) + signals = 
t.on_remember("hello", "m1") + assert signals.get("needs_auto_context") is None + + def test_checkpoint_marks_session(self): + t = SessionTracker() + t.on_remember("x", "m1") + t.on_checkpoint() + assert t._checkpoint_called is True + + def test_op_count_increments(self): + t = SessionTracker() + t.on_remember("a", "m1") + t.on_recall("b", []) + t.on_context("c") + t.on_checkpoint() + assert t.op_count == 4 + + +# ── Timeout detection ──────────────────────────────────────────────── + + +class TestSessionTimeout: + def test_timeout_triggers_auto_checkpoint(self): + t = SessionTracker(session_timeout=1.0) # 1 second timeout + t.on_remember("hello", "m1") + t.on_remember("world", "m2") + t.on_recall("test", [{"id": "m1", "memory": "hello"}]) + + # Simulate timeout + t._last_activity_time = time.time() - 2.0 + + signals = t.on_remember("new session", "m3") + assert signals.get("needs_auto_checkpoint") is True + args = signals["auto_checkpoint_args"] + assert "summary" in args + assert args["status"] == "completed" + + def test_no_timeout_within_window(self): + t = SessionTracker(session_timeout=3600.0) + t.on_remember("hello", "m1") + signals = t.on_remember("world", "m2") + assert signals.get("needs_auto_checkpoint") is None + + def test_timeout_resets_session(self): + t = SessionTracker(session_timeout=1.0) + t.on_remember("first session", "m1") + t._last_activity_time = time.time() - 2.0 + t.on_remember("second session", "m2") + # After timeout, a new session starts + assert t.op_count == 1 # reset + 1 new op + + def test_no_auto_checkpoint_after_explicit(self): + t = SessionTracker(session_timeout=1.0) + t.on_remember("hello", "m1") + t.on_checkpoint() + t._last_activity_time = time.time() - 2.0 + signals = t.on_remember("new", "m2") + # Should not trigger auto-checkpoint since explicit was called + assert signals.get("needs_auto_checkpoint") is None + + +# ── Outcome inference ──────────────────────────────────────────────── + + +class TestOutcomeInference: + def 
test_outcome_with_recalls(self): + t = SessionTracker() + t.on_remember("setup", "m1") + t.on_recall("test", [{"id": "m1", "memory": "setup"}]) + t.on_remember("result", "m2") + t.on_remember("more", "m3") + + outcome = t.get_outcome_signals() + assert 0.0 < outcome["outcome_score"] <= 1.0 + assert "signals" in outcome + + def test_empty_session_neutral(self): + t = SessionTracker() + outcome = t.get_outcome_signals() + assert outcome["outcome_score"] >= 0.1 + + def test_what_worked_from_top_recall(self): + t = SessionTracker() + t.on_recall("help", [{"id": "m1", "memory": "git blame first"}]) + t.on_recall("more", [{"id": "m1", "memory": "git blame first"}]) + + outcome = t.get_outcome_signals() + assert outcome["what_worked"] == "git blame first" + + +# ── Finalize (atexit) ──────────────────────────────────────────────── + + +class TestFinalize: + def test_finalize_returns_args_for_active_session(self): + t = SessionTracker() + t.on_remember("some work", "m1") + t.on_recall("query", [{"id": "m1", "memory": "some work"}]) + + args = t.finalize() + assert args is not None + assert "summary" in args + assert args["status"] == "completed" + + def test_finalize_returns_none_after_checkpoint(self): + t = SessionTracker() + t.on_remember("done", "m1") + t.on_checkpoint() + assert t.finalize() is None + + def test_finalize_returns_none_for_empty_session(self): + t = SessionTracker() + assert t.finalize() is None + + def test_inferred_task_type_in_auto_checkpoint(self): + t = SessionTracker() + t.on_remember("fix the crash in auth module", "m1") + t.on_recall("debug error in login", []) + args = t.finalize() + assert args is not None + assert args.get("task_type") == "bug_fix"