diff --git a/dhee/__init__.py b/dhee/__init__.py
index 3172da0..f912e24 100644
--- a/dhee/__init__.py
+++ b/dhee/__init__.py
@@ -24,7 +24,7 @@
from dhee.memory.smart import SmartMemory
from dhee.memory.main import FullMemory
from dhee.simple import Dhee, Engram
-from dhee.adapters.base import DheePlugin
+from dhee.plugin import DheePlugin
from dhee.core.category import CategoryProcessor, Category, CategoryType, CategoryMatch
from dhee.core.echo import EchoProcessor, EchoDepth, EchoResult
from dhee.configs.base import MemoryConfig, FadeMemConfig, EchoMemConfig, CategoryMemConfig, ScopeConfig
@@ -32,7 +32,7 @@
# Default: CoreMemory (lightest, zero-config)
Memory = CoreMemory
-__version__ = "3.0.1"
+__version__ = "3.1.0"
__all__ = [
# Memory classes
"Engram",
diff --git a/dhee/adapters/__init__.py b/dhee/adapters/__init__.py
deleted file mode 100644
index 85a7879..0000000
--- a/dhee/adapters/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-"""Dhee adapters — universal plugin interface for any agent framework.
-
-Available adapters:
- - DheePlugin: Base universal plugin (remember/recall/context/checkpoint)
- - OpenAIToolAdapter: OpenAI function calling (tools= parameter)
- - get_dhee_tools: LangChain BaseTool wrappers
- - get_autogen_functions: AutoGen v0.2 callables + schemas
- - generate_snapshot: Frozen system prompt for non-tool-calling agents
-"""
-
-from dhee.adapters.base import DheePlugin
-
-__all__ = ["DheePlugin"]
diff --git a/dhee/adapters/autogen.py b/dhee/adapters/autogen.py
deleted file mode 100644
index 8d4b4e1..0000000
--- a/dhee/adapters/autogen.py
+++ /dev/null
@@ -1,268 +0,0 @@
-"""AutoGen adapter — wraps DheePlugin tools as AutoGen-callable functions.
-
-Supports both AutoGen v0.2 (register_for_llm/register_for_execution) and
-the newer AG2 / AutoGen 0.4+ patterns.
-
-Usage with AutoGen v0.2:
- from dhee import DheePlugin
- from dhee.adapters.autogen import get_autogen_functions, register_dhee_tools
-
- plugin = DheePlugin()
-
- # Option 1: Get callables + schemas for manual registration
- functions = get_autogen_functions(plugin)
-
- # Option 2: Auto-register on an assistant + executor pair
- register_dhee_tools(plugin, assistant=assistant, executor=user_proxy)
-
-Usage with AG2 / AutoGen 0.4+:
- from dhee.adapters.autogen import get_autogen_tool_specs
-
- specs = get_autogen_tool_specs(plugin)
- # Pass to ConversableAgent(tools=specs)
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
-logger = logging.getLogger(__name__)
-
-
-# ---------------------------------------------------------------------------
-# Tool callables
-# ---------------------------------------------------------------------------
-
-def _make_callables(plugin: Any) -> Dict[str, Callable]:
- """Create plain callables wrapping DheePlugin methods."""
-
- def remember(content: str, user_id: str = "default") -> str:
- """Store a fact, preference, or observation to memory."""
- result = plugin.remember(content=content, user_id=user_id)
- return json.dumps(result, default=str)
-
- def recall(query: str, user_id: str = "default", limit: int = 5) -> str:
- """Search memory for relevant facts."""
- results = plugin.recall(query=query, user_id=user_id, limit=limit)
- return json.dumps(results, default=str)
-
- def context(
- task_description: str = "", user_id: str = "default",
- ) -> str:
- """HyperAgent session bootstrap. Returns full cognition context."""
- result = plugin.context(
- task_description=task_description or None, user_id=user_id,
- )
- return json.dumps(result, default=str)
-
- def checkpoint(
- summary: str,
- task_type: str = "",
- outcome_score: float = -1.0,
- what_worked: str = "",
- what_failed: str = "",
- remember_to: str = "",
- ) -> str:
- """Save session state and learnings."""
- kwargs: Dict[str, Any] = {"summary": summary}
- if task_type:
- kwargs["task_type"] = task_type
- if outcome_score >= 0:
- kwargs["outcome_score"] = outcome_score
- if what_worked:
- kwargs["what_worked"] = what_worked
- if what_failed:
- kwargs["what_failed"] = what_failed
- if remember_to:
- kwargs["remember_to"] = remember_to
- result = plugin.checkpoint(**kwargs)
- return json.dumps(result, default=str)
-
- return {
- "dhee_remember": remember,
- "dhee_recall": recall,
- "dhee_context": context,
- "dhee_checkpoint": checkpoint,
- }
-
-
-# ---------------------------------------------------------------------------
-# AutoGen v0.2 schemas
-# ---------------------------------------------------------------------------
-
-_AUTOGEN_SCHEMAS: List[Dict[str, Any]] = [
- {
- "name": "dhee_remember",
- "description": (
- "Store a fact, preference, or observation to memory. "
- "Zero LLM calls, one embedding call."
- ),
- "parameters": {
- "type": "object",
- "properties": {
- "content": {
- "type": "string",
- "description": "The fact to remember",
- },
- "user_id": {
- "type": "string",
- "description": "User identifier (default: 'default')",
- "default": "default",
- },
- },
- "required": ["content"],
- },
- },
- {
- "name": "dhee_recall",
- "description": (
- "Search memory for relevant facts. Returns top-K ranked by relevance."
- ),
- "parameters": {
- "type": "object",
- "properties": {
- "query": {
- "type": "string",
- "description": "What you're trying to remember",
- },
- "user_id": {
- "type": "string",
- "description": "User identifier",
- "default": "default",
- },
- "limit": {
- "type": "integer",
- "description": "Max results (default: 5)",
- "default": 5,
- },
- },
- "required": ["query"],
- },
- },
- {
- "name": "dhee_context",
- "description": (
- "HyperAgent session bootstrap. Returns performance, insights, "
- "intentions, warnings, heuristics, and memories."
- ),
- "parameters": {
- "type": "object",
- "properties": {
- "task_description": {
- "type": "string",
- "description": "What you're about to work on",
- "default": "",
- },
- "user_id": {
- "type": "string",
- "description": "User identifier",
- "default": "default",
- },
- },
- },
- },
- {
- "name": "dhee_checkpoint",
- "description": (
- "Save session state and learnings. Records outcomes, synthesizes "
- "insights, stores intentions."
- ),
- "parameters": {
- "type": "object",
- "properties": {
- "summary": {
- "type": "string",
- "description": "What you were working on",
- },
- "task_type": {
- "type": "string",
- "description": "Task category (e.g., 'bug_fix')",
- "default": "",
- },
- "outcome_score": {
- "type": "number",
- "description": "0.0-1.0 outcome score (-1 to skip)",
- "default": -1.0,
- },
- "what_worked": {
- "type": "string",
- "description": "Approach that worked",
- "default": "",
- },
- "what_failed": {
- "type": "string",
- "description": "Approach that failed",
- "default": "",
- },
- "remember_to": {
- "type": "string",
- "description": "Future intention: 'remember to X when Y'",
- "default": "",
- },
- },
- "required": ["summary"],
- },
- },
-]
-
-
-# ---------------------------------------------------------------------------
-# Public API
-# ---------------------------------------------------------------------------
-
-def get_autogen_functions(
- plugin: Any,
-) -> List[Tuple[Callable, Dict[str, Any]]]:
- """Get (callable, schema) pairs for AutoGen v0.2 registration.
-
- Returns:
- List of (function, schema_dict) tuples ready for
- register_for_llm / register_for_execution.
- """
- callables = _make_callables(plugin)
- return [
- (callables[schema["name"]], schema)
- for schema in _AUTOGEN_SCHEMAS
- ]
-
-
-def register_dhee_tools(
- plugin: Any,
- assistant: Any,
- executor: Any,
-) -> None:
- """Register Dhee tools on an AutoGen v0.2 assistant + executor pair.
-
- Args:
- plugin: A DheePlugin instance.
- assistant: An AssistantAgent (or ConversableAgent) for LLM.
- executor: A UserProxyAgent (or ConversableAgent) for execution.
- """
- callables = _make_callables(plugin)
-
- for schema in _AUTOGEN_SCHEMAS:
- name = schema["name"]
- fn = callables[name]
-
- # Register for LLM (tool definition)
- assistant.register_for_llm(
- name=name,
- description=schema["description"],
- )(fn)
-
- # Register for execution
- executor.register_for_execution(name=name)(fn)
-
-
-def get_autogen_tool_specs(plugin: Any) -> List[Dict[str, Any]]:
- """Get tool specs for AG2 / AutoGen 0.4+ ConversableAgent(tools=...).
-
- Returns a list of dicts with 'function' and 'schema' keys.
- """
- callables = _make_callables(plugin)
- return [
- {"function": callables[schema["name"]], "schema": schema}
- for schema in _AUTOGEN_SCHEMAS
- ]
diff --git a/dhee/adapters/langchain.py b/dhee/adapters/langchain.py
deleted file mode 100644
index a13a4a6..0000000
--- a/dhee/adapters/langchain.py
+++ /dev/null
@@ -1,236 +0,0 @@
-"""LangChain adapter — wraps DheePlugin tools as LangChain BaseTool instances.
-
-Usage:
- from dhee import DheePlugin
- from dhee.adapters.langchain import get_dhee_tools
-
- plugin = DheePlugin()
- tools = get_dhee_tools(plugin)
-
- # Use with any LangChain agent:
- agent = create_react_agent(llm, tools)
-
- # Or pick individual tools:
- remember_tool, recall_tool, context_tool, checkpoint_tool = tools
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-from typing import Any, Dict, List, Optional, Type
-
-logger = logging.getLogger(__name__)
-
-# Lazy import — LangChain is optional
-_HAS_LANGCHAIN = None
-
-
-def _check_langchain() -> bool:
- global _HAS_LANGCHAIN
- if _HAS_LANGCHAIN is None:
- try:
- from langchain_core.tools import BaseTool # noqa: F401
- _HAS_LANGCHAIN = True
- except ImportError:
- _HAS_LANGCHAIN = False
- return _HAS_LANGCHAIN
-
-
-def _get_base_classes():
- """Import LangChain base classes (raises ImportError if not installed)."""
- from langchain_core.tools import BaseTool
- from langchain_core.callbacks import CallbackManagerForToolRun
- try:
- from pydantic import BaseModel, Field
- except ImportError:
- from langchain_core.pydantic_v1 import BaseModel, Field
- return BaseTool, CallbackManagerForToolRun, BaseModel, Field
-
-
-# ---------------------------------------------------------------------------
-# Tool implementations
-# ---------------------------------------------------------------------------
-
-def _make_remember_tool(plugin: Any):
- BaseTool, CallbackManagerForToolRun, BaseModel, Field = _get_base_classes()
-
- class RememberInput(BaseModel):
- content: str = Field(description="The fact, preference, or observation to remember")
- user_id: Optional[str] = Field(default=None, description="User identifier")
-
- class DheeRemember(BaseTool):
- name: str = "dhee_remember"
- description: str = (
- "Store a fact, preference, or observation to memory. "
- "Zero LLM calls, one embedding call. Fast."
- )
- args_schema: Type[BaseModel] = RememberInput
- _plugin: Any = None
-
- def __init__(self, plugin: Any, **kwargs):
- super().__init__(**kwargs)
- self._plugin = plugin
-
- def _run(
- self,
- content: str,
- user_id: Optional[str] = None,
- run_manager: Optional[CallbackManagerForToolRun] = None,
- ) -> str:
- result = self._plugin.remember(content=content, user_id=user_id)
- return json.dumps(result, default=str)
-
- return DheeRemember(plugin=plugin)
-
-
-def _make_recall_tool(plugin: Any):
- BaseTool, CallbackManagerForToolRun, BaseModel, Field = _get_base_classes()
-
- class RecallInput(BaseModel):
- query: str = Field(description="What you're trying to remember")
- user_id: Optional[str] = Field(default=None, description="User identifier")
- limit: int = Field(default=5, description="Maximum results to return")
-
- class DheeRecall(BaseTool):
- name: str = "dhee_recall"
- description: str = (
- "Search memory for relevant facts. Returns top-K results "
- "ranked by relevance. Zero LLM calls, one embedding."
- )
- args_schema: Type[BaseModel] = RecallInput
- _plugin: Any = None
-
- def __init__(self, plugin: Any, **kwargs):
- super().__init__(**kwargs)
- self._plugin = plugin
-
- def _run(
- self,
- query: str,
- user_id: Optional[str] = None,
- limit: int = 5,
- run_manager: Optional[CallbackManagerForToolRun] = None,
- ) -> str:
- results = self._plugin.recall(query=query, user_id=user_id, limit=limit)
- return json.dumps(results, default=str)
-
- return DheeRecall(plugin=plugin)
-
-
-def _make_context_tool(plugin: Any):
- BaseTool, CallbackManagerForToolRun, BaseModel, Field = _get_base_classes()
-
- class ContextInput(BaseModel):
- task_description: Optional[str] = Field(
- default=None, description="What you're about to work on",
- )
- user_id: Optional[str] = Field(default=None, description="User identifier")
-
- class DheeContext(BaseTool):
- name: str = "dhee_context"
- description: str = (
- "HyperAgent session bootstrap. Returns performance snapshots, "
- "insights, intentions, warnings, heuristics, and relevant memories. "
- "Call once at the start of a task."
- )
- args_schema: Type[BaseModel] = ContextInput
- _plugin: Any = None
-
- def __init__(self, plugin: Any, **kwargs):
- super().__init__(**kwargs)
- self._plugin = plugin
-
- def _run(
- self,
- task_description: Optional[str] = None,
- user_id: Optional[str] = None,
- run_manager: Optional[CallbackManagerForToolRun] = None,
- ) -> str:
- result = self._plugin.context(
- task_description=task_description, user_id=user_id,
- )
- return json.dumps(result, default=str)
-
- return DheeContext(plugin=plugin)
-
-
-def _make_checkpoint_tool(plugin: Any):
- BaseTool, CallbackManagerForToolRun, BaseModel, Field = _get_base_classes()
-
- class CheckpointInput(BaseModel):
- summary: str = Field(description="What you were working on")
- task_type: Optional[str] = Field(
- default=None, description="Task category (e.g., 'bug_fix')",
- )
- outcome_score: Optional[float] = Field(
- default=None, description="0.0-1.0 outcome score",
- )
- what_worked: Optional[str] = Field(
- default=None, description="Approach that worked",
- )
- what_failed: Optional[str] = Field(
- default=None, description="Approach that failed",
- )
- remember_to: Optional[str] = Field(
- default=None, description="Future intention: 'remember to X when Y'",
- )
-
- class DheeCheckpoint(BaseTool):
- name: str = "dhee_checkpoint"
- description: str = (
- "Save session state and learnings. Records outcomes, "
- "synthesizes insights from what worked/failed, stores intentions."
- )
- args_schema: Type[BaseModel] = CheckpointInput
- _plugin: Any = None
-
- def __init__(self, plugin: Any, **kwargs):
- super().__init__(**kwargs)
- self._plugin = plugin
-
- def _run(
- self,
- summary: str,
- task_type: Optional[str] = None,
- outcome_score: Optional[float] = None,
- what_worked: Optional[str] = None,
- what_failed: Optional[str] = None,
- remember_to: Optional[str] = None,
- run_manager: Optional[CallbackManagerForToolRun] = None,
- ) -> str:
- result = self._plugin.checkpoint(
- summary=summary, task_type=task_type,
- outcome_score=outcome_score, what_worked=what_worked,
- what_failed=what_failed, remember_to=remember_to,
- )
- return json.dumps(result, default=str)
-
- return DheeCheckpoint(plugin=plugin)
-
-
-# ---------------------------------------------------------------------------
-# Public API
-# ---------------------------------------------------------------------------
-
-def get_dhee_tools(plugin: Any) -> List[Any]:
- """Create LangChain tool instances from a DheePlugin.
-
- Returns:
- [DheeRemember, DheeRecall, DheeContext, DheeCheckpoint]
-
- Raises:
- ImportError: If langchain-core is not installed.
- """
- if not _check_langchain():
- raise ImportError(
- "langchain-core is required for LangChain integration. "
- "Install it with: pip install langchain-core"
- )
-
- return [
- _make_remember_tool(plugin),
- _make_recall_tool(plugin),
- _make_context_tool(plugin),
- _make_checkpoint_tool(plugin),
- ]
diff --git a/dhee/adapters/openai_funcs.py b/dhee/adapters/openai_funcs.py
deleted file mode 100644
index 0702dfa..0000000
--- a/dhee/adapters/openai_funcs.py
+++ /dev/null
@@ -1,183 +0,0 @@
-"""OpenAI function calling adapter for DheePlugin.
-
-Generates tool definitions compatible with:
- - OpenAI Chat Completions API (tools parameter)
- - Any OpenAI-compatible API (Ollama, vLLM, LiteLLM, etc.)
-
-Usage with the OpenAI SDK:
- from dhee.adapters.openai_funcs import OpenAIToolAdapter
-
- adapter = OpenAIToolAdapter(plugin)
- response = client.chat.completions.create(
- model="gpt-4",
- messages=messages,
- tools=adapter.tool_definitions(),
- )
-
- # Execute the function call
- for call in response.choices[0].message.tool_calls:
- result = adapter.execute(call.function.name, json.loads(call.function.arguments))
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-from typing import Any, Callable, Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-
-class OpenAIToolAdapter:
- """Wraps DheePlugin as OpenAI-compatible function calling tools.
-
- Provides tool_definitions() for the API request and execute() for
- dispatching tool calls from the response.
- """
-
- def __init__(self, plugin: Any):
- """
- Args:
- plugin: A DheePlugin instance.
- """
- self._plugin = plugin
- self._dispatchers: Dict[str, Callable] = {
- "remember": self._exec_remember,
- "recall": self._exec_recall,
- "context": self._exec_context,
- "checkpoint": self._exec_checkpoint,
- "session_start": self._exec_session_start,
- "session_end": self._exec_session_end,
- }
-
- def tool_definitions(self, include_session: bool = False) -> List[Dict[str, Any]]:
- """Return OpenAI-format tool definitions.
-
- Args:
- include_session: If True, also includes session_start and
- session_end as callable tools.
- """
- tools = self._plugin.as_openai_functions()
-
- if include_session:
- tools.extend([
- {
- "type": "function",
- "function": {
- "name": "session_start",
- "description": (
- "Start a Dhee cognition session. Returns a frozen context "
- "block. Call once at the beginning of a task."
- ),
- "parameters": {
- "type": "object",
- "properties": {
- "task_description": {
- "type": "string",
- "description": "What you're about to work on",
- },
- },
- },
- },
- },
- {
- "type": "function",
- "function": {
- "name": "session_end",
- "description": (
- "End the current Dhee session and save learnings."
- ),
- "parameters": {
- "type": "object",
- "properties": {
- "summary": {
- "type": "string",
- "description": "What you accomplished",
- },
- "outcome_score": {
- "type": "number",
- "description": "0.0-1.0 outcome score",
- },
- "what_worked": {
- "type": "string",
- "description": "Approach that worked",
- },
- "what_failed": {
- "type": "string",
- "description": "Approach that failed",
- },
- },
- "required": ["summary"],
- },
- },
- },
- ])
-
- return tools
-
- def execute(self, function_name: str, arguments: Dict[str, Any]) -> str:
- """Execute a tool call and return the JSON-encoded result.
-
- This is the glue between OpenAI's tool_call response and DheePlugin.
-
- Args:
- function_name: The function name from the tool call.
- arguments: Parsed JSON arguments from the tool call.
-
- Returns:
- JSON string suitable for a tool message.
- """
- dispatcher = self._dispatchers.get(function_name)
- if not dispatcher:
- return json.dumps({"error": f"Unknown function: {function_name}"})
-
- try:
- result = dispatcher(arguments)
- return json.dumps(result, default=str, ensure_ascii=False)
- except Exception as e:
- logger.warning("Tool execution failed for %s: %s", function_name, e)
- return json.dumps({"error": str(e)})
-
- def _exec_remember(self, args: Dict[str, Any]) -> Any:
- return self._plugin.remember(
- content=args["content"],
- user_id=args.get("user_id"),
- )
-
- def _exec_recall(self, args: Dict[str, Any]) -> Any:
- return self._plugin.recall(
- query=args["query"],
- user_id=args.get("user_id"),
- limit=args.get("limit", 5),
- )
-
- def _exec_context(self, args: Dict[str, Any]) -> Any:
- return self._plugin.context(
- task_description=args.get("task_description"),
- user_id=args.get("user_id"),
- )
-
- def _exec_checkpoint(self, args: Dict[str, Any]) -> Any:
- return self._plugin.checkpoint(
- summary=args["summary"],
- task_type=args.get("task_type"),
- outcome_score=args.get("outcome_score"),
- what_worked=args.get("what_worked"),
- what_failed=args.get("what_failed"),
- remember_to=args.get("remember_to"),
- trigger_keywords=args.get("trigger_keywords"),
- )
-
- def _exec_session_start(self, args: Dict[str, Any]) -> Any:
- prompt = self._plugin.session_start(
- task_description=args.get("task_description"),
- )
- return {"system_prompt": prompt}
-
- def _exec_session_end(self, args: Dict[str, Any]) -> Any:
- return self._plugin.session_end(
- summary=args["summary"],
- outcome_score=args.get("outcome_score"),
- what_worked=args.get("what_worked"),
- what_failed=args.get("what_failed"),
- )
diff --git a/dhee/adapters/system_prompt.py b/dhee/adapters/system_prompt.py
deleted file mode 100644
index 60b39f9..0000000
--- a/dhee/adapters/system_prompt.py
+++ /dev/null
@@ -1,245 +0,0 @@
-"""Frozen snapshot generator — renders DheePlugin context as a system prompt.
-
-For agents that don't support tool calling (e.g., simple prompt → completion
-workflows, humanoid robot controllers, voice assistants), this module generates
-a self-contained system prompt block that includes the full HyperContext.
-
-The "frozen snapshot" pattern (from NousResearch Hermes Agent architecture):
- 1. At session start, load HyperContext into the system prompt.
- 2. During the session, the system prompt is NEVER mutated — this preserves
- LLM KV-cache / prefix caches for fast inference.
- 3. At session end, new knowledge is written to storage for next time.
-
-Usage:
- from dhee import DheePlugin
- from dhee.adapters.system_prompt import generate_snapshot, SnapshotConfig
-
- plugin = DheePlugin()
- prompt = generate_snapshot(plugin, task="fixing auth bug")
-
- # Or with custom config:
- config = SnapshotConfig(
- include_memories=True,
- include_heuristics=True,
- max_memories=10,
- include_tool_instructions=True,
- )
- prompt = generate_snapshot(plugin, task="fixing auth bug", config=config)
-"""
-
-from __future__ import annotations
-
-import logging
-import textwrap
-from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class SnapshotConfig:
- """Controls what goes into the frozen snapshot."""
-
- include_performance: bool = True
- include_warnings: bool = True
- include_insights: bool = True
- include_intentions: bool = True
- include_contrasts: bool = True
- include_heuristics: bool = True
- include_memories: bool = True
- include_tool_instructions: bool = False
- include_hive: bool = False
-
- max_performance: int = 5
- max_warnings: int = 5
- max_insights: int = 5
- max_intentions: int = 5
- max_contrasts: int = 3
- max_heuristics: int = 3
- max_memories: int = 10
-
- # Prefix/suffix for wrapping the snapshot
- header: str = "## Dhee Cognition Context (Frozen Snapshot)"
- footer: str = ""
-
-
-# Tool usage instructions (for agents that CAN call tools after loading snapshot)
-_TOOL_INSTRUCTIONS = """\
-### Available Memory Tools
-- **remember(content)** — store a new fact/observation
-- **recall(query)** — search memory for relevant facts
-- **context(task)** — load full HyperContext (already loaded above)
-- **checkpoint(summary, ...)** — save session state and learnings
-
-Use `remember` proactively when you learn new facts. Use `recall` before
-answering questions that may depend on stored knowledge. Use `checkpoint`
-at natural breakpoints and at session end.
-"""
-
-
-def generate_snapshot(
- plugin: Any,
- task: Optional[str] = None,
- user_id: Optional[str] = None,
- config: Optional[SnapshotConfig] = None,
- hive: Optional[Any] = None,
-) -> str:
- """Generate a frozen system prompt snapshot from DheePlugin.
-
- Args:
- plugin: A DheePlugin instance.
- task: Current task description.
- user_id: User identifier.
- config: Snapshot configuration. Defaults to include everything.
- hive: Optional HiveMemory instance for multi-agent context.
-
- Returns:
- A complete system prompt block as a string.
- """
- cfg = config or SnapshotConfig()
- ctx = plugin.context(task_description=task, user_id=user_id)
-
- parts: List[str] = [cfg.header]
-
- if task:
- parts.append(f"\n**Current task:** {task}")
-
- # Performance
- if cfg.include_performance:
- perf = ctx.get("performance", [])[:cfg.max_performance]
- if perf:
- parts.append("\n### Performance History")
- for p in perf:
- trend = p.get("trend", 0)
- direction = "improving" if trend > 0 else "declining" if trend < 0 else "stable"
- parts.append(
- f"- **{p['task_type']}**: avg={p['avg_score']:.2f}, "
- f"trend={p['trend']:+.3f} ({direction}), "
- f"attempts={p['total_attempts']}"
- )
-
- # Warnings
- if cfg.include_warnings:
- warnings = ctx.get("warnings", [])[:cfg.max_warnings]
- if warnings:
- parts.append("\n### Warnings")
- for w in warnings:
- parts.append(f"- {w}")
-
- # Insights
- if cfg.include_insights:
- insights = ctx.get("insights", [])[:cfg.max_insights]
- if insights:
- parts.append("\n### Insights from Past Work")
- for i in insights:
- parts.append(f"- [{i.get('type', 'general')}] {i['content']}")
-
- # Intentions (triggered reminders)
- if cfg.include_intentions:
- intentions = ctx.get("intentions", [])[:cfg.max_intentions]
- if intentions:
- parts.append("\n### Triggered Reminders")
- for i in intentions:
- parts.append(f"- {i['description']}")
-
- # Contrastive evidence
- if cfg.include_contrasts:
- contrasts = ctx.get("contrasts", [])[:cfg.max_contrasts]
- if contrasts:
- parts.append("\n### Contrastive Evidence (Do / Avoid)")
- for c in contrasts:
- do_text = c.get("do", "")[:200]
- avoid_text = c.get("avoid", "")[:200]
- parts.append(f"- **Do:** {do_text}")
- parts.append(f" **Avoid:** {avoid_text}")
- confidence = c.get("confidence")
- if confidence is not None:
- parts.append(f" *confidence: {confidence:.1%}*")
-
- # Heuristics
- if cfg.include_heuristics:
- heuristics = ctx.get("heuristics", [])[:cfg.max_heuristics]
- if heuristics:
- parts.append("\n### Learned Heuristics")
- for h in heuristics:
- level = h.get("level", "domain")
- text = h.get("heuristic", "")[:250]
- parts.append(f"- [{level}] {text}")
-
- # Memories
- if cfg.include_memories:
- memories = ctx.get("memories", [])[:cfg.max_memories]
- if memories:
- parts.append("\n### Relevant Memories")
- for m in memories:
- mem_text = m.get("memory", "")[:250]
- score = m.get("score", 0)
- if mem_text:
- parts.append(f"- {mem_text}")
- if score > 0:
- parts.append(f" *(relevance: {score:.2f})*")
-
- # Hive context
- if cfg.include_hive and hive:
- try:
- hive_block = hive.get_context_block(limit=3)
- hive_insights = hive_block.get("hive_insights", [])
- hive_heuristics = hive_block.get("hive_heuristics", [])
-
- if hive_insights or hive_heuristics:
- parts.append("\n### Hive Knowledge (from other agents)")
- for hi in hive_insights:
- parts.append(
- f"- [insight from {hi['source']}] "
- f"{hi['content'].get('content', '')[:150]}"
- )
- for hh in hive_heuristics:
- parts.append(
- f"- [heuristic from {hh['source']}] "
- f"{hh['content'].get('heuristic', '')[:150]}"
- )
- except Exception as e:
- logger.debug("Hive context failed: %s", e)
-
- # Tool instructions
- if cfg.include_tool_instructions:
- parts.append("\n" + _TOOL_INSTRUCTIONS.strip())
-
- # Meta
- meta = ctx.get("meta", {})
- if meta:
- meta_parts = []
- for key in ["insight_count", "intention_count", "contrast_count", "heuristic_count"]:
- val = meta.get(key, 0)
- if val > 0:
- meta_parts.append(f"{key.replace('_count', '')}s: {val}")
- if meta_parts:
- parts.append(f"\n*Loaded: {', '.join(meta_parts)}*")
-
- if cfg.footer:
- parts.append(f"\n{cfg.footer}")
-
- return "\n".join(parts)
-
-
-def generate_minimal_snapshot(
- plugin: Any,
- task: Optional[str] = None,
- user_id: Optional[str] = None,
-) -> str:
- """Generate a minimal snapshot — just warnings, intentions, and top memories.
-
- Suitable for edge/embedded agents with tight context budgets.
- """
- cfg = SnapshotConfig(
- include_performance=False,
- include_insights=False,
- include_contrasts=False,
- include_heuristics=False,
- max_warnings=3,
- max_intentions=3,
- max_memories=3,
- header="## Dhee Context (Minimal)",
- )
- return generate_snapshot(plugin, task=task, user_id=user_id, config=cfg)
diff --git a/dhee/api/__init__.py b/dhee/api/__init__.py
deleted file mode 100644
index 017dcb4..0000000
--- a/dhee/api/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""Engram REST API module."""
-
-from dhee.api.app import app
-from dhee.api.server import run
-
-__all__ = ["app", "run"]
diff --git a/dhee/api/app.py b/dhee/api/app.py
deleted file mode 100644
index f49f382..0000000
--- a/dhee/api/app.py
+++ /dev/null
@@ -1,198 +0,0 @@
-"""Engram core REST API — lightweight handoff endpoints (no auth required).
-
-These endpoints mirror the enterprise ``/v1/handoff/*`` routes but delegate
-directly to ``engram.core.kernel`` without session/token enforcement.
-They are intended for local development and for the ``prompt_context.py`` hook
-which fires as a subprocess with no auth context.
-"""
-
-from __future__ import annotations
-
-import logging
-import os
-from typing import Any, Dict, List, Optional
-
-from fastapi import FastAPI, Query
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, Field
-
-logger = logging.getLogger(__name__)
-
-app = FastAPI(
- title="Engram Core API",
- version="0.1.0",
- description="Lightweight handoff + memory endpoints.",
-)
-
-app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_methods=["*"],
- allow_headers=["*"],
-)
-
-
-# ---------------------------------------------------------------------------
-# Health
-# ---------------------------------------------------------------------------
-
-@app.get("/health")
-async def health():
- return {"status": "ok"}
-
-
-# ---------------------------------------------------------------------------
-# Request / response schemas
-# ---------------------------------------------------------------------------
-
-class CheckpointRequest(BaseModel):
- task_summary: Optional[str] = None
- event_type: str = "hook_checkpoint"
- agent_id: str = "claude-code"
- context_snapshot: Optional[str] = None
- repo_path: Optional[str] = None
- status: Optional[str] = None
- decisions_made: Optional[List[str]] = None
- files_touched: Optional[List[str]] = None
- todos_remaining: Optional[List[str]] = None
- blockers: Optional[List[str]] = None
- key_commands: Optional[List[str]] = None
- test_results: Optional[str] = None
-
-
-class RecoverRequest(BaseModel):
- repo_path: str
- agent_id: str = "claude-code"
-
-
-class SessionDigestRequest(BaseModel):
- task_summary: str
- repo: Optional[str] = None
- status: str = "active"
- agent_id: str = "claude-code"
- decisions_made: Optional[List[str]] = None
- files_touched: Optional[List[str]] = None
- todos_remaining: Optional[List[str]] = None
- blockers: Optional[List[str]] = None
- key_commands: Optional[List[str]] = None
- test_results: Optional[str] = None
-
-
-# ---------------------------------------------------------------------------
-# Handoff endpoints
-# ---------------------------------------------------------------------------
-
-@app.post("/v1/handoff/checkpoint")
-async def handoff_checkpoint(request: CheckpointRequest):
- """Receive a lightweight checkpoint from the hook or an agent.
-
- Creates a dhee-bus session (if needed) and writes a checkpoint snapshot.
- """
- from dhee.core.kernel import _get_bus
-
- bus = None
- try:
- bus = _get_bus()
-
- # Find or create a session for this agent
- session = bus.get_session(
- agent_id=request.agent_id,
- repo=request.repo_path,
- )
- if session is None:
- sid = bus.save_session(
- agent_id=request.agent_id,
- repo=request.repo_path,
- status=request.status or "active",
- task_summary=request.task_summary or "",
- )
- else:
- sid = session["id"]
- # Update task_summary if provided
- updates: Dict[str, Any] = {}
- if request.task_summary:
- updates["task_summary"] = request.task_summary
- if request.status:
- updates["status"] = request.status
- if updates:
- bus.update_session(sid, **updates)
-
- snapshot = {
- "event_type": request.event_type,
- "task_summary": request.task_summary,
- "context_snapshot": request.context_snapshot,
- "files_touched": request.files_touched or [],
- "key_commands": request.key_commands or [],
- "decisions_made": request.decisions_made or [],
- "todos_remaining": request.todos_remaining or [],
- "blockers": request.blockers or [],
- "test_results": request.test_results,
- }
- cid = bus.checkpoint(sid, request.agent_id, snapshot)
-
- return {"status": "ok", "session_id": sid, "checkpoint_id": cid}
-
- except Exception as exc:
- logger.exception("Checkpoint failed")
- return {"status": "error", "detail": str(exc)}
- finally:
- if bus is not None:
- try:
- bus.close()
- except Exception:
- pass
-
-
-@app.get("/v1/handoff/sessions/last")
-async def handoff_last_session(
- agent_id: Optional[str] = Query(default=None),
- repo: Optional[str] = Query(default=None),
- fallback_log_recovery: bool = Query(default=True),
-):
- """Get the last session, falling back to JSONL log parsing."""
- from dhee.core.kernel import get_last_session
-
- session = get_last_session(
- agent_id=agent_id or "mcp-server",
- repo=repo,
- fallback_log_recovery=fallback_log_recovery,
- )
- if session is None:
- return {"status": "no_session", "message": "No previous session found."}
- return session
-
-
-@app.post("/v1/handoff/recover")
-async def handoff_recover(request: RecoverRequest):
- """Direct log recovery — parse JSONL logs without checking bus first."""
- from dhee.core.log_parser import find_latest_log, parse_conversation_log
-
- log_path = find_latest_log(request.repo_path)
- if log_path is None:
- return {"status": "no_logs", "message": "No conversation logs found."}
-
- digest = parse_conversation_log(log_path)
- if digest.get("message_count", 0) == 0:
- return {"status": "empty_log", "message": "Log file was empty."}
-
- return digest
-
-
-@app.post("/v1/handoff/sessions/digest")
-async def save_handoff_digest(request: SessionDigestRequest):
- """Save a session digest (lightweight, no auth)."""
- from dhee.core.kernel import save_session_digest
-
- result = save_session_digest(
- task_summary=request.task_summary,
- agent_id=request.agent_id,
- repo=request.repo,
- status=request.status,
- decisions_made=request.decisions_made,
- files_touched=request.files_touched,
- todos_remaining=request.todos_remaining,
- blockers=request.blockers,
- key_commands=request.key_commands,
- test_results=request.test_results,
- )
- return result
diff --git a/dhee/api/server.py b/dhee/api/server.py
deleted file mode 100644
index adc2e34..0000000
--- a/dhee/api/server.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""Engram core API server runner.
-
-Starts the lightweight FastAPI app from ``engram.api.app`` on the configured
-host/port. This is the standalone server — the enterprise version adds auth,
-governance, and more endpoints on top.
-"""
-
-from __future__ import annotations
-
-
-def run():
- """Run the Engram core API server."""
- import argparse
- import uvicorn
-
- parser = argparse.ArgumentParser(description="Engram Core API Server")
- parser.add_argument("--host", default="127.0.0.1", help="Host to bind to")
- parser.add_argument("--port", type=int, default=8100, help="Port to listen on")
- parser.add_argument("--reload", action="store_true", help="Enable auto-reload")
- args = parser.parse_args()
-
- print(f"Starting Engram Core API on http://{args.host}:{args.port}")
- print(f"Docs at http://{args.host}:{args.port}/docs")
-
- uvicorn.run(
- "engram.api.app:app",
- host=args.host,
- port=args.port,
- reload=args.reload,
- )
-
-
-if __name__ == "__main__":
- run()
diff --git a/dhee/api/static/dashboard.html b/dhee/api/static/dashboard.html
deleted file mode 100644
index fc0bc8b..0000000
--- a/dhee/api/static/dashboard.html
+++ /dev/null
@@ -1,736 +0,0 @@
-
-
-
-
-
-Engram Memory Visualizer
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
Memory Layers
-
Top Categories
-
Decay History
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/dhee/benchmarks/hippocamp.py b/dhee/benchmarks/hippocamp.py
new file mode 100644
index 0000000..ba0d2f5
--- /dev/null
+++ b/dhee/benchmarks/hippocamp.py
@@ -0,0 +1,1272 @@
+"""HippoCamp benchmark runner for Dhee.
+
+This runner targets the official HippoCamp release artifacts on Hugging Face.
+It supports:
+
+* ``gold``: the released ``HippoCamp_Gold`` parsed-text setting. Useful for
+ ablations, but easier than the paper's default raw-file benchmark.
+* ``raw``: raw file-tree ingestion using Dhee's own deterministic parsers and
+  metadata extraction. This preserves the raw-file exposure boundary, but the
+  coverage is only partially multimodal unless augmented with OCR / ASR / vision.
+
+Usage:
+ python -m dhee.benchmarks.hippocamp \
+ --config adam_subset \
+ --mode raw \
+ --embedder-provider simple \
+ --llm-provider mock \
+ --answer-strategy extractive
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import logging
+import math
+import os
+import re
+import time
+from dataclasses import dataclass
+from pathlib import Path, PurePosixPath
+from statistics import mean
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+
+from dhee.benchmarks.longmemeval import build_memory
+from dhee.benchmarks.raw_extractors import raw_file_to_items
+from dhee.memory.utils import strip_code_fences
+from dhee.utils.factory import LLMFactory
+
+logger = logging.getLogger("dhee.benchmarks.hippocamp")
+
+DEFAULT_REPO_ID = "MMMem-org/HippoCamp"
+NO_INFO = "No information available"
+
+_JSON_BLOCK_RE = re.compile(r"\{.*\}", re.S)
+_TOKEN_RE = re.compile(r"[a-z0-9]+")
+_WHITESPACE_RE = re.compile(r"\s+")
+
+
+def _require_huggingface_hub():
+ try:
+ from huggingface_hub import hf_hub_download, list_repo_files
+ except ImportError as exc:
+ raise ImportError(
+ "HippoCamp benchmark support requires 'huggingface_hub'. "
+ "Install it with `pip install 'dhee[benchmarks]'` or `pip install huggingface_hub`."
+ ) from exc
+ return hf_hub_download, list_repo_files
+
+
+@dataclass(frozen=True)
+class HippoCampConfigSpec:
+ name: str
+ profile: str
+ manifest_path: str
+ environment_prefix: str
+
+
+CONFIG_SPECS: Dict[str, HippoCampConfigSpec] = {
+ "adam_fullset": HippoCampConfigSpec(
+ name="adam_fullset",
+ profile="Adam",
+ manifest_path="Adam/Fullset/Adam.json",
+ environment_prefix="Adam/Fullset/Adam/",
+ ),
+ "adam_subset": HippoCampConfigSpec(
+ name="adam_subset",
+ profile="Adam",
+ manifest_path="Adam/Subset/Adam_Subset.json",
+ environment_prefix="Adam/Subset/Adam_Subset/",
+ ),
+ "bei_fullset": HippoCampConfigSpec(
+ name="bei_fullset",
+ profile="Bei",
+ manifest_path="Bei/Fullset/Bei.json",
+ environment_prefix="Bei/Fullset/Bei/",
+ ),
+ "bei_subset": HippoCampConfigSpec(
+ name="bei_subset",
+ profile="Bei",
+ manifest_path="Bei/Subset/Bei_Subset.json",
+ environment_prefix="Bei/Subset/Bei_Subset/",
+ ),
+ "victoria_fullset": HippoCampConfigSpec(
+ name="victoria_fullset",
+ profile="Victoria",
+ manifest_path="Victoria/Fullset/Victoria.json",
+ environment_prefix="Victoria/Fullset/Victoria/",
+ ),
+ "victoria_subset": HippoCampConfigSpec(
+ name="victoria_subset",
+ profile="Victoria",
+ manifest_path="Victoria/Subset/Victoria_Subset.json",
+ environment_prefix="Victoria/Subset/Victoria_Subset/",
+ ),
+}
+
+
+def _configure_logging(level: str) -> None:
+ logging.basicConfig(
+ level=getattr(logging, level.upper(), logging.INFO),
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ )
+
+
+def _persist_json(path: Path, payload: Any) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with path.open("w", encoding="utf-8") as handle:
+ json.dump(payload, handle, ensure_ascii=False, indent=2)
+
+
+def _append_jsonl(path: Path, payload: Dict[str, Any]) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with path.open("a", encoding="utf-8") as handle:
+ handle.write(json.dumps(payload, ensure_ascii=False) + "\n")
+
+
+def _load_checkpoint(checkpoint_path: Path) -> Tuple[List[Dict[str, Any]], set]:
+ """Load completed records from a checkpoint JSONL file.
+
+    Blank or malformed JSONL lines are skipped. Returns (records, completed_question_ids).
+ """
+ records: List[Dict[str, Any]] = []
+ completed_ids: set = set()
+ if not checkpoint_path.exists():
+ return records, completed_ids
+ for line in checkpoint_path.read_text(encoding="utf-8").splitlines():
+ line = line.strip()
+ if not line:
+ continue
+ try:
+ record = json.loads(line)
+ except json.JSONDecodeError:
+ continue
+ if isinstance(record, dict) and record.get("id"):
+ records.append(record)
+ completed_ids.add(str(record["id"]))
+ logger.info(
+ "Resumed from checkpoint: %d completed questions loaded from %s",
+ len(completed_ids),
+ checkpoint_path,
+ )
+ return records, completed_ids
+
+
+def _memory_has_data(history_db_path: Path, user_id: str) -> bool:
+ """Check if the SQLite memory DB already has indexed data for the user."""
+ if not history_db_path.exists():
+ return False
+ try:
+ import sqlite3 as _sqlite3
+
+ conn = _sqlite3.connect(str(history_db_path))
+ cursor = conn.execute(
+ "SELECT count(*) FROM memories WHERE user_id = ?", (user_id,)
+ )
+ count = cursor.fetchone()[0]
+ conn.close()
+ return count > 0
+ except Exception:
+ return False
+
+
+def _load_env_file(env_path: Path) -> int:
+ if not env_path.exists() or not env_path.is_file():
+ return 0
+ loaded = 0
+ for raw_line in env_path.read_text(encoding="utf-8", errors="ignore").splitlines():
+ line = raw_line.strip()
+ if not line or line.startswith("#") or "=" not in line:
+ continue
+ key, value = line.split("=", 1)
+ key = key.strip()
+ value = value.strip().strip('"').strip("'")
+ if not key or key in os.environ:
+ continue
+ os.environ[key] = value
+ loaded += 1
+ return loaded
+
+
+def _resolve_config(name: str) -> HippoCampConfigSpec:
+ try:
+ return CONFIG_SPECS[name]
+ except KeyError as exc:
+ supported = ", ".join(sorted(CONFIG_SPECS))
+ raise ValueError(f"Unsupported HippoCamp config '{name}'. Supported: {supported}") from exc
+
+
+def _normalize_answer(text: str) -> str:
+ cleaned = strip_code_fences(str(text or "")).strip()
+ cleaned = cleaned.replace("Answer:", "").replace("Final answer:", "").strip()
+ cleaned = _WHITESPACE_RE.sub(" ", cleaned)
+ return cleaned
+
+
+def _normalize_for_match(text: str) -> str:
+ lowered = _normalize_answer(text).lower()
+ lowered = re.sub(r"[^a-z0-9\s]", " ", lowered)
+ return _WHITESPACE_RE.sub(" ", lowered).strip()
+
+
+def _tokenize(text: str) -> List[str]:
+ return _TOKEN_RE.findall(_normalize_for_match(text))
+
+
+def _preview_text(text: str, *, limit: int = 160) -> str:
+ cleaned = _WHITESPACE_RE.sub(" ", str(text or "")).strip()
+ if len(cleaned) <= limit:
+ return cleaned
+ return cleaned[: max(0, limit - 3)].rstrip() + "..."
+
+
+def token_f1(prediction: str, gold: str) -> float:
+ pred_tokens = _tokenize(prediction)
+ gold_tokens = _tokenize(gold)
+ if not pred_tokens and not gold_tokens:
+ return 1.0
+ if not pred_tokens or not gold_tokens:
+ return 0.0
+
+ pred_counts: Dict[str, int] = {}
+ gold_counts: Dict[str, int] = {}
+ for token in pred_tokens:
+ pred_counts[token] = pred_counts.get(token, 0) + 1
+ for token in gold_tokens:
+ gold_counts[token] = gold_counts.get(token, 0) + 1
+
+ overlap = 0
+ for token, pred_count in pred_counts.items():
+ overlap += min(pred_count, gold_counts.get(token, 0))
+ if overlap == 0:
+ return 0.0
+
+ precision = overlap / len(pred_tokens)
+ recall = overlap / len(gold_tokens)
+ return (2 * precision * recall) / (precision + recall)
+
+
+def exact_match(prediction: str, gold: str) -> bool:
+ return _normalize_for_match(prediction) == _normalize_for_match(gold)
+
+
+def file_retrieval_metrics(predicted_paths: Sequence[str], gold_paths: Sequence[str]) -> Dict[str, float]:
+ predicted = {str(path).strip() for path in predicted_paths if str(path).strip()}
+ gold = {str(path).strip() for path in gold_paths if str(path).strip()}
+
+ if not gold and not predicted:
+ return {"precision": 1.0, "recall": 1.0, "f1": 1.0}
+ if not predicted:
+ return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
+ if not gold:
+ return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
+
+ true_positives = len(predicted & gold)
+ precision = true_positives / len(predicted)
+ recall = true_positives / len(gold)
+ if precision + recall == 0:
+ f1 = 0.0
+ else:
+ f1 = (2 * precision * recall) / (precision + recall)
+ return {"precision": precision, "recall": recall, "f1": f1}
+
+
+def _extract_json_block(text: str) -> Optional[Dict[str, Any]]:
+ cleaned = strip_code_fences(str(text or "")).strip()
+ match = _JSON_BLOCK_RE.search(cleaned)
+ if not match:
+ return None
+ try:
+ payload = json.loads(match.group(0))
+ except json.JSONDecodeError:
+ return None
+ return payload if isinstance(payload, dict) else None
+
+
+def _make_gold_repo_path(profile: str, relative_path: str) -> str:
+ rel = PurePosixPath(str(relative_path))
+ return f"HippoCamp_Gold/{profile}/{rel.with_suffix('.json').as_posix()}"
+
+
+def _relative_path_from_environment(spec: HippoCampConfigSpec, repo_path: str) -> str:
+ if not str(repo_path).startswith(spec.environment_prefix):
+ raise ValueError(f"Path '{repo_path}' is outside '{spec.environment_prefix}'")
+ return str(repo_path)[len(spec.environment_prefix):]
+
+
+def _render_file_header(
+ *,
+ profile: str,
+ config_name: str,
+ relative_path: str,
+ file_info: Dict[str, Any],
+ summary: str,
+) -> str:
+ header_lines = [
+ "[HippoCamp Gold File]",
+ f"Profile: {profile}",
+ f"Config: {config_name}",
+ f"Relative Path: {relative_path}",
+ ]
+ for key, label in (
+ ("file_modality", "Modality"),
+ ("file_type", "File Type"),
+ ("creation_date", "Created"),
+ ("modification_date", "Modified"),
+ ("location", "Location"),
+ ("latitude", "Latitude"),
+ ("longitude", "Longitude"),
+ ):
+ value = str(file_info.get(key) or "").strip()
+ if value:
+ header_lines.append(f"{label}: {value}")
+ summary_text = str(summary or "").strip()
+ if summary_text:
+ header_lines.append(f"Summary: {summary_text[:1200]}")
+ return "\n".join(header_lines)
+
+
+def _segment_label(segment: Dict[str, Any], index: int) -> str:
+ labels = [f"segment={index}"]
+ for key in ("page", "timestamp", "frame", "sheet", "row"):
+ value = segment.get(key)
+ if value not in (None, "", []):
+ labels.append(f"{key}={value}")
+ return "[" + ", ".join(labels) + "]"
+
+
+def gold_document_to_items(
+ *,
+ doc: Dict[str, Any],
+ profile: str,
+ config_name: str,
+ relative_path: str,
+ chunk_chars: int,
+) -> List[Dict[str, Any]]:
+ file_info = dict(doc.get("file_info") or {})
+ summary = str(doc.get("summary") or "").strip()
+ header = _render_file_header(
+ profile=profile,
+ config_name=config_name,
+ relative_path=relative_path,
+ file_info=file_info,
+ summary=summary,
+ )
+
+ raw_segments = list(doc.get("segments") or [])
+ segment_blocks: List[str] = []
+ for index, segment in enumerate(raw_segments, start=1):
+ content = str(segment.get("content") or "").strip()
+ if not content:
+ continue
+ segment_blocks.append(f"{_segment_label(segment, index)}\n{content}")
+
+ if not segment_blocks:
+ if summary:
+ segment_blocks = [f"[segment=1]\n{summary}"]
+ else:
+ segment_blocks = [f"[segment=1]\n{relative_path}"]
+
+ items: List[Dict[str, Any]] = []
+ current_blocks: List[str] = []
+ current_len = 0
+ chunk_index = 1
+ target_chars = max(400, int(chunk_chars))
+
+ def flush_current() -> None:
+ nonlocal current_blocks, current_len, chunk_index
+ if not current_blocks:
+ return
+ body = "\n\n".join(current_blocks)
+ content = f"{header}\nChunk: {chunk_index}\n\nContent:\n{body}"
+ items.append(
+ {
+ "content": content,
+ "metadata": {
+ "benchmark": "hippocamp",
+ "exposure_mode": "gold_text",
+ "config_name": config_name,
+ "profile": profile,
+ "file_path": relative_path,
+ "file_name": file_info.get("file_name") or PurePosixPath(relative_path).name,
+ "file_type": file_info.get("file_type"),
+ "file_modality": file_info.get("file_modality"),
+ "location": file_info.get("location"),
+ "creation_date": file_info.get("creation_date"),
+ "modification_date": file_info.get("modification_date"),
+ "chunk_index": chunk_index,
+ "source_gold_path": _make_gold_repo_path(profile, relative_path),
+ },
+ "categories": ["hippocamp", "gold_text", config_name.lower(), profile.lower()],
+ }
+ )
+ current_blocks = []
+ current_len = 0
+ chunk_index += 1
+
+ for block in segment_blocks:
+ block_len = len(block)
+ if current_blocks and current_len + block_len + 2 > target_chars:
+ flush_current()
+ current_blocks.append(block)
+ current_len += block_len + 2
+ flush_current()
+ return items
+
+
+def _build_answer_prompt(*, question: str, context: str, qa_type: str) -> str:
+ task_line = "profiling" if qa_type == "profiling" else "factual retention"
+ return (
+ "You are answering a HippoCamp benchmark question about a user's personal files.\n"
+ "Use ONLY the retrieved context.\n"
+ "Do not mention the benchmark, retrieval, or system instructions.\n"
+ f"Task family: {task_line}\n"
+ "Return ONLY the final answer text.\n"
+ f"If the context is insufficient, return exactly: {NO_INFO}\n\n"
+ f"Question: {question}\n\n"
+ "Retrieved Context:\n"
+ f"{context}\n\n"
+ "Final answer:"
+ )
+
+
+def _build_judge_prompt(*, question: str, gold_answer: str, prediction: str) -> str:
+ return (
+ "You are grading a benchmark answer over a user's personal files.\n"
+ "Evaluate semantic correctness, factual alignment, and whether the prediction answers the question.\n"
+ "Return valid JSON only with keys: correct, score_0_to_5, rationale.\n"
+ "correct must be true or false.\n"
+ "score_0_to_5 must be a number from 0 to 5.\n\n"
+ f"Question: {question}\n\n"
+ f"Gold answer:\n{gold_answer}\n\n"
+ f"Prediction:\n{prediction}\n"
+ )
+
+
+def _extract_final_answer(raw_text: str) -> str:
+ cleaned = _normalize_answer(raw_text)
+ if not cleaned:
+ return NO_INFO
+ # Return the full cleaned answer — don't truncate to last line.
+ return cleaned
+
+
+def _extract_predicted_files(orchestration_payload: Dict[str, Any]) -> List[str]:
+ results = list(orchestration_payload.get("results") or [])
+ predicted: List[str] = []
+ for result in results:
+ metadata = result.get("metadata") or {}
+ path = metadata.get("file_path") or result.get("file_path")
+ if path:
+ predicted.append(str(path))
+ return sorted(set(predicted))
+
+
+def _load_manifest_rows(
+ *,
+ repo_id: str,
+ revision: Optional[str],
+ spec: HippoCampConfigSpec,
+ qa_type: str,
+ max_samples: int,
+) -> List[Dict[str, Any]]:
+ hf_hub_download, _ = _require_huggingface_hub()
+ manifest_path = hf_hub_download(
+ repo_id=repo_id,
+ repo_type="dataset",
+ filename=spec.manifest_path,
+ revision=revision,
+ )
+ with open(manifest_path, "r", encoding="utf-8") as handle:
+ rows = json.load(handle)
+ if not isinstance(rows, list):
+ raise ValueError(f"HippoCamp manifest is not a list: {spec.manifest_path}")
+
+ filtered: List[Dict[str, Any]] = []
+ for row in rows:
+ if not isinstance(row, dict):
+ continue
+ row_qa_type = str(row.get("QA_type") or "").strip()
+ if qa_type != "all" and row_qa_type != qa_type:
+ continue
+ filtered.append(row)
+ if max_samples > 0:
+ filtered = filtered[: max_samples]
+ return filtered
+
+
+def _list_environment_repo_files(
+ *,
+ repo_id: str,
+ revision: Optional[str],
+ spec: HippoCampConfigSpec,
+) -> List[str]:
+ _, hf_list_repo_files = _require_huggingface_hub()
+ files = hf_list_repo_files(repo_id=repo_id, repo_type="dataset", revision=revision)
+ return sorted(path for path in files if str(path).startswith(spec.environment_prefix))
+
+
+def _build_environment_items(
+ *,
+ repo_id: str,
+ revision: Optional[str],
+ spec: HippoCampConfigSpec,
+ max_environment_files: int,
+ chunk_chars: int,
+) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ hf_hub_download, _ = _require_huggingface_hub()
+ repo_files = _list_environment_repo_files(repo_id=repo_id, revision=revision, spec=spec)
+ if max_environment_files > 0:
+ repo_files = repo_files[: max_environment_files]
+
+ items: List[Dict[str, Any]] = []
+ missing_gold: List[str] = []
+
+ for repo_path in repo_files:
+ relative_path = _relative_path_from_environment(spec, repo_path)
+ gold_repo_path = _make_gold_repo_path(spec.profile, relative_path)
+ try:
+ local_path = hf_hub_download(
+ repo_id=repo_id,
+ repo_type="dataset",
+ filename=gold_repo_path,
+ revision=revision,
+ )
+ except Exception:
+ missing_gold.append(relative_path)
+ continue
+
+ with open(local_path, "r", encoding="utf-8") as handle:
+ document = json.load(handle)
+ if not isinstance(document, dict):
+ logger.warning("Skipping malformed gold file: %s", gold_repo_path)
+ continue
+ items.extend(
+ gold_document_to_items(
+ doc=document,
+ profile=spec.profile,
+ config_name=spec.name,
+ relative_path=relative_path,
+ chunk_chars=chunk_chars,
+ )
+ )
+
+ metadata = {
+ "environment_repo_files": len(repo_files),
+ "environment_index_items": len(items),
+ "missing_gold_files": missing_gold,
+ }
+ return items, metadata
+
+
+def _build_raw_environment_items(
+ *,
+ repo_id: str,
+ revision: Optional[str],
+ spec: HippoCampConfigSpec,
+ max_environment_files: int,
+ chunk_chars: int,
+) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ hf_hub_download, _ = _require_huggingface_hub()
+ repo_files = _list_environment_repo_files(repo_id=repo_id, revision=revision, spec=spec)
+ if max_environment_files > 0:
+ repo_files = repo_files[: max_environment_files]
+
+ items: List[Dict[str, Any]] = []
+ mode_counts: Dict[str, int] = {}
+ metadata_only_files: List[str] = []
+
+ for repo_path in repo_files:
+ relative_path = _relative_path_from_environment(spec, repo_path)
+ local_path = Path(
+ hf_hub_download(
+ repo_id=repo_id,
+ repo_type="dataset",
+ filename=repo_path,
+ revision=revision,
+ )
+ )
+ result = raw_file_to_items(
+ local_path=local_path,
+ relative_path=relative_path,
+ profile=spec.profile,
+ config_name=spec.name,
+ chunk_chars=chunk_chars,
+ )
+ mode_counts[result.mode] = mode_counts.get(result.mode, 0) + 1
+ if result.mode == "metadata_only":
+ metadata_only_files.append(relative_path)
+ items.extend(result.items)
+
+ metadata = {
+ "environment_repo_files": len(repo_files),
+ "environment_index_items": len(items),
+ "raw_extraction_modes": mode_counts,
+ "metadata_only_files": metadata_only_files,
+ }
+ return items, metadata
+
+
+def _make_llm(
+ *,
+ provider: str,
+ model: Optional[str],
+ max_tokens: int,
+ timeout: int,
+ temperature: float,
+ top_p: float,
+ enable_thinking: bool = False,
+) -> Any:
+ config: Dict[str, Any] = {
+ "max_tokens": max(32, int(max_tokens)),
+ "timeout": max(1, int(timeout)),
+ "temperature": temperature,
+ "top_p": top_p,
+ "enable_thinking": bool(enable_thinking),
+ }
+ if model:
+ config["model"] = model
+ return LLMFactory.create(provider, config)
+
+
+def _maybe_judge(
+ *,
+ judge_llm: Optional[Any],
+ question: str,
+ gold_answer: str,
+ prediction: str,
+) -> Optional[Dict[str, Any]]:
+ if judge_llm is None:
+ return None
+ prompt = _build_judge_prompt(question=question, gold_answer=gold_answer, prediction=prediction)
+ try:
+ raw = str(judge_llm.generate(prompt)).strip()
+ except Exception as exc:
+ logger.warning("Judge generation failed: %s", exc)
+ return None
+ payload = _extract_json_block(raw)
+ if not payload:
+ return None
+ score = payload.get("score_0_to_5")
+ try:
+ score_value = max(0.0, min(5.0, float(score)))
+ except (TypeError, ValueError):
+ score_value = None
+ return {
+ "correct": bool(payload.get("correct")),
+ "score_0_to_5": score_value,
+ "rationale": str(payload.get("rationale") or "").strip(),
+ "raw": raw,
+ }
+
+
+def _generate_prediction(
+ *,
+ answer_llm: Optional[Any],
+ answer_strategy: str,
+ question: str,
+ qa_type: str,
+ context: str,
+ reduced_answer: str,
+) -> str:
+ strategy = answer_strategy
+ if strategy == "auto":
+ strategy = "llm" if answer_llm is not None else "extractive"
+
+ if strategy == "extractive":
+ return _normalize_answer(reduced_answer) or NO_INFO
+
+ if answer_llm is None:
+ return _normalize_answer(reduced_answer) or NO_INFO
+
+ prompt = _build_answer_prompt(question=question, context=context, qa_type=qa_type)
+ try:
+ raw = str(answer_llm.generate(prompt)).strip()
+ except Exception as exc:
+ logger.warning("Answer generation failed: %s", exc)
+ return _normalize_answer(reduced_answer) or NO_INFO
+
+ prediction = _extract_final_answer(raw)
+ if prediction == NO_INFO and reduced_answer:
+ fallback = _normalize_answer(reduced_answer)
+ if fallback:
+ return fallback
+ return prediction
+
+
+def _score_record(
+ *,
+ row: Dict[str, Any],
+ prediction: str,
+ predicted_files: Sequence[str],
+ judge_payload: Optional[Dict[str, Any]],
+) -> Dict[str, Any]:
+ gold_answer = str(row.get("answer") or "").strip()
+ gold_files = [str(path) for path in row.get("file_path") or []]
+ retrieval = file_retrieval_metrics(predicted_files, gold_files)
+ record: Dict[str, Any] = {
+ "id": str(row.get("id") or ""),
+ "qa_type": str(row.get("QA_type") or "").strip(),
+ "profiling_type": str(row.get("profiling_type") or "").strip(),
+ "question": str(row.get("question") or "").strip(),
+ "gold_answer": gold_answer,
+ "prediction": prediction,
+ "exact_match": exact_match(prediction, gold_answer),
+ "token_f1": round(token_f1(prediction, gold_answer), 4),
+ "gold_files": gold_files,
+ "predicted_files": list(predicted_files),
+ "file_precision": round(retrieval["precision"], 4),
+ "file_recall": round(retrieval["recall"], 4),
+ "file_f1": round(retrieval["f1"], 4),
+ "agent_cap": row.get("agent_cap") or {},
+ }
+ if judge_payload is not None:
+ record["judge_correct"] = bool(judge_payload.get("correct"))
+ record["judge_score_0_to_5"] = judge_payload.get("score_0_to_5")
+ record["judge_rationale"] = judge_payload.get("rationale")
+ return record
+
+
+def _family_accuracy(records: Sequence[Dict[str, Any]], family: str) -> Optional[float]:
+ by_subcategory: Dict[str, List[bool]] = {}
+ for record in records:
+ if "judge_correct" not in record:
+ continue
+ labels = ((record.get("agent_cap") or {}).get(family) or [])
+ for label in labels:
+ by_subcategory.setdefault(str(label), []).append(bool(record["judge_correct"]))
+ if not by_subcategory:
+ return None
+ sub_scores = [sum(values) / len(values) for values in by_subcategory.values() if values]
+ return mean(sub_scores) if sub_scores else None
+
+
+def _summarize_records(records: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
+ if not records:
+ return {"count": 0}
+
+ summary: Dict[str, Any] = {
+ "count": len(records),
+ "exact_match": round(mean(1.0 if record["exact_match"] else 0.0 for record in records), 4),
+ "token_f1": round(mean(float(record["token_f1"]) for record in records), 4),
+ "file_precision": round(mean(float(record["file_precision"]) for record in records), 4),
+ "file_recall": round(mean(float(record["file_recall"]) for record in records), 4),
+ "file_f1": round(mean(float(record["file_f1"]) for record in records), 4),
+ }
+ if any("judge_correct" in record for record in records):
+ judged = [record for record in records if "judge_correct" in record]
+ summary["judge_accuracy"] = round(
+ mean(1.0 if record["judge_correct"] else 0.0 for record in judged),
+ 4,
+ )
+ judge_scores = [
+ float(record["judge_score_0_to_5"])
+ for record in judged
+ if record.get("judge_score_0_to_5") is not None
+ ]
+ if judge_scores:
+ summary["judge_avg_score_0_to_5"] = round(mean(judge_scores), 4)
+ summary["judge_avg_score_0_to_10"] = round(mean(score * 2.0 for score in judge_scores), 4)
+
+ capability_breakdown = {}
+ for family in ("search", "evidence_perception", "reasoning"):
+ family_score = _family_accuracy(judged, family)
+ if family_score is not None:
+ capability_breakdown[family] = round(family_score, 4)
+ if capability_breakdown:
+ summary["capability_accuracy"] = capability_breakdown
+ return summary
+
+
+def _emit_progress(progress_path: Optional[Path], event: str, **payload: Any) -> None:
+ if progress_path is None:
+ return
+ row = {
+ "ts": round(time.time(), 3),
+ "event": event,
+ **payload,
+ }
+ try:
+ _append_jsonl(progress_path, row)
+ except Exception as exc:
+ logger.warning("Failed to append progress event '%s' to %s: %s", event, progress_path, exc)
+
+
+def _run_config(args: argparse.Namespace, spec: HippoCampConfigSpec) -> Dict[str, Any]:
+ output_root = Path(args.output_dir)
+ history_db_path = output_root / f"{spec.name}.sqlite"
+ progress_path = output_root / f"{spec.name}.progress.jsonl"
+ checkpoint_path = output_root / f"{spec.name}.checkpoint.jsonl"
+ user_id = f"hippocamp_{spec.name}"
+ # Don't nuke progress — append to it for resumed runs.
+ # Load checkpoint of previously completed questions.
+ resumed_records, completed_ids = _load_checkpoint(checkpoint_path)
+
+ memory = build_memory(
+ llm_provider=args.llm_provider,
+ embedder_provider=args.embedder_provider,
+ vector_store_provider=args.vector_store_provider,
+ embedding_dims=args.embedding_dims,
+ history_db_path=str(history_db_path),
+ llm_model=args.llm_model,
+ llm_timeout=args.llm_timeout,
+ embedder_model=args.embedder_model,
+ full_potential=not args.minimal,
+ defer_enrichment=args.defer_enrichment,
+ enable_rerank=args.enable_rerank,
+ rerank_model=args.rerank_model,
+ rerank_config=None,
+ enable_episodic_index=not args.disable_episodic_index,
+ enable_hierarchical_retrieval=not args.disable_hierarchical_retrieval,
+ enable_orchestrated_search=not args.disable_orchestrated_search,
+ cost_guardrail_strict=not args.no_cost_guardrail_strict,
+ )
+
+ answer_llm = None
+ if args.answer_provider not in {"", "none", "mock"}:
+ answer_llm = _make_llm(
+ provider=args.answer_provider,
+ model=args.answer_model,
+ max_tokens=args.answer_max_tokens,
+ timeout=args.answer_timeout,
+ temperature=args.answer_temperature,
+ top_p=args.answer_top_p,
+ enable_thinking=args.answer_enable_thinking,
+ )
+
+ judge_llm = None
+ if args.judge_provider:
+ judge_llm = _make_llm(
+ provider=args.judge_provider,
+ model=args.judge_model,
+ max_tokens=args.judge_max_tokens,
+ timeout=args.judge_timeout,
+ temperature=args.judge_temperature,
+ top_p=args.judge_top_p,
+ enable_thinking=args.judge_enable_thinking,
+ )
+
+ rows = _load_manifest_rows(
+ repo_id=args.repo_id,
+ revision=args.revision,
+ spec=spec,
+ qa_type=args.qa_type,
+ max_samples=args.max_samples,
+ )
+
+ # When resuming with an existing index, skip the expensive environment download/parse.
+ can_skip_env = bool(completed_ids) and _memory_has_data(history_db_path, user_id)
+ if can_skip_env:
+ environment_items: List[Dict[str, Any]] = []
+ environment_meta: Dict[str, Any] = {"environment_repo_files": 0, "environment_index_items": 0, "resumed": True}
+ logger.info("RESUME: Skipping environment download — memory DB already populated for %s", spec.name)
+ elif args.mode == "gold":
+ environment_items, environment_meta = _build_environment_items(
+ repo_id=args.repo_id,
+ revision=args.revision,
+ spec=spec,
+ max_environment_files=args.max_environment_files,
+ chunk_chars=args.chunk_chars,
+ )
+ elif args.mode == "raw":
+ environment_items, environment_meta = _build_raw_environment_items(
+ repo_id=args.repo_id,
+ revision=args.revision,
+ spec=spec,
+ max_environment_files=args.max_environment_files,
+ chunk_chars=args.chunk_chars,
+ )
+ else:
+ raise ValueError(f"Unsupported mode: {args.mode}")
+
+ logger.info("Live progress %s -> %s", spec.name, progress_path)
+ _emit_progress(
+ progress_path,
+ "config_start",
+ config=spec.name,
+ profile=spec.profile,
+ mode=args.mode,
+ qa_type=args.qa_type,
+ question_count=len(rows),
+ environment_repo_files=environment_meta["environment_repo_files"],
+ environment_index_items=environment_meta["environment_index_items"],
+ )
+    # Skip re-indexing on resume (same condition as can_skip_env above — keep the two in sync).
+ skip_indexing = bool(completed_ids) and _memory_has_data(history_db_path, user_id)
+ if skip_indexing:
+ logger.info(
+ "RESUME: Skipping indexing for %s — memory DB exists with data, %d questions already completed",
+ spec.name,
+ len(completed_ids),
+ )
+ index_seconds = 0.0
+ else:
+ logger.info(
+ "Indexing %s: %d files -> %d memory items",
+ spec.name,
+ environment_meta["environment_repo_files"],
+ environment_meta["environment_index_items"],
+ )
+ _emit_progress(
+ progress_path,
+ "index_start",
+ config=spec.name,
+ environment_repo_files=environment_meta["environment_repo_files"],
+ environment_index_items=environment_meta["environment_index_items"],
+ )
+ t0 = time.time()
+ memory.delete_all(user_id=user_id)
+ memory.add_batch(items=environment_items, user_id=user_id)
+ index_seconds = time.time() - t0
+ _emit_progress(
+ progress_path,
+ "index_done",
+ config=spec.name,
+ index_seconds=round(index_seconds, 4),
+ environment=environment_meta,
+ )
+
+ if args.enrich_after_ingest:
+ try:
+ memory.enrich_pending(user_id=user_id, batch_size=args.enrich_batch_size, max_batches=args.enrich_max_batches)
+ _emit_progress(
+ progress_path,
+ "enrichment_done",
+ config=spec.name,
+ batch_size=args.enrich_batch_size,
+ max_batches=args.enrich_max_batches,
+ )
+ except Exception as exc:
+ logger.warning("Post-ingest enrichment failed for %s: %s", spec.name, exc)
+ _emit_progress(
+ progress_path,
+ "enrichment_failed",
+ config=spec.name,
+ error=str(exc),
+ )
+
+ records: List[Dict[str, Any]] = list(resumed_records)
+ skipped_count = 0
+ total_rows = len(rows)
+ for index, row in enumerate(rows, start=1):
+ row_id = str(row.get("id") or "")
+ if row_id and row_id in completed_ids:
+ skipped_count += 1
+ continue
+ question = str(row.get("question") or "").strip()
+ qa_type = str(row.get("QA_type") or "").strip() or "factual_retention"
+ _emit_progress(
+ progress_path,
+ "question_start",
+ config=spec.name,
+ question_index=index,
+ question_count=total_rows,
+ qa_type=qa_type,
+ question=_preview_text(question, limit=240),
+ )
+ payload = memory.search_orchestrated(
+ query=question,
+ user_id=user_id,
+ question_type=f"hippocamp-{qa_type}",
+ question_date="",
+ limit=args.top_k,
+ orchestration_mode=args.answer_orchestration_mode,
+ base_search_limit=args.top_k,
+ base_context_limit=args.answer_context_top_k,
+ search_cap=args.search_cap,
+ context_cap=args.context_cap,
+ map_max_candidates=args.map_max_candidates,
+ map_max_chars=args.map_max_chars,
+ keyword_search=True,
+ hybrid_alpha=0.7,
+ include_evidence=True,
+ evidence_strategy=args.evidence_strategy,
+ evidence_max_chars=args.evidence_max_chars,
+ evidence_context_lines=args.evidence_context_lines,
+ max_context_chars=args.max_context_chars,
+ rerank=args.enable_rerank,
+ orchestrator_llm=memory.llm if args.answer_orchestration_mode != "off" else None,
+ reflection_max_hops=1,
+ )
+ context = str(payload.get("context") or "").strip()
+ reduced_answer = str(payload.get("reduced_answer") or "").strip()
+ predicted_files = _extract_predicted_files(payload)
+ _emit_progress(
+ progress_path,
+ "search_done",
+ config=spec.name,
+ question_index=index,
+ question_count=total_rows,
+ search_result_count=len(payload.get("results") or []),
+ predicted_files=predicted_files,
+ reduced_answer_preview=_preview_text(reduced_answer),
+ )
+ _emit_progress(
+ progress_path,
+ "answer_start",
+ config=spec.name,
+ question_index=index,
+ question_count=total_rows,
+ strategy=args.answer_strategy,
+ provider=args.answer_provider or None,
+ model=args.answer_model or None,
+ )
+ prediction = _generate_prediction(
+ answer_llm=answer_llm,
+ answer_strategy=args.answer_strategy,
+ question=question,
+ qa_type=qa_type,
+ context=context,
+ reduced_answer=reduced_answer,
+ )
+ _emit_progress(
+ progress_path,
+ "answer_done",
+ config=spec.name,
+ question_index=index,
+ question_count=total_rows,
+ prediction_preview=_preview_text(prediction),
+ )
+ if judge_llm is not None:
+ _emit_progress(
+ progress_path,
+ "judge_start",
+ config=spec.name,
+ question_index=index,
+ question_count=total_rows,
+ provider=args.judge_provider or None,
+ model=args.judge_model or None,
+ )
+ judge_payload = _maybe_judge(
+ judge_llm=judge_llm,
+ question=question,
+ gold_answer=str(row.get("answer") or ""),
+ prediction=prediction,
+ )
+ if judge_llm is not None:
+ _emit_progress(
+ progress_path,
+ "judge_done",
+ config=spec.name,
+ question_index=index,
+ question_count=total_rows,
+ judge_correct=None if judge_payload is None else bool(judge_payload.get("correct")),
+ judge_score_0_to_5=None if judge_payload is None else judge_payload.get("score_0_to_5"),
+ judge_rationale_preview="" if judge_payload is None else _preview_text(judge_payload.get("rationale") or ""),
+ )
+ record = _score_record(
+ row=row,
+ prediction=prediction,
+ predicted_files=predicted_files,
+ judge_payload=judge_payload,
+ )
+ record["search_result_count"] = len(payload.get("results") or [])
+ record["question_index"] = index
+ records.append(record)
+ # Checkpoint: persist each completed record so we can resume.
+ _append_jsonl(checkpoint_path, record)
+ _emit_progress(
+ progress_path,
+ "question_done",
+ config=spec.name,
+ question_index=index,
+ question_count=total_rows,
+ exact_match=bool(record["exact_match"]),
+ token_f1=float(record["token_f1"]),
+ file_f1=float(record["file_f1"]),
+ judge_correct=record.get("judge_correct"),
+ judge_score_0_to_5=record.get("judge_score_0_to_5"),
+ )
+
+ if args.print_every > 0 and index % args.print_every == 0:
+ logger.info("Progress %s: %d/%d", spec.name, index, len(rows))
+
+ if skipped_count > 0:
+ logger.info(
+ "RESUME: Skipped %d already-completed questions, processed %d new for %s",
+ skipped_count,
+ len(records) - len(resumed_records),
+ spec.name,
+ )
+ overall = _summarize_records(records)
+ by_type: Dict[str, Dict[str, Any]] = {}
+ for qa_type in ("profiling", "factual_retention"):
+ subset = [record for record in records if record.get("qa_type") == qa_type]
+ if subset:
+ by_type[qa_type] = _summarize_records(subset)
+
+ _emit_progress(
+ progress_path,
+ "config_done",
+ config=spec.name,
+ overall=overall,
+ by_type=by_type,
+ progress_jsonl=str(progress_path),
+ )
+
+ return {
+ "config": spec.name,
+ "profile": spec.profile,
+ "mode": args.mode,
+ "repo_id": args.repo_id,
+ "revision": args.revision,
+ "qa_type_filter": args.qa_type,
+ "answer_strategy": args.answer_strategy,
+ "llm_provider": args.llm_provider,
+ "llm_model": args.llm_model,
+ "answer_provider": args.answer_provider or None,
+ "answer_model": args.answer_model or None,
+ "judge_provider": args.judge_provider or None,
+ "judge_model": args.judge_model or None,
+ "index_seconds": round(index_seconds, 4),
+ "environment": environment_meta,
+ "progress_jsonl": str(progress_path),
+ "overall": overall,
+ "by_type": by_type,
+ "records": records,
+ }
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="Run Dhee on the HippoCamp benchmark release.")
+ parser.add_argument("--repo-id", default=DEFAULT_REPO_ID)
+ parser.add_argument("--revision", default=None)
+ parser.add_argument(
+ "--config",
+ action="append",
+ dest="configs",
+ choices=sorted(CONFIG_SPECS),
+ help="HippoCamp config(s) to run. May be passed multiple times.",
+ )
+ parser.add_argument("--mode", choices=["gold", "raw"], default="raw")
+ parser.add_argument("--qa-type", choices=["all", "profiling", "factual_retention"], default="all")
+ parser.add_argument("--max-samples", type=int, default=-1)
+ parser.add_argument("--max-environment-files", type=int, default=-1)
+ parser.add_argument("--chunk-chars", type=int, default=2600)
+ parser.add_argument("--output-dir", default="runs/hippocamp")
+ parser.add_argument("--print-every", type=int, default=10)
+
+ # --- Full-power Dhee defaults: NVIDIA embedder + reranker, all features ON ---
+ parser.add_argument("--llm-provider", default="nvidia")
+ parser.add_argument("--llm-model", default="meta/llama-3.3-70b-instruct")
+ parser.add_argument("--llm-timeout", type=int, default=240)
+ parser.add_argument("--embedder-provider", default="nvidia")
+ parser.add_argument("--embedder-model", default="nvidia/llama-nemotron-embed-vl-1b-v2")
+ parser.add_argument("--embedding-dims", type=int, default=2048)
+ parser.add_argument("--vector-store-provider", choices=["memory", "sqlite_vec"], default="memory")
+
+ parser.add_argument("--minimal", action="store_true", default=False)
+ # defer_enrichment=True: fast 0-LLM ingestion; batch enrichment runs after ingest via enrich_after_ingest
+ parser.add_argument("--defer-enrichment", dest="defer_enrichment", action="store_true", default=True)
+ parser.add_argument("--no-defer-enrichment", dest="defer_enrichment", action="store_false")
+ parser.add_argument("--enrich-after-ingest", dest="enrich_after_ingest", action="store_true", default=True)
+ parser.add_argument("--no-enrich-after-ingest", dest="enrich_after_ingest", action="store_false")
+ parser.add_argument("--enrich-batch-size", type=int, default=10)
+ parser.add_argument("--enrich-max-batches", type=int, default=200)
+ parser.add_argument("--enable-rerank", dest="enable_rerank", action="store_true", default=True)
+ parser.add_argument("--disable-rerank", dest="enable_rerank", action="store_false")
+ parser.add_argument("--rerank-model", default="nvidia/llama-3.2-nv-rerankqa-1b-v2")
+ parser.add_argument("--disable-episodic-index", action="store_true", default=False)
+ parser.add_argument("--disable-hierarchical-retrieval", action="store_true", default=False)
+ parser.add_argument("--disable-orchestrated-search", action="store_true", default=False)
+ parser.add_argument("--no-cost-guardrail-strict", action="store_true", default=False)
+
+ parser.add_argument("--top-k", type=int, default=12)
+ parser.add_argument("--answer-context-top-k", type=int, default=8)
+ parser.add_argument("--search-cap", type=int, default=30)
+ parser.add_argument("--context-cap", type=int, default=20)
+ parser.add_argument("--map-max-candidates", type=int, default=8)
+ parser.add_argument("--map-max-chars", type=int, default=1400)
+ parser.add_argument("--max-context-chars", type=int, default=28000)
+ parser.add_argument("--evidence-strategy", choices=["full", "vector_text", "snippet"], default="snippet")
+ parser.add_argument("--evidence-max-chars", type=int, default=3500)
+ parser.add_argument("--evidence-context-lines", type=int, default=1)
+ parser.add_argument("--answer-orchestration-mode", choices=["off", "hybrid", "strict"], default="hybrid")
+
+ parser.add_argument("--answer-strategy", choices=["auto", "llm", "extractive"], default="auto")
+ parser.add_argument("--answer-provider", default="nvidia")
+ parser.add_argument("--answer-model", default="meta/llama-3.3-70b-instruct")
+ parser.add_argument("--answer-timeout", type=int, default=240)
+ parser.add_argument("--answer-max-tokens", type=int, default=1024)
+ parser.add_argument("--answer-temperature", type=float, default=0.2)
+ parser.add_argument("--answer-top-p", type=float, default=0.7)
+ parser.add_argument("--answer-enable-thinking", dest="answer_enable_thinking", action="store_true", default=False)
+ parser.add_argument("--answer-disable-thinking", dest="answer_enable_thinking", action="store_false")
+
+ parser.add_argument("--judge-provider", default="nvidia")
+ parser.add_argument("--judge-model", default="deepseek-ai/deepseek-v3.1-terminus")
+ parser.add_argument("--judge-timeout", type=int, default=60)
+ parser.add_argument("--judge-max-tokens", type=int, default=2048)
+ parser.add_argument("--judge-temperature", type=float, default=0.2)
+ parser.add_argument("--judge-top-p", type=float, default=0.7)
+ parser.add_argument("--judge-enable-thinking", dest="judge_enable_thinking", action="store_true", default=False)
+ parser.add_argument("--judge-disable-thinking", dest="judge_enable_thinking", action="store_false")
+
+ parser.add_argument("--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"])
+ args = parser.parse_args()
+ if not args.configs:
+ args.configs = ["adam_subset"]
+ return args
+
+
+def main() -> None:
+ args = parse_args()
+ _configure_logging(args.log_level)
+ _load_env_file(Path.cwd() / ".env")
+ _load_env_file(Path(__file__).resolve().parents[2] / ".env")
+
+ output_root = Path(args.output_dir)
+ output_root.mkdir(parents=True, exist_ok=True)
+
+ started = time.time()
+ config_results = []
+ for config_name in args.configs:
+ spec = _resolve_config(config_name)
+ logger.info("Running HippoCamp config=%s mode=%s qa_type=%s", spec.name, args.mode, args.qa_type)
+ config_results.append(_run_config(args, spec))
+
+ summary = {
+ "runner": "dhee.benchmarks.hippocamp",
+ "repo_id": args.repo_id,
+ "revision": args.revision,
+ "configs": args.configs,
+ "mode": args.mode,
+ "qa_type": args.qa_type,
+ "elapsed_seconds": round(time.time() - started, 4),
+ "config_results": config_results,
+ }
+
+ summary_path = output_root / "summary.json"
+ predictions_path = output_root / "predictions.json"
+ flat_records = []
+ for config_result in config_results:
+ for record in config_result.get("records", []):
+ flat_records.append(
+ {
+ "config": config_result["config"],
+ "profile": config_result["profile"],
+ **record,
+ }
+ )
+
+ _persist_json(summary_path, summary)
+ _persist_json(predictions_path, flat_records)
+
+ print(
+ json.dumps(
+ {
+ "summary_json": str(summary_path),
+ "predictions_json": str(predictions_path),
+ "configs": args.configs,
+ "mode": args.mode,
+ "qa_type": args.qa_type,
+ "records": len(flat_records),
+ },
+ indent=2,
+ )
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/dhee/benchmarks/longmemeval.py b/dhee/benchmarks/longmemeval.py
index e4a7984..35947d4 100644
--- a/dhee/benchmarks/longmemeval.py
+++ b/dhee/benchmarks/longmemeval.py
@@ -38,7 +38,6 @@
from dhee.core.answer_orchestration import (
is_low_confidence_answer,
)
-from dhee.core.code_exec_counter import refine_count_with_code_exec
logger = logging.getLogger(__name__)
@@ -1771,25 +1770,7 @@ def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]:
except Exception as reg_exc:
logger.debug("Entity registry lookup failed for %s: %s", question_id, reg_exc)
- # 2) Code-exec counting (1 LLM call + deterministic exec)
code_exec_answer = None
- try:
- ce_llm = _code_exec_llm or _answer_llm or memory.llm
- code_exec_answer = refine_count_with_code_exec(
- llm=ce_llm,
- question=query,
- question_type=question_type,
- retrieved_context=context,
- draft_answer=hypothesis,
- question_date=question_date,
- )
- if code_exec_answer:
- logger.info(
- "Code-exec answer for %s: %r",
- question_id, code_exec_answer,
- )
- except Exception as ce_exc:
- logger.debug("Code-exec counting failed for %s: %s", question_id, ce_exc)
# 3) Pick best: code_exec > registry > existing refinement
if code_exec_answer and not _is_refusal(code_exec_answer):
diff --git a/dhee/benchmarks/raw_extractors.py b/dhee/benchmarks/raw_extractors.py
new file mode 100644
index 0000000..2435dd4
--- /dev/null
+++ b/dhee/benchmarks/raw_extractors.py
@@ -0,0 +1,505 @@
+"""Raw-file extraction helpers for HippoCamp-style benchmarks.
+
+These extractors operate only on released raw files. They intentionally avoid
+benchmark gold text and QA annotations.
+"""
+
+from __future__ import annotations
+
+import csv
+import json
+import logging
+import sqlite3
+import subprocess
+import zipfile
+from dataclasses import dataclass
+from email import policy
+from email.parser import BytesParser
+from pathlib import Path, PurePosixPath
+from typing import Any, Dict, List, Optional, Sequence
+from xml.etree import ElementTree as ET
+
+logger = logging.getLogger(__name__)
+
+
+TEXT_EXTENSIONS = {"txt", "md", "py", "log", "ics", "json", "csv"}
+IMAGE_EXTENSIONS = {"png", "jpg", "jpeg", "gif"}
+VIDEO_EXTENSIONS = {"mp4", "mkv"}
+AUDIO_EXTENSIONS = {"mp3", "wav", "m4a", "aac"}
+
+
+@dataclass
+class RawExtractionResult:
+ items: List[Dict[str, Any]]
+ mode: str
+ notes: List[str]
+
+
+def _run_command(args: Sequence[str]) -> str:
+ proc = subprocess.run(
+ list(args),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ check=False,
+ )
+ if proc.returncode != 0:
+ raise RuntimeError(proc.stderr.strip() or f"Command failed: {' '.join(args)}")
+ return proc.stdout
+
+
+def _chunk_blocks(
+ *,
+ header: str,
+ blocks: Sequence[str],
+ chunk_chars: int,
+ metadata: Dict[str, Any],
+ categories: Sequence[str],
+) -> List[Dict[str, Any]]:
+ items: List[Dict[str, Any]] = []
+ target_chars = max(400, int(chunk_chars))
+ current_blocks: List[str] = []
+ current_len = 0
+ chunk_index = 1
+
+ def flush() -> None:
+ nonlocal current_blocks, current_len, chunk_index
+ if not current_blocks:
+ return
+ body = "\n\n".join(current_blocks)
+ content = f"{header}\nChunk: {chunk_index}\n\nContent:\n{body}"
+ item_meta = dict(metadata)
+ item_meta["chunk_index"] = chunk_index
+ items.append({"content": content, "metadata": item_meta, "categories": list(categories)})
+ current_blocks = []
+ current_len = 0
+ chunk_index += 1
+
+ for raw_block in blocks:
+ block = str(raw_block or "").strip()
+ if not block:
+ continue
+ block_len = len(block)
+ if current_blocks and current_len + block_len + 2 > target_chars:
+ flush()
+ current_blocks.append(block)
+ current_len += block_len + 2
+ flush()
+ return items
+
+
+def _read_text_file(path: Path) -> str:
+ for encoding in ("utf-8", "utf-16", "latin-1"):
+ try:
+ return path.read_text(encoding=encoding)
+ except Exception:
+ continue
+ return path.read_text(encoding="utf-8", errors="ignore")
+
+
+def _extract_email_text(path: Path) -> str:
+ with path.open("rb") as handle:
+ message = BytesParser(policy=policy.default).parse(handle)
+ parts: List[str] = []
+ for header in ("subject", "from", "to", "cc", "date"):
+ value = message.get(header)
+ if value:
+ parts.append(f"{header.title()}: {value}")
+ body = message.get_body(preferencelist=("plain", "html"))
+ if body is not None:
+ try:
+ body_text = body.get_content()
+ except Exception:
+ body_text = ""
+ if body_text:
+ parts.extend(["", str(body_text)])
+ return "\n".join(parts).strip()
+
+
+def _extract_pdf_pages(path: Path) -> List[str]:
+ output = _run_command(["pdftotext", str(path), "-"])
+ return [page.strip() for page in output.split("\f") if page.strip()]
+
+
+def _extract_docx_text(path: Path) -> str:
+ try:
+ text = _run_command(["textutil", "-convert", "txt", "-stdout", str(path)]).strip()
+ if text:
+ return text
+ except Exception:
+ pass
+ return _extract_docx_ooxml_text(path)
+
+
+def _extract_docx_ooxml_text(path: Path) -> str:
+ with zipfile.ZipFile(path) as archive:
+ names = [name for name in archive.namelist() if name.startswith("word/") and name.endswith(".xml")]
+ parts: List[str] = []
+ for name in sorted(names):
+ if not any(key in name for key in ("document.xml", "header", "footer", "footnotes", "endnotes")):
+ continue
+ root = ET.fromstring(archive.read(name))
+ texts = [elem.text for elem in root.iter() if elem.tag.endswith("}t") and elem.text]
+ if texts:
+ parts.append("\n".join(texts))
+ return "\n\n".join(parts).strip()
+
+
+def _extract_pptx_ooxml_text(path: Path) -> str:
+ with zipfile.ZipFile(path) as archive:
+ slides = []
+ for name in sorted(archive.namelist()):
+ if not name.startswith("ppt/slides/slide") or not name.endswith(".xml"):
+ continue
+ root = ET.fromstring(archive.read(name))
+ texts = [elem.text for elem in root.iter() if elem.tag.endswith("}t") and elem.text]
+ if texts:
+ slides.append("\n".join(texts))
+ return "\n\n".join(f"[slide={idx + 1}]\n{text}" for idx, text in enumerate(slides)).strip()
+
+
+def _col_ref_to_index(col_ref: str) -> int:
+ result = 0
+ for ch in col_ref:
+ if not ch.isalpha():
+ break
+ result = result * 26 + (ord(ch.upper()) - ord("A") + 1)
+ return max(1, result)
+
+
+def _extract_xlsx_ooxml_text(path: Path) -> str:
+ with zipfile.ZipFile(path) as archive:
+ shared_strings: List[str] = []
+ if "xl/sharedStrings.xml" in archive.namelist():
+ root = ET.fromstring(archive.read("xl/sharedStrings.xml"))
+ for elem in root.iter():
+ if elem.tag.endswith("}t") and elem.text is not None:
+ shared_strings.append(elem.text)
+
+ sheet_names: Dict[str, str] = {}
+ if "xl/workbook.xml" in archive.namelist():
+ root = ET.fromstring(archive.read("xl/workbook.xml"))
+ for elem in root.iter():
+ if elem.tag.endswith("}sheet"):
+ rel_id = elem.attrib.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id", "")
+ if rel_id:
+ sheet_names[rel_id] = elem.attrib.get("name", rel_id)
+
+ rel_map: Dict[str, str] = {}
+ if "xl/_rels/workbook.xml.rels" in archive.namelist():
+ root = ET.fromstring(archive.read("xl/_rels/workbook.xml.rels"))
+ for elem in root.iter():
+ if elem.tag.endswith("}Relationship"):
+ rel_id = elem.attrib.get("Id", "")
+ target = elem.attrib.get("Target", "")
+ if rel_id and target:
+ rel_map[target] = sheet_names.get(rel_id, PurePosixPath(target).stem)
+
+ sheets_out: List[str] = []
+ for name in sorted(archive.namelist()):
+ if not name.startswith("xl/worksheets/") or not name.endswith(".xml"):
+ continue
+ root = ET.fromstring(archive.read(name))
+ rows_out: List[str] = []
+ for row in root.iter():
+ if not row.tag.endswith("}row"):
+ continue
+ values: Dict[int, str] = {}
+ for cell in row:
+ if not cell.tag.endswith("}c"):
+ continue
+ ref = cell.attrib.get("r", "")
+ col_idx = _col_ref_to_index(ref) if ref else (len(values) + 1)
+ cell_type = cell.attrib.get("t", "")
+ value_text = ""
+ value_elem = None
+ for child in cell:
+ if child.tag.endswith("}v"):
+ value_elem = child
+ break
+ if value_elem is not None and value_elem.text is not None:
+ raw_value = value_elem.text
+ if cell_type == "s":
+ try:
+ value_text = shared_strings[int(raw_value)]
+ except Exception:
+ value_text = raw_value
+ else:
+ value_text = raw_value
+ if value_text:
+ values[col_idx] = value_text
+ if values:
+ rows_out.append("\t".join(values[idx] for idx in sorted(values)))
+ if rows_out:
+ sheet_label = rel_map.get(name.replace("xl/", ""), Path(name).stem)
+ sheets_out.append(f"[sheet={sheet_label}]\n" + "\n".join(rows_out))
+ return "\n\n".join(sheets_out).strip()
+
+
+def _extract_ipynb_text(path: Path) -> str:
+ with path.open("r", encoding="utf-8") as handle:
+ notebook = json.load(handle)
+ parts: List[str] = []
+ for idx, cell in enumerate(notebook.get("cells") or [], start=1):
+ source = cell.get("source") or []
+ if isinstance(source, list):
+ text = "".join(str(part) for part in source)
+ else:
+ text = str(source)
+ text = text.strip()
+ if text:
+ parts.append(f"[{cell.get('cell_type', 'cell')} {idx}]\n{text}")
+ return "\n\n".join(parts).strip()
+
+
+def _extract_sqlite_text(path: Path, max_tables: int = 6, max_rows: int = 10) -> str:
+ conn = sqlite3.connect(str(path))
+ conn.row_factory = sqlite3.Row
+ try:
+ rows = conn.execute(
+ "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
+ ).fetchall()
+ parts: List[str] = []
+ for row in rows[:max_tables]:
+ table = str(row["name"])
+ escaped_table = table.replace("'", "''")
+ parts.append(f"[table={table}]")
+ schema = conn.execute(f"PRAGMA table_info('{escaped_table}')").fetchall()
+ if schema:
+ parts.append("columns: " + ", ".join(f"{col['name']}:{col['type']}" for col in schema))
+ try:
+ sample_rows = conn.execute(
+ f"SELECT * FROM '{escaped_table}' LIMIT {int(max_rows)}"
+ ).fetchall()
+ except Exception:
+ sample_rows = []
+ for sample in sample_rows:
+ payload = {key: sample[key] for key in sample.keys()}
+ parts.append(json.dumps(payload, ensure_ascii=False))
+ parts.append("")
+ return "\n".join(parts).strip()
+ finally:
+ conn.close()
+
+
+def _extract_csv_text(path: Path, max_rows: int = 100) -> str:
+ with path.open("r", encoding="utf-8", errors="ignore", newline="") as handle:
+ reader = csv.reader(handle)
+ rows = []
+ for idx, row in enumerate(reader):
+ if idx >= max_rows:
+ break
+ rows.append("\t".join(str(cell) for cell in row))
+ return "\n".join(rows).strip()
+
+
+def _extract_media_metadata(path: Path) -> str:
+ try:
+ output = _run_command(
+ [
+ "ffprobe",
+ "-v",
+ "error",
+ "-show_entries",
+ "format=filename,format_name,duration,size:stream=index,codec_type,codec_name,width,height,sample_rate,channels:stream_tags=language,title",
+ "-of",
+ "json",
+ str(path),
+ ]
+ )
+ return json.dumps(json.loads(output), ensure_ascii=False, indent=2)
+ except Exception:
+ return ""
+
+
+def _extract_subtitles(path: Path) -> str:
+ try:
+ output = _run_command(["ffmpeg", "-v", "error", "-i", str(path), "-map", "0:s:0", "-f", "srt", "-"])
+ except Exception:
+ return ""
+ lines: List[str] = []
+ for raw_line in output.splitlines():
+ line = raw_line.strip()
+ if not line or line.isdigit() or "-->" in line:
+ continue
+ lines.append(line)
+ return "\n".join(lines).strip()
+
+
+def _extract_image_metadata(path: Path) -> str:
+ parts: List[str] = []
+ try:
+ from PIL import Image
+
+ with Image.open(path) as image:
+ parts.append(f"format: {image.format}")
+ parts.append(f"size: {image.size[0]}x{image.size[1]}")
+ parts.append(f"mode: {image.mode}")
+ except Exception:
+ pass
+ return "\n".join(parts).strip()
+
+
+def _raw_header(*, profile: str, config_name: str, relative_path: str, local_path: Path, mode: str) -> str:
+ stat = local_path.stat()
+ ext = local_path.suffix.lower().lstrip(".")
+ return "\n".join(
+ [
+ "[HippoCamp Raw File]",
+ f"Profile: {profile}",
+ f"Config: {config_name}",
+ f"Relative Path: {relative_path}",
+ f"File Type: {ext}",
+ f"Extraction Mode: {mode}",
+ f"File Size Bytes: {stat.st_size}",
+ ]
+ )
+
+
+def _extract_audio_transcript(local_path: Path) -> Optional[str]:
+ """Transcribe audio using OpenAI whisper (local model). Returns None if not available."""
+ try:
+ import whisper # type: ignore[import-untyped]
+ except ImportError:
+ logger.debug("whisper not installed — skipping audio transcription for %s", local_path.name)
+ return None
+ try:
+ model = whisper.load_model("base")
+ result = model.transcribe(str(local_path), fp16=False)
+ text = str(result.get("text", "")).strip()
+ if not text:
+ return None
+ return f"[audio_transcript]\n{text}"
+ except Exception as exc:
+ logger.warning("Whisper transcription failed for %s: %s", local_path.name, exc)
+ return None
+
+
+def _extract_image_ocr(local_path: Path) -> Optional[str]:
+ """Extract text from image using pytesseract OCR. Returns None if not available."""
+ try:
+ from PIL import Image # type: ignore[import-untyped]
+ import pytesseract # type: ignore[import-untyped]
+ except ImportError:
+ logger.debug("pytesseract/Pillow not installed — skipping OCR for %s", local_path.name)
+ return None
+ try:
+ img = Image.open(local_path)
+ text = pytesseract.image_to_string(img).strip()
+ if not text or len(text) < 3:
+ return None
+ return f"[ocr_text]\n{text}"
+ except Exception as exc:
+ logger.warning("OCR failed for %s: %s", local_path.name, exc)
+ return None
+
+
+def raw_file_to_items(
+ *,
+ local_path: Path,
+ relative_path: str,
+ profile: str,
+ config_name: str,
+ chunk_chars: int,
+) -> RawExtractionResult:
+ ext = local_path.suffix.lower().lstrip(".")
+ mode = "text"
+ notes: List[str] = []
+ blocks: List[str] = []
+
+ try:
+ if ext in TEXT_EXTENSIONS:
+ blocks = [_extract_csv_text(local_path) if ext == "csv" else _read_text_file(local_path)]
+ elif ext == "eml":
+ blocks = [_extract_email_text(local_path)]
+ elif ext == "pdf":
+ blocks = [f"[page={idx + 1}]\n{page}" for idx, page in enumerate(_extract_pdf_pages(local_path))]
+ elif ext == "docx":
+ blocks = [_extract_docx_text(local_path)]
+ elif ext == "pptx":
+ blocks = [_extract_pptx_ooxml_text(local_path)]
+ elif ext == "xlsx":
+ blocks = [_extract_xlsx_ooxml_text(local_path)]
+ elif ext == "ipynb":
+ blocks = [_extract_ipynb_text(local_path)]
+ elif ext == "sqlite":
+ blocks = [_extract_sqlite_text(local_path)]
+ elif ext in VIDEO_EXTENSIONS:
+ subtitles = _extract_subtitles(local_path)
+ metadata = _extract_media_metadata(local_path)
+ if subtitles:
+ mode = "subtitle_text"
+ blocks = [subtitles]
+ if metadata:
+ blocks.append("[media_metadata]\n" + metadata)
+ else:
+ mode = "metadata_only"
+ blocks = ["[media_metadata]\n" + metadata] if metadata else [f"Raw {ext} video file."]
+ notes.append("No subtitle text extracted.")
+ elif ext in AUDIO_EXTENSIONS:
+ transcript = _extract_audio_transcript(local_path)
+ if transcript:
+ mode = "transcript"
+ blocks = [transcript]
+ metadata_text = _extract_media_metadata(local_path)
+ if metadata_text:
+ blocks.append("[media_metadata]\n" + metadata_text)
+ else:
+ mode = "metadata_only"
+ metadata_text = _extract_media_metadata(local_path)
+ blocks = ["[media_metadata]\n" + metadata_text] if metadata_text else [f"Raw {ext} audio file."]
+ notes.append("No speech-to-text available (whisper not installed or failed).")
+ elif ext in IMAGE_EXTENSIONS:
+ ocr_text = _extract_image_ocr(local_path)
+ if ocr_text:
+ mode = "ocr_text"
+ blocks = [ocr_text]
+ img_metadata = _extract_image_metadata(local_path)
+ if img_metadata:
+ blocks.append("[image_metadata]\n" + img_metadata)
+ else:
+ mode = "metadata_only"
+ img_metadata = _extract_image_metadata(local_path)
+ blocks = ["[image_metadata]\n" + img_metadata] if img_metadata else [f"Raw {ext} image file."]
+ notes.append("No OCR text extracted (pytesseract not installed or no text found).")
+ else:
+ mode = "metadata_only"
+ blocks = [f"Unsupported raw file type for text extraction: {ext or 'unknown'}"]
+ notes.append(f"Unsupported file extension: {ext or 'unknown'}")
+ except Exception as exc:
+ mode = "metadata_only"
+ blocks = [f"Extraction failed for {relative_path}: {exc}"]
+ notes.append(str(exc))
+
+ blocks = [block for block in blocks if str(block or "").strip()]
+ if not blocks:
+ blocks = [f"Empty extracted content for {relative_path}"]
+ notes.append("Empty extracted content.")
+
+ header = _raw_header(
+ profile=profile,
+ config_name=config_name,
+ relative_path=relative_path,
+ local_path=local_path,
+ mode=mode,
+ )
+ metadata = {
+ "benchmark": "hippocamp",
+ "exposure_mode": "raw_files_only",
+ "config_name": config_name,
+ "profile": profile,
+ "file_path": relative_path,
+ "file_name": local_path.name,
+ "file_type": ext or None,
+ "raw_extraction_mode": mode,
+ "raw_extraction_notes": notes,
+ }
+ categories = ["hippocamp", "raw_files_only", config_name.lower(), profile.lower()]
+ items = _chunk_blocks(
+ header=header,
+ blocks=blocks,
+ chunk_chars=chunk_chars,
+ metadata=metadata,
+ categories=categories,
+ )
+ return RawExtractionResult(items=items, mode=mode, notes=notes)
diff --git a/dhee/configs/base.py b/dhee/configs/base.py
index c017ee9..e3e80ad 100644
--- a/dhee/configs/base.py
+++ b/dhee/configs/base.py
@@ -591,12 +591,6 @@ def _positive_int(cls, v: int) -> int:
return max(1, int(v))
-def _get_teaching_config():
- """Lazy import to avoid circular dependency."""
- from dhee.teaching.config import TeachingConfig
- return TeachingConfig()
-
-
class MemoryConfig(BaseModel):
vector_store: VectorStoreConfig = Field(default_factory=VectorStoreConfig)
llm: LLMConfig = Field(default_factory=LLMConfig)
@@ -631,7 +625,6 @@ class MemoryConfig(BaseModel):
cost_guardrail: CostGuardrailConfig = Field(default_factory=CostGuardrailConfig)
skill: SkillConfig = Field(default_factory=SkillConfig)
task: TaskConfig = Field(default_factory=TaskConfig)
- teaching: "TeachingConfig" = Field(default_factory=lambda: _get_teaching_config())
metamemory: MetamemoryInlineConfig = Field(default_factory=MetamemoryInlineConfig)
prospective: ProspectiveInlineConfig = Field(default_factory=ProspectiveInlineConfig)
procedural: ProceduralInlineConfig = Field(default_factory=ProceduralInlineConfig)
@@ -675,12 +668,3 @@ def full(cls) -> "MemoryConfig":
return full_config()
-# Resolve forward reference for TeachingConfig
-def _rebuild_memory_config():
- try:
- from dhee.teaching.config import TeachingConfig
- MemoryConfig.model_rebuild()
- except ImportError:
- pass
-
-_rebuild_memory_config()
diff --git a/dhee/core/agi_loop.py b/dhee/core/agi_loop.py
deleted file mode 100644
index 12ea159..0000000
--- a/dhee/core/agi_loop.py
+++ /dev/null
@@ -1,145 +0,0 @@
-"""Dhee v3 — Cognitive Maintenance Cycle.
-
-Replaces the phantom AGI loop with honest, real maintenance operations.
-
-v2.2 had 8 steps, 6 of which imported non-existent engram_* packages.
-v3 runs only what actually exists:
- 1. Consolidation (active → passive, via safe consolidation engine)
- 2. Decay (forgetting curves)
-
-Planned but not yet implemented (will be added as real Job classes):
- - Anchor candidate resolution
- - Distillation promotion
- - Conflict scanning
- - Stale intention cleanup
-
-The old API surface (run_agi_cycle, get_system_health) is preserved
-for backward compatibility with existing callers.
-"""
-
-from __future__ import annotations
-
-import logging
-from datetime import datetime, timezone
-from typing import Any, Dict, Optional
-
-logger = logging.getLogger(__name__)
-
-
-def run_agi_cycle(
- memory: Any,
- user_id: str = "default",
- context: Optional[str] = None,
-) -> Dict[str, Any]:
- """Run one maintenance cycle. Only executes real subsystems.
-
- Args:
- memory: Dhee Memory instance
- user_id: User identifier for scoped operations
- context: Optional current context (reserved for future use)
-
- Returns:
- Dict with status of each step
- """
- now = datetime.now(timezone.utc).isoformat()
- results: Dict[str, Any] = {"timestamp": now, "user_id": user_id}
-
- # Step 1: Consolidation — run distillation (episodic → semantic)
- try:
- if hasattr(memory, "_kernel") and memory._kernel:
- consolidation = memory._kernel.sleep_cycle(user_id=user_id)
- results["consolidation"] = {"status": "ok", "result": consolidation}
- else:
- results["consolidation"] = {"status": "skipped", "reason": "no kernel"}
- except Exception as e:
- results["consolidation"] = {"status": "error", "error": str(e)}
-
- # Step 2: Decay — apply forgetting curves
- try:
- decay_result = memory.apply_decay(scope={"user_id": user_id})
- results["decay"] = {"status": "ok", "result": decay_result}
- except Exception as e:
- results["decay"] = {"status": "error", "error": str(e)}
-
- # Compute summary
- statuses = [
- v.get("status", "unknown")
- for v in results.values()
- if isinstance(v, dict) and "status" in v
- ]
- ok_count = statuses.count("ok")
- error_count = statuses.count("error")
- skipped_count = statuses.count("skipped")
-
- results["summary"] = {
- "ok": ok_count,
- "errors": error_count,
- "skipped": skipped_count,
- "total_subsystems": len(statuses),
- }
-
- return results
-
-
-def get_system_health(memory: Any, user_id: str = "default") -> Dict[str, Any]:
- """Report health status across real cognitive subsystems.
-
- Only reports subsystems that actually exist — no phantom package checks.
- """
- now = datetime.now(timezone.utc).isoformat()
- systems: Dict[str, Dict] = {}
-
- # Core memory
- try:
- stats = memory.get_stats(user_id=user_id)
- systems["core_memory"] = {"available": True, "stats": stats}
- except Exception as e:
- systems["core_memory"] = {"available": False, "error": str(e)}
-
- # Knowledge graph
- systems["knowledge_graph"] = {
- "available": (
- hasattr(memory, "knowledge_graph")
- and memory.knowledge_graph is not None
- ),
- }
- if systems["knowledge_graph"]["available"]:
- try:
- systems["knowledge_graph"]["stats"] = memory.knowledge_graph.stats()
- except Exception:
- pass
-
- # Cognition kernel
- has_kernel = hasattr(memory, "_kernel") and memory._kernel is not None
- systems["cognition_kernel"] = {"available": has_kernel}
- if has_kernel:
- try:
- systems["cognition_kernel"]["stats"] = memory._kernel.cognition_health(
- user_id=user_id
- )
- except Exception:
- pass
-
- # Active memory / consolidation
- systems["consolidation"] = {
- "available": (
- hasattr(memory, "_consolidation_engine")
- and memory._consolidation_engine is not None
- ),
- }
-
- # v3 stores (if wired)
- systems["v3_event_store"] = {
- "available": hasattr(memory, "_event_store") and memory._event_store is not None,
- }
-
- available = sum(1 for s in systems.values() if s.get("available"))
- total = len(systems)
-
- return {
- "timestamp": now,
- "systems": systems,
- "available": available,
- "total": total,
- "health_pct": round(available / total * 100, 1) if total else 0,
- }
diff --git a/dhee/core/anchor_resolver.py b/dhee/core/anchor_resolver.py
deleted file mode 100644
index 0cf1dad..0000000
--- a/dhee/core/anchor_resolver.py
+++ /dev/null
@@ -1,330 +0,0 @@
-"""Dhee v3 — Anchor Resolver: per-field candidates + confidence-weighted resolution.
-
-Makes context extraction fallible, revisable, and auditable.
-
-Instead of a single ContextAnchor with one confidence score, each field
-(era, place, time_absolute, activity, etc.) gets competing candidates.
-Resolution picks the best candidate per field. Re-anchoring is safe
-because raw events are never touched.
-
-Design contract:
- - Extraction produces candidates, not final truth
- - Same memory can hold alternate candidate anchors
- - Anchor correction does not mutate raw event history
- - Resolution is deterministic: highest confidence per field wins
- - Zero LLM calls — rule-based extraction + confidence scoring
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import sqlite3
-import threading
-import uuid
-from contextlib import contextmanager
-from dataclasses import dataclass, field
-from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-
-def _utcnow_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
-
-
-# Fields that can have competing candidates
-ANCHOR_FIELDS = frozenset({
- "era", "place", "place_type", "place_detail",
- "time_absolute", "time_range_start", "time_range_end",
- "time_derivation", "activity",
-})
-
-
-@dataclass
-class AnchorCandidate:
- """A proposed value for a single anchor field."""
-
- candidate_id: str
- anchor_id: str
- field_name: str
- field_value: str
- confidence: float = 0.5
- extractor_source: str = "default"
- source_event_ids: List[str] = field(default_factory=list)
- derivation_version: int = 1
- status: str = "pending" # pending | accepted | rejected | superseded
-
-
-class AnchorCandidateStore:
- """Manages per-field anchor candidates in the database."""
-
- def __init__(self, conn: sqlite3.Connection, lock: threading.RLock):
- self._conn = conn
- self._lock = lock
-
- @contextmanager
- def _tx(self):
- with self._lock:
- try:
- yield self._conn
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
-
- def submit(
- self,
- anchor_id: str,
- field_name: str,
- field_value: str,
- *,
- confidence: float = 0.5,
- extractor_source: str = "default",
- source_event_ids: Optional[List[str]] = None,
- ) -> str:
- """Submit a candidate for an anchor field. Returns candidate_id."""
- if field_name not in ANCHOR_FIELDS:
- raise ValueError(f"Invalid anchor field: {field_name}")
-
- cid = str(uuid.uuid4())
- with self._tx() as conn:
- conn.execute(
- """INSERT INTO anchor_candidates
- (candidate_id, anchor_id, field_name, field_value,
- confidence, extractor_source, source_event_ids,
- derivation_version, status, created_at)
- VALUES (?, ?, ?, ?, ?, ?, ?, 1, 'pending', ?)""",
- (
- cid, anchor_id, field_name, field_value,
- confidence, extractor_source,
- json.dumps(source_event_ids or []),
- _utcnow_iso(),
- ),
- )
- return cid
-
- def get_candidates(
- self,
- anchor_id: str,
- field_name: Optional[str] = None,
- *,
- status: Optional[str] = None,
- ) -> List[Dict[str, Any]]:
- """Get candidates for an anchor, optionally filtered by field and status."""
- query = "SELECT * FROM anchor_candidates WHERE anchor_id = ?"
- params: list = [anchor_id]
- if field_name:
- query += " AND field_name = ?"
- params.append(field_name)
- if status:
- query += " AND status = ?"
- params.append(status)
- query += " ORDER BY confidence DESC"
-
- with self._lock:
- rows = self._conn.execute(query, params).fetchall()
- return [self._row_to_dict(r) for r in rows]
-
- def set_status(self, candidate_id: str, status: str) -> bool:
- with self._lock:
- try:
- result = self._conn.execute(
- "UPDATE anchor_candidates SET status = ? WHERE candidate_id = ?",
- (status, candidate_id),
- )
- self._conn.commit()
- return result.rowcount > 0
- except Exception:
- self._conn.rollback()
- raise
-
- @staticmethod
- def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]:
- source_ids = row["source_event_ids"]
- if isinstance(source_ids, str):
- try:
- source_ids = json.loads(source_ids)
- except (json.JSONDecodeError, TypeError):
- source_ids = []
- return {
- "candidate_id": row["candidate_id"],
- "anchor_id": row["anchor_id"],
- "field_name": row["field_name"],
- "field_value": row["field_value"],
- "confidence": row["confidence"],
- "extractor_source": row["extractor_source"],
- "source_event_ids": source_ids,
- "derivation_version": row["derivation_version"],
- "status": row["status"],
- "created_at": row["created_at"],
- }
-
-
-class AnchorResolver:
- """Resolves anchor fields from competing candidates.
-
- Resolution strategy:
- 1. For each anchor field, find all pending/accepted candidates
- 2. Pick the highest-confidence candidate per field
- 3. Mark winner as 'accepted', losers as 'superseded'
- 4. Update the anchor row with resolved values
-
- Re-resolution: when new candidates arrive (correction, new evidence),
- call resolve() again — it re-evaluates all candidates.
- """
-
- def __init__(
- self,
- candidate_store: AnchorCandidateStore,
- anchor_store: "AnchorStore",
- ):
- self.candidates = candidate_store
- self.anchors = anchor_store
-
- def resolve(self, anchor_id: str) -> Dict[str, Any]:
- """Resolve all fields for an anchor. Returns the resolved field values.
-
- Steps:
- 1. Get all non-rejected candidates grouped by field
- 2. For each field, pick highest confidence
- 3. Mark winners as accepted, others as superseded
- 4. Update anchor with resolved values
- """
- resolved: Dict[str, str] = {}
- resolution_details: Dict[str, Dict[str, Any]] = {}
-
- for field_name in ANCHOR_FIELDS:
- candidates = self.candidates.get_candidates(
- anchor_id, field_name
- )
- # Filter to only pending/accepted (not rejected)
- active = [
- c for c in candidates
- if c["status"] in ("pending", "accepted")
- ]
-
- if not active:
- continue
-
- # Sort by confidence descending, then by created_at ascending (earlier = better tiebreak)
- active.sort(key=lambda c: (-c["confidence"], c["created_at"]))
- winner = active[0]
-
- resolved[field_name] = winner["field_value"]
- resolution_details[field_name] = {
- "value": winner["field_value"],
- "confidence": winner["confidence"],
- "source": winner["extractor_source"],
- "candidate_id": winner["candidate_id"],
- "competing_count": len(active),
- }
-
- # Mark winner accepted, others superseded
- for c in active:
- if c["candidate_id"] == winner["candidate_id"]:
- self.candidates.set_status(c["candidate_id"], "accepted")
- else:
- self.candidates.set_status(c["candidate_id"], "superseded")
-
- # Update anchor with resolved values
- if resolved:
- self.anchors.update_fields(anchor_id, **resolved)
-
- return {
- "anchor_id": anchor_id,
- "resolved_fields": resolved,
- "details": resolution_details,
- }
-
- def re_anchor(
- self,
- anchor_id: str,
- field_name: str,
- new_value: str,
- *,
- confidence: float = 0.9,
- source: str = "user_correction",
- source_event_ids: Optional[List[str]] = None,
- ) -> Dict[str, Any]:
- """Submit a correction for a specific field and re-resolve.
-
- The user says "no, the place was Bengaluru, not Ghazipur."
- This submits a high-confidence candidate and re-runs resolution.
- """
- # Submit the correction as a new candidate
- cid = self.candidates.submit(
- anchor_id=anchor_id,
- field_name=field_name,
- field_value=new_value,
- confidence=confidence,
- extractor_source=source,
- source_event_ids=source_event_ids,
- )
-
- # Re-resolve just this field (and all others for consistency)
- result = self.resolve(anchor_id)
- result["correction_candidate_id"] = cid
- return result
-
- def extract_and_submit(
- self,
- anchor_id: str,
- content: str,
- *,
- source_event_ids: Optional[List[str]] = None,
- ) -> List[str]:
- """Rule-based extraction of anchor candidates from content.
-
- Extracts candidates for each field it can identify. Returns
- list of candidate_ids.
-
- Zero LLM calls — keyword/pattern matching only.
- """
- candidates_created: List[str] = []
- eids = source_event_ids or []
- lower = content.lower()
-
- # Activity detection
- activity_keywords = {
- "coding": ["coding", "programming", "debug", "commit", "deploy", "refactor"],
- "meeting": ["meeting", "standup", "call", "sync", "discussion"],
- "research": ["research", "reading", "paper", "study", "learn"],
- "travel": ["travel", "flight", "airport", "train", "driving"],
- "writing": ["writing", "blog", "document", "email", "report"],
- }
- for activity, keywords in activity_keywords.items():
- matches = sum(1 for kw in keywords if kw in lower)
- if matches >= 1:
- confidence = min(0.3 + 0.15 * matches, 0.85)
- cid = self.candidates.submit(
- anchor_id=anchor_id,
- field_name="activity",
- field_value=activity,
- confidence=confidence,
- extractor_source="keyword_activity",
- source_event_ids=eids,
- )
- candidates_created.append(cid)
-
- # Place type detection
- place_types = {
- "office": ["office", "workplace", "desk", "cubicle"],
- "home": ["home", "house", "apartment", "flat"],
- "school": ["school", "university", "college", "campus", "class"],
- "travel": ["airport", "station", "hotel", "flight"],
- }
- for ptype, keywords in place_types.items():
- if any(kw in lower for kw in keywords):
- cid = self.candidates.submit(
- anchor_id=anchor_id,
- field_name="place_type",
- field_value=ptype,
- confidence=0.6,
- extractor_source="keyword_place_type",
- source_event_ids=eids,
- )
- candidates_created.append(cid)
-
- return candidates_created
diff --git a/dhee/core/answer_orchestration.py b/dhee/core/answer_orchestration.py
index 1fff68b..095c064 100644
--- a/dhee/core/answer_orchestration.py
+++ b/dhee/core/answer_orchestration.py
@@ -1,41 +1,24 @@
-"""Answer-time orchestration utilities for memory-heavy QA.
+"""Query planning utilities for orchestrated search.
-This module is benchmark-agnostic and can be reused by runtime APIs.
-It provides:
-- lightweight query-intent routing
-- optional query rewriting for retrieval
-- map stage (atomic fact extraction)
-- deterministic reducers for high-leverage question types
+Provides intent classification and query rewriting to improve retrieval.
+No answer synthesis — Dhee retrieves context, the agent answers.
"""
from __future__ import annotations
-import json
-import logging
import re
from dataclasses import dataclass
-from datetime import date, datetime, timezone
from enum import Enum
-from typing import Any, Dict, List, Optional, Sequence, Tuple
+from typing import Optional
-logger = logging.getLogger(__name__)
-
-_SESSION_ID_RE = re.compile(r"^Session ID:\s*(?P\S+)\s*$", re.MULTILINE)
_RECENT_QUERY_RE = re.compile(r"\b(latest|most recent(?:ly)?|currently|current|recent(?:ly)?|as of|last)\b", re.I)
-# Superlative patterns: "the most", "the least", "the first", "the last"
-# These require ARGMAX/ARGMIN — not just listing set members.
_SUPERLATIVE_RE = re.compile(
r"\b(?:the\s+)?(?:most|least|fewest|highest|lowest|biggest|smallest|first|last"
r"|(?:fly|flew|visit|use|eat|watch|play|buy|read|drive|travel)\w*\s+(?:the\s+)?most)\b",
re.I,
)
_LOW_CONFIDENCE_RE = re.compile(
- r"\b(i\s+don['’]?t\s+know|not\s+enough\s+information|insufficient\s+information|unknown|cannot\s+determine)\b",
- re.I,
-)
-_MONEY_RE = re.compile(r"[-+]?\$?\s*(\d{1,3}(?:,\d{3})*|\d+)(?:\.(\d+))?")
-_DURATION_RE = re.compile(
- r"([-+]?\d+(?:\.\d+)?)\s*(years?|months?|weeks?|days?|hours?|minutes?)",
+ r"\b(i\s+don['']?t\s+know|not\s+enough\s+information|insufficient\s+information|unknown|cannot\s+determine)\b",
re.I,
)
@@ -46,19 +29,17 @@ class AnswerIntent(str, Enum):
DURATION = "duration"
LATEST = "latest"
SET_MEMBERS = "set_members"
+ ANALYSIS = "analysis"
FREEFORM = "freeform"
-_NUMERIC_INTENTS = {AnswerIntent.COUNT, AnswerIntent.MONEY_SUM, AnswerIntent.DURATION}
-
-
@dataclass
class QueryPlan:
intent: AnswerIntent
rewritten_query: str
search_limit: int
context_limit: int
- should_map_reduce: bool
+ should_map_reduce: bool # kept for API compat; always False in new path
def classify_answer_intent(question: str, question_type: str = "") -> AnswerIntent:
@@ -68,36 +49,23 @@ def classify_answer_intent(question: str, question_type: str = "") -> AnswerInte
if not q:
return AnswerIntent.FREEFORM
- # DURATION must be checked BEFORE money — "how much time did I spend" is duration, not money.
if re.search(r"\b(how long|duration|elapsed|time spent|total years?|total months?)\b", q):
return AnswerIntent.DURATION
if re.search(r"\bhow much time\b", q):
return AnswerIntent.DURATION
-
- # "How many days/months/weeks ago" or "how many days between X and Y"
- # are temporal-duration questions, not counting questions.
if re.search(r"\bhow many\s+(days?|weeks?|months?|years?|hours?|minutes?)\b", q):
return AnswerIntent.DURATION
- # Money: strict signals only. "spend/spent" + time words is DURATION (caught above).
- # Exclude "days/hours spent" — that's DURATION, not money.
money_signals = bool(
re.search(r"\b(money|dollars?|usd|spent|spend|cost|price)\b", q)
)
if money_signals and re.search(r"\b(how much|total|sum|spent|cost)\b", q):
- # "total number of days spent" is DURATION, not money
if re.search(r"\b(days?|weeks?|months?|years?|hours?|minutes?)\s+(spent|in)\b", q):
return AnswerIntent.DURATION
- # "what percentage" is FREEFORM, not money
if re.search(r"\bpercentage\b", q):
return AnswerIntent.FREEFORM
return AnswerIntent.MONEY_SUM
- # "How many [quantity-unit]" asks for a numeric VALUE, not a COUNT of distinct items.
- # COUNT = enumerate distinct items (cities, books, sports, doctors)
- # FREEFORM = read/compute a numeric value (points, pages, followers, views, copies)
- # Allow up to 3 modifier words between "how many" and the unit:
- # "how many Instagram followers" / "how many rare items" / "how many completed videos"
_QUANTITY_UNITS = (
r"points?|dollars?|credits?|tokens?|calories?|miles?|steps?|pounds?"
r"|kilograms?|grams?|liters?|gallons?|servings?|reps?|sets?"
@@ -112,48 +80,61 @@ def classify_answer_intent(question: str, question_type: str = "") -> AnswerInte
)
if _QUANTITY_UNITS_RE.search(q):
return AnswerIntent.FREEFORM
-
- # "[noun] count" pattern: "page count", "word count", "calorie count", "step count"
if re.search(r"\b(page|word|calorie|step|follower|subscriber|view|video|item)\s+count\b", q):
return AnswerIntent.FREEFORM
- # "What is the total number of [quantity]" needs arithmetic (sum), not item counting.
_TOTAL_QUANTITY_RE = re.compile(
r"\btotal\s+number\s+of\s+(?:\w+\s+){0,2}(" + _QUANTITY_UNITS + r")\b"
)
if _TOTAL_QUANTITY_RE.search(q):
return AnswerIntent.FREEFORM
- # Knowledge-update questions need LATEST intent — must check BEFORE "how many"
- # to prevent "How many times did X change?" from being routed to COUNT
if "knowledge-update" in qtype:
return AnswerIntent.LATEST
-
- # "How much" alone (without money signals) is a value question, not a count.
- # "How much is the painting worth?" → FREEFORM
- # "How much will I save?" → FREEFORM
if re.search(r"\bhow much\b", q):
return AnswerIntent.FREEFORM
-
if re.search(r"\b(how many|number of|count|total number)\b", q):
return AnswerIntent.COUNT
- # Superlative questions FIRST: "which X the most/least/first/last"
- # Must come before generic LATEST check — "the most last month" is COUNT, not LATEST.
if _SUPERLATIVE_RE.search(q):
- # Frequency superlatives → COUNT (need argmax)
if re.search(r"\b(the most|most often|most frequent)\b", q, re.I):
return AnswerIntent.COUNT
- # Temporal superlatives → LATEST (ordering by date)
if re.search(r"\b(most recent|first|earliest|latest|newest|oldest)\b", q, re.I):
return AnswerIntent.LATEST
if _RECENT_QUERY_RE.search(q):
return AnswerIntent.LATEST
-
if re.search(r"\b(which|what are|list|name all)\b", q):
return AnswerIntent.SET_MEMBERS
+ if re.search(
+ r"\b(analy[sz]e|clarify|summarize|explain\b.*\b(reasoning|basis|context|approach))"
+ r"|\b(legal opinion|legal basis|legal aid)"
+ r"|\b(comprehensive|in[- ]?depth)\b",
+ q,
+ ):
+ return AnswerIntent.ANALYSIS
+ if re.search(r"\bhow should (i|we)\s+(reply|respond|draft|write|prepare)\b", q):
+ return AnswerIntent.ANALYSIS
+ if re.search(
+ r"\b(check\s+(my|the|our)\s+\w+|flag\s+any|missing\s+or\s+incorrect"
+ r"|inconsisten|verify|validate|cross[- ]?check)\b",
+ q,
+ ):
+ return AnswerIntent.ANALYSIS
+ if re.search(
+ r"\b(my\s+(routine|usual|typical|approach|process|habit|workflow))"
+ r"|\b(how\s+(do|did)\s+i\s+usually)\b"
+ r"|\b(what\s+(is|are)\s+my\s+\w*\s*(like|routine|habit|approach))\b"
+ r"|\b(can you\s+(take a look|check|help me))\b",
+ q,
+ ):
+ return AnswerIntent.ANALYSIS
+ if re.search(r"\b(categorize|categori[sz]e|filter|organize|sort|group)\b", q):
+ return AnswerIntent.ANALYSIS
+ if "hippocamp" in qtype and "profiling" in qtype:
+ return AnswerIntent.ANALYSIS
+
return AnswerIntent.FREEFORM
@@ -163,7 +144,6 @@ def rewrite_query_for_intent(question: str, intent: AnswerIntent) -> str:
return q
if intent == AnswerIntent.COUNT:
- # If the original question asks "which X the most", the count is for argmax
if _SUPERLATIVE_RE.search(q):
return (
f"{q}\nList each occurrence of each distinct item across ALL sessions. "
@@ -182,7 +162,12 @@ def rewrite_query_for_intent(question: str, intent: AnswerIntent) -> str:
)
if intent == AnswerIntent.SET_MEMBERS:
return f"{q}\nList all distinct relevant items with deduplication."
- # FREEFORM: add derivation instruction for computation questions
+ if intent == AnswerIntent.ANALYSIS:
+ return (
+ f"{q}\nExtract all relevant facts, details, and context from the retrieved documents. "
+ f"Synthesize a comprehensive, grounded answer using ONLY information found in the evidence. "
+ f"Cite specific details (names, dates, amounts, file references) where available."
+ )
if re.search(r"\bwhat time\b", q, re.I):
return (
f"{q}\nIf the exact answer is not stated directly, COMPUTE it from the "
@@ -213,16 +198,14 @@ def build_query_plan(
AnswerIntent.DURATION,
AnswerIntent.LATEST,
AnswerIntent.SET_MEMBERS,
+ AnswerIntent.ANALYSIS,
}
- # Also expand FREEFORM questions that need multi-fact derivation.
- # "how many" questions may be classified FREEFORM (e.g., "how many items"
- # hits the quantity-unit filter) but still need map-reduce to aggregate
- # across multiple sessions. Same for multi-session aggregation patterns.
if not should_expand and intent == AnswerIntent.FREEFORM:
q_lower = question.lower()
if re.search(
r"\b(what time|what day|what date|at what age|how many|how much"
- r"|total number|in total|all the|list all|what are all)\b",
+ r"|total number|in total|all the|list all|what are all"
+ r"|based on|according to|please help|can you help)\b",
q_lower,
):
should_expand = True
@@ -241,702 +224,5 @@ def build_query_plan(
)
-def build_map_candidates(
- results: Sequence[Dict[str, Any]],
- *,
- max_candidates: int,
- per_candidate_max_chars: int,
-) -> List[Dict[str, str]]:
- out: List[Dict[str, str]] = []
- for row in list(results)[: max(1, int(max_candidates))]:
- metadata = row.get("metadata") or {}
- session_id = str(metadata.get("session_id") or "").strip()
- memory_text = str(row.get("memory") or "")
- if not session_id and memory_text:
- match = _SESSION_ID_RE.search(memory_text)
- if match:
- session_id = match.group("session_id")
- session_date = str(
- metadata.get("event_time")
- or metadata.get("session_date")
- or metadata.get("event_date")
- or ""
- ).strip()
- evidence = str(row.get("evidence_text") or "").strip() or memory_text
- if not evidence.strip():
- continue
-
- out.append(
- {
- "session_id": session_id or "unknown",
- "session_date": session_date,
- "text": _truncate_text(evidence, per_candidate_max_chars),
- }
- )
- return out
-
-
-def _truncate_text(text: str, max_chars: int) -> str:
- if max_chars <= 0:
- return ""
- text = str(text or "")
- if len(text) <= max_chars:
- return text
- return text[: max_chars - 3].rstrip() + "..."
-
-
-def _extract_json_payload(raw: str) -> Optional[Any]:
- if not raw:
- return None
- raw = raw.strip()
-
- try:
- return json.loads(raw)
- except json.JSONDecodeError:
- pass
-
- obj_match = re.search(r"\{[\s\S]*\}", raw)
- if obj_match:
- candidate = obj_match.group(0)
- try:
- return json.loads(candidate)
- except json.JSONDecodeError:
- pass
-
- arr_match = re.search(r"\[[\s\S]*\]", raw)
- if arr_match:
- candidate = arr_match.group(0)
- try:
- return json.loads(candidate)
- except json.JSONDecodeError:
- pass
-
- return None
-
-
-def _normalize_bool(value: Any) -> bool:
- if isinstance(value, bool):
- return value
- if isinstance(value, (int, float)):
- return bool(value)
- if isinstance(value, str):
- return value.strip().lower() in {"1", "true", "yes", "y"}
- return False
-
-
-def _to_float(value: Any) -> Optional[float]:
- if value is None:
- return None
- if isinstance(value, (int, float)):
- return float(value)
- text = str(value).strip()
- if not text:
- return None
- text = text.replace(",", "")
- if text.startswith("$"):
- text = text[1:]
- try:
- return float(text)
- except ValueError:
- return None
-
-
-def _parse_event_datetime(value: Any) -> Optional[datetime]:
- if value is None:
- return None
- if isinstance(value, datetime):
- dt = value
- elif isinstance(value, date):
- dt = datetime.combine(value, datetime.min.time())
- else:
- text = str(value).strip()
- if not text:
- return None
- if text.endswith("Z"):
- text = text[:-1] + "+00:00"
- try:
- dt = datetime.fromisoformat(text)
- except ValueError:
- date_match = re.match(r"^(\d{4}-\d{2}-\d{2})", text)
- if not date_match:
- return None
- try:
- d = date.fromisoformat(date_match.group(1))
- except ValueError:
- return None
- dt = datetime.combine(d, datetime.min.time())
-
- if dt.tzinfo is None:
- dt = dt.replace(tzinfo=timezone.utc)
- else:
- dt = dt.astimezone(timezone.utc)
- return dt
-
-
-def extract_atomic_facts(
- *,
- llm: Any,
- question: str,
- question_type: str,
- question_date: str,
- candidates: Sequence[Dict[str, str]],
-) -> List[Dict[str, Any]]:
- if not candidates:
- return []
-
- candidate_blocks = []
- for idx, c in enumerate(candidates, start=1):
- candidate_blocks.append(
- "\n".join(
- [
- f"[Candidate {idx}] session_id={c.get('session_id', 'unknown')} date={c.get('session_date', '')}",
- c.get("text", ""),
- ]
- )
- )
-
- prompt = (
- "You are a fact extraction engine for memory QA.\n"
- "Extract only facts relevant to answering the question.\n"
- "IMPORTANT: Deduplicate facts. If the same item/event appears in "
- "multiple sessions, emit it ONCE with the canonical_key set.\n"
- "canonical_key = a short lowercase identifier for the unique item "
- "(e.g. 'boots_zara', 'blazer_dry_cleaning', 'project_alpha'). "
- "Same real-world item across sessions MUST share the same canonical_key.\n"
- "Return STRICT JSON only, no markdown.\n\n"
- f"Question: {question}\n"
- f"Question Type: {question_type or 'unknown'}\n"
- f"Question Date: {question_date or 'unknown'}\n\n"
- "Candidate Context:\n"
- + "\n\n".join(candidate_blocks)
- + "\n\n"
- "Required JSON schema:\n"
- "{\"facts\":["
- "{"
- "\"session_id\":\"string\","
- "\"event_date\":\"YYYY-MM-DD or empty\","
- "\"subject\":\"string\","
- "\"predicate\":\"string\","
- "\"value\":\"string\","
- "\"numeric_value\":0,"
- "\"unit\":\"string\","
- "\"currency\":\"string\","
- "\"canonical_key\":\"unique_item_id (REQUIRED, lowercase, e.g. boots_zara)\","
- "\"relevant\":true"
- "}"
- "]}\n"
- "Return an empty list if nothing relevant: {\"facts\":[]}"
- )
-
- try:
- raw = str(llm.generate(prompt)).strip()
- except Exception as exc:
- logger.warning("Map-stage fact extraction failed: %s", exc)
- return []
-
- logger.info("Map-stage raw LLM response (first 800 chars): %s", raw[:800])
- payload = _extract_json_payload(raw)
- if payload is None:
- logger.warning("Map-stage payload parse failed. Raw response (first 500 chars): %s", raw[:500])
- return []
-
- if isinstance(payload, list):
- facts_raw = payload
- elif isinstance(payload, dict):
- facts_raw = payload.get("facts")
- else:
- facts_raw = None
-
- if not isinstance(facts_raw, list):
- logger.warning("Map-stage facts_raw is not a list: %s", type(facts_raw))
- return []
-
- logger.info("Map-stage extracted %d raw facts from LLM", len(facts_raw))
- facts: List[Dict[str, Any]] = []
- for row in facts_raw:
- if not isinstance(row, dict):
- continue
- value = str(row.get("value") or "").strip()
- subject = str(row.get("subject") or "").strip()
- predicate = str(row.get("predicate") or "").strip()
- if not value and not subject and not predicate:
- continue
-
- facts.append(
- {
- "session_id": str(row.get("session_id") or "").strip(),
- "event_date": str(row.get("event_date") or "").strip(),
- "subject": subject,
- "predicate": predicate,
- "value": value,
- "numeric_value": _to_float(row.get("numeric_value")),
- "unit": str(row.get("unit") or "").strip().lower(),
- "currency": str(row.get("currency") or "").strip().upper(),
- "canonical_key": str(row.get("canonical_key") or "").strip(),
- "relevant": _normalize_bool(row.get("relevant", True)),
- }
- )
-
- logger.info("Map-stage final facts: %d (relevant=%d)", len(facts), sum(1 for f in facts if _normalize_bool(f.get("relevant", True))))
- if facts:
- for i, f in enumerate(facts[:5]):
- logger.info(" fact[%d]: subject=%s predicate=%s value=%s canonical_key=%s relevant=%s",
- i, f.get("subject", ""), f.get("predicate", ""), f.get("value", ""), f.get("canonical_key", ""), f.get("relevant"))
- return facts
-
-
-def _extract_money_value(text: str) -> Optional[float]:
- if not text:
- return None
- m = _MONEY_RE.search(text)
- if not m:
- return None
- whole = m.group(1).replace(",", "")
- frac = m.group(2)
- try:
- if frac:
- return float(f"{whole}.{frac}")
- return float(whole)
- except ValueError:
- return None
-
-
-def _extract_duration_value(text: str) -> Optional[Tuple[float, str]]:
- if not text:
- return None
- m = _DURATION_RE.search(text)
- if not m:
- return None
- try:
- return float(m.group(1)), m.group(2).lower()
- except ValueError:
- return None
-
-
-def _normalize_unit(unit: str) -> str:
- unit = str(unit or "").strip().lower()
- if unit.endswith("s"):
- unit = unit[:-1]
- aliases = {
- "yr": "year",
- "yrs": "year",
- "hr": "hour",
- "hrs": "hour",
- "min": "minute",
- "mins": "minute",
- }
- return aliases.get(unit, unit)
-
-
-def _duration_target_unit(question: str) -> str:
- q = str(question or "").lower()
- for unit in ("year", "month", "week", "day", "hour", "minute"):
- if re.search(rf"\b{unit}s?\b", q):
- return unit
- return "day"
-
-
-def _convert_duration(value: float, src_unit: str, dst_unit: str) -> Optional[float]:
- src = _normalize_unit(src_unit)
- dst = _normalize_unit(dst_unit)
- to_days = {
- "year": 365.0,
- "month": 30.0,
- "week": 7.0,
- "day": 1.0,
- "hour": 1.0 / 24.0,
- "minute": 1.0 / 1440.0,
- }
- if src not in to_days or dst not in to_days:
- return None
- return (value * to_days[src]) / to_days[dst]
-
-
-def reduce_atomic_facts(
- *,
- question: str,
- intent: AnswerIntent,
- facts: Sequence[Dict[str, Any]],
-) -> Tuple[Optional[str], Dict[str, Any]]:
- relevant = [f for f in facts if _normalize_bool(f.get("relevant", True))]
- meta: Dict[str, Any] = {
- "fact_count": len(facts),
- "relevant_fact_count": len(relevant),
- "intent": intent.value,
- }
- if not relevant:
- return None, meta
-
- if intent == AnswerIntent.COUNT:
- # Superlative COUNT (argmax): "which X the most" → return the most frequent VALUE
- is_argmax = bool(_SUPERLATIVE_RE.search(question))
- if is_argmax:
- # Count occurrences of each distinct value
- value_counts: Dict[str, int] = {}
- value_display: Dict[str, str] = {} # lowercase → original case
- for f in relevant:
- val = str(f.get("value") or "").strip()
- if not val:
- continue
- key = val.lower()
- value_counts[key] = value_counts.get(key, 0) + 1
- if key not in value_display:
- value_display[key] = val
- if not value_counts:
- return None, meta
- # Find the value with highest count
- best_key = max(value_counts, key=value_counts.get)
- best_count = value_counts[best_key]
- meta["argmax_value"] = value_display[best_key]
- meta["argmax_count"] = best_count
- meta["value_distribution"] = {
- value_display[k]: c for k, c in value_counts.items()
- }
- return value_display[best_key], meta
-
- # Standard COUNT: count unique items
- keys = set()
- for f in relevant:
- key = str(f.get("canonical_key") or "").strip().lower()
- if not key:
- # Build key from predicate + normalized value.
- # Include predicate so "return boots" != "pick up boots".
- val = str(f.get("value") or "").strip().lower()
- pred = str(f.get("predicate") or "").strip().lower()
- # Strip common prefixes that don't change identity
- for prefix in ("new ", "a ", "an ", "the ", "my ", "pair of ", "some "):
- if val.startswith(prefix):
- val = val[len(prefix):]
- parts = [p for p in (pred, val) if p]
- key = " | ".join(parts) if parts else ""
- if key:
- keys.add(key)
- if not keys:
- return None, meta
- meta["reduced_unique_keys"] = len(keys)
- return str(len(keys)), meta
-
- if intent == AnswerIntent.MONEY_SUM:
- values: List[float] = []
- for f in relevant:
- amount = _to_float(f.get("numeric_value"))
- if amount is None:
- amount = _extract_money_value(str(f.get("value") or ""))
- if amount is None:
- continue
- values.append(amount)
- if not values:
- return None, meta
- total = sum(values)
- meta["money_terms"] = len(values)
- if abs(total - round(total)) < 1e-9:
- return f"${int(round(total)):,}", meta
- return f"${total:,.2f}", meta
-
- if intent == AnswerIntent.DURATION:
- target = _duration_target_unit(question)
- values: List[float] = []
- for f in relevant:
- numeric = _to_float(f.get("numeric_value"))
- unit = _normalize_unit(str(f.get("unit") or ""))
- if numeric is None or not unit:
- parsed = _extract_duration_value(str(f.get("value") or ""))
- if parsed:
- numeric, unit = parsed
- unit = _normalize_unit(unit)
- if numeric is None or not unit:
- continue
- converted = _convert_duration(float(numeric), unit, target)
- if converted is None:
- continue
- values.append(converted)
- if not values:
- return None, meta
- total = sum(values)
- meta["duration_terms"] = len(values)
- rounded = round(total, 2)
- if abs(rounded - round(rounded)) < 1e-9:
- rounded = int(round(rounded))
- unit_out = target if rounded == 1 else f"{target}s"
- return f"{rounded} {unit_out}", meta
-
- if intent == AnswerIntent.LATEST:
- dated: List[Tuple[datetime, Dict[str, Any]]] = []
- for f in relevant:
- dt = _parse_event_datetime(f.get("event_date"))
- if dt is not None:
- dated.append((dt, f))
- if dated:
- dated.sort(key=lambda x: x[0], reverse=True)
- best = dated[0][1]
- answer = str(best.get("value") or "").strip()
- if answer:
- return answer, meta
- # fallback: first relevant value
- for f in relevant:
- answer = str(f.get("value") or "").strip()
- if answer:
- return answer, meta
- return None, meta
-
- if intent == AnswerIntent.SET_MEMBERS:
- values = []
- seen = set()
- for f in relevant:
- val = str(f.get("value") or "").strip()
- if not val:
- continue
- key = val.lower()
- if key in seen:
- continue
- seen.add(key)
- values.append(val)
- if values:
- return ", ".join(values), meta
- return None, meta
-
- return None, meta
-
-
-def _extract_numeric_mentions(text: str) -> List[float]:
- values: List[float] = []
- if not text:
- return values
- for match in _MONEY_RE.finditer(str(text)):
- whole = (match.group(1) or "").replace(",", "")
- frac = match.group(2)
- if not whole:
- continue
- try:
- number = float(f"{whole}.{frac}") if frac else float(whole)
- except ValueError:
- continue
- values.append(number)
- return values
-
-
-def deterministic_inconsistency_check(
- *,
- question: str,
- intent: AnswerIntent,
- results: Sequence[Dict[str, Any]],
- coverage: Optional[Dict[str, Any]] = None,
-) -> Dict[str, Any]:
- """Detect deterministic evidence inconsistencies before map/reduce.
-
- This is intentionally cheap and LLM-free.
- """
- reasons: List[str] = []
- coverage_payload = dict(coverage or {})
- coverage_sufficient = bool(coverage_payload.get("sufficient"))
- if not coverage_sufficient:
- reasons.append("coverage_insufficient")
-
- if not results:
- reasons.append("no_results")
- return {"inconsistent": True, "reasons": reasons}
-
- intent_value = intent.value if isinstance(intent, AnswerIntent) else str(intent or "")
- top_rows = list(results)[:12]
-
- if intent_value in {"count", "set_members"}:
- numeric_candidates = set()
- for row in top_rows:
- text = str(row.get("evidence_text") or row.get("memory") or "")
- for value in _extract_numeric_mentions(text):
- if value >= 0:
- numeric_candidates.add(round(float(value), 3))
- if len(numeric_candidates) >= 2:
- reasons.append("count_numeric_conflict")
-
- if intent_value == "money_sum":
- amounts = set()
- for row in top_rows:
- text = str(row.get("evidence_text") or row.get("memory") or "")
- for value in _extract_numeric_mentions(text):
- amounts.add(round(float(value), 2))
- if len(amounts) >= 2:
- reasons.append("money_terms_multiple")
-
- if intent_value == "duration":
- units = set()
- for row in top_rows:
- text = str(row.get("evidence_text") or row.get("memory") or "")
- parsed = _extract_duration_value(text)
- if parsed:
- units.add(_normalize_unit(parsed[1]))
- if len(units) >= 2:
- reasons.append("duration_unit_mixed")
-
- if intent_value == "latest":
- dated_hits = int(coverage_payload.get("dated_fact_count", 0) or 0)
- if dated_hits <= 0:
- # fallback scan for explicit dates in evidence
- has_date_like = False
- for row in top_rows:
- text = str(row.get("evidence_text") or row.get("memory") or "")
- if re.search(r"\b\d{4}-\d{2}-\d{2}\b", text):
- has_date_like = True
- break
- if not has_date_like:
- reasons.append("latest_missing_dated_evidence")
-
- return {"inconsistent": bool(reasons), "reasons": reasons}
-
-
-def render_fact_context(facts: Sequence[Dict[str, Any]], max_facts: int = 20) -> str:
- lines: List[str] = []
- for f in list(facts)[: max(1, int(max_facts))]:
- if not _normalize_bool(f.get("relevant", True)):
- continue
- parts = []
- if f.get("event_date"):
- parts.append(f"date={f['event_date']}")
- if f.get("session_id"):
- parts.append(f"session={f['session_id']}")
- label = " ".join(parts)
- value = str(f.get("value") or "").strip()
- subj = str(f.get("subject") or "").strip()
- pred = str(f.get("predicate") or "").strip()
- body = " | ".join(x for x in [subj, pred, value] if x)
- if not body:
- continue
- if label:
- lines.append(f"- [{label}] {body}")
- else:
- lines.append(f"- {body}")
- return "\n".join(lines)
-
-
def is_low_confidence_answer(answer: str) -> bool:
return bool(_LOW_CONFIDENCE_RE.search(str(answer or "").strip()))
-
-
-def should_override_with_reducer(intent: AnswerIntent) -> bool:
- return intent in _NUMERIC_INTENTS or intent == AnswerIntent.LATEST
-
-
-# ---------------------------------------------------------------------------
-# Event-first reducer — zero LLM cost
-# ---------------------------------------------------------------------------
-
-
-def reduce_from_episodic_events(
- *,
- question: str,
- intent: AnswerIntent,
- events: Sequence[Dict[str, Any]],
-) -> Tuple[Optional[str], Dict[str, Any]]:
- """Reduce episodic events into an answer — zero LLM cost.
-
- Works directly from event dicts produced by episodic_index rather than
- LLM-extracted atomic facts. Uses the same deterministic logic as
- reduce_atomic_facts() but adapted for the event schema.
- """
- meta: Dict[str, Any] = {
- "event_count": len(events),
- "intent": intent.value,
- "source": "episodic_events",
- }
- if not events:
- return None, meta
-
- if intent == AnswerIntent.COUNT:
- keys: set = set()
- for ev in events:
- key = str(ev.get("canonical_key") or "").strip().lower()
- if not key:
- value = str(ev.get("value_text") or "").strip().lower()
- if value:
- key = value
- if key:
- keys.add(key)
- if not keys:
- return None, meta
- meta["reduced_unique_keys"] = len(keys)
- return str(len(keys)), meta
-
- if intent == AnswerIntent.MONEY_SUM:
- values: List[float] = []
- for ev in events:
- if str(ev.get("event_type") or "").lower() != "money":
- continue
- amount = _to_float(ev.get("value_num"))
- if amount is None:
- amount = _extract_money_value(str(ev.get("value_text") or ""))
- if amount is not None:
- values.append(amount)
- if not values:
- return None, meta
- total = sum(values)
- meta["money_terms"] = len(values)
- if abs(total - round(total)) < 1e-9:
- return f"${int(round(total)):,}", meta
- return f"${total:,.2f}", meta
-
- if intent == AnswerIntent.DURATION:
- target = _duration_target_unit(question)
- values = []
- for ev in events:
- if str(ev.get("event_type") or "").lower() != "duration":
- continue
- numeric = _to_float(ev.get("value_num"))
- unit = _normalize_unit(str(ev.get("value_unit") or ""))
- if numeric is None or not unit:
- parsed = _extract_duration_value(str(ev.get("value_text") or ""))
- if parsed:
- numeric, unit = parsed
- unit = _normalize_unit(unit)
- if numeric is None or not unit:
- continue
- converted = _convert_duration(float(numeric), unit, target)
- if converted is not None:
- values.append(converted)
- if not values:
- return None, meta
- total = sum(values)
- meta["duration_terms"] = len(values)
- rounded = round(total, 2)
- if abs(rounded - round(rounded)) < 1e-9:
- rounded = int(round(rounded))
- unit_out = target if rounded == 1 else f"{target}s"
- return f"{rounded} {unit_out}", meta
-
- if intent == AnswerIntent.LATEST:
- dated: List[Tuple[datetime, Dict[str, Any]]] = []
- for ev in events:
- dt = _parse_event_datetime(
- ev.get("normalized_time_start") or ev.get("event_time")
- )
- if dt is not None:
- dated.append((dt, ev))
- if dated:
- dated.sort(key=lambda x: x[0], reverse=True)
- best = dated[0][1]
- answer = str(best.get("value_text") or "").strip()
- if answer:
- return answer, meta
- # fallback: first event value
- for ev in events:
- answer = str(ev.get("value_text") or "").strip()
- if answer:
- return answer, meta
- return None, meta
-
- if intent == AnswerIntent.SET_MEMBERS:
- values_list: List[str] = []
- seen: set = set()
- for ev in events:
- val = str(ev.get("value_text") or "").strip()
- if not val:
- continue
- key = val.lower()
- if key in seen:
- continue
- seen.add(key)
- values_list.append(val)
- if values_list:
- return ", ".join(values_list), meta
- return None, meta
-
- return None, meta
diff --git a/dhee/core/code_exec_counter.py b/dhee/core/code_exec_counter.py
deleted file mode 100644
index 16f1a52..0000000
--- a/dhee/core/code_exec_counter.py
+++ /dev/null
@@ -1,266 +0,0 @@
-"""Sandboxed code-exec counting for deterministic aggregation.
-
-Instead of asking the LLM to count/sum items across multiple sessions,
-this module has the LLM emit Python code that enumerates items, then
-executes it in a restricted sandbox to produce a deterministic answer.
-"""
-
-from __future__ import annotations
-
-import io
-import logging
-import re
-import threading
-from datetime import date, datetime, timedelta
-from typing import Any, Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-# Allowed builtins for the sandbox — no file I/O, no imports, no exec/eval.
-# Includes datetime types for date arithmetic (e.g., days between two dates).
-_SAFE_BUILTINS = {
- "len": len,
- "sum": sum,
- "max": max,
- "min": min,
- "sorted": sorted,
- "set": set,
- "list": list,
- "dict": dict,
- "int": int,
- "float": float,
- "str": str,
- "range": range,
- "enumerate": enumerate,
- "zip": zip,
- "print": print, # will be redirected to StringIO
- "abs": abs,
- "round": round,
- "tuple": tuple,
- "True": True,
- "False": False,
- "None": None,
- # Date/time types — safe, no I/O, needed for temporal arithmetic
- "datetime": datetime,
- "date": date,
- "timedelta": timedelta,
-}
-
-# Safe import lines that can be stripped before the blocked-pattern check.
-# LLMs habitually emit these even when the types are already available.
-_SAFE_IMPORT_RE = re.compile(
- r"^\s*(?:from\s+datetime\s+import\s+[\w\s,]+|import\s+datetime)\s*$",
- re.MULTILINE,
-)
-
-# Patterns that indicate dangerous code.
-# Only match import/from at statement-start (^) to avoid false positives in
-# comments like "# data from session 1" or strings like "trip from NYC".
-_BLOCKED_PATTERNS = re.compile(
- r"\b(__\w+__|exec|eval|compile|globals|locals|getattr|setattr|delattr"
- r"|subprocess|shutil|pathlib)\b"
- r"|^import\s+(?!datetime\b)" # block `import X` unless X is datetime
- r"|^from\s+(?!datetime\b)\w" # block `from X` unless X is datetime
- r"|\bos\.\w" # block os.anything
- r"|\bsys\.\w" # block sys.anything
- r"|\bopen\s*\(", # block open() calls, not the word "open"
- re.MULTILINE,
-)
-
-_CODE_BLOCK_RE = re.compile(r"```(?:python)?\s*\n(.*?)```", re.DOTALL)
-_ANSWER_RE = re.compile(r"^ANSWER:\s*(.+)$", re.MULTILINE)
-_ITEMS_RE = re.compile(r"^ITEMS:\s*(.+)$", re.MULTILINE)
-
-
-_MAX_CODE_EXEC_CONTEXT_CHARS = 12000
-
-
-def build_code_counting_prompt(
- question: str,
- retrieved_context: str,
- question_date: str = "",
-) -> str:
- """Build a prompt that asks the LLM to generate Python counting code."""
- # Truncate context to fit smaller models' context windows
- ctx = retrieved_context[:_MAX_CODE_EXEC_CONTEXT_CHARS] if len(retrieved_context) > _MAX_CODE_EXEC_CONTEXT_CHARS else retrieved_context
- date_line = f"\nQuestion Date (today): {question_date}" if question_date else ""
- return f"""You are a precise counting and date-arithmetic assistant. Read the question and context below,
-then write a short Python script following this EXACT pattern:
-
-```python
-# Step 1: Define ONE list of (label, value) tuples — one entry per distinct item
-items = [
- ("item description from session X", numeric_value),
- ("item description from session Y", numeric_value),
-]
-
-# Step 2: Compute answer from the items list ONLY
-answer = sum(v for _, v in items) # for sum/duration questions
-# OR: answer = len(items) # for counting questions
-
-# Step 3: Print results — the answer MUST be derived from the items list above
-print(f"ANSWER: {{answer}}")
-print(f"ITEMS: {{[label for label, _ in items]}}")
-```
-
-CRITICAL RULES:
-- Create exactly ONE list called `items` with ALL relevant entries from ALL sessions
-- For "how many hours/days/weeks" questions: each tuple is ("description", hours_or_days_number), answer = sum()
-- For "how many times/items" questions: each tuple is ("description", 1), answer = len()
-- For "how much money" questions: each tuple is ("description", dollar_amount), answer = sum()
-- For "how many days/months between X and Y" questions: use datetime(year, month, day) to compute exact differences
-- For "how many months/days ago" questions: compute from the Question Date using datetime arithmetic
-- The ANSWER line MUST be computed from the `items` list or datetime math — never hardcode it
-- Available without import: datetime, date, timedelta, len, sum, min, max, abs, round
-- Read EVERY session in the context — missing one entry means a wrong answer
-- If the same event appears in multiple sessions, include it only ONCE (deduplicate)
-- Add the unit to the ANSWER (e.g., "8 days", "140 hours", "$5850", "30 days")
-
-Question: {question}{date_line}
-
-Context:
-{ctx}
-
-Write ONLY the Python code inside a code block:
-```python
-"""
-
-
-def execute_counting_code(code: str, timeout: float = 5.0) -> Optional[Dict[str, Any]]:
- """Execute counting code in a restricted sandbox.
-
- Returns {{"answer": str, "items": list}} or None on failure.
- """
- # Strip safe import lines (datetime) before the blocked-pattern check.
- # LLMs emit these habitually; the types are already in _SAFE_BUILTINS.
- code = _SAFE_IMPORT_RE.sub("", code)
-
- # Validate: block dangerous patterns
- if _BLOCKED_PATTERNS.search(code):
- logger.warning("Code-exec blocked: dangerous pattern detected in: %s", code[:200])
- return None
-
- # Capture stdout
- captured = io.StringIO()
-
- restricted_globals = {"__builtins__": {}}
- for name, obj in _SAFE_BUILTINS.items():
- restricted_globals[name] = obj
-
- # Redirect print to captured output
- def safe_print(*args, **kwargs):
- kwargs["file"] = captured
- print(*args, **kwargs)
-
- restricted_globals["print"] = safe_print
-
- result = {"completed": False, "error": None}
-
- def _run():
- try:
- exec(code, restricted_globals) # noqa: S102
- result["completed"] = True
- except Exception as e:
- result["error"] = str(e)
-
- thread = threading.Thread(target=_run, daemon=True)
- thread.start()
- thread.join(timeout=timeout)
-
- if not result["completed"]:
- if result["error"]:
- logger.warning("Code-exec error: %s", result["error"])
- else:
- logger.warning("Code-exec timeout after %.1fs", timeout)
- return None
-
- output = captured.getvalue()
-
- # Parse ANSWER line
- answer_match = _ANSWER_RE.search(output)
- if not answer_match:
- logger.debug("Code-exec: no ANSWER line in output: %r", output[:200])
- return None
-
- answer = answer_match.group(1).strip()
-
- # Parse ITEMS line (optional)
- items: List[str] = []
- items_match = _ITEMS_RE.search(output)
- if items_match:
- items_raw = items_match.group(1).strip()
- # Parse list by splitting on commas (avoid eval for safety)
- inner = items_raw.strip("[]")
- if inner:
- items = [x.strip().strip("'\"") for x in inner.split(",") if x.strip().strip("'\"")]
-
- return {"answer": answer, "items": items}
-
-
-def refine_count_with_code_exec(
- *,
- llm: Any,
- question: str,
- question_type: str,
- retrieved_context: str,
- draft_answer: str,
- question_date: str = "",
-) -> Optional[str]:
- """Full pipeline: prompt LLM to generate code -> exec -> parse answer.
-
- Returns the refined answer string, or None if code-exec fails.
- """
- prompt = build_code_counting_prompt(question, retrieved_context, question_date=question_date)
-
- try:
- raw_response = str(llm.generate(prompt)).strip()
- except Exception as e:
- logger.warning("Code-exec LLM call failed: %s", e)
- return None
-
- # Extract code block
- code_match = _CODE_BLOCK_RE.search(raw_response)
- if code_match:
- code = code_match.group(1).strip()
- else:
- # Try treating the entire response as code if it looks like Python
- lines = raw_response.strip().splitlines()
- code_lines = [ln for ln in lines if not ln.startswith("```")]
- if any(ln.strip().startswith(("items", "total", "count", "print", "#", "result")) for ln in code_lines):
- code = "\n".join(code_lines)
- else:
- logger.debug("Code-exec: no code block found in LLM response")
- return None
-
- result = execute_counting_code(code)
- if not result:
- return None
-
- answer = result["answer"]
- items = result.get("items", [])
-
- logger.info(
- "Code-exec result: answer=%r, items_count=%d, draft=%r",
- answer, len(items), draft_answer,
- )
-
- # Cross-check: if we have items and an answer, verify consistency
- if items and answer:
- try:
- answer_num = float(re.sub(r"[^\d.]", "", answer.split()[0]))
- # For pure counting, items list length should match
- q_lower = question.lower()
- if any(w in q_lower for w in ("how many times", "how many", "number of")):
- if abs(answer_num - len(items)) > 0.5 and not any(
- w in q_lower for w in ("hours", "days", "weeks", "months", "minutes")
- ):
- # Items count is more trustworthy for pure counting
- logger.debug(
- "Code-exec cross-check: stated %s but %d items; using items count",
- answer, len(items),
- )
- answer = str(len(items))
- except (ValueError, IndexError):
- pass
-
- return answer if answer else None
diff --git a/dhee/core/conflicts.py b/dhee/core/conflicts.py
deleted file mode 100644
index a8c7ca9..0000000
--- a/dhee/core/conflicts.py
+++ /dev/null
@@ -1,238 +0,0 @@
-"""Dhee v3 — Cognitive Conflict Store + Auto-Resolution.
-
-Tracks contradictions and disagreements between derived objects.
-Conflicts are explicit rows, not silent resolution.
-
-Auto-resolution: if one side has confidence > 0.8 and the other < 0.3,
-auto-resolve in favor of the high-confidence side. Otherwise, flag
-for manual resolution.
-
-Conflict types:
- - belief_contradiction: two beliefs claim opposing things
- - anchor_disagreement: anchor candidate disagrees with resolved anchor
- - distillation_conflict: candidate conflicts with promoted truth
- - invalidation_dispute: partial invalidation verification disagrees
-
-Design contract:
- - Every contradiction gets an explicit row
- - Auto-resolution only when confidence gap > 0.5
- - Zero LLM calls — confidence comparison only
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import sqlite3
-import threading
-import uuid
-from contextlib import contextmanager
-from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-
-def _utcnow_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
-
-
-# Auto-resolution thresholds
-AUTO_RESOLVE_HIGH = 0.8
-AUTO_RESOLVE_LOW = 0.3
-AUTO_RESOLVE_GAP = 0.5
-
-
-class ConflictStore:
- """Manages cognitive conflicts in the database."""
-
- def __init__(self, conn: sqlite3.Connection, lock: threading.RLock):
- self._conn = conn
- self._lock = lock
-
- @contextmanager
- def _tx(self):
- with self._lock:
- try:
- yield self._conn
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
-
- def create(
- self,
- conflict_type: str,
- side_a_type: str,
- side_a_id: str,
- side_b_type: str,
- side_b_id: str,
- *,
- side_a_confidence: Optional[float] = None,
- side_b_confidence: Optional[float] = None,
- ) -> Dict[str, Any]:
- """Create a conflict. Attempts auto-resolution if confidence gap is clear.
-
- Returns dict with conflict_id, resolution_status, and auto_resolution details.
- """
- cid = str(uuid.uuid4())
- now = _utcnow_iso()
-
- # Attempt auto-resolution
- resolution_status = "open"
- resolution_json = None
- auto_confidence = None
-
- if side_a_confidence is not None and side_b_confidence is not None:
- gap = abs(side_a_confidence - side_b_confidence)
- if gap >= AUTO_RESOLVE_GAP:
- if (side_a_confidence >= AUTO_RESOLVE_HIGH
- and side_b_confidence <= AUTO_RESOLVE_LOW):
- resolution_status = "auto_resolved"
- auto_confidence = side_a_confidence
- resolution_json = json.dumps({
- "winner": "side_a",
- "winner_type": side_a_type,
- "winner_id": side_a_id,
- "reason": f"confidence gap: {side_a_confidence:.2f} vs {side_b_confidence:.2f}",
- })
- elif (side_b_confidence >= AUTO_RESOLVE_HIGH
- and side_a_confidence <= AUTO_RESOLVE_LOW):
- resolution_status = "auto_resolved"
- auto_confidence = side_b_confidence
- resolution_json = json.dumps({
- "winner": "side_b",
- "winner_type": side_b_type,
- "winner_id": side_b_id,
- "reason": f"confidence gap: {side_b_confidence:.2f} vs {side_a_confidence:.2f}",
- })
-
- with self._tx() as conn:
- conn.execute(
- """INSERT INTO cognitive_conflicts
- (conflict_id, conflict_type, side_a_type, side_a_id,
- side_b_type, side_b_id, detected_at,
- resolution_status, resolution_json,
- auto_resolution_confidence)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (
- cid, conflict_type, side_a_type, side_a_id,
- side_b_type, side_b_id, now,
- resolution_status, resolution_json, auto_confidence,
- ),
- )
-
- result = {
- "conflict_id": cid,
- "conflict_type": conflict_type,
- "resolution_status": resolution_status,
- }
- if resolution_status == "auto_resolved":
- result["auto_resolution"] = json.loads(resolution_json)
- return result
-
- def resolve(
- self,
- conflict_id: str,
- resolution: Dict[str, Any],
- *,
- by: str = "user",
- ) -> bool:
- """Manually resolve a conflict. Returns True if updated."""
- status = "user_resolved" if by == "user" else "auto_resolved"
- with self._tx() as conn:
- result = conn.execute(
- """UPDATE cognitive_conflicts
- SET resolution_status = ?, resolution_json = ?
- WHERE conflict_id = ? AND resolution_status = 'open'""",
- (status, json.dumps(resolution), conflict_id),
- )
- return result.rowcount > 0
-
- def defer(self, conflict_id: str) -> bool:
- """Defer a conflict for later resolution."""
- with self._tx() as conn:
- result = conn.execute(
- """UPDATE cognitive_conflicts
- SET resolution_status = 'deferred'
- WHERE conflict_id = ? AND resolution_status = 'open'""",
- (conflict_id,),
- )
- return result.rowcount > 0
-
- def get(self, conflict_id: str) -> Optional[Dict[str, Any]]:
- with self._lock:
- row = self._conn.execute(
- "SELECT * FROM cognitive_conflicts WHERE conflict_id = ?",
- (conflict_id,),
- ).fetchone()
- return self._row_to_dict(row) if row else None
-
- def list_open(self, *, limit: int = 50) -> List[Dict[str, Any]]:
- """Get all open (unresolved) conflicts."""
- with self._lock:
- rows = self._conn.execute(
- """SELECT * FROM cognitive_conflicts
- WHERE resolution_status = 'open'
- ORDER BY detected_at DESC
- LIMIT ?""",
- (limit,),
- ).fetchall()
- return [self._row_to_dict(r) for r in rows]
-
- def list_for_object(
- self, object_type: str, object_id: str
- ) -> List[Dict[str, Any]]:
- """Get all conflicts involving a specific object."""
- with self._lock:
- rows = self._conn.execute(
- """SELECT * FROM cognitive_conflicts
- WHERE (side_a_type = ? AND side_a_id = ?)
- OR (side_b_type = ? AND side_b_id = ?)
- ORDER BY detected_at DESC""",
- (object_type, object_id, object_type, object_id),
- ).fetchall()
- return [self._row_to_dict(r) for r in rows]
-
- def count_open(self) -> int:
- with self._lock:
- row = self._conn.execute(
- "SELECT COUNT(*) FROM cognitive_conflicts WHERE resolution_status = 'open'"
- ).fetchone()
- return row[0] if row else 0
-
- def has_open_conflicts(
- self, object_type: str, object_id: str
- ) -> bool:
- """Check if an object has any open conflicts (for retrieval penalty)."""
- with self._lock:
- row = self._conn.execute(
- """SELECT 1 FROM cognitive_conflicts
- WHERE resolution_status = 'open'
- AND ((side_a_type = ? AND side_a_id = ?)
- OR (side_b_type = ? AND side_b_id = ?))
- LIMIT 1""",
- (object_type, object_id, object_type, object_id),
- ).fetchone()
- return row is not None
-
- @staticmethod
- def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]:
- resolution = row["resolution_json"]
- if isinstance(resolution, str):
- try:
- resolution = json.loads(resolution)
- except (json.JSONDecodeError, TypeError):
- resolution = None
- return {
- "conflict_id": row["conflict_id"],
- "conflict_type": row["conflict_type"],
- "side_a_type": row["side_a_type"],
- "side_a_id": row["side_a_id"],
- "side_b_type": row["side_b_type"],
- "side_b_id": row["side_b_id"],
- "detected_at": row["detected_at"],
- "resolution_status": row["resolution_status"],
- "resolution": resolution,
- "auto_resolution_confidence": row["auto_resolution_confidence"],
- }
diff --git a/dhee/core/consolidation.py b/dhee/core/consolidation.py
deleted file mode 100644
index 7c07add..0000000
--- a/dhee/core/consolidation.py
+++ /dev/null
@@ -1,162 +0,0 @@
-"""
-Consolidation Engine — promotes important active signals to passive memory.
-
-v3 FIX: Breaks the feedback loop identified in the architecture critique.
-
-Old behavior (DANGEROUS):
- _promote_to_passive() called memory.add() → triggered full enrichment
- pipeline → could create new active signals → infinite consolidation loop.
-
-New behavior (SAFE):
- _promote_to_passive() calls memory.add() with infer=False AND tags
- promoted memories with source="consolidated" provenance metadata.
- _should_promote() rejects signals that were already consolidated
- (prevents re-consolidation of promoted content).
-
-The enrichment pipeline is explicitly skipped for consolidated memories
-because the content was already enriched when it entered active memory.
-"""
-
-import logging
-from typing import Any, Dict, Protocol, TYPE_CHECKING
-
-from dhee.configs.active import ActiveMemoryConfig
-
-if TYPE_CHECKING:
- from dhee.memory.main import FullMemory
-
-logger = logging.getLogger(__name__)
-
-
-class ActiveMemoryStore(Protocol):
- """Structural contract for consolidation's active-memory dependency."""
-
- def get_consolidation_candidates(
- self,
- *,
- min_age_seconds: int,
- min_reads: int,
- ) -> list[Dict[str, Any]]:
- ...
-
- def mark_consolidated(self, signal_ids: list[str]) -> None:
- ...
-
-
-class ConsolidationEngine:
- """Promotes qualifying active signals into passive (Engram) memory."""
-
- def __init__(
- self,
- active_store: ActiveMemoryStore,
- memory: "FullMemory",
- config: ActiveMemoryConfig,
- ):
- self.active = active_store
- self.memory = memory
- self.config = config
- self.consolidation = config.consolidation
-
- def run_cycle(self) -> Dict[str, Any]:
- """Run one consolidation cycle. Returns promotion stats."""
- candidates = self.active.get_consolidation_candidates(
- min_age_seconds=self.config.consolidation_min_age_seconds,
- min_reads=self.config.consolidation_min_reads,
- )
-
- promoted = []
- skipped = 0
- errors = 0
- feedback_loop_blocked = 0
-
- for signal in candidates:
- if not self._should_promote(signal):
- skipped += 1
- continue
- try:
- self._promote_to_passive(signal)
- promoted.append(signal["id"])
- except Exception:
- logger.exception("Failed to promote signal %s", signal["id"])
- errors += 1
-
- if promoted:
- self.active.mark_consolidated(promoted)
-
- return {
- "promoted": len(promoted),
- "checked": len(candidates),
- "skipped": skipped,
- "errors": errors,
- "feedback_loop_blocked": feedback_loop_blocked,
- }
-
- def _should_promote(self, signal: Dict[str, Any]) -> bool:
- """Determine if a signal qualifies for promotion to passive memory."""
- signal_type = signal.get("signal_type", "")
- ttl_tier = signal.get("ttl_tier", "")
- read_count = signal.get("read_count", 0)
-
- # v3 FIX: Block re-consolidation of already-consolidated content.
- # This breaks the feedback loop where promoted content generates
- # new active signals that get re-consolidated infinitely.
- signal_metadata = signal.get("metadata", {})
- if isinstance(signal_metadata, dict):
- if signal_metadata.get("source") == "consolidated":
- return False
- if signal_metadata.get("consolidated_from"):
- return False
-
- # Also check the value field for consolidation markers
- value = signal.get("value", "")
- if isinstance(value, str) and "[consolidated]" in value.lower():
- return False
-
- # Directives always promote
- if signal_type == "directive" and self.consolidation.directive_to_passive:
- return True
-
- # Critical tier promotes
- if ttl_tier == "critical" and self.consolidation.promote_critical:
- return True
-
- # High-read signals promote
- if (
- self.consolidation.promote_high_read
- and read_count >= self.consolidation.promote_read_threshold
- ):
- return True
-
- return False
-
- def _promote_to_passive(self, signal: Dict[str, Any]) -> None:
- """Add a signal's content to passive memory.
-
- v3 FIX: Uses infer=False to skip the LLM enrichment pipeline.
- Tags with source="consolidated" to prevent re-consolidation.
- """
- signal_type = signal.get("signal_type", "event")
- user_id = signal.get("user_id", "default")
- key = signal.get("key", "")
- value = signal.get("value", "")
-
- # Build content string
- content = f"[{key}] {value}" if key else value
-
- self.memory.add(
- messages=content,
- user_id=user_id,
- metadata={
- # Provenance: identifies this as consolidated content
- "source": "consolidated",
- "consolidated_from": signal.get("id"),
- "signal_key": key,
- "signal_type": signal_type,
- },
- immutable=(signal_type == "directive"),
- initial_layer="lml" if signal_type == "directive" else "sml",
- # v3 FIX: Skip enrichment pipeline entirely.
- # Content was already enriched when it entered active memory.
- # Re-enrichment would generate divergent facts/entities.
- infer=False,
- )
diff --git a/dhee/core/derived_store.py b/dhee/core/derived_store.py
deleted file mode 100644
index 390492c..0000000
--- a/dhee/core/derived_store.py
+++ /dev/null
@@ -1,1145 +0,0 @@
-"""Dhee v3 — Type-specific derived cognition stores.
-
-Each derived type gets its own store class because they have different:
-- Lifecycle rules (beliefs: Bayesian; policies: win-rate; insights: strength)
-- Indexing needs (anchors: era/place; policies: granularity/utility)
-- Invalidation behavior (beliefs: retract; policies: deprecate; anchors: re-resolve)
-- Conflict semantics (beliefs: contradiction pairs; policies: approach conflicts)
-
-All stores share a common database connection (from RawEventStore) and
-the derived_lineage table for traceability.
-
-Design contract:
- - Every derived object has derivation_version + lineage_fingerprint
- - Invalidation statuses (stale, suspect, invalidated) are orthogonal to
- type-specific lifecycle statuses
- - Zero LLM calls — pure storage and state transitions
-"""
-
-from __future__ import annotations
-
-import hashlib
-import json
-import logging
-import sqlite3
-import threading
-import uuid
-from contextlib import contextmanager
-from dataclasses import dataclass, field
-from datetime import datetime, timezone
-from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
-
-logger = logging.getLogger(__name__)
-
-
-def _utcnow_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
-
-
-def _parse_json(value: Any, default: Any = None) -> Any:
- if value is None:
- return default if default is not None else []
- if isinstance(value, (dict, list)):
- return value
- try:
- return json.loads(value)
- except (json.JSONDecodeError, TypeError):
- return default if default is not None else []
-
-
-def _compute_lineage_fingerprint(source_event_ids: List[str], version: int) -> str:
- """Deterministic fingerprint from sorted source IDs + version."""
- payload = "|".join(sorted(source_event_ids)) + f"|v{version}"
- return hashlib.sha256(payload.encode()).hexdigest()[:16]
-
-
-# =========================================================================
-# Enums
-# =========================================================================
-
-class BeliefStatus(str, Enum):
- PROPOSED = "proposed"
- HELD = "held"
- CHALLENGED = "challenged"
- REVISED = "revised"
- RETRACTED = "retracted"
- # Invalidation statuses (from three-tier model)
- STALE = "stale"
- SUSPECT = "suspect"
- INVALIDATED = "invalidated"
-
-
-class PolicyStatus(str, Enum):
- PROPOSED = "proposed"
- ACTIVE = "active"
- VALIDATED = "validated"
- DEPRECATED = "deprecated"
- STALE = "stale"
- SUSPECT = "suspect"
- INVALIDATED = "invalidated"
-
-
-class PolicyGranularity(str, Enum):
- TASK = "task"
- STEP = "step"
-
-
-class InsightType(str, Enum):
- CAUSAL = "causal"
- WARNING = "warning"
- STRATEGY = "strategy"
- PATTERN = "pattern"
-
-
-class AbstractionLevel(str, Enum):
- SPECIFIC = "specific"
- DOMAIN = "domain"
- UNIVERSAL = "universal"
-
-
-class DerivedType(str, Enum):
- BELIEF = "belief"
- POLICY = "policy"
- ANCHOR = "anchor"
- INSIGHT = "insight"
- HEURISTIC = "heuristic"
-
-
-class DerivedStatus(str, Enum):
- """Common invalidation statuses across all derived types."""
- ACTIVE = "active"
- STALE = "stale"
- SUSPECT = "suspect"
- INVALIDATED = "invalidated"
-
-
-# =========================================================================
-# Base store with shared connection management
-# =========================================================================
-
-class _DerivedStoreBase:
- """Shared connection management for all derived stores.
-
- All stores share a single SQLite connection. The connection is
- created externally (by RawEventStore or a coordinator) and passed in.
- """
-
- def __init__(self, conn: sqlite3.Connection, lock: threading.RLock):
- self._conn = conn
- self._lock = lock
-
- @contextmanager
- def _tx(self):
- with self._lock:
- try:
- yield self._conn
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
-
-
-# =========================================================================
-# BeliefStore
-# =========================================================================
-
-class BeliefStore(_DerivedStoreBase):
- """Confidence-tracked claims with Bayesian updates and contradiction detection."""
-
- def add(
- self,
- user_id: str,
- claim: str,
- *,
- domain: str = "general",
- confidence: float = 0.5,
- source_memory_ids: Optional[List[str]] = None,
- source_episode_ids: Optional[List[str]] = None,
- tags: Optional[List[str]] = None,
- belief_id: Optional[str] = None,
- ) -> str:
- bid = belief_id or str(uuid.uuid4())
- now = _utcnow_iso()
- smids = source_memory_ids or []
- seids = source_episode_ids or []
- fp = _compute_lineage_fingerprint(smids, 1)
-
- with self._tx() as conn:
- conn.execute(
- """INSERT INTO beliefs
- (belief_id, user_id, claim, domain, status, confidence,
- source_memory_ids, source_episode_ids, derivation_version,
- lineage_fingerprint, created_at, updated_at, tags_json)
- VALUES (?, ?, ?, ?, 'proposed', ?, ?, ?, 1, ?, ?, ?, ?)""",
- (
- bid, user_id, claim, domain, confidence,
- json.dumps(smids), json.dumps(seids), fp,
- now, now, json.dumps(tags or []),
- ),
- )
- return bid
-
- def get(self, belief_id: str) -> Optional[Dict[str, Any]]:
- with self._lock:
- row = self._conn.execute(
- "SELECT * FROM beliefs WHERE belief_id = ?",
- (belief_id,),
- ).fetchone()
- return self._row_to_dict(row) if row else None
-
- def list_by_user(
- self,
- user_id: str,
- *,
- domain: Optional[str] = None,
- status: Optional[str] = None,
- min_confidence: float = 0.0,
- limit: int = 50,
- ) -> List[Dict[str, Any]]:
- query = "SELECT * FROM beliefs WHERE user_id = ? AND confidence >= ?"
- params: list = [user_id, min_confidence]
- if domain:
- query += " AND domain = ?"
- params.append(domain)
- if status:
- query += " AND status = ?"
- params.append(status)
- query += " ORDER BY confidence DESC LIMIT ?"
- params.append(limit)
-
- with self._lock:
- rows = self._conn.execute(query, params).fetchall()
- return [self._row_to_dict(r) for r in rows]
-
- def update_confidence(
- self,
- belief_id: str,
- new_confidence: float,
- *,
- new_status: Optional[str] = None,
- evidence: Optional[Dict[str, Any]] = None,
- revision_reason: Optional[str] = None,
- ) -> bool:
- """Update belief confidence with optional evidence and revision tracking."""
- now = _utcnow_iso()
- with self._tx() as conn:
- row = conn.execute(
- "SELECT confidence, status, evidence_json, revisions_json FROM beliefs WHERE belief_id = ?",
- (belief_id,),
- ).fetchone()
- if not row:
- return False
-
- old_conf = row["confidence"]
- old_status = row["status"]
-
- # Append evidence
- evidence_list = _parse_json(row["evidence_json"], [])
- if evidence:
- evidence_list.append(evidence)
-
- # Append revision
- revisions = _parse_json(row["revisions_json"], [])
- revisions.append({
- "timestamp": now,
- "old_confidence": old_conf,
- "new_confidence": new_confidence,
- "old_status": old_status,
- "new_status": new_status or old_status,
- "reason": revision_reason or "confidence_update",
- })
-
- # Auto-derive status if not explicitly set
- status = new_status
- if not status:
- if new_confidence >= 0.7:
- status = BeliefStatus.HELD.value
- elif new_confidence <= 0.1:
- status = BeliefStatus.RETRACTED.value
- else:
- status = old_status
-
- conn.execute(
- """UPDATE beliefs
- SET confidence = ?, status = ?, evidence_json = ?,
- revisions_json = ?, updated_at = ?
- WHERE belief_id = ?""",
- (
- new_confidence, status, json.dumps(evidence_list),
- json.dumps(revisions), now, belief_id,
- ),
- )
- return True
-
- def add_contradiction(self, belief_a_id: str, belief_b_id: str) -> None:
- """Link two beliefs as contradicting each other."""
- now = _utcnow_iso()
- with self._tx() as conn:
- for bid, other_id in [(belief_a_id, belief_b_id), (belief_b_id, belief_a_id)]:
- row = conn.execute(
- "SELECT contradicts_ids FROM beliefs WHERE belief_id = ?",
- (bid,),
- ).fetchone()
- if row:
- ids = _parse_json(row["contradicts_ids"], [])
- if other_id not in ids:
- ids.append(other_id)
- conn.execute(
- "UPDATE beliefs SET contradicts_ids = ?, status = 'challenged', updated_at = ? WHERE belief_id = ?",
- (json.dumps(ids), now, bid),
- )
-
- def set_status(self, belief_id: str, status: str) -> bool:
- with self._tx() as conn:
- result = conn.execute(
- "UPDATE beliefs SET status = ?, updated_at = ? WHERE belief_id = ?",
- (status, _utcnow_iso(), belief_id),
- )
- return result.rowcount > 0
-
- def get_by_invalidation_status(
- self, status: str, *, limit: int = 100
- ) -> List[Dict[str, Any]]:
- """Get beliefs in stale/suspect/invalidated status for repair jobs."""
- with self._lock:
- rows = self._conn.execute(
- "SELECT * FROM beliefs WHERE status = ? LIMIT ?",
- (status, limit),
- ).fetchall()
- return [self._row_to_dict(r) for r in rows]
-
- @staticmethod
- def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]:
- return {
- "belief_id": row["belief_id"],
- "user_id": row["user_id"],
- "claim": row["claim"],
- "domain": row["domain"],
- "status": row["status"],
- "confidence": row["confidence"],
- "evidence": _parse_json(row["evidence_json"], []),
- "revisions": _parse_json(row["revisions_json"], []),
- "contradicts_ids": _parse_json(row["contradicts_ids"], []),
- "source_memory_ids": _parse_json(row["source_memory_ids"], []),
- "source_episode_ids": _parse_json(row["source_episode_ids"], []),
- "derivation_version": row["derivation_version"],
- "lineage_fingerprint": row["lineage_fingerprint"],
- "created_at": row["created_at"],
- "updated_at": row["updated_at"],
- "tags": _parse_json(row["tags_json"], []),
- }
-
-
-# =========================================================================
-# PolicyStore
-# =========================================================================
-
-class PolicyStore(_DerivedStoreBase):
- """Condition→action rules with utility tracking (D2Skill dual-granularity)."""
-
- def add(
- self,
- user_id: str,
- name: str,
- condition: Dict[str, Any],
- action: Dict[str, Any],
- *,
- granularity: str = "task",
- source_task_ids: Optional[List[str]] = None,
- source_episode_ids: Optional[List[str]] = None,
- tags: Optional[List[str]] = None,
- policy_id: Optional[str] = None,
- ) -> str:
- pid = policy_id or str(uuid.uuid4())
- now = _utcnow_iso()
- stids = source_task_ids or []
- seids = source_episode_ids or []
- fp = _compute_lineage_fingerprint(stids, 1)
-
- with self._tx() as conn:
- conn.execute(
- """INSERT INTO policies
- (policy_id, user_id, name, granularity, status,
- condition_json, action_json, source_task_ids,
- source_episode_ids, derivation_version,
- lineage_fingerprint, created_at, updated_at, tags_json)
- VALUES (?, ?, ?, ?, 'proposed', ?, ?, ?, ?, 1, ?, ?, ?, ?)""",
- (
- pid, user_id, name, granularity,
- json.dumps(condition), json.dumps(action),
- json.dumps(stids), json.dumps(seids), fp,
- now, now, json.dumps(tags or []),
- ),
- )
- return pid
-
- def get(self, policy_id: str) -> Optional[Dict[str, Any]]:
- with self._lock:
- row = self._conn.execute(
- "SELECT * FROM policies WHERE policy_id = ?",
- (policy_id,),
- ).fetchone()
- return self._row_to_dict(row) if row else None
-
- def list_by_user(
- self,
- user_id: str,
- *,
- granularity: Optional[str] = None,
- status: Optional[str] = None,
- limit: int = 50,
- ) -> List[Dict[str, Any]]:
- query = "SELECT * FROM policies WHERE user_id = ?"
- params: list = [user_id]
- if granularity:
- query += " AND granularity = ?"
- params.append(granularity)
- if status:
- query += " AND status = ?"
- params.append(status)
- query += " ORDER BY utility DESC LIMIT ?"
- params.append(limit)
-
- with self._lock:
- rows = self._conn.execute(query, params).fetchall()
- return [self._row_to_dict(r) for r in rows]
-
- def record_outcome(
- self,
- policy_id: str,
- success: bool,
- *,
- baseline_score: Optional[float] = None,
- actual_score: Optional[float] = None,
- ) -> Optional[float]:
- """Record an outcome for a policy. Returns the delta if scores provided.
-
- Updates apply_count, success/failure counts, and utility EMA.
- Auto-transitions status: validated (win_rate >= 0.6 after 5+) or
- deprecated (win_rate < 0.4 after 5+).
- """
- now = _utcnow_iso()
- with self._tx() as conn:
- row = conn.execute(
- """SELECT apply_count, success_count, failure_count,
- utility, cumulative_delta, status
- FROM policies WHERE policy_id = ?""",
- (policy_id,),
- ).fetchone()
- if not row:
- return None
-
- apply_count = row["apply_count"] + 1
- success_count = row["success_count"] + (1 if success else 0)
- failure_count = row["failure_count"] + (0 if success else 1)
- utility = row["utility"]
- cumulative = row["cumulative_delta"]
- status = row["status"]
-
- delta = 0.0
- if baseline_score is not None and actual_score is not None:
- delta = actual_score - baseline_score
- utility = 0.3 * delta + 0.7 * utility # EMA alpha=0.3
- cumulative += delta
-
- # Auto-transition after enough data
- if apply_count >= 5 and status not in ("stale", "suspect", "invalidated"):
- win_rate = (success_count + 1) / (apply_count + 2) # Laplace
- if win_rate >= 0.6:
- status = PolicyStatus.VALIDATED.value
- elif win_rate < 0.4:
- status = PolicyStatus.DEPRECATED.value
- elif status == PolicyStatus.PROPOSED.value:
- status = PolicyStatus.ACTIVE.value
-
- conn.execute(
- """UPDATE policies
- SET apply_count = ?, success_count = ?, failure_count = ?,
- utility = ?, last_delta = ?, cumulative_delta = ?,
- status = ?, updated_at = ?
- WHERE policy_id = ?""",
- (
- apply_count, success_count, failure_count,
- utility, delta, cumulative, status, now, policy_id,
- ),
- )
- return delta
-
- def set_status(self, policy_id: str, status: str) -> bool:
- with self._tx() as conn:
- result = conn.execute(
- "UPDATE policies SET status = ?, updated_at = ? WHERE policy_id = ?",
- (status, _utcnow_iso(), policy_id),
- )
- return result.rowcount > 0
-
- def get_by_invalidation_status(
- self, status: str, *, limit: int = 100
- ) -> List[Dict[str, Any]]:
- with self._lock:
- rows = self._conn.execute(
- "SELECT * FROM policies WHERE status = ? LIMIT ?",
- (status, limit),
- ).fetchall()
- return [self._row_to_dict(r) for r in rows]
-
- @staticmethod
- def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]:
- return {
- "policy_id": row["policy_id"],
- "user_id": row["user_id"],
- "name": row["name"],
- "granularity": row["granularity"],
- "status": row["status"],
- "condition": _parse_json(row["condition_json"], {}),
- "action": _parse_json(row["action_json"], {}),
- "apply_count": row["apply_count"],
- "success_count": row["success_count"],
- "failure_count": row["failure_count"],
- "utility": row["utility"],
- "last_delta": row["last_delta"],
- "cumulative_delta": row["cumulative_delta"],
- "source_task_ids": _parse_json(row["source_task_ids"], []),
- "source_episode_ids": _parse_json(row["source_episode_ids"], []),
- "derivation_version": row["derivation_version"],
- "lineage_fingerprint": row["lineage_fingerprint"],
- "created_at": row["created_at"],
- "updated_at": row["updated_at"],
- "tags": _parse_json(row["tags_json"], []),
- }
-
-
-# =========================================================================
-# AnchorStore
-# =========================================================================
-
-class AnchorStore(_DerivedStoreBase):
- """Hierarchical context anchors (era/place/time/activity)."""
-
- def add(
- self,
- user_id: str,
- *,
- memory_event_id: Optional[str] = None,
- era: Optional[str] = None,
- place: Optional[str] = None,
- place_type: Optional[str] = None,
- place_detail: Optional[str] = None,
- time_absolute: Optional[str] = None,
- time_markers: Optional[List[str]] = None,
- time_range_start: Optional[str] = None,
- time_range_end: Optional[str] = None,
- time_derivation: Optional[str] = None,
- activity: Optional[str] = None,
- session_id: Optional[str] = None,
- session_position: int = 0,
- anchor_id: Optional[str] = None,
- ) -> str:
- aid = anchor_id or str(uuid.uuid4())
- now = _utcnow_iso()
- source_ids = [memory_event_id] if memory_event_id else []
- fp = _compute_lineage_fingerprint(source_ids, 1)
-
- with self._tx() as conn:
- conn.execute(
- """INSERT INTO anchors
- (anchor_id, user_id, memory_event_id, era, place,
- place_type, place_detail, time_absolute,
- time_markers_json, time_range_start, time_range_end,
- time_derivation, activity, session_id, session_position,
- derivation_version, lineage_fingerprint, created_at)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1, ?, ?)""",
- (
- aid, user_id, memory_event_id, era, place,
- place_type, place_detail, time_absolute,
- json.dumps(time_markers or []),
- time_range_start, time_range_end, time_derivation,
- activity, session_id, session_position, fp, now,
- ),
- )
- return aid
-
- def get(self, anchor_id: str) -> Optional[Dict[str, Any]]:
- with self._lock:
- row = self._conn.execute(
- "SELECT * FROM anchors WHERE anchor_id = ?",
- (anchor_id,),
- ).fetchone()
- return self._row_to_dict(row) if row else None
-
- def get_by_event(self, memory_event_id: str) -> Optional[Dict[str, Any]]:
- with self._lock:
- row = self._conn.execute(
- "SELECT * FROM anchors WHERE memory_event_id = ?",
- (memory_event_id,),
- ).fetchone()
- return self._row_to_dict(row) if row else None
-
- def list_by_user(
- self,
- user_id: str,
- *,
- era: Optional[str] = None,
- place: Optional[str] = None,
- activity: Optional[str] = None,
- limit: int = 50,
- ) -> List[Dict[str, Any]]:
- query = "SELECT * FROM anchors WHERE user_id = ?"
- params: list = [user_id]
- if era:
- query += " AND era = ?"
- params.append(era)
- if place:
- query += " AND place = ?"
- params.append(place)
- if activity:
- query += " AND activity = ?"
- params.append(activity)
- query += " ORDER BY created_at DESC LIMIT ?"
- params.append(limit)
-
- with self._lock:
- rows = self._conn.execute(query, params).fetchall()
- return [self._row_to_dict(r) for r in rows]
-
- def update_fields(self, anchor_id: str, **fields: Any) -> bool:
- """Update specific anchor fields. Only allows known anchor columns."""
- allowed = {
- "era", "place", "place_type", "place_detail",
- "time_absolute", "time_markers_json", "time_range_start",
- "time_range_end", "time_derivation", "activity",
- }
- updates = {k: v for k, v in fields.items() if k in allowed}
- if not updates:
- return False
-
- set_clause = ", ".join(f"{k} = ?" for k in updates)
- values = list(updates.values()) + [anchor_id]
-
- with self._tx() as conn:
- result = conn.execute(
- f"UPDATE anchors SET {set_clause} WHERE anchor_id = ?",
- values,
- )
- return result.rowcount > 0
-
- @staticmethod
- def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]:
- return {
- "anchor_id": row["anchor_id"],
- "user_id": row["user_id"],
- "memory_event_id": row["memory_event_id"],
- "era": row["era"],
- "place": row["place"],
- "place_type": row["place_type"],
- "place_detail": row["place_detail"],
- "time_absolute": row["time_absolute"],
- "time_markers": _parse_json(row["time_markers_json"], []),
- "time_range_start": row["time_range_start"],
- "time_range_end": row["time_range_end"],
- "time_derivation": row["time_derivation"],
- "activity": row["activity"],
- "session_id": row["session_id"],
- "session_position": row["session_position"],
- "derivation_version": row["derivation_version"],
- "lineage_fingerprint": row["lineage_fingerprint"],
- "created_at": row["created_at"],
- }
-
-
-# =========================================================================
-# InsightStore
-# =========================================================================
-
-class InsightStore(_DerivedStoreBase):
- """Synthesized causal hypotheses with strength tracking."""
-
- def add(
- self,
- user_id: str,
- content: str,
- *,
- insight_type: str = "pattern",
- source_task_types: Optional[List[str]] = None,
- confidence: float = 0.5,
- tags: Optional[List[str]] = None,
- insight_id: Optional[str] = None,
- ) -> str:
- iid = insight_id or str(uuid.uuid4())
- now = _utcnow_iso()
-
- with self._tx() as conn:
- conn.execute(
- """INSERT INTO insights
- (insight_id, user_id, content, insight_type,
- source_task_types_json, confidence,
- derivation_version, lineage_fingerprint,
- created_at, tags_json)
- VALUES (?, ?, ?, ?, ?, ?, 1, '', ?, ?)""",
- (
- iid, user_id, content, insight_type,
- json.dumps(source_task_types or []),
- confidence, now, json.dumps(tags or []),
- ),
- )
- return iid
-
- def get(self, insight_id: str) -> Optional[Dict[str, Any]]:
- with self._lock:
- row = self._conn.execute(
- "SELECT * FROM insights WHERE insight_id = ?",
- (insight_id,),
- ).fetchone()
- return self._row_to_dict(row) if row else None
-
- def list_by_user(
- self,
- user_id: str,
- *,
- insight_type: Optional[str] = None,
- min_confidence: float = 0.0,
- status: Optional[str] = None,
- limit: int = 50,
- ) -> List[Dict[str, Any]]:
- query = "SELECT * FROM insights WHERE user_id = ? AND confidence >= ?"
- params: list = [user_id, min_confidence]
- if insight_type:
- query += " AND insight_type = ?"
- params.append(insight_type)
- if status:
- query += " AND status = ?"
- params.append(status)
- query += " ORDER BY confidence DESC LIMIT ?"
- params.append(limit)
-
- with self._lock:
- rows = self._conn.execute(query, params).fetchall()
- return [self._row_to_dict(r) for r in rows]
-
- def record_outcome(
- self,
- insight_id: str,
- success: bool,
- *,
- baseline_score: Optional[float] = None,
- actual_score: Optional[float] = None,
- ) -> bool:
- """Record validation/invalidation outcome. Updates confidence + utility."""
- now = _utcnow_iso()
- with self._tx() as conn:
- row = conn.execute(
- """SELECT confidence, validation_count, invalidation_count,
- utility, apply_count, status
- FROM insights WHERE insight_id = ?""",
- (insight_id,),
- ).fetchone()
- if not row:
- return False
-
- conf = row["confidence"]
- v_count = row["validation_count"]
- i_count = row["invalidation_count"]
- utility = row["utility"]
- apply_count = row["apply_count"] + 1
-
- if success:
- v_count += 1
- conf = min(1.0, conf + 0.05)
- else:
- i_count += 1
- conf = max(0.0, conf - 0.1)
-
- if baseline_score is not None and actual_score is not None:
- delta = actual_score - baseline_score
- utility = 0.3 * delta + 0.7 * utility
-
- conn.execute(
- """UPDATE insights
- SET confidence = ?, validation_count = ?,
- invalidation_count = ?, utility = ?,
- apply_count = ?, last_validated = ?,
- status = ?
- WHERE insight_id = ?""",
- (
- conf, v_count, i_count, utility, apply_count, now,
- row["status"], # preserve current status
- insight_id,
- ),
- )
- return True
-
- def set_status(self, insight_id: str, status: str) -> bool:
- with self._tx() as conn:
- result = conn.execute(
- "UPDATE insights SET status = ? WHERE insight_id = ?",
- (status, insight_id),
- )
- return result.rowcount > 0
-
- def get_by_invalidation_status(
- self, status: str, *, limit: int = 100
- ) -> List[Dict[str, Any]]:
- with self._lock:
- rows = self._conn.execute(
- "SELECT * FROM insights WHERE status = ? LIMIT ?",
- (status, limit),
- ).fetchall()
- return [self._row_to_dict(r) for r in rows]
-
- @staticmethod
- def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]:
- return {
- "insight_id": row["insight_id"],
- "user_id": row["user_id"],
- "content": row["content"],
- "insight_type": row["insight_type"],
- "source_task_types": _parse_json(row["source_task_types_json"], []),
- "confidence": row["confidence"],
- "validation_count": row["validation_count"],
- "invalidation_count": row["invalidation_count"],
- "utility": row["utility"],
- "apply_count": row["apply_count"],
- "derivation_version": row["derivation_version"],
- "lineage_fingerprint": row["lineage_fingerprint"],
- "created_at": row["created_at"],
- "last_validated": row["last_validated"],
- "tags": _parse_json(row["tags_json"], []),
- "status": row["status"],
- }
-
-
-# =========================================================================
-# HeuristicStore
-# =========================================================================
-
-class HeuristicStore(_DerivedStoreBase):
- """Transferable reasoning patterns (ERL, 3 abstraction levels)."""
-
- def add(
- self,
- user_id: str,
- content: str,
- *,
- abstraction_level: str = "specific",
- source_task_types: Optional[List[str]] = None,
- confidence: float = 0.5,
- tags: Optional[List[str]] = None,
- heuristic_id: Optional[str] = None,
- ) -> str:
- hid = heuristic_id or str(uuid.uuid4())
- now = _utcnow_iso()
-
- with self._tx() as conn:
- conn.execute(
- """INSERT INTO heuristics
- (heuristic_id, user_id, content, abstraction_level,
- source_task_types_json, confidence,
- derivation_version, lineage_fingerprint,
- created_at, tags_json)
- VALUES (?, ?, ?, ?, ?, ?, 1, '', ?, ?)""",
- (
- hid, user_id, content, abstraction_level,
- json.dumps(source_task_types or []),
- confidence, now, json.dumps(tags or []),
- ),
- )
- return hid
-
- def get(self, heuristic_id: str) -> Optional[Dict[str, Any]]:
- with self._lock:
- row = self._conn.execute(
- "SELECT * FROM heuristics WHERE heuristic_id = ?",
- (heuristic_id,),
- ).fetchone()
- return self._row_to_dict(row) if row else None
-
- def list_by_user(
- self,
- user_id: str,
- *,
- abstraction_level: Optional[str] = None,
- min_confidence: float = 0.0,
- status: Optional[str] = None,
- limit: int = 50,
- ) -> List[Dict[str, Any]]:
- query = "SELECT * FROM heuristics WHERE user_id = ? AND confidence >= ?"
- params: list = [user_id, min_confidence]
- if abstraction_level:
- query += " AND abstraction_level = ?"
- params.append(abstraction_level)
- if status:
- query += " AND status = ?"
- params.append(status)
- query += " ORDER BY confidence DESC LIMIT ?"
- params.append(limit)
-
- with self._lock:
- rows = self._conn.execute(query, params).fetchall()
- return [self._row_to_dict(r) for r in rows]
-
- def record_outcome(
- self,
- heuristic_id: str,
- success: bool,
- *,
- baseline_score: Optional[float] = None,
- actual_score: Optional[float] = None,
- ) -> bool:
- now = _utcnow_iso()
- with self._tx() as conn:
- row = conn.execute(
- """SELECT confidence, validation_count, invalidation_count,
- utility, last_delta, apply_count, status
- FROM heuristics WHERE heuristic_id = ?""",
- (heuristic_id,),
- ).fetchone()
- if not row:
- return False
-
- conf = row["confidence"]
- v_count = row["validation_count"]
- i_count = row["invalidation_count"]
- utility = row["utility"]
- apply_count = row["apply_count"] + 1
- delta = 0.0
-
- if success:
- v_count += 1
- conf = min(1.0, conf + 0.05)
- else:
- i_count += 1
- conf = max(0.0, conf - 0.1)
-
- if baseline_score is not None and actual_score is not None:
- delta = actual_score - baseline_score
- utility = 0.3 * delta + 0.7 * utility
-
- conn.execute(
- """UPDATE heuristics
- SET confidence = ?, validation_count = ?,
- invalidation_count = ?, utility = ?,
- last_delta = ?, apply_count = ?, status = ?
- WHERE heuristic_id = ?""",
- (
- conf, v_count, i_count, utility,
- delta, apply_count, row["status"],
- heuristic_id,
- ),
- )
- return True
-
- def set_status(self, heuristic_id: str, status: str) -> bool:
- with self._tx() as conn:
- result = conn.execute(
- "UPDATE heuristics SET status = ? WHERE heuristic_id = ?",
- (status, heuristic_id),
- )
- return result.rowcount > 0
-
- def get_by_invalidation_status(
- self, status: str, *, limit: int = 100
- ) -> List[Dict[str, Any]]:
- with self._lock:
- rows = self._conn.execute(
- "SELECT * FROM heuristics WHERE status = ? LIMIT ?",
- (status, limit),
- ).fetchall()
- return [self._row_to_dict(r) for r in rows]
-
- @staticmethod
- def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]:
- return {
- "heuristic_id": row["heuristic_id"],
- "user_id": row["user_id"],
- "content": row["content"],
- "abstraction_level": row["abstraction_level"],
- "source_task_types": _parse_json(row["source_task_types_json"], []),
- "confidence": row["confidence"],
- "validation_count": row["validation_count"],
- "invalidation_count": row["invalidation_count"],
- "utility": row["utility"],
- "last_delta": row["last_delta"],
- "apply_count": row["apply_count"],
- "derivation_version": row["derivation_version"],
- "lineage_fingerprint": row["lineage_fingerprint"],
- "created_at": row["created_at"],
- "tags": _parse_json(row["tags_json"], []),
- "status": row["status"],
- }
-
-
-# =========================================================================
-# DerivedLineageStore
-# =========================================================================
-
-class DerivedLineageStore(_DerivedStoreBase):
- """Links derived objects to source raw events for traceability.
-
- Supports the three-tier invalidation model:
- - Given a source event, find all derived objects that depend on it
- - Given a derived object, find all source events it was built from
- - Contribution weight enables partial invalidation decisions
- """
-
- def add(
- self,
- derived_type: str,
- derived_id: str,
- source_event_id: str,
- *,
- contribution_weight: float = 1.0,
- lineage_id: Optional[str] = None,
- ) -> str:
- lid = lineage_id or str(uuid.uuid4())
- with self._tx() as conn:
- conn.execute(
- """INSERT INTO derived_lineage
- (lineage_id, derived_type, derived_id,
- source_event_id, contribution_weight)
- VALUES (?, ?, ?, ?, ?)""",
- (lid, derived_type, derived_id, source_event_id, contribution_weight),
- )
- return lid
-
- def add_batch(
- self,
- derived_type: str,
- derived_id: str,
- source_event_ids: List[str],
- *,
- weights: Optional[List[float]] = None,
- ) -> List[str]:
- """Add multiple lineage links at once."""
- w = weights or [1.0] * len(source_event_ids)
- ids = []
- with self._tx() as conn:
- for eid, weight in zip(source_event_ids, w):
- lid = str(uuid.uuid4())
- conn.execute(
- """INSERT INTO derived_lineage
- (lineage_id, derived_type, derived_id,
- source_event_id, contribution_weight)
- VALUES (?, ?, ?, ?, ?)""",
- (lid, derived_type, derived_id, eid, weight),
- )
- ids.append(lid)
- return ids
-
- def get_sources(
- self, derived_type: str, derived_id: str
- ) -> List[Dict[str, Any]]:
- """Get all source events for a derived object."""
- with self._lock:
- rows = self._conn.execute(
- """SELECT lineage_id, source_event_id, contribution_weight, created_at
- FROM derived_lineage
- WHERE derived_type = ? AND derived_id = ?""",
- (derived_type, derived_id),
- ).fetchall()
- return [
- {
- "lineage_id": r["lineage_id"],
- "source_event_id": r["source_event_id"],
- "contribution_weight": r["contribution_weight"],
- "created_at": r["created_at"],
- }
- for r in rows
- ]
-
- def get_dependents(
- self, source_event_id: str
- ) -> List[Dict[str, Any]]:
- """Get all derived objects that depend on a source event.
-
- This is the key query for invalidation cascades.
- """
- with self._lock:
- rows = self._conn.execute(
- """SELECT lineage_id, derived_type, derived_id,
- contribution_weight, created_at
- FROM derived_lineage
- WHERE source_event_id = ?""",
- (source_event_id,),
- ).fetchall()
- return [
- {
- "lineage_id": r["lineage_id"],
- "derived_type": r["derived_type"],
- "derived_id": r["derived_id"],
- "contribution_weight": r["contribution_weight"],
- "created_at": r["created_at"],
- }
- for r in rows
- ]
-
- def get_source_count(self, derived_type: str, derived_id: str) -> int:
- """Count source events for a derived object."""
- with self._lock:
- row = self._conn.execute(
- """SELECT COUNT(*) FROM derived_lineage
- WHERE derived_type = ? AND derived_id = ?""",
- (derived_type, derived_id),
- ).fetchone()
- return row[0] if row else 0
-
- def get_contribution_weight(
- self, derived_type: str, derived_id: str, source_event_id: str
- ) -> Optional[float]:
- """Get the contribution weight of a specific source to a derived object.
-
- Used by partial invalidation to decide severity.
- """
- with self._lock:
- row = self._conn.execute(
- """SELECT contribution_weight FROM derived_lineage
- WHERE derived_type = ? AND derived_id = ? AND source_event_id = ?""",
- (derived_type, derived_id, source_event_id),
- ).fetchone()
- return row["contribution_weight"] if row else None
-
- def delete_for_derived(self, derived_type: str, derived_id: str) -> int:
- """Remove all lineage links for a derived object (e.g., before re-deriving)."""
- with self._tx() as conn:
- result = conn.execute(
- "DELETE FROM derived_lineage WHERE derived_type = ? AND derived_id = ?",
- (derived_type, derived_id),
- )
- return result.rowcount
-
-
-# =========================================================================
-# CognitionStore — Coordinator that holds all sub-stores
-# =========================================================================
-
-class CognitionStore:
- """Unified access to all v3 stores sharing a single SQLite connection.
-
- Usage:
- store = CognitionStore() # or CognitionStore(db_path="...")
- store.events.add(content="...", user_id="...")
- store.beliefs.add(user_id="...", claim="...")
- store.lineage.add("belief", bid, event_id)
- """
-
- def __init__(self, db_path: Optional[str] = None):
- from dhee.core.events import RawEventStore, _default_db_path
-
- self.db_path = db_path or _default_db_path()
-
- # RawEventStore owns the connection and schema initialization
- self.events = RawEventStore(self.db_path)
-
- # All derived stores share the same connection + lock
- conn = self.events._conn
- lock = self.events._lock
-
- self.beliefs = BeliefStore(conn, lock)
- self.policies = PolicyStore(conn, lock)
- self.anchors = AnchorStore(conn, lock)
- self.insights = InsightStore(conn, lock)
- self.heuristics = HeuristicStore(conn, lock)
- self.lineage = DerivedLineageStore(conn, lock)
-
- def close(self) -> None:
- self.events.close()
diff --git a/dhee/core/events.py b/dhee/core/events.py
deleted file mode 100644
index 5bd12b7..0000000
--- a/dhee/core/events.py
+++ /dev/null
@@ -1,443 +0,0 @@
-"""Dhee v3 — RawEventStore: immutable source-of-truth memory events.
-
-Every call to remember() writes an immutable raw event. Corrections create
-new events with supersedes_event_id pointing to the original. Deletions
-mark events as 'deleted' (soft delete — never physically removed).
-
-Design contract:
- - Raw events are never mutated after creation
- - Content-hash dedup prevents duplicate storage of identical content
- - Corrections/deletions change status of the OLD event and create a NEW event
- - All derived cognition traces back to raw events via derived_lineage
- - Zero LLM calls — this is a pure storage layer
-"""
-
-from __future__ import annotations
-
-import hashlib
-import json
-import logging
-import os
-import sqlite3
-import threading
-import uuid
-from contextlib import contextmanager
-from datetime import datetime, timezone
-from dataclasses import dataclass, field
-from enum import Enum
-from typing import Any, Dict, List, Optional
-
-from dhee.core.storage import initialize_schema
-
-logger = logging.getLogger(__name__)
-
-
-class EventStatus(str, Enum):
- ACTIVE = "active"
- CORRECTED = "corrected"
- DELETED = "deleted"
-
-
-@dataclass
-class RawMemoryEvent:
- """In-memory representation of a raw memory event."""
-
- event_id: str
- user_id: str
- content: str
- content_hash: str
- status: EventStatus = EventStatus.ACTIVE
- session_id: Optional[str] = None
- source: str = "user"
- supersedes_event_id: Optional[str] = None
- metadata: Dict[str, Any] = field(default_factory=dict)
- created_at: Optional[str] = None
-
- @staticmethod
- def compute_hash(content: str) -> str:
- """SHA-256 content hash for dedup."""
- return hashlib.sha256(content.encode("utf-8")).hexdigest()
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "event_id": self.event_id,
- "user_id": self.user_id,
- "content": self.content,
- "content_hash": self.content_hash,
- "status": self.status.value,
- "session_id": self.session_id,
- "source": self.source,
- "supersedes_event_id": self.supersedes_event_id,
- "metadata": self.metadata,
- "created_at": self.created_at,
- }
-
-
-def _utcnow_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
-
-
-def _default_db_path() -> str:
- data_dir = os.environ.get("DHEE_DATA_DIR") or os.path.join(
- os.path.expanduser("~"), ".dhee"
- )
- return os.path.join(data_dir, "v3.db")
-
-
-class RawEventStore:
- """Immutable raw event storage backed by SQLite.
-
- Thread-safe via RLock. Follows the same connection pattern as
- dhee/db/sqlite.py (_SQLiteBase).
- """
-
- def __init__(self, db_path: Optional[str] = None):
- self.db_path = db_path or _default_db_path()
- db_dir = os.path.dirname(self.db_path)
- if db_dir:
- os.makedirs(db_dir, exist_ok=True)
-
- self._conn = sqlite3.connect(self.db_path, check_same_thread=False)
- self._conn.execute("PRAGMA journal_mode=WAL")
- self._conn.execute("PRAGMA busy_timeout=5000")
- self._conn.execute("PRAGMA synchronous=FULL")
- self._conn.execute("PRAGMA cache_size=-8000")
- self._conn.execute("PRAGMA temp_store=MEMORY")
- self._conn.row_factory = sqlite3.Row
- self._lock = threading.RLock()
-
- # Initialize all v3 tables
- initialize_schema(self._conn)
-
- def close(self) -> None:
- with self._lock:
- if self._conn:
- try:
- self._conn.close()
- except Exception:
- pass
- self._conn = None # type: ignore[assignment]
-
- @contextmanager
- def _tx(self):
- """Yield connection under lock with commit/rollback."""
- with self._lock:
- try:
- yield self._conn
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
-
- # ------------------------------------------------------------------
- # Write operations
- # ------------------------------------------------------------------
-
- def add(
- self,
- content: str,
- user_id: str,
- *,
- session_id: Optional[str] = None,
- source: str = "user",
- metadata: Optional[Dict[str, Any]] = None,
- event_id: Optional[str] = None,
- ) -> RawMemoryEvent:
- """Store a new raw memory event. Returns the event (existing if dedup hit).
-
- Content-hash dedup: if identical content already exists for this user
- and is active, returns the existing event instead of creating a duplicate.
- """
- content_hash = RawMemoryEvent.compute_hash(content)
- eid = event_id or str(uuid.uuid4())
- meta = metadata or {}
- now = _utcnow_iso()
-
- with self._tx() as conn:
- # Dedup check — same content, same user, still active
- existing = conn.execute(
- """SELECT event_id, user_id, content, content_hash, status,
- session_id, source, supersedes_event_id,
- metadata_json, created_at
- FROM raw_memory_events
- WHERE content_hash = ? AND user_id = ? AND status = 'active'
- LIMIT 1""",
- (content_hash, user_id),
- ).fetchone()
-
- if existing:
- return self._row_to_event(existing)
-
- conn.execute(
- """INSERT INTO raw_memory_events
- (event_id, user_id, session_id, created_at, content,
- content_hash, source, status, metadata_json)
- VALUES (?, ?, ?, ?, ?, ?, ?, 'active', ?)""",
- (
- eid, user_id, session_id, now, content,
- content_hash, source, json.dumps(meta),
- ),
- )
-
- return RawMemoryEvent(
- event_id=eid,
- user_id=user_id,
- content=content,
- content_hash=content_hash,
- status=EventStatus.ACTIVE,
- session_id=session_id,
- source=source,
- metadata=meta,
- created_at=now,
- )
-
- def correct(
- self,
- original_event_id: str,
- new_content: str,
- *,
- source: str = "user_correction",
- metadata: Optional[Dict[str, Any]] = None,
- ) -> RawMemoryEvent:
- """Correct an existing event.
-
- 1. Marks the original event as 'corrected'
- 2. Creates a new event with supersedes_event_id pointing to original
- 3. Returns the new event
-
- Raises ValueError if original event not found or not active.
- """
- with self._tx() as conn:
- original = conn.execute(
- """SELECT event_id, user_id, session_id, status
- FROM raw_memory_events WHERE event_id = ?""",
- (original_event_id,),
- ).fetchone()
-
- if not original:
- raise ValueError(f"Event not found: {original_event_id}")
- if original["status"] != "active":
- raise ValueError(
- f"Cannot correct event with status '{original['status']}': "
- f"{original_event_id}"
- )
-
- # Mark original as corrected
- conn.execute(
- "UPDATE raw_memory_events SET status = 'corrected' WHERE event_id = ?",
- (original_event_id,),
- )
-
- # Create correction event
- new_id = str(uuid.uuid4())
- content_hash = RawMemoryEvent.compute_hash(new_content)
- meta = metadata or {}
- now = _utcnow_iso()
-
- conn.execute(
- """INSERT INTO raw_memory_events
- (event_id, user_id, session_id, created_at, content,
- content_hash, source, status, supersedes_event_id, metadata_json)
- VALUES (?, ?, ?, ?, ?, ?, ?, 'active', ?, ?)""",
- (
- new_id, original["user_id"], original["session_id"],
- now, new_content, content_hash, source,
- original_event_id, json.dumps(meta),
- ),
- )
-
- return RawMemoryEvent(
- event_id=new_id,
- user_id=original["user_id"],
- content=new_content,
- content_hash=content_hash,
- status=EventStatus.ACTIVE,
- session_id=original["session_id"],
- source=source,
- supersedes_event_id=original_event_id,
- metadata=meta,
- created_at=now,
- )
-
- def delete(self, event_id: str) -> bool:
- """Soft-delete a raw event. Returns True if status changed.
-
- Marks the event as 'deleted'. Does NOT physically remove it.
- Raises ValueError if event not found.
- """
- with self._tx() as conn:
- row = conn.execute(
- "SELECT status FROM raw_memory_events WHERE event_id = ?",
- (event_id,),
- ).fetchone()
-
- if not row:
- raise ValueError(f"Event not found: {event_id}")
- if row["status"] == "deleted":
- return False
-
- conn.execute(
- "UPDATE raw_memory_events SET status = 'deleted' WHERE event_id = ?",
- (event_id,),
- )
- return True
-
- # ------------------------------------------------------------------
- # Read operations
- # ------------------------------------------------------------------
-
- def get(self, event_id: str) -> Optional[RawMemoryEvent]:
- """Get a single event by ID."""
- with self._lock:
- row = self._conn.execute(
- """SELECT event_id, user_id, content, content_hash, status,
- session_id, source, supersedes_event_id,
- metadata_json, created_at
- FROM raw_memory_events WHERE event_id = ?""",
- (event_id,),
- ).fetchone()
- return self._row_to_event(row) if row else None
-
- def get_by_hash(
- self, content_hash: str, user_id: str
- ) -> Optional[RawMemoryEvent]:
- """Get active event by content hash + user."""
- with self._lock:
- row = self._conn.execute(
- """SELECT event_id, user_id, content, content_hash, status,
- session_id, source, supersedes_event_id,
- metadata_json, created_at
- FROM raw_memory_events
- WHERE content_hash = ? AND user_id = ? AND status = 'active'
- LIMIT 1""",
- (content_hash, user_id),
- ).fetchone()
- return self._row_to_event(row) if row else None
-
- def list_by_user(
- self,
- user_id: str,
- *,
- status: Optional[EventStatus] = None,
- limit: int = 100,
- offset: int = 0,
- ) -> List[RawMemoryEvent]:
- """List events for a user, newest first."""
- with self._lock:
- if status:
- rows = self._conn.execute(
- """SELECT event_id, user_id, content, content_hash, status,
- session_id, source, supersedes_event_id,
- metadata_json, created_at
- FROM raw_memory_events
- WHERE user_id = ? AND status = ?
- ORDER BY created_at DESC
- LIMIT ? OFFSET ?""",
- (user_id, status.value, limit, offset),
- ).fetchall()
- else:
- rows = self._conn.execute(
- """SELECT event_id, user_id, content, content_hash, status,
- session_id, source, supersedes_event_id,
- metadata_json, created_at
- FROM raw_memory_events
- WHERE user_id = ?
- ORDER BY created_at DESC
- LIMIT ? OFFSET ?""",
- (user_id, limit, offset),
- ).fetchall()
- return [self._row_to_event(r) for r in rows]
-
- def get_supersedes_chain(self, event_id: str) -> List[RawMemoryEvent]:
- """Walk the supersedes chain from newest to oldest.
-
- Given an event that supersedes another, returns the full chain:
- [newest_correction, ..., original_event]
- """
- chain: List[RawMemoryEvent] = []
- seen: set = set()
- current_id: Optional[str] = event_id
-
- while current_id and current_id not in seen:
- seen.add(current_id)
- event = self.get(current_id)
- if not event:
- break
- chain.append(event)
- current_id = event.supersedes_event_id
-
- return chain
-
- def count(
- self, user_id: str, *, status: Optional[EventStatus] = None
- ) -> int:
- """Count events for a user."""
- with self._lock:
- if status:
- row = self._conn.execute(
- "SELECT COUNT(*) FROM raw_memory_events WHERE user_id = ? AND status = ?",
- (user_id, status.value),
- ).fetchone()
- else:
- row = self._conn.execute(
- "SELECT COUNT(*) FROM raw_memory_events WHERE user_id = ?",
- (user_id,),
- ).fetchone()
- return row[0] if row else 0
-
- def get_events_since(
- self, user_id: str, since_iso: str, *, status: Optional[EventStatus] = None
- ) -> List[RawMemoryEvent]:
- """Get events created after a given ISO timestamp."""
- with self._lock:
- if status:
- rows = self._conn.execute(
- """SELECT event_id, user_id, content, content_hash, status,
- session_id, source, supersedes_event_id,
- metadata_json, created_at
- FROM raw_memory_events
- WHERE user_id = ? AND created_at > ? AND status = ?
- ORDER BY created_at ASC""",
- (user_id, since_iso, status.value),
- ).fetchall()
- else:
- rows = self._conn.execute(
- """SELECT event_id, user_id, content, content_hash, status,
- session_id, source, supersedes_event_id,
- metadata_json, created_at
- FROM raw_memory_events
- WHERE user_id = ? AND created_at > ?
- ORDER BY created_at ASC""",
- (user_id, since_iso),
- ).fetchall()
- return [self._row_to_event(r) for r in rows]
-
- # ------------------------------------------------------------------
- # Internal
- # ------------------------------------------------------------------
-
- @staticmethod
- def _row_to_event(row: sqlite3.Row) -> RawMemoryEvent:
- meta_raw = row["metadata_json"]
- if isinstance(meta_raw, str):
- try:
- meta = json.loads(meta_raw)
- except (json.JSONDecodeError, TypeError):
- meta = {}
- elif isinstance(meta_raw, dict):
- meta = meta_raw
- else:
- meta = {}
-
- return RawMemoryEvent(
- event_id=row["event_id"],
- user_id=row["user_id"],
- content=row["content"],
- content_hash=row["content_hash"],
- status=EventStatus(row["status"]),
- session_id=row["session_id"],
- source=row["source"],
- supersedes_event_id=row["supersedes_event_id"],
- metadata=meta,
- created_at=row["created_at"],
- )
diff --git a/dhee/core/fusion_v3.py b/dhee/core/fusion_v3.py
deleted file mode 100644
index 03732c7..0000000
--- a/dhee/core/fusion_v3.py
+++ /dev/null
@@ -1,303 +0,0 @@
-"""Dhee v3 — 5-Stage Weighted Reciprocal Rank Fusion Pipeline.
-
-Explicit ranking contract (zero LLM on hot path):
-
-Stage 1: Per-index retrieval (parallel, 0 LLM)
-Stage 2: Score normalization (min-max within each index)
-Stage 3: Weighted RRF (k=60, configurable weights per index)
-Stage 4: Post-fusion adjustments (recency, confidence, staleness, conflicts)
-Stage 5: Final ranking + dedup
-
-No reranker stage. If retrieval quality is insufficient, fix embeddings
-or distillation, not the hot path.
-"""
-
-from __future__ import annotations
-
-import logging
-import math
-from dataclasses import dataclass, field
-from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional, Tuple
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class FusionConfig:
- """Configuration for the 5-stage fusion pipeline."""
-
- # Stage 1: Per-index top-K
- raw_top_k: int = 20
- distilled_top_k: int = 15
- episodic_top_k: int = 10
-
- # Stage 3: RRF weights
- rrf_k: int = 60 # standard RRF constant
- weight_distilled: float = 1.0
- weight_episodic: float = 0.7
- weight_raw: float = 0.5
-
- # Stage 4: Adjustment parameters
- recency_boost_max: float = 0.3 # max 30% boost for fresh raw
- recency_decay_hours: float = 24.0
- confidence_floor: float = 0.5 # score *= 0.5 + 0.5 * confidence
- staleness_penalty: float = 0.3 # stale/suspect get 70% penalty
- contradiction_penalty: float = 0.5 # open conflicts get 50% penalty
-
- # Stage 5: Final output
- final_top_n: int = 10
-
-
-@dataclass
-class FusionCandidate:
- """A candidate passing through the fusion pipeline."""
-
- row_id: str
- source_kind: str # raw | distilled | episodic
- source_type: str # event | belief | policy | insight | heuristic
- source_id: str
- retrieval_text: str
- raw_score: float = 0.0 # cosine similarity from index
- normalized_score: float = 0.0 # after min-max normalization
- rrf_score: float = 0.0 # after weighted RRF
- adjusted_score: float = 0.0 # after post-fusion adjustments
- confidence: float = 1.0
- utility: float = 0.0
- status: str = "active"
- created_at: Optional[str] = None
- has_open_conflicts: bool = False
- # Lineage for dedup
- lineage_event_ids: Optional[List[str]] = None
-
-
-@dataclass
-class FusionBreakdown:
- """Loggable breakdown of how fusion produced its results."""
-
- query: str
- config: Dict[str, Any]
- per_index_counts: Dict[str, int]
- pre_adjustment_top5: List[Dict[str, Any]]
- post_adjustment_top5: List[Dict[str, Any]]
- dedup_removed: int
- final_count: int
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "query": self.query[:100],
- "per_index_counts": self.per_index_counts,
- "pre_adjustment_top5": self.pre_adjustment_top5,
- "post_adjustment_top5": self.post_adjustment_top5,
- "dedup_removed": self.dedup_removed,
- "final_count": self.final_count,
- }
-
-
-def _parse_iso(iso_str: Optional[str]) -> Optional[datetime]:
- if not iso_str:
- return None
- try:
- return datetime.fromisoformat(iso_str.replace("Z", "+00:00"))
- except (ValueError, AttributeError):
- return None
-
-
-class RRFFusion:
- """5-stage Weighted Reciprocal Rank Fusion pipeline.
-
- Usage:
- fusion = RRFFusion(config)
- results, breakdown = fusion.fuse(
- raw_candidates=[...],
- distilled_candidates=[...],
- episodic_candidates=[...],
- conflict_checker=lambda type, id: bool,
- )
- """
-
- def __init__(self, config: Optional[FusionConfig] = None):
- self.config = config or FusionConfig()
-
- def fuse(
- self,
- raw_candidates: List[FusionCandidate],
- distilled_candidates: List[FusionCandidate],
- episodic_candidates: Optional[List[FusionCandidate]] = None,
- *,
- conflict_checker: Optional[Any] = None,
- query: str = "",
- ) -> Tuple[List[FusionCandidate], FusionBreakdown]:
- """Run the full 5-stage fusion pipeline.
-
- Args:
- raw_candidates: Candidates from raw_index with raw_score set
- distilled_candidates: Candidates from distilled_index
- episodic_candidates: Optional candidates from episodic_index
- conflict_checker: callable(source_type, source_id) -> bool
- query: The original query string (for logging)
-
- Returns:
- (ranked_results, breakdown)
- """
- cfg = self.config
- episodic = episodic_candidates or []
-
- # Stage 1: Trim to per-index top-K
- raw = sorted(raw_candidates, key=lambda c: -c.raw_score)[:cfg.raw_top_k]
- dist = sorted(distilled_candidates, key=lambda c: -c.raw_score)[:cfg.distilled_top_k]
- epi = sorted(episodic, key=lambda c: -c.raw_score)[:cfg.episodic_top_k]
-
- per_index_counts = {
- "raw": len(raw), "distilled": len(dist), "episodic": len(epi),
- }
-
- # Stage 2: Min-max normalization within each index
- self._normalize(raw)
- self._normalize(dist)
- self._normalize(epi)
-
- # Stage 3: Weighted RRF
- # Build a combined dict: row_id → candidate, accumulating RRF score
- combined: Dict[str, FusionCandidate] = {}
-
- for rank, c in enumerate(raw):
- rrf = cfg.weight_raw / (cfg.rrf_k + rank + 1)
- if c.row_id in combined:
- combined[c.row_id].rrf_score += rrf
- else:
- c.rrf_score = rrf
- combined[c.row_id] = c
-
- for rank, c in enumerate(dist):
- rrf = cfg.weight_distilled / (cfg.rrf_k + rank + 1)
- if c.row_id in combined:
- combined[c.row_id].rrf_score += rrf
- else:
- c.rrf_score = rrf
- combined[c.row_id] = c
-
- for rank, c in enumerate(epi):
- rrf = cfg.weight_episodic / (cfg.rrf_k + rank + 1)
- if c.row_id in combined:
- combined[c.row_id].rrf_score += rrf
- else:
- c.rrf_score = rrf
- combined[c.row_id] = c
-
- # Pre-adjustment snapshot
- pre_sorted = sorted(combined.values(), key=lambda c: -c.rrf_score)
- pre_top5 = [
- {"row_id": c.row_id, "kind": c.source_kind, "rrf": round(c.rrf_score, 6)}
- for c in pre_sorted[:5]
- ]
-
- # Stage 4: Post-fusion adjustments
- now = datetime.now(timezone.utc)
- for c in combined.values():
- score = c.rrf_score
-
- # Recency boost (raw only)
- if c.source_kind == "raw" and c.created_at:
- created = _parse_iso(c.created_at)
- if created:
- age_hours = max(0, (now - created).total_seconds() / 3600)
- boost = 1.0 + cfg.recency_boost_max * math.exp(
- -age_hours / cfg.recency_decay_hours
- )
- score *= boost
-
- # Confidence normalization
- score *= cfg.confidence_floor + (1.0 - cfg.confidence_floor) * c.confidence
-
- # Staleness penalty
- if c.status in ("stale", "suspect"):
- score *= cfg.staleness_penalty
-
- # Hard invalidation exclusion
- if c.status == "invalidated":
- score = 0.0
-
- # Contradiction penalty
- if conflict_checker and c.source_type and c.source_id:
- try:
- if conflict_checker(c.source_type, c.source_id):
- c.has_open_conflicts = True
- score *= cfg.contradiction_penalty
- except Exception:
- pass
-
- c.adjusted_score = score
-
- # Stage 5: Final ranking + dedup
- ranked = sorted(combined.values(), key=lambda c: -c.adjusted_score)
-
- # Dedup: if raw and distilled of same content via lineage, keep distilled
- seen_source_ids: Dict[str, FusionCandidate] = {}
- deduped: List[FusionCandidate] = []
- dedup_removed = 0
-
- for c in ranked:
- sid = c.source_id
- if sid in seen_source_ids:
- existing = seen_source_ids[sid]
- # Keep the distilled version
- if c.source_kind == "distilled" and existing.source_kind == "raw":
- deduped = [x for x in deduped if x.source_id != sid]
- deduped.append(c)
- seen_source_ids[sid] = c
- dedup_removed += 1
- else:
- dedup_removed += 1
- else:
- seen_source_ids[sid] = c
- deduped.append(c)
-
- final = deduped[:cfg.final_top_n]
-
- # Post-adjustment snapshot
- post_top5 = [
- {
- "row_id": c.row_id, "kind": c.source_kind,
- "adjusted": round(c.adjusted_score, 6),
- "conflicts": c.has_open_conflicts,
- }
- for c in final[:5]
- ]
-
- breakdown = FusionBreakdown(
- query=query,
- config={
- "rrf_k": cfg.rrf_k,
- "weights": {
- "raw": cfg.weight_raw,
- "distilled": cfg.weight_distilled,
- "episodic": cfg.weight_episodic,
- },
- },
- per_index_counts=per_index_counts,
- pre_adjustment_top5=pre_top5,
- post_adjustment_top5=post_top5,
- dedup_removed=dedup_removed,
- final_count=len(final),
- )
-
- return final, breakdown
-
- @staticmethod
- def _normalize(candidates: List[FusionCandidate]) -> None:
- """Min-max normalize raw_score within a candidate list."""
- if not candidates:
- return
-
- scores = [c.raw_score for c in candidates]
- min_s = min(scores)
- max_s = max(scores)
- spread = max_s - min_s
-
- if spread < 1e-9:
- for c in candidates:
- c.normalized_score = 1.0 if max_s > 0 else 0.0
- else:
- for c in candidates:
- c.normalized_score = (c.raw_score - min_s) / spread
diff --git a/dhee/core/graph_evolution.py b/dhee/core/graph_evolution.py
deleted file mode 100644
index 3192c53..0000000
--- a/dhee/core/graph_evolution.py
+++ /dev/null
@@ -1,600 +0,0 @@
-"""Evolving knowledge graph — versioned entities, personalized PageRank, schema-free extraction.
-
-Extends KnowledgeGraph (graph.py) with three capabilities:
-
-1. **Entity versioning**: Every entity mutation is stored as a version snapshot.
- Queries can ask "what was X at time T?" and diffs show how entities evolve.
-
-2. **Personalized PageRank**: Per-user / per-agent importance scores over
- the entity–memory graph. Guides retrieval toward what matters *to this user*.
-
-3. **Schema-free extraction**: Uses BuddhiMini (or any LLM) to discover
- entity types at runtime, stored as EntityType.DYNAMIC with a type_label.
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import os
-import time
-from collections import defaultdict
-from dataclasses import dataclass, field
-from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional, Set, Tuple
-
-from dhee.core.graph import (
- Entity,
- EntityType,
- KnowledgeGraph,
- Relationship,
- RelationType,
-)
-
-logger = logging.getLogger(__name__)
-
-
-# ---------------------------------------------------------------------------
-# Entity Versioning
-# ---------------------------------------------------------------------------
-
-@dataclass
-class EntityVersion:
- """A point-in-time snapshot of an entity's state."""
-
- entity_name: str
- version: int
- timestamp: str # ISO-8601
- entity_type: str
- type_label: Optional[str] = None # For DYNAMIC entities
- aliases: List[str] = field(default_factory=list)
- memory_ids: List[str] = field(default_factory=list)
- metadata: Dict[str, Any] = field(default_factory=dict)
- change_reason: str = "" # What triggered this version
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "entity_name": self.entity_name,
- "version": self.version,
- "timestamp": self.timestamp,
- "entity_type": self.entity_type,
- "type_label": self.type_label,
- "aliases": self.aliases,
- "memory_ids": self.memory_ids,
- "metadata": self.metadata,
- "change_reason": self.change_reason,
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "EntityVersion":
- return cls(
- entity_name=data["entity_name"],
- version=data["version"],
- timestamp=data["timestamp"],
- entity_type=data["entity_type"],
- type_label=data.get("type_label"),
- aliases=data.get("aliases", []),
- memory_ids=data.get("memory_ids", []),
- metadata=data.get("metadata", {}),
- change_reason=data.get("change_reason", ""),
- )
-
-
-def _now_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
-
-
-class EntityVersionStore:
- """Append-only version log for entity snapshots.
-
- Stored as JSONL on disk — one line per version, O(1) append.
- """
-
- def __init__(self, path: str):
- self._path = path
- self._versions: Dict[str, List[EntityVersion]] = defaultdict(list)
- self._load()
-
- def _load(self) -> None:
- if not os.path.exists(self._path):
- return
- try:
- with open(self._path, "r", encoding="utf-8") as f:
- for line in f:
- line = line.strip()
- if not line:
- continue
- v = EntityVersion.from_dict(json.loads(line))
- self._versions[v.entity_name].append(v)
- except (OSError, json.JSONDecodeError) as e:
- logger.warning("Failed to load entity versions: %s", e)
-
- def record(self, entity: Entity, reason: str = "") -> EntityVersion:
- """Snapshot the current state of an entity."""
- history = self._versions[entity.name]
- version_num = (history[-1].version + 1) if history else 1
-
- v = EntityVersion(
- entity_name=entity.name,
- version=version_num,
- timestamp=_now_iso(),
- entity_type=entity.entity_type.value,
- type_label=entity.metadata.get("type_label"),
- aliases=sorted(entity.aliases),
- memory_ids=sorted(entity.memory_ids),
- metadata=dict(entity.metadata),
- change_reason=reason,
- )
-
- history.append(v)
- self._append(v)
- return v
-
- def _append(self, v: EntityVersion) -> None:
- try:
- os.makedirs(os.path.dirname(self._path) or ".", exist_ok=True)
- with open(self._path, "a", encoding="utf-8") as f:
- f.write(json.dumps(v.to_dict(), ensure_ascii=False) + "\n")
- except OSError as e:
- logger.warning("Failed to persist entity version: %s", e)
-
- def get_history(self, entity_name: str) -> List[EntityVersion]:
- """All versions of an entity, oldest first."""
- return list(self._versions.get(entity_name, []))
-
- def get_at_time(self, entity_name: str, iso_time: str) -> Optional[EntityVersion]:
- """Get the entity version that was current at a given time."""
- history = self._versions.get(entity_name, [])
- result = None
- for v in history:
- if v.timestamp <= iso_time:
- result = v
- else:
- break
- return result
-
- def diff(self, entity_name: str, v1: int, v2: int) -> Dict[str, Any]:
- """Compute the difference between two versions of an entity."""
- history = self._versions.get(entity_name, [])
- ver_map = {v.version: v for v in history}
- old = ver_map.get(v1)
- new = ver_map.get(v2)
- if not old or not new:
- return {"error": "version not found"}
-
- changes: Dict[str, Any] = {}
- if old.entity_type != new.entity_type:
- changes["entity_type"] = {"old": old.entity_type, "new": new.entity_type}
- if old.type_label != new.type_label:
- changes["type_label"] = {"old": old.type_label, "new": new.type_label}
-
- old_aliases = set(old.aliases)
- new_aliases = set(new.aliases)
- if old_aliases != new_aliases:
- changes["aliases_added"] = sorted(new_aliases - old_aliases)
- changes["aliases_removed"] = sorted(old_aliases - new_aliases)
-
- old_mids = set(old.memory_ids)
- new_mids = set(new.memory_ids)
- if old_mids != new_mids:
- changes["memories_added"] = sorted(new_mids - old_mids)
- changes["memories_removed"] = sorted(old_mids - new_mids)
-
- # Metadata diff (shallow)
- for key in set(old.metadata) | set(new.metadata):
- old_val = old.metadata.get(key)
- new_val = new.metadata.get(key)
- if old_val != new_val:
- changes.setdefault("metadata", {})[key] = {
- "old": old_val, "new": new_val,
- }
-
- return {
- "entity": entity_name,
- "from_version": v1,
- "to_version": v2,
- "changes": changes,
- }
-
- @property
- def entity_count(self) -> int:
- return len(self._versions)
-
-
-# ---------------------------------------------------------------------------
-# Personalized PageRank
-# ---------------------------------------------------------------------------
-
-class PersonalizedPageRank:
- """Per-user importance ranking over the entity–memory graph.
-
- Runs a standard power-iteration PageRank seeded from a user's
- interaction history (memories they wrote, entities they mentioned).
- Results are cached and refreshed when the graph changes.
- """
-
- def __init__(
- self,
- damping: float = 0.85,
- iterations: int = 20,
- tolerance: float = 1e-6,
- ):
- self.damping = damping
- self.iterations = iterations
- self.tolerance = tolerance
- self._cache: Dict[str, Dict[str, float]] = {} # user_id -> {node -> score}
-
- def compute(
- self,
- graph: KnowledgeGraph,
- seed_memory_ids: Optional[Set[str]] = None,
- user_id: str = "default",
- ) -> Dict[str, float]:
- """Compute personalized PageRank for a user.
-
- Args:
- graph: The knowledge graph.
- seed_memory_ids: Memories that form the user's personalization
- vector. If None, uses all memories.
- user_id: Cache key.
-
- Returns:
- Dict mapping node IDs (memory IDs and "entity:name") to scores.
- """
- # Build adjacency from relationships
- adj: Dict[str, Set[str]] = defaultdict(set)
- all_nodes: Set[str] = set()
-
- for rel in graph.relationships:
- adj[rel.source_id].add(rel.target_id)
- adj[rel.target_id].add(rel.source_id)
- all_nodes.add(rel.source_id)
- all_nodes.add(rel.target_id)
-
- # Add entity nodes
- for entity_name, entity in graph.entities.items():
- enode = f"entity:{entity_name}"
- all_nodes.add(enode)
- for mid in entity.memory_ids:
- adj[mid].add(enode)
- adj[enode].add(mid)
- all_nodes.add(mid)
-
- if not all_nodes:
- return {}
-
- n = len(all_nodes)
- node_list = sorted(all_nodes)
- node_idx = {node: i for i, node in enumerate(node_list)}
-
- # Personalization vector: uniform over seeds, zero elsewhere
- personalization = [0.0] * n
- if seed_memory_ids:
- seeds_in_graph = [
- node_idx[mid] for mid in seed_memory_ids if mid in node_idx
- ]
- if seeds_in_graph:
- weight = 1.0 / len(seeds_in_graph)
- for idx in seeds_in_graph:
- personalization[idx] = weight
- else:
- personalization = [1.0 / n] * n
- else:
- personalization = [1.0 / n] * n
-
- # Power iteration
- scores = [1.0 / n] * n
- for _ in range(self.iterations):
- new_scores = [0.0] * n
- for i, node in enumerate(node_list):
- neighbors = adj.get(node, set())
- if not neighbors:
- # Dangling node — distribute uniformly
- share = scores[i] / n
- for j in range(n):
- new_scores[j] += share
- else:
- share = scores[i] / len(neighbors)
- for nb in neighbors:
- if nb in node_idx:
- new_scores[node_idx[nb]] += share
-
- # Apply damping + personalization
- for i in range(n):
- new_scores[i] = (
- (1 - self.damping) * personalization[i]
- + self.damping * new_scores[i]
- )
-
- # Check convergence
- delta = sum(abs(new_scores[i] - scores[i]) for i in range(n))
- scores = new_scores
- if delta < self.tolerance:
- break
-
- result = {node_list[i]: scores[i] for i in range(n)}
- self._cache[user_id] = result
- return result
-
- def get_top_entities(
- self,
- graph: KnowledgeGraph,
- user_id: str = "default",
- seed_memory_ids: Optional[Set[str]] = None,
- limit: int = 20,
- ) -> List[Tuple[str, float]]:
- """Get top-ranked entities for a user."""
- if user_id not in self._cache:
- self.compute(graph, seed_memory_ids=seed_memory_ids, user_id=user_id)
-
- scores = self._cache.get(user_id, {})
- entity_scores = [
- (name.replace("entity:", ""), score)
- for name, score in scores.items()
- if name.startswith("entity:")
- ]
- entity_scores.sort(key=lambda x: x[1], reverse=True)
- return entity_scores[:limit]
-
- def boost_retrieval(
- self,
- memory_ids: List[str],
- user_id: str = "default",
- ) -> Dict[str, float]:
- """Get PageRank boost factors for a set of candidate memory IDs."""
- scores = self._cache.get(user_id, {})
- if not scores:
- return {}
- return {mid: scores.get(mid, 0.0) for mid in memory_ids}
-
- def invalidate(self, user_id: Optional[str] = None) -> None:
- """Clear cached scores. Call when graph changes."""
- if user_id:
- self._cache.pop(user_id, None)
- else:
- self._cache.clear()
-
-
-# ---------------------------------------------------------------------------
-# Schema-Free Entity Extraction
-# ---------------------------------------------------------------------------
-
-_SCHEMA_FREE_PROMPT = """Extract entities from the following text. For each entity, provide:
-- name: The entity name
-- type: A descriptive type (e.g., "person", "technology", "framework", "metric",
- "emotion", "event", "disease", "recipe" — any type that fits, not limited to a fixed set)
-- relevance: How important this entity is to the text (0.0 to 1.0)
-
-Text: {content}
-
-Return a JSON array. Example:
-[{{"name": "FastAPI", "type": "framework", "relevance": 0.9}}]
-
-Return ONLY the JSON array:"""
-
-
-def extract_entities_schema_free(
- content: str,
- memory_id: str,
- graph: KnowledgeGraph,
- llm: Any = None,
- min_relevance: float = 0.3,
-) -> List[Entity]:
- """Extract entities without a fixed type schema.
-
- Uses LLM to discover entity types at runtime. Discovered types are
- stored as EntityType.DYNAMIC with a ``type_label`` in metadata.
-
- Falls back to graph.extract_entities() regex path if no LLM.
- """
- if not llm:
- return graph.extract_entities(content, memory_id, use_llm=False)
-
- prompt = _SCHEMA_FREE_PROMPT.format(content=content[:2000])
-
- try:
- response = llm.generate(prompt)
- arr_start = response.find("[")
- if arr_start < 0:
- return graph.extract_entities(content, memory_id, use_llm=False)
-
- items, _ = json.JSONDecoder().raw_decode(response, arr_start)
- except Exception as e:
- logger.debug("Schema-free extraction failed (%s), falling back to regex", e)
- return graph.extract_entities(content, memory_id, use_llm=False)
-
- # Known EntityType values (lowercase)
- _known_types = {t.value for t in EntityType}
-
- entities: List[Entity] = []
- for item in items:
- name = item.get("name", "").strip()
- if not name:
- continue
-
- relevance = float(item.get("relevance", 0.5))
- if relevance < min_relevance:
- continue
-
- raw_type = item.get("type", "unknown").strip().lower()
-
- # Map to existing enum if possible; otherwise DYNAMIC
- if raw_type in _known_types:
- entity_type = EntityType(raw_type)
- type_label = None
- else:
- entity_type = EntityType.DYNAMIC
- type_label = raw_type
-
- entity = graph._get_or_create_entity(name, entity_type)
- entity.memory_ids.add(memory_id)
- entity.metadata["relevance"] = max(
- entity.metadata.get("relevance", 0.0), relevance,
- )
- if type_label:
- entity.metadata["type_label"] = type_label
-
- entities.append(entity)
-
- graph.memory_entities[memory_id] = {e.name for e in entities}
- return entities
-
-
-# ---------------------------------------------------------------------------
-# EvolvingGraph — wraps it all together
-# ---------------------------------------------------------------------------
-
-class EvolvingGraph:
- """Knowledge graph with entity versioning, PageRank, and schema-free extraction.
-
- Drop-in extension of KnowledgeGraph — delegates core graph operations
- and adds evolution capabilities on top.
- """
-
- def __init__(
- self,
- data_dir: Optional[str] = None,
- llm: Any = None,
- damping: float = 0.85,
- ):
- self._data_dir = data_dir or os.path.join(
- os.path.expanduser("~"), ".dhee", "graph",
- )
- os.makedirs(self._data_dir, exist_ok=True)
-
- graph_path = os.path.join(self._data_dir, "graph.json")
- self.graph = KnowledgeGraph.load(graph_path, llm=llm)
-
- self._versions = EntityVersionStore(
- os.path.join(self._data_dir, "entity_versions.jsonl"),
- )
- self._pagerank = PersonalizedPageRank(damping=damping)
- self._llm = llm
-
- # ── Entity operations (versioned) ──
-
- def extract_and_version(
- self,
- content: str,
- memory_id: str,
- reason: str = "new_memory",
- schema_free: bool = True,
- ) -> List[Entity]:
- """Extract entities from content and record versions for any changes."""
- if schema_free and self._llm:
- entities = extract_entities_schema_free(
- content, memory_id, self.graph, llm=self._llm,
- )
- else:
- entities = self.graph.extract_entities(content, memory_id)
-
- # Record version for each entity that was touched
- for entity in entities:
- self._versions.record(entity, reason=reason)
-
- # Invalidate PageRank caches (graph changed)
- self._pagerank.invalidate()
-
- return entities
-
- def update_entity(
- self,
- entity_name: str,
- updates: Dict[str, Any],
- reason: str = "update",
- ) -> Optional[Entity]:
- """Update an entity's fields and record the version."""
- entity = self.graph.entities.get(entity_name)
- if not entity:
- return None
-
- if "entity_type" in updates:
- entity.entity_type = EntityType(updates["entity_type"])
- if "aliases" in updates:
- entity.aliases.update(updates["aliases"])
- if "metadata" in updates:
- entity.metadata.update(updates["metadata"])
-
- self._versions.record(entity, reason=reason)
- self._pagerank.invalidate()
- return entity
-
- def get_entity_history(self, entity_name: str) -> List[EntityVersion]:
- return self._versions.get_history(entity_name)
-
- def get_entity_at_time(
- self, entity_name: str, iso_time: str,
- ) -> Optional[EntityVersion]:
- return self._versions.get_at_time(entity_name, iso_time)
-
- def entity_diff(
- self, entity_name: str, v1: int, v2: int,
- ) -> Dict[str, Any]:
- return self._versions.diff(entity_name, v1, v2)
-
- # ── PageRank ──
-
- def compute_pagerank(
- self,
- user_id: str = "default",
- seed_memory_ids: Optional[Set[str]] = None,
- ) -> Dict[str, float]:
- return self._pagerank.compute(
- self.graph, seed_memory_ids=seed_memory_ids, user_id=user_id,
- )
-
- def get_important_entities(
- self,
- user_id: str = "default",
- seed_memory_ids: Optional[Set[str]] = None,
- limit: int = 20,
- ) -> List[Tuple[str, float]]:
- return self._pagerank.get_top_entities(
- self.graph, user_id=user_id,
- seed_memory_ids=seed_memory_ids, limit=limit,
- )
-
- def pagerank_boost(
- self,
- memory_ids: List[str],
- user_id: str = "default",
- ) -> Dict[str, float]:
- return self._pagerank.boost_retrieval(memory_ids, user_id=user_id)
-
- # ── Graph delegation ──
-
- def add_relationship(self, *args, **kwargs) -> Relationship:
- rel = self.graph.add_relationship(*args, **kwargs)
- self._pagerank.invalidate()
- return rel
-
- def link_by_shared_entities(self, memory_id: str) -> List[Relationship]:
- rels = self.graph.link_by_shared_entities(memory_id)
- if rels:
- self._pagerank.invalidate()
- return rels
-
- def get_related_memories(self, *args, **kwargs):
- return self.graph.get_related_memories(*args, **kwargs)
-
- def get_causal_chain(self, *args, **kwargs):
- return self.graph.get_causal_chain(*args, **kwargs)
-
- def get_memory_graph(self, memory_id: str) -> Dict[str, Any]:
- return self.graph.get_memory_graph(memory_id)
-
- # ── Persistence ──
-
- def save(self) -> None:
- """Persist graph to disk. Version store auto-persists on append."""
- graph_path = os.path.join(self._data_dir, "graph.json")
- self.graph.save(graph_path)
-
- def stats(self) -> Dict[str, Any]:
- base = self.graph.stats()
- base["versioned_entities"] = self._versions.entity_count
- base["dynamic_entities"] = sum(
- 1 for e in self.graph.entities.values()
- if e.entity_type == EntityType.DYNAMIC
- )
- return base
diff --git a/dhee/core/invalidation.py b/dhee/core/invalidation.py
deleted file mode 100644
index 2b64316..0000000
--- a/dhee/core/invalidation.py
+++ /dev/null
@@ -1,248 +0,0 @@
-"""Dhee v3 — Three-Tier Invalidation Engine.
-
-Graduated invalidation based on what happened to the source:
-
-1. Hard invalidation: source deleted → derived tombstoned
-2. Soft invalidation: source corrected → derived marked stale, re-eval queued
-3. Partial invalidation: one of N sources changed, contribution < 30%
- → derived marked suspect with confidence penalty
-
-Design contract:
- - Invalidation is async — marks status + enqueues repair jobs
- - Never synchronously rewrites derived objects
- - Type-aware: each derived type has its own invalidation response
- - All cascades are traceable via maintenance_jobs
- - Zero LLM calls
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import uuid
-from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-
-def _utcnow_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
-
-
-# Threshold: if a changed source contributed >= this fraction,
-# escalate from partial to soft invalidation
-PARTIAL_ESCALATION_THRESHOLD = 0.30
-
-
-class InvalidationEngine:
- """Cascades invalidation from raw events to derived objects.
-
- Usage:
- engine = InvalidationEngine(lineage, stores_map, job_enqueuer)
- engine.on_event_corrected(event_id) # soft + partial
- engine.on_event_deleted(event_id) # hard + partial
- """
-
- def __init__(
- self,
- lineage: "DerivedLineageStore",
- stores: Dict[str, Any],
- conn: "sqlite3.Connection",
- lock: "threading.RLock",
- ):
- """
- Args:
- lineage: DerivedLineageStore for tracing dependencies
- stores: Map of derived_type → store instance, e.g.
- {"belief": belief_store, "policy": policy_store, ...}
- conn: Shared SQLite connection for enqueuing jobs
- lock: Shared threading lock
- """
- self.lineage = lineage
- self.stores = stores
- self._conn = conn
- self._lock = lock
-
- def on_event_corrected(self, event_id: str) -> Dict[str, Any]:
- """Handle a raw event being corrected (superseded by new event).
-
- For each dependent derived object:
- - If sole source → soft invalidation (stale)
- - If one of many → check contribution weight:
- - >= 30% → soft invalidation
- - < 30% → partial invalidation (suspect + confidence penalty)
- """
- return self._cascade(event_id, mode="corrected")
-
- def on_event_deleted(self, event_id: str) -> Dict[str, Any]:
- """Handle a raw event being deleted.
-
- For each dependent derived object:
- - If sole source → hard invalidation (tombstone)
- - If one of many → check contribution weight:
- - >= 30% → soft invalidation (stale + re-eval)
- - < 30% → partial invalidation (suspect + confidence penalty)
- """
- return self._cascade(event_id, mode="deleted")
-
- def _cascade(self, event_id: str, mode: str) -> Dict[str, Any]:
- """Core cascade logic for both correction and deletion."""
- dependents = self.lineage.get_dependents(event_id)
-
- result = {
- "event_id": event_id,
- "mode": mode,
- "hard_invalidated": [],
- "soft_invalidated": [],
- "partial_invalidated": [],
- "jobs_enqueued": [],
- "errors": [],
- }
-
- for dep in dependents:
- dtype = dep["derived_type"]
- did = dep["derived_id"]
- weight = dep["contribution_weight"]
-
- try:
- # How many total sources does this derived object have?
- total_sources = self.lineage.get_source_count(dtype, did)
-
- if total_sources <= 1:
- # Sole source — hard or soft depending on mode
- if mode == "deleted":
- self._hard_invalidate(dtype, did)
- result["hard_invalidated"].append(
- {"type": dtype, "id": did}
- )
- else: # corrected
- self._soft_invalidate(dtype, did)
- job_id = self._enqueue_repair(
- dtype, did, "repair_stale_derived"
- )
- result["soft_invalidated"].append(
- {"type": dtype, "id": did}
- )
- result["jobs_enqueued"].append(job_id)
-
- elif weight >= PARTIAL_ESCALATION_THRESHOLD:
- # Major contributor — treat as soft invalidation
- self._soft_invalidate(dtype, did)
- job_id = self._enqueue_repair(
- dtype, did, "repair_stale_derived"
- )
- result["soft_invalidated"].append(
- {"type": dtype, "id": did, "weight": weight}
- )
- result["jobs_enqueued"].append(job_id)
-
- else:
- # Minor contributor — partial invalidation
- self._partial_invalidate(dtype, did, weight)
- job_id = self._enqueue_repair(
- dtype, did, "verify_suspect_derived"
- )
- result["partial_invalidated"].append(
- {"type": dtype, "id": did, "weight": weight}
- )
- result["jobs_enqueued"].append(job_id)
-
- except Exception as e:
- logger.exception(
- "Invalidation failed for %s:%s from event %s",
- dtype, did, event_id,
- )
- result["errors"].append({
- "type": dtype, "id": did, "error": str(e)
- })
-
- return result
-
- # ------------------------------------------------------------------
- # Invalidation tier implementations
- # ------------------------------------------------------------------
-
- def _hard_invalidate(self, dtype: str, did: str) -> None:
- """Source gone, child unusable. Mark as tombstone."""
- store = self.stores.get(dtype)
- if store and hasattr(store, "set_status"):
- store.set_status(did, "invalidated")
- logger.info("Hard invalidated %s:%s", dtype, did)
-
- def _soft_invalidate(self, dtype: str, did: str) -> None:
- """Source changed, child needs re-evaluation."""
- store = self.stores.get(dtype)
- if store and hasattr(store, "set_status"):
- store.set_status(did, "stale")
- logger.info("Soft invalidated %s:%s → stale", dtype, did)
-
- def _partial_invalidate(
- self, dtype: str, did: str, weight: float
- ) -> None:
- """Minor source change. Mark suspect + confidence penalty."""
- store = self.stores.get(dtype)
- if not store:
- return
-
- # Apply confidence penalty proportional to contribution weight
- if hasattr(store, "get") and hasattr(store, "update_confidence"):
- obj = store.get(did)
- if obj and "confidence" in obj:
- penalty = weight * 0.5 # half the contribution weight
- new_conf = max(0.05, obj["confidence"] - penalty)
- store.update_confidence(
- did, new_conf,
- new_status="suspect",
- revision_reason=f"partial_invalidation (weight={weight:.2f})",
- )
- elif hasattr(store, "set_status"):
- store.set_status(did, "suspect")
-
- logger.info(
- "Partial invalidated %s:%s → suspect (weight=%.2f)",
- dtype, did, weight,
- )
-
- # ------------------------------------------------------------------
- # Job enqueuing
- # ------------------------------------------------------------------
-
- def _enqueue_repair(
- self, dtype: str, did: str, job_name: str
- ) -> str:
- """Enqueue a repair job for a derived object."""
- job_id = str(uuid.uuid4())
- now = _utcnow_iso()
- payload = json.dumps({
- "derived_type": dtype,
- "derived_id": did,
- })
- idem_key = f"{job_name}:{dtype}:{did}"
-
- with self._lock:
- try:
- # Check idempotency — don't enqueue if already pending/running
- existing = self._conn.execute(
- """SELECT job_id FROM maintenance_jobs
- WHERE idempotency_key = ?
- AND status IN ('pending', 'running')
- LIMIT 1""",
- (idem_key,),
- ).fetchone()
-
- if existing:
- return existing["job_id"]
-
- self._conn.execute(
- """INSERT INTO maintenance_jobs
- (job_id, job_name, status, payload_json,
- created_at, idempotency_key)
- VALUES (?, ?, 'pending', ?, ?, ?)""",
- (job_id, job_name, payload, now, idem_key),
- )
- self._conn.commit()
- return job_id
- except Exception:
- self._conn.rollback()
- raise
diff --git a/dhee/core/jobs.py b/dhee/core/jobs.py
deleted file mode 100644
index 69f7e5f..0000000
--- a/dhee/core/jobs.py
+++ /dev/null
@@ -1,477 +0,0 @@
-"""Dhee v3 — Job Registry: named, idempotent, observable maintenance jobs.
-
-Replaces agi_loop.py's phantom subsystems with real, independently testable jobs.
-
-Each job:
- - Has a unique name (e.g., "distill_episodic_to_semantic")
- - Is idempotent: same input → same output, safe to retry
- - Is observable: status, timing, retry count tracked in maintenance_jobs table
- - Is leasable: acquires a lock before running (via LeaseManager)
- - Returns a structured result dict
-
-Design contract:
- - Jobs NEVER call memory.add() or any write-path that triggers enrichment
- - Jobs write to derived stores + lineage only
- - Jobs are cold-path: called by heartbeat/cron, never by hot-path remember/recall
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import traceback
-import uuid
-from abc import ABC, abstractmethod
-from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional, Type
-
-from dhee.core.lease_manager import LeaseManager
-
-logger = logging.getLogger(__name__)
-
-
-def _utcnow_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
-
-
-class Job(ABC):
- """Base class for all maintenance jobs."""
-
- # Subclasses MUST set this
- name: str = ""
-
- def __init__(self):
- if not self.name:
- raise ValueError(f"{self.__class__.__name__} must set 'name'")
-
- @abstractmethod
- def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
- """Run the job. Returns a result dict.
-
- Args:
- payload: Job-specific input parameters
-
- Returns:
- Dict with job results (stored in maintenance_jobs.result_json)
-
- Raises:
- Exception on failure (caught by JobRegistry, stored as error)
- """
- ...
-
- def make_idempotency_key(self, payload: Dict[str, Any]) -> Optional[str]:
- """Generate an idempotency key for dedup. Override if needed.
-
- Returns None to skip idempotency checking (every run creates a new job).
- """
- return None
-
-
-class JobRegistry:
- """Manages job registration, scheduling, execution, and observability.
-
- Usage:
- registry = JobRegistry(conn, lock, lease_manager)
- registry.register(ApplyForgettingJob)
- registry.register(DistillEpisodicJob)
-
- # Run a specific job
- result = registry.run("apply_forgetting", payload={"user_id": "u1"})
-
- # Run all due jobs (heartbeat)
- results = registry.run_all(owner_id="worker-1")
- """
-
- def __init__(
- self,
- conn: "sqlite3.Connection",
- lock: "threading.RLock",
- lease_manager: LeaseManager,
- ):
- import sqlite3
- import threading
-
- self._conn = conn
- self._lock = lock
- self._lease = lease_manager
- self._jobs: Dict[str, Job] = {}
-
- def register(self, job_class: Type[Job]) -> None:
- """Register a job class. Instantiates it."""
- job = job_class()
- if job.name in self._jobs:
- logger.warning("Job %s already registered, replacing", job.name)
- self._jobs[job.name] = job
-
- def list_registered(self) -> List[str]:
- """List all registered job names."""
- return list(self._jobs.keys())
-
- def run(
- self,
- job_name: str,
- *,
- payload: Optional[Dict[str, Any]] = None,
- owner_id: str = "default-worker",
- ) -> Dict[str, Any]:
- """Run a single job by name. Acquires lease, executes, records result.
-
- Returns:
- Dict with: job_id, job_name, status, result/error, timing
- """
- if job_name not in self._jobs:
- return {
- "job_name": job_name,
- "status": "error",
- "error": f"Unknown job: {job_name}",
- }
-
- job = self._jobs[job_name]
- payload = payload or {}
-
- # Idempotency check
- idem_key = job.make_idempotency_key(payload)
- if idem_key:
- with self._lock:
- existing = self._conn.execute(
- """SELECT job_id, status, result_json FROM maintenance_jobs
- WHERE idempotency_key = ? AND status IN ('completed', 'running')
- LIMIT 1""",
- (idem_key,),
- ).fetchone()
- if existing:
- return {
- "job_id": existing["job_id"],
- "job_name": job_name,
- "status": "skipped_idempotent",
- "existing_status": existing["status"],
- }
-
- # Acquire lease
- lock_id = f"job:{job_name}"
- if not self._lease.acquire(lock_id, owner_id):
- return {
- "job_name": job_name,
- "status": "skipped_locked",
- "holder": self._lease.get_holder(lock_id),
- }
-
- # Create job record
- job_id = str(uuid.uuid4())
- now = _utcnow_iso()
-
- try:
- with self._lock:
- self._conn.execute(
- """INSERT INTO maintenance_jobs
- (job_id, job_name, status, payload_json,
- created_at, started_at, idempotency_key)
- VALUES (?, ?, 'running', ?, ?, ?, ?)""",
- (
- job_id, job_name, json.dumps(payload),
- now, now, idem_key,
- ),
- )
- self._conn.commit()
-
- # Execute
- result = job.execute(payload)
- completed_at = _utcnow_iso()
-
- with self._lock:
- self._conn.execute(
- """UPDATE maintenance_jobs
- SET status = 'completed', result_json = ?,
- completed_at = ?
- WHERE job_id = ?""",
- (json.dumps(result, default=str), completed_at, job_id),
- )
- self._conn.commit()
-
- return {
- "job_id": job_id,
- "job_name": job_name,
- "status": "completed",
- "result": result,
- "started_at": now,
- "completed_at": completed_at,
- }
-
- except Exception as e:
- error_msg = f"{type(e).__name__}: {e}"
- logger.exception("Job %s failed: %s", job_name, error_msg)
-
- with self._lock:
- self._conn.execute(
- """UPDATE maintenance_jobs
- SET status = 'failed', error_message = ?,
- completed_at = ?,
- retry_count = retry_count + 1
- WHERE job_id = ?""",
- (error_msg, _utcnow_iso(), job_id),
- )
- self._conn.commit()
-
- return {
- "job_id": job_id,
- "job_name": job_name,
- "status": "failed",
- "error": error_msg,
- }
-
- finally:
- self._lease.release(lock_id, owner_id)
-
- def run_all(
- self, *, owner_id: str = "default-worker"
- ) -> List[Dict[str, Any]]:
- """Run all registered jobs. Returns list of results."""
- results = []
- for name in self._jobs:
- result = self.run(name, owner_id=owner_id)
- results.append(result)
- return results
-
- def get_job_history(
- self,
- job_name: str,
- *,
- limit: int = 10,
- ) -> List[Dict[str, Any]]:
- """Get recent execution history for a job."""
- with self._lock:
- rows = self._conn.execute(
- """SELECT job_id, job_name, status, payload_json,
- result_json, error_message,
- created_at, started_at, completed_at,
- retry_count
- FROM maintenance_jobs
- WHERE job_name = ?
- ORDER BY created_at DESC
- LIMIT ?""",
- (job_name, limit),
- ).fetchall()
-
- return [
- {
- "job_id": r["job_id"],
- "job_name": r["job_name"],
- "status": r["status"],
- "payload": json.loads(r["payload_json"] or "{}"),
- "result": json.loads(r["result_json"] or "null"),
- "error": r["error_message"],
- "created_at": r["created_at"],
- "started_at": r["started_at"],
- "completed_at": r["completed_at"],
- "retry_count": r["retry_count"],
- }
- for r in rows
- ]
-
- def get_health(self) -> Dict[str, Any]:
- """Get health summary across all registered jobs."""
- health: Dict[str, Any] = {
- "registered_jobs": list(self._jobs.keys()),
- "total_registered": len(self._jobs),
- "job_status": {},
- }
-
- for name in self._jobs:
- with self._lock:
- row = self._conn.execute(
- """SELECT status, completed_at, error_message
- FROM maintenance_jobs
- WHERE job_name = ?
- ORDER BY created_at DESC
- LIMIT 1""",
- (name,),
- ).fetchone()
-
- if row:
- health["job_status"][name] = {
- "last_status": row["status"],
- "last_completed": row["completed_at"],
- "last_error": row["error_message"],
- }
- else:
- health["job_status"][name] = {"last_status": "never_run"}
-
- return health
-
-
-# =========================================================================
-# Concrete Job Implementations
-# =========================================================================
-
-class ApplyForgettingJob(Job):
- """Apply decay/forgetting curves to memory strengths.
-
- Replaces agi_loop step 2 (decay).
- """
-
- name = "apply_forgetting"
-
- def __init__(self):
- super().__init__()
- self._memory = None # Set by caller via set_context()
-
- def set_context(self, memory: Any) -> "ApplyForgettingJob":
- self._memory = memory
- return self
-
- def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
- if not self._memory:
- return {"status": "skipped", "reason": "no memory instance"}
-
- user_id = payload.get("user_id", "default")
- try:
- result = self._memory.apply_decay(scope={"user_id": user_id})
- return {"status": "ok", "decay_result": result}
- except Exception as e:
- return {"status": "error", "error": str(e)}
-
-
-class RunConsolidationJob(Job):
- """Run the cognition kernel's sleep_cycle (distillation).
-
- Replaces agi_loop step 1 (consolidate).
- """
-
- name = "run_consolidation"
-
- def __init__(self):
- super().__init__()
- self._kernel = None
-
- def set_context(self, kernel: Any) -> "RunConsolidationJob":
- self._kernel = kernel
- return self
-
- def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
- if not self._kernel:
- return {"status": "skipped", "reason": "no kernel instance"}
-
- user_id = payload.get("user_id", "default")
- try:
- result = self._kernel.sleep_cycle(user_id=user_id)
- return {"status": "ok", "consolidation_result": result}
- except Exception as e:
- return {"status": "error", "error": str(e)}
-
-
-class ExtractStepPoliciesJob(Job):
- """Extract step-level policies from completed tasks.
-
- Replaces the inline policy extraction in record_learning_outcomes.
- """
-
- name = "extract_step_policies"
-
- def __init__(self):
- super().__init__()
- self._kernel = None
-
- def set_context(self, kernel: Any) -> "ExtractStepPoliciesJob":
- self._kernel = kernel
- return self
-
- def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
- if not self._kernel:
- return {"status": "skipped", "reason": "no kernel instance"}
-
- user_id = payload.get("user_id", "default")
- task_id = payload.get("task_id")
-
- if not task_id:
- return {"status": "skipped", "reason": "no task_id in payload"}
-
- try:
- # Look up the task
- task = self._kernel.task_manager.get(task_id)
- if not task:
- return {"status": "skipped", "reason": f"task {task_id} not found"}
-
- # Use existing policy extraction
- policies_created = 0
- if hasattr(self._kernel, 'policy_manager') and self._kernel.policy_manager:
- from dhee.core.task_state import TaskStatus
- if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
- # Extract via existing mechanisms
- policies_created = self._kernel.policy_manager.extract_from_task(
- task, user_id=user_id
- ) if hasattr(self._kernel.policy_manager, 'extract_from_task') else 0
-
- return {"status": "ok", "policies_created": policies_created}
- except Exception as e:
- return {"status": "error", "error": str(e)}
-
- def make_idempotency_key(self, payload: Dict[str, Any]) -> Optional[str]:
- task_id = payload.get("task_id", "")
- return f"step_policies:{task_id}" if task_id else None
-
-
-class DetectFailurePatternsJob(Job):
- """Run the FailurePatternDetector on terminal tasks.
-
- Replaces the inline pattern detection in record_learning_outcomes.
- """
-
- name = "detect_failure_patterns"
-
- def __init__(self):
- super().__init__()
- self._kernel = None
-
- def set_context(self, kernel: Any) -> "DetectFailurePatternsJob":
- self._kernel = kernel
- return self
-
- def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
- if not self._kernel:
- return {"status": "skipped", "reason": "no kernel instance"}
-
- user_id = payload.get("user_id", "default")
- try:
- from dhee.core.pattern_detector import (
- FailurePatternDetector, extract_features,
- )
- from dhee.core.task_state import TaskStatus
-
- # Get terminal tasks
- all_tasks = self._kernel.task_manager.list_tasks(
- user_id=user_id, limit=200
- )
- terminal = [
- t for t in all_tasks
- if t.status in (TaskStatus.COMPLETED, TaskStatus.FAILED)
- ]
-
- if len(terminal) < 10:
- return {
- "status": "ok", "patterns_found": 0,
- "reason": f"only {len(terminal)} terminal tasks (need 10+)",
- }
-
- features = extract_features(terminal)
- detector = FailurePatternDetector()
- patterns = detector.detect_and_describe(features)
-
- stored = 0
- for pattern in patterns:
- if hasattr(self._kernel, '_store_pattern_as_policy'):
- policy = self._kernel._store_pattern_as_policy(
- user_id, "detected", pattern,
- )
- if policy:
- stored += 1
-
- return {
- "status": "ok",
- "terminal_tasks": len(terminal),
- "patterns_found": len(patterns),
- "patterns_stored": stored,
- }
- except ImportError:
- return {"status": "skipped", "reason": "pattern_detector not available"}
- except Exception as e:
- return {"status": "error", "error": str(e)}
diff --git a/dhee/core/lease_manager.py b/dhee/core/lease_manager.py
deleted file mode 100644
index aeb70b6..0000000
--- a/dhee/core/lease_manager.py
+++ /dev/null
@@ -1,256 +0,0 @@
-"""Dhee v3 — SQLite Lease Manager for job concurrency control.
-
-Ensures that only one runner can execute a given maintenance job at a time.
-Uses SQLite's BEGIN IMMEDIATE for atomic lease acquisition.
-
-Design contract:
- - Leases are time-bounded (default 300s)
- - Expired leases are automatically stolen
- - Renew extends lease while holding it
- - Release is explicit; stale leases cleaned on next acquire
- - Zero external dependencies — pure SQLite
-"""
-
-from __future__ import annotations
-
-import logging
-import sqlite3
-import threading
-import uuid
-from contextlib import contextmanager
-from datetime import datetime, timezone, timedelta
-from typing import Optional
-
-logger = logging.getLogger(__name__)
-
-DEFAULT_LEASE_DURATION_SECONDS = 300
-
-
-def _utcnow() -> datetime:
- return datetime.now(timezone.utc)
-
-
-def _utcnow_iso() -> str:
- return _utcnow().isoformat()
-
-
-class LeaseManager:
- """SQLite-based distributed lease manager.
-
- Each lock_id represents a named resource (e.g., a job name).
- Only one owner can hold a lease at a time. Expired leases are
- automatically reclaimed.
-
- Usage:
- lm = LeaseManager(conn, lock)
- acquired = lm.acquire("distill_batch", owner_id="worker-1")
- if acquired:
- try:
- # do work
- lm.renew("distill_batch", "worker-1") # extend if long
- finally:
- lm.release("distill_batch", "worker-1")
- """
-
- def __init__(
- self,
- conn: sqlite3.Connection,
- lock: threading.RLock,
- *,
- default_duration_seconds: int = DEFAULT_LEASE_DURATION_SECONDS,
- ):
- self._conn = conn
- self._lock = lock
- self.default_duration = default_duration_seconds
-
- @contextmanager
- def _tx(self):
- with self._lock:
- try:
- yield self._conn
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
-
- def acquire(
- self,
- lock_id: str,
- owner_id: str,
- *,
- duration_seconds: Optional[int] = None,
- ) -> bool:
- """Try to acquire a lease. Returns True if acquired.
-
- If the lock is held by another owner and not expired, returns False.
- If the lock is expired, steals it (atomic via BEGIN IMMEDIATE).
- If the lock is held by the same owner, renews it.
- """
- duration = duration_seconds or self.default_duration
- now = _utcnow()
- expires = (now + timedelta(seconds=duration)).isoformat()
- now_iso = now.isoformat()
-
- with self._tx() as conn:
- row = conn.execute(
- "SELECT owner_id, lease_expires_at FROM locks WHERE lock_id = ?",
- (lock_id,),
- ).fetchone()
-
- if row is None:
- # No lock exists — create it
- conn.execute(
- """INSERT INTO locks (lock_id, owner_id, lease_expires_at, updated_at)
- VALUES (?, ?, ?, ?)""",
- (lock_id, owner_id, expires, now_iso),
- )
- return True
-
- existing_owner = row["owner_id"]
- existing_expires = row["lease_expires_at"]
-
- # Same owner — renew
- if existing_owner == owner_id:
- conn.execute(
- "UPDATE locks SET lease_expires_at = ?, updated_at = ? WHERE lock_id = ?",
- (expires, now_iso, lock_id),
- )
- return True
-
- # Different owner — check if expired
- try:
- exp_dt = datetime.fromisoformat(existing_expires.replace("Z", "+00:00"))
- except (ValueError, AttributeError):
- exp_dt = _utcnow() # treat unparseable as expired
-
- if now >= exp_dt:
- # Expired — steal the lease
- conn.execute(
- "UPDATE locks SET owner_id = ?, lease_expires_at = ?, updated_at = ? WHERE lock_id = ?",
- (owner_id, expires, now_iso, lock_id),
- )
- logger.info(
- "Lease %s stolen from %s (expired %s) by %s",
- lock_id, existing_owner, existing_expires, owner_id,
- )
- return True
-
- # Not expired — someone else holds it
- return False
-
- def release(self, lock_id: str, owner_id: str) -> bool:
- """Release a lease. Returns True if successfully released.
-
- Only the current owner can release. Returns False if:
- - Lock doesn't exist
- - Lock is held by a different owner
- """
- with self._tx() as conn:
- row = conn.execute(
- "SELECT owner_id FROM locks WHERE lock_id = ?",
- (lock_id,),
- ).fetchone()
-
- if not row or row["owner_id"] != owner_id:
- return False
-
- conn.execute("DELETE FROM locks WHERE lock_id = ?", (lock_id,))
- return True
-
- def renew(
- self,
- lock_id: str,
- owner_id: str,
- *,
- duration_seconds: Optional[int] = None,
- ) -> bool:
- """Extend a lease. Returns True if renewed.
-
- Only the current owner can renew. Returns False if:
- - Lock doesn't exist
- - Lock is held by a different owner
- - Lock has already expired (use acquire to re-take)
- """
- duration = duration_seconds or self.default_duration
- now = _utcnow()
- expires = (now + timedelta(seconds=duration)).isoformat()
- now_iso = now.isoformat()
-
- with self._tx() as conn:
- row = conn.execute(
- "SELECT owner_id, lease_expires_at FROM locks WHERE lock_id = ?",
- (lock_id,),
- ).fetchone()
-
- if not row or row["owner_id"] != owner_id:
- return False
-
- # Check not expired
- try:
- exp_dt = datetime.fromisoformat(
- row["lease_expires_at"].replace("Z", "+00:00")
- )
- except (ValueError, AttributeError):
- return False
-
- if now >= exp_dt:
- return False # expired — must re-acquire
-
- conn.execute(
- "UPDATE locks SET lease_expires_at = ?, updated_at = ? WHERE lock_id = ?",
- (expires, now_iso, lock_id),
- )
- return True
-
- def is_held(self, lock_id: str) -> bool:
- """Check if a lock is currently held (not expired)."""
- with self._lock:
- row = self._conn.execute(
- "SELECT lease_expires_at FROM locks WHERE lock_id = ?",
- (lock_id,),
- ).fetchone()
-
- if not row:
- return False
-
- try:
- exp_dt = datetime.fromisoformat(
- row["lease_expires_at"].replace("Z", "+00:00")
- )
- except (ValueError, AttributeError):
- return False
-
- return _utcnow() < exp_dt
-
- def get_holder(self, lock_id: str) -> Optional[str]:
- """Get the current holder of a lock, or None if unheld/expired."""
- with self._lock:
- row = self._conn.execute(
- "SELECT owner_id, lease_expires_at FROM locks WHERE lock_id = ?",
- (lock_id,),
- ).fetchone()
-
- if not row:
- return None
-
- try:
- exp_dt = datetime.fromisoformat(
- row["lease_expires_at"].replace("Z", "+00:00")
- )
- except (ValueError, AttributeError):
- return None
-
- if _utcnow() >= exp_dt:
- return None
-
- return row["owner_id"]
-
- def cleanup_expired(self) -> int:
- """Remove all expired lease rows. Returns count removed."""
- now_iso = _utcnow_iso()
- with self._tx() as conn:
- result = conn.execute(
- "DELETE FROM locks WHERE lease_expires_at < ?",
- (now_iso,),
- )
- return result.rowcount
diff --git a/dhee/core/metrics.py b/dhee/core/metrics.py
deleted file mode 100644
index 08a0f9f..0000000
--- a/dhee/core/metrics.py
+++ /dev/null
@@ -1,207 +0,0 @@
-"""Observability — structured logging and metrics for Dhee.
-
-Counters, histograms, and a @measure decorator for all critical paths.
-Prometheus-compatible /metrics endpoint output.
-"""
-
-import logging
-import time
-import threading
-from collections import defaultdict
-from dataclasses import dataclass, field
-from functools import wraps
-from typing import Any, Callable, Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-
-class Counter:
- """Thread-safe monotonic counter."""
-
- def __init__(self, name: str, description: str = ""):
- self.name = name
- self.description = description
- self._value = 0
- self._lock = threading.Lock()
-
- def inc(self, amount: int = 1) -> None:
- with self._lock:
- self._value += amount
-
- @property
- def value(self) -> int:
- return self._value
-
- def to_prometheus(self) -> str:
- return (
- f"# HELP {self.name} {self.description}\n"
- f"# TYPE {self.name} counter\n"
- f"{self.name} {self._value}\n"
- )
-
-
-class Histogram:
- """Thread-safe latency histogram with percentile tracking."""
-
- def __init__(self, name: str, description: str = "", max_samples: int = 1000):
- self.name = name
- self.description = description
- self._samples: List[float] = []
- self._max_samples = max_samples
- self._lock = threading.Lock()
- self._sum = 0.0
- self._count = 0
-
- def observe(self, value: float) -> None:
- with self._lock:
- self._samples.append(value)
- self._sum += value
- self._count += 1
- if len(self._samples) > self._max_samples:
- self._samples = self._samples[-self._max_samples:]
-
- @property
- def count(self) -> int:
- return self._count
-
- @property
- def avg(self) -> float:
- return self._sum / self._count if self._count else 0.0
-
- def percentile(self, p: float) -> float:
- with self._lock:
- if not self._samples:
- return 0.0
- sorted_samples = sorted(self._samples)
- idx = int(len(sorted_samples) * p / 100)
- return sorted_samples[min(idx, len(sorted_samples) - 1)]
-
- def to_prometheus(self) -> str:
- p50 = self.percentile(50)
- p95 = self.percentile(95)
- p99 = self.percentile(99)
- return (
- f"# HELP {self.name} {self.description}\n"
- f"# TYPE {self.name} histogram\n"
- f'{self.name}{{quantile="0.5"}} {p50:.4f}\n'
- f'{self.name}{{quantile="0.95"}} {p95:.4f}\n'
- f'{self.name}{{quantile="0.99"}} {p99:.4f}\n'
- f"{self.name}_sum {self._sum:.4f}\n"
- f"{self.name}_count {self._count}\n"
- )
-
-
-class DheeMetrics:
- """Centralized metrics registry for Dhee."""
-
- _instance = None
- _lock = threading.Lock()
-
- def __new__(cls):
- with cls._lock:
- if cls._instance is None:
- cls._instance = super().__new__(cls)
- cls._instance._initialized = False
- return cls._instance
-
- def __init__(self):
- if self._initialized:
- return
- self._initialized = True
-
- # Counters
- self.memory_add = Counter("dhee_memory_add_total", "Total memories added")
- self.memory_search = Counter("dhee_memory_search_total", "Total memory searches")
- self.deterministic_resolve = Counter(
- "dhee_deterministic_resolve_total", "Queries resolved deterministically via SQL"
- )
- self.vector_search = Counter("dhee_vector_search_total", "Vector search operations")
- self.cognitive_loop = Counter("dhee_cognitive_loop_total", "Cognitive decomposition runs")
- self.llm_calls = Counter("dhee_llm_calls_total", "Total LLM API calls")
- self.engram_extractions = Counter(
- "dhee_engram_extractions_total", "Structured engram extractions"
- )
- self.prospective_triggered = Counter(
- "dhee_prospective_triggered_total", "Prospective scenes triggered"
- )
- self.facts_stored = Counter("dhee_facts_stored_total", "Structured facts stored")
- self.rerank_calls = Counter("dhee_rerank_calls_total", "Reranker invocations")
-
- # Histograms
- self.add_latency = Histogram(
- "dhee_memory_add_seconds", "Memory add latency in seconds"
- )
- self.search_latency = Histogram(
- "dhee_memory_search_seconds", "Memory search latency in seconds"
- )
- self.resolve_latency = Histogram(
- "dhee_deterministic_resolve_seconds", "Deterministic resolution latency"
- )
- self.extraction_latency = Histogram(
- "dhee_engram_extraction_seconds", "Engram extraction latency"
- )
- self.cognitive_latency = Histogram(
- "dhee_cognitive_loop_seconds", "Cognitive loop latency"
- )
- self.llm_latency = Histogram("dhee_llm_call_seconds", "LLM call latency")
- self.rerank_latency = Histogram("dhee_rerank_seconds", "Reranker latency")
-
- def to_prometheus(self) -> str:
- """Render all metrics in Prometheus text format."""
- parts = []
- for attr_name in dir(self):
- attr = getattr(self, attr_name)
- if isinstance(attr, (Counter, Histogram)):
- parts.append(attr.to_prometheus())
- return "\n".join(parts)
-
- def to_dict(self) -> Dict[str, Any]:
- """Render metrics as a dict for JSON API."""
- result = {}
- for attr_name in dir(self):
- attr = getattr(self, attr_name)
- if isinstance(attr, Counter):
- result[attr.name] = attr.value
- elif isinstance(attr, Histogram):
- result[attr.name] = {
- "count": attr.count,
- "avg": round(attr.avg, 4),
- "p50": round(attr.percentile(50), 4),
- "p95": round(attr.percentile(95), 4),
- "p99": round(attr.percentile(99), 4),
- }
- return result
-
-
-def get_metrics() -> DheeMetrics:
- """Get the singleton metrics instance."""
- return DheeMetrics()
-
-
-def measure(counter_name: str = "", histogram_name: str = ""):
- """Decorator to measure function execution.
-
- Usage:
- @measure("memory_add", "add_latency")
- def add(self, ...): ...
- """
- def decorator(func: Callable) -> Callable:
- @wraps(func)
- def wrapper(*args, **kwargs):
- metrics = get_metrics()
- t0 = time.monotonic()
- try:
- result = func(*args, **kwargs)
- return result
- finally:
- elapsed = time.monotonic() - t0
- if counter_name:
- counter = getattr(metrics, counter_name, None)
- if counter:
- counter.inc()
- if histogram_name:
- histogram = getattr(metrics, histogram_name, None)
- if histogram:
- histogram.observe(elapsed)
- return wrapper
- return decorator
diff --git a/dhee/core/promotion.py b/dhee/core/promotion.py
deleted file mode 100644
index 99e0550..0000000
--- a/dhee/core/promotion.py
+++ /dev/null
@@ -1,316 +0,0 @@
-"""Dhee v3 — Promotion Pipeline: validate and promote distillation candidates.
-
-The promotion flow:
- 1. Select pending candidates from distillation_candidates
- 2. Validate (confidence threshold, conflict check)
- 3. Promote transactionally into the target derived store
- 4. Write lineage rows
- 5. Mark candidate as promoted with promoted_id
-
-Design contract:
- - Promotion is transactional: either fully committed or rolled back
- - Every promoted object gets lineage rows linking to source events
- - Idempotent: re-running on already-promoted candidates is a no-op
- - Zero LLM calls — pure storage operations
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import uuid
-from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional
-
-from dhee.core.derived_store import (
- BeliefStore,
- PolicyStore,
- InsightStore,
- HeuristicStore,
- DerivedLineageStore,
-)
-from dhee.core.distillation import DistillationStore
-
-logger = logging.getLogger(__name__)
-
-
-def _utcnow_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
-
-
-# Minimum confidence to promote a candidate
-MIN_PROMOTION_CONFIDENCE = 0.3
-
-
-class PromotionResult:
- """Result of a promotion batch."""
-
- def __init__(self):
- self.promoted: List[str] = []
- self.rejected: List[str] = []
- self.quarantined: List[str] = []
- self.skipped: List[str] = []
- self.errors: List[Dict[str, str]] = []
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "promoted": len(self.promoted),
- "rejected": len(self.rejected),
- "quarantined": len(self.quarantined),
- "skipped": len(self.skipped),
- "errors": len(self.errors),
- "promoted_ids": self.promoted,
- }
-
-
-class PromotionEngine:
- """Validates and promotes distillation candidates into derived stores.
-
- Usage:
- engine = PromotionEngine(
- distillation=distillation_store,
- beliefs=belief_store,
- policies=policy_store,
- insights=insight_store,
- heuristics=heuristic_store,
- lineage=lineage_store,
- )
- result = engine.promote_pending(target_type="belief", limit=20)
- """
-
- def __init__(
- self,
- distillation: DistillationStore,
- beliefs: BeliefStore,
- policies: PolicyStore,
- insights: InsightStore,
- heuristics: HeuristicStore,
- lineage: DerivedLineageStore,
- *,
- min_confidence: float = MIN_PROMOTION_CONFIDENCE,
- ):
- self.distillation = distillation
- self.beliefs = beliefs
- self.policies = policies
- self.insights = insights
- self.heuristics = heuristics
- self.lineage = lineage
- self.min_confidence = min_confidence
-
- self._promoters = {
- "belief": self._promote_belief,
- "policy": self._promote_policy,
- "insight": self._promote_insight,
- "heuristic": self._promote_heuristic,
- }
-
- def promote_pending(
- self,
- target_type: Optional[str] = None,
- *,
- limit: int = 50,
- ) -> PromotionResult:
- """Promote all pending candidates of a given type.
-
- Args:
- target_type: Filter by type (belief, policy, etc.) or None for all
- limit: Max candidates to process
-
- Returns:
- PromotionResult with counts and IDs
- """
- result = PromotionResult()
- candidates = self.distillation.get_pending(target_type, limit=limit)
-
- for candidate in candidates:
- cid = candidate["candidate_id"]
- ctype = candidate["target_type"]
-
- try:
- # Validate
- validation = self._validate(candidate)
-
- if validation == "reject":
- self.distillation.set_status(cid, "rejected")
- result.rejected.append(cid)
- continue
-
- if validation == "quarantine":
- self.distillation.set_status(cid, "quarantined")
- result.quarantined.append(cid)
- continue
-
- # Promote
- promoter = self._promoters.get(ctype)
- if not promoter:
- logger.warning("No promoter for type: %s", ctype)
- result.skipped.append(cid)
- continue
-
- promoted_id = promoter(candidate)
- if promoted_id:
- # Write lineage
- self._write_lineage(
- ctype, promoted_id, candidate["source_event_ids"]
- )
- # Mark candidate as promoted
- self.distillation.set_status(
- cid, "promoted", promoted_id=promoted_id
- )
- result.promoted.append(promoted_id)
- else:
- result.skipped.append(cid)
-
- except Exception as e:
- logger.exception(
- "Failed to promote candidate %s: %s", cid, e
- )
- result.errors.append({
- "candidate_id": cid,
- "error": str(e),
- })
-
- return result
-
- def promote_single(
- self, candidate_id: str
- ) -> Dict[str, Any]:
- """Promote a single candidate by ID."""
- candidate = self.distillation.get(candidate_id)
- if not candidate:
- return {"status": "error", "error": f"Candidate not found: {candidate_id}"}
-
- if candidate["status"] != "pending_validation":
- return {
- "status": "skipped",
- "reason": f"Candidate status is '{candidate['status']}', not pending",
- }
-
- ctype = candidate["target_type"]
- validation = self._validate(candidate)
-
- if validation == "reject":
- self.distillation.set_status(candidate_id, "rejected")
- return {"status": "rejected", "reason": "validation_failed"}
-
- if validation == "quarantine":
- self.distillation.set_status(candidate_id, "quarantined")
- return {"status": "quarantined", "reason": "needs_review"}
-
- promoter = self._promoters.get(ctype)
- if not promoter:
- return {"status": "error", "error": f"No promoter for type: {ctype}"}
-
- promoted_id = promoter(candidate)
- if promoted_id:
- self._write_lineage(ctype, promoted_id, candidate["source_event_ids"])
- self.distillation.set_status(
- candidate_id, "promoted", promoted_id=promoted_id
- )
- return {"status": "promoted", "promoted_id": promoted_id}
-
- return {"status": "skipped", "reason": "promoter_returned_none"}
-
- # ------------------------------------------------------------------
- # Validation
- # ------------------------------------------------------------------
-
- def _validate(self, candidate: Dict[str, Any]) -> str:
- """Validate a candidate. Returns: 'accept', 'reject', or 'quarantine'."""
- confidence = candidate.get("confidence", 0.0)
- payload = candidate.get("payload", {})
-
- # Hard reject: below minimum confidence
- if confidence < self.min_confidence:
- return "reject"
-
- # Hard reject: empty payload
- if not payload:
- return "reject"
-
- # Type-specific validation
- target_type = candidate["target_type"]
-
- if target_type == "belief":
- claim = payload.get("claim", "")
- if not claim or len(claim.strip()) < 5:
- return "reject"
-
- elif target_type == "policy":
- if not payload.get("name") or not payload.get("condition"):
- return "reject"
-
- elif target_type in ("insight", "heuristic"):
- content = payload.get("content", "")
- if not content or len(content.strip()) < 10:
- return "reject"
-
- return "accept"
-
- # ------------------------------------------------------------------
- # Type-specific promoters
- # ------------------------------------------------------------------
-
- def _promote_belief(self, candidate: Dict[str, Any]) -> Optional[str]:
- payload = candidate["payload"]
- return self.beliefs.add(
- user_id=payload["user_id"],
- claim=payload["claim"],
- domain=payload.get("domain", "general"),
- confidence=candidate["confidence"],
- source_memory_ids=candidate["source_event_ids"],
- tags=payload.get("tags"),
- )
-
- def _promote_policy(self, candidate: Dict[str, Any]) -> Optional[str]:
- payload = candidate["payload"]
- return self.policies.add(
- user_id=payload["user_id"],
- name=payload["name"],
- condition=payload.get("condition", {}),
- action=payload.get("action", {}),
- granularity=payload.get("granularity", "task"),
- source_task_ids=candidate["source_event_ids"],
- tags=payload.get("tags"),
- )
-
- def _promote_insight(self, candidate: Dict[str, Any]) -> Optional[str]:
- payload = candidate["payload"]
- return self.insights.add(
- user_id=payload["user_id"],
- content=payload["content"],
- insight_type=payload.get("insight_type", "pattern"),
- confidence=candidate["confidence"],
- tags=payload.get("tags"),
- )
-
- def _promote_heuristic(self, candidate: Dict[str, Any]) -> Optional[str]:
- payload = candidate["payload"]
- return self.heuristics.add(
- user_id=payload["user_id"],
- content=payload["content"],
- abstraction_level=payload.get("abstraction_level", "specific"),
- confidence=candidate["confidence"],
- tags=payload.get("tags"),
- )
-
- # ------------------------------------------------------------------
- # Lineage
- # ------------------------------------------------------------------
-
- def _write_lineage(
- self,
- derived_type: str,
- derived_id: str,
- source_event_ids: List[str],
- ) -> None:
- """Write lineage rows linking the promoted object to source events."""
- if not source_event_ids:
- return
-
- # Equal weight distribution across sources
- weight = 1.0 / len(source_event_ids)
- self.lineage.add_batch(
- derived_type, derived_id, source_event_ids,
- weights=[weight] * len(source_event_ids),
- )
diff --git a/dhee/core/proposition_context.py b/dhee/core/proposition_context.py
deleted file mode 100644
index 7f7642f..0000000
--- a/dhee/core/proposition_context.py
+++ /dev/null
@@ -1,144 +0,0 @@
-"""Proposition-based context builder for memory QA.
-
-Instead of feeding 30K chars of raw conversation to an LLM, this module
-builds ~3-8K chars of focused, structured facts with source citations
-from episodic events and retrieval results.
-
-This improves answer accuracy for freeform questions by reducing noise
-and letting even weak models answer correctly with clean input.
-"""
-
-from __future__ import annotations
-
-import re
-from typing import Any, Dict, List, Optional, Sequence
-
-
-def build_proposition_context(
- *,
- events: Sequence[Dict[str, Any]],
- results: Sequence[Dict[str, Any]],
- question: str,
- max_chars: int = 8000,
-) -> str:
- """Build LLM context from propositions + evidence snippets.
-
- Instead of 30K chars of raw conversation, produces ~3-8K chars of
- focused, structured facts with source citations.
-
- Args:
- events: Episodic events matched to the query.
- results: Search results from the retrieval pipeline.
- question: The user's question (used for relevance hints).
- max_chars: Maximum output length.
-
- Returns:
- A compact context string ready for LLM consumption.
- """
- lines: List[str] = []
- remaining = max(1, int(max_chars))
-
- # Section 1: Structured facts from episodic events.
- if events:
- lines.append("Structured Facts:")
- remaining -= len(lines[-1]) + 1
- seen_keys: set = set()
- for event in events:
- value = str(event.get("value_text") or "").strip()
- if not value:
- continue
- # Deduplicate by canonical key.
- ckey = str(event.get("canonical_key") or "").strip().lower()
- if ckey and ckey in seen_keys:
- continue
- if ckey:
- seen_keys.add(ckey)
-
- session_id = str(event.get("session_id") or "").strip()
- event_time = str(event.get("event_time") or "").strip()
- event_type = str(event.get("event_type") or "fact").strip()
- actor = str(event.get("actor_role") or event.get("actor_id") or "").strip()
-
- source_parts: List[str] = []
- if session_id:
- source_parts.append(f"Session {session_id}")
- if event_time:
- # Show only date portion for readability.
- date_part = event_time[:10] if len(event_time) >= 10 else event_time
- source_parts.append(date_part)
- source = " | ".join(source_parts) if source_parts else "unknown"
-
- fact_parts: List[str] = []
- if actor:
- fact_parts.append(actor)
- fact_parts.append(f"({event_type})")
- fact_parts.append(value[:200])
- fact = " ".join(fact_parts)
-
- line = f"- [{source}] {fact}"
- if len(line) + 1 > remaining:
- break
- lines.append(line)
- remaining -= len(line) + 1
-
- # Section 2: Evidence snippets from retrieval results.
- evidence_results = list(results)[:5]
- if evidence_results and remaining > 100:
- lines.append("")
- lines.append("Evidence Snippets:")
- remaining -= 20 # header overhead
-
- for result in evidence_results:
- evidence = str(
- result.get("evidence_text") or result.get("memory") or ""
- ).strip()
- if not evidence:
- continue
- metadata = result.get("metadata") or {}
- session_id = str(metadata.get("session_id") or "").strip()
- session_date = str(metadata.get("session_date") or "").strip()
-
- header_parts: List[str] = []
- if session_id:
- header_parts.append(f"Session {session_id}")
- if session_date:
- header_parts.append(session_date)
- header = " | ".join(header_parts) if header_parts else "unknown"
-
- # Truncate evidence to fit budget.
- budget = min(1500, remaining - 20)
- if budget <= 50:
- break
- snippet = evidence[:budget]
-
- block = f"\n[{header}]\n{snippet}"
- if len(block) + 1 > remaining:
- break
- lines.append(block)
- remaining -= len(block) + 1
-
- context = "\n".join(lines).strip()
- return context[:max_chars]
-
-
-def build_proposition_answer_prompt(
- question: str,
- prop_context: str,
- question_date: str = "",
-) -> str:
- """Build a simplified answer prompt for proposition-based context.
-
- When the LLM receives clean, structured facts instead of 30K chars
- of raw conversation, the prompt can be dramatically simpler.
- """
- date_str = question_date or "Not specified"
- return (
- "Answer this question using ONLY the facts below.\n\n"
- f"Question: {question}\n"
- f"Date: {date_str}\n\n"
- f"Facts:\n{prop_context}\n\n"
- "Answer concisely using the exact values from the facts above. "
- "If the answer requires counting items, count the distinct items "
- "listed in the facts. If the facts do not contain enough "
- "information, say so."
- )
diff --git a/dhee/core/read_model.py b/dhee/core/read_model.py
deleted file mode 100644
index 9b97bc6..0000000
--- a/dhee/core/read_model.py
+++ /dev/null
@@ -1,327 +0,0 @@
-"""Dhee v3 — Read Model: materialized retrieval view + delta overlay.
-
-Writes are normalized across type-specific tables. Reads are fast via
-a precomputed retrieval_view table plus a delta overlay of recent changes
-not yet folded in.
-
-Design contract:
- - retrieval_view is a real table (materialized), not a SQL VIEW
- - Delta overlay covers raw events + derived objects created since last refresh
- - Hot-path retrieval queries the view + delta, fuses results
- - View refresh is a cold-path job (recompute_retrieval_view)
- - Zero LLM calls
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import sqlite3
-import threading
-from contextlib import contextmanager
-from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-
-def _utcnow_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
-
-
-# Schema for the materialized retrieval view
-RETRIEVAL_VIEW_SCHEMA = """
-CREATE TABLE IF NOT EXISTS retrieval_view (
- row_id TEXT PRIMARY KEY,
- source_kind TEXT NOT NULL CHECK (source_kind IN ('raw', 'distilled', 'episodic')),
- source_type TEXT NOT NULL,
- source_id TEXT NOT NULL,
- user_id TEXT NOT NULL,
- retrieval_text TEXT NOT NULL,
- summary TEXT,
- anchor_era TEXT,
- anchor_place TEXT,
- anchor_activity TEXT,
- confidence REAL DEFAULT 1.0,
- utility REAL DEFAULT 0.0,
- status TEXT DEFAULT 'active',
- created_at TEXT NOT NULL,
- refreshed_at TEXT NOT NULL
-);
-
-CREATE INDEX IF NOT EXISTS idx_rv_user_kind ON retrieval_view(user_id, source_kind);
-CREATE INDEX IF NOT EXISTS idx_rv_source ON retrieval_view(source_type, source_id);
-CREATE INDEX IF NOT EXISTS idx_rv_status ON retrieval_view(status) WHERE status != 'active';
-"""
-
-
-class ReadModel:
- """Materialized retrieval view with delta overlay.
-
- Usage:
- model = ReadModel(conn, lock)
- model.refresh(events, beliefs, policies, ...) # cold-path
- results = model.query(user_id, limit=20) # hot-path
- """
-
- def __init__(self, conn: sqlite3.Connection, lock: threading.RLock):
- self._conn = conn
- self._lock = lock
- self._ensure_schema()
- self._last_refresh: Optional[str] = None
-
- def _ensure_schema(self) -> None:
- with self._lock:
- self._conn.executescript(RETRIEVAL_VIEW_SCHEMA)
- self._conn.commit()
-
- @contextmanager
- def _tx(self):
- with self._lock:
- try:
- yield self._conn
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
-
- # ------------------------------------------------------------------
- # Cold-path: refresh the materialized view
- # ------------------------------------------------------------------
-
- def refresh(
- self,
- user_id: str,
- *,
- events_store: Optional[Any] = None,
- beliefs_store: Optional[Any] = None,
- policies_store: Optional[Any] = None,
- insights_store: Optional[Any] = None,
- heuristics_store: Optional[Any] = None,
- anchors_store: Optional[Any] = None,
- ) -> Dict[str, int]:
- """Rebuild the retrieval view for a user. Cold-path operation.
-
- Returns counts of rows refreshed per source type.
- """
- now = _utcnow_iso()
- counts: Dict[str, int] = {}
-
- with self._tx() as conn:
- # Clear existing rows for this user
- conn.execute(
- "DELETE FROM retrieval_view WHERE user_id = ?",
- (user_id,),
- )
-
- # Raw events
- if events_store:
- from dhee.core.events import EventStatus
- events = events_store.list_by_user(
- user_id, status=EventStatus.ACTIVE, limit=5000
- )
- for e in events:
- conn.execute(
- """INSERT INTO retrieval_view
- (row_id, source_kind, source_type, source_id,
- user_id, retrieval_text, confidence,
- status, created_at, refreshed_at)
- VALUES (?, 'raw', 'event', ?, ?, ?, 1.0, 'active', ?, ?)""",
- (
- f"raw:{e.event_id}", e.event_id, user_id,
- e.content, e.created_at or now, now,
- ),
- )
- counts["raw_events"] = len(events)
-
- # Beliefs
- if beliefs_store:
- beliefs = beliefs_store.list_by_user(user_id, limit=1000)
- for b in beliefs:
- if b["status"] in ("invalidated",):
- continue
- conn.execute(
- """INSERT INTO retrieval_view
- (row_id, source_kind, source_type, source_id,
- user_id, retrieval_text, summary, confidence,
- utility, status, created_at, refreshed_at)
- VALUES (?, 'distilled', 'belief', ?, ?, ?, ?, ?, 0.0, ?, ?, ?)""",
- (
- f"belief:{b['belief_id']}", b["belief_id"],
- user_id, b["claim"],
- f"[{b['domain']}] {b['claim']}",
- b["confidence"], b["status"],
- b["created_at"], now,
- ),
- )
- counts["beliefs"] = len(beliefs)
-
- # Policies
- if policies_store:
- policies = policies_store.list_by_user(user_id, limit=500)
- for p in policies:
- if p["status"] in ("invalidated",):
- continue
- text = f"{p['name']}: {json.dumps(p['action'])}"
- conn.execute(
- """INSERT INTO retrieval_view
- (row_id, source_kind, source_type, source_id,
- user_id, retrieval_text, summary, confidence,
- utility, status, created_at, refreshed_at)
- VALUES (?, 'distilled', 'policy', ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (
- f"policy:{p['policy_id']}", p["policy_id"],
- user_id, text, p["name"],
- 1.0, p["utility"], p["status"],
- p["created_at"], now,
- ),
- )
- counts["policies"] = len(policies)
-
- # Insights
- if insights_store:
- insights = insights_store.list_by_user(user_id, limit=500)
- for i in insights:
- if i["status"] in ("invalidated",):
- continue
- conn.execute(
- """INSERT INTO retrieval_view
- (row_id, source_kind, source_type, source_id,
- user_id, retrieval_text, confidence,
- utility, status, created_at, refreshed_at)
- VALUES (?, 'distilled', 'insight', ?, ?, ?, ?, ?, ?, ?, ?)""",
- (
- f"insight:{i['insight_id']}", i["insight_id"],
- user_id, i["content"], i["confidence"],
- i["utility"], i["status"],
- i["created_at"], now,
- ),
- )
- counts["insights"] = len(insights)
-
- # Heuristics
- if heuristics_store:
- heuristics = heuristics_store.list_by_user(user_id, limit=500)
- for h in heuristics:
- if h["status"] in ("invalidated",):
- continue
- conn.execute(
- """INSERT INTO retrieval_view
- (row_id, source_kind, source_type, source_id,
- user_id, retrieval_text, confidence,
- utility, status, created_at, refreshed_at)
- VALUES (?, 'distilled', 'heuristic', ?, ?, ?, ?, ?, ?, ?, ?)""",
- (
- f"heuristic:{h['heuristic_id']}", h["heuristic_id"],
- user_id, h["content"], h["confidence"],
- h["utility"], h["status"],
- h["created_at"], now,
- ),
- )
- counts["heuristics"] = len(heuristics)
-
- self._last_refresh = now
- return counts
-
- # ------------------------------------------------------------------
- # Hot-path: query the view
- # ------------------------------------------------------------------
-
- def query(
- self,
- user_id: str,
- *,
- source_kind: Optional[str] = None,
- source_type: Optional[str] = None,
- status_exclude: Optional[List[str]] = None,
- limit: int = 100,
- ) -> List[Dict[str, Any]]:
- """Query the retrieval view. Returns rows for downstream fusion."""
- query = "SELECT * FROM retrieval_view WHERE user_id = ?"
- params: list = [user_id]
-
- if source_kind:
- query += " AND source_kind = ?"
- params.append(source_kind)
- if source_type:
- query += " AND source_type = ?"
- params.append(source_type)
-
- excludes = status_exclude or ["invalidated"]
- for s in excludes:
- query += " AND status != ?"
- params.append(s)
-
- query += " ORDER BY created_at DESC LIMIT ?"
- params.append(limit)
-
- with self._lock:
- rows = self._conn.execute(query, params).fetchall()
-
- return [
- {
- "row_id": r["row_id"],
- "source_kind": r["source_kind"],
- "source_type": r["source_type"],
- "source_id": r["source_id"],
- "user_id": r["user_id"],
- "retrieval_text": r["retrieval_text"],
- "summary": r["summary"],
- "anchor_era": r["anchor_era"],
- "anchor_place": r["anchor_place"],
- "anchor_activity": r["anchor_activity"],
- "confidence": r["confidence"],
- "utility": r["utility"],
- "status": r["status"],
- "created_at": r["created_at"],
- }
- for r in rows
- ]
-
- def get_delta(
- self,
- user_id: str,
- since_iso: str,
- *,
- events_store: Optional[Any] = None,
- ) -> List[Dict[str, Any]]:
- """Get raw events created since the last refresh.
-
- These haven't been folded into the retrieval_view yet.
- Used by fusion to overlay recent changes on top of the materialized view.
- """
- if not events_store:
- return []
-
- from dhee.core.events import EventStatus
- recent = events_store.get_events_since(
- user_id, since_iso, status=EventStatus.ACTIVE
- )
- return [
- {
- "row_id": f"delta:{e.event_id}",
- "source_kind": "raw",
- "source_type": "event",
- "source_id": e.event_id,
- "user_id": e.user_id,
- "retrieval_text": e.content,
- "summary": None,
- "confidence": 1.0,
- "utility": 0.0,
- "status": "active",
- "created_at": e.created_at,
- }
- for e in recent
- ]
-
- @property
- def last_refresh(self) -> Optional[str]:
- return self._last_refresh
-
- def row_count(self, user_id: str) -> int:
- with self._lock:
- row = self._conn.execute(
- "SELECT COUNT(*) FROM retrieval_view WHERE user_id = ?",
- (user_id,),
- ).fetchone()
- return row[0] if row else 0
diff --git a/dhee/core/salience.py b/dhee/core/salience.py
deleted file mode 100644
index 2746196..0000000
--- a/dhee/core/salience.py
+++ /dev/null
@@ -1,113 +0,0 @@
-"""Salience tagging for memories.
-
-Computes emotional valence, arousal, and overall salience score.
-High-salience memories decay slower and rank higher in search.
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import re
-from typing import Any, Dict, Optional
-
-logger = logging.getLogger(__name__)
-
-# Heuristic keyword lists for fast salience estimation
-_POSITIVE_WORDS = frozenset({
- "love", "great", "excellent", "amazing", "wonderful", "happy", "success",
- "perfect", "awesome", "fantastic", "brilliant", "enjoy", "solved", "fixed",
- "win", "achieved", "celebrate", "milestone", "breakthrough", "promoted",
-})
-_NEGATIVE_WORDS = frozenset({
- "hate", "terrible", "awful", "horrible", "fail", "failure", "crash",
- "error", "bug", "broken", "angry", "frustrated", "blocked", "lost",
- "dead", "killed", "disaster", "critical", "urgent", "emergency",
-})
-_HIGH_AROUSAL_WORDS = frozenset({
- "urgent", "critical", "emergency", "asap", "immediately", "deadline",
- "panic", "crash", "outage", "breaking", "alert", "important", "warning",
- "danger", "production", "incident", "blocker", "showstopper",
-})
-
-_SALIENCE_PROMPT = """Rate the emotional content of this text.
-
-Text: {content}
-
-Respond in JSON:
-{{"valence": 0.0, "arousal": 0.0, "reasoning": "..."}}
-
-valence: -1.0 (very negative) to 1.0 (very positive), 0.0 = neutral
-arousal: 0.0 (calm/routine) to 1.0 (intense/urgent)"""
-
-
-def compute_salience_heuristic(content: str) -> Dict[str, float]:
- """Fast heuristic salience computation from keyword matching."""
- words = set(re.findall(r'\b\w+\b', content.lower()))
-
- pos_count = len(words & _POSITIVE_WORDS)
- neg_count = len(words & _NEGATIVE_WORDS)
- arousal_count = len(words & _HIGH_AROUSAL_WORDS)
-
- total_emotional = pos_count + neg_count
- if total_emotional > 0:
- valence = (pos_count - neg_count) / total_emotional
- else:
- valence = 0.0
-
- arousal = min(1.0, arousal_count * 0.25)
-
- salience_score = min(1.0, (abs(valence) + arousal) / 2)
-
- return {
- "sal_valence": round(valence, 3),
- "sal_arousal": round(arousal, 3),
- "sal_salience_score": round(salience_score, 3),
- }
-
-
-def compute_salience_llm(content: str, llm: Any) -> Dict[str, float]:
- """LLM-based salience computation (slower, more accurate)."""
- formatted = _SALIENCE_PROMPT.format(content=content)
-
- try:
- response = llm.generate(formatted)
- text = response if isinstance(response, str) else str(response)
- start = text.find("{")
- if start >= 0:
- parsed, _ = json.JSONDecoder().raw_decode(text, start)
- valence = max(-1.0, min(1.0, float(parsed.get("valence", 0.0))))
- arousal = max(0.0, min(1.0, float(parsed.get("arousal", 0.0))))
- salience_score = min(1.0, (abs(valence) + arousal) / 2)
- return {
- "sal_valence": round(valence, 3),
- "sal_arousal": round(arousal, 3),
- "sal_salience_score": round(salience_score, 3),
- }
- except Exception as e:
- logger.warning("LLM salience computation failed: %s", e)
-
- return compute_salience_heuristic(content)
-
-
-def compute_salience(
- content: str,
- llm: Optional[Any] = None,
- use_llm: bool = False,
-) -> Dict[str, float]:
- """Compute salience for a memory's content.
-
- Returns dict with sal_valence, sal_arousal, sal_salience_score.
- """
- if use_llm and llm:
- return compute_salience_llm(content, llm)
- return compute_salience_heuristic(content)
-
-
-def salience_decay_modifier(salience_score: float) -> float:
- """Compute decay rate modifier based on salience.
-
- High-salience memories decay slower.
- Returns a multiplier for the decay lambda (< 1.0 means slower decay).
- """
- return 1.0 - (salience_score * 0.5)
diff --git a/dhee/core/storage.py b/dhee/core/storage.py
deleted file mode 100644
index 08e1e08..0000000
--- a/dhee/core/storage.py
+++ /dev/null
@@ -1,428 +0,0 @@
-"""Dhee v3 — Schema DDL for the event-sourced cognition substrate.
-
-All tables live in a single SQLite database (v3.db). Schema is organized as:
-
-Layer 1 — Raw truth:
- raw_memory_events Immutable source-of-truth memory events
-
-Layer 2 — Derived cognition (type-specific tables):
- beliefs Confidence-tracked claims with Bayesian updates
- policies Condition→action rules with utility tracking
- anchors Hierarchical context (era/place/time/activity)
- insights Synthesized causal hypotheses
- heuristics Transferable reasoning patterns
-
-Layer 3 — Infrastructure:
- derived_lineage Links derived objects → source raw events
- maintenance_jobs Cold-path job registry
- locks SQLite lease manager for job concurrency
- cognitive_conflicts Contradiction/disagreement queue
- anchor_candidates Per-field extraction candidates (Phase 2)
- distillation_candidates Consolidation promotion candidates (Phase 4)
-
-All tables use TEXT PRIMARY KEY (UUIDs), ISO timestamps, and JSON for
-nested structures. Follows existing Dhee conventions from dhee/db/sqlite.py.
-"""
-
-# ---------------------------------------------------------------------------
-# Layer 1: Raw truth
-# ---------------------------------------------------------------------------
-
-RAW_MEMORY_EVENTS = """
-CREATE TABLE IF NOT EXISTS raw_memory_events (
- event_id TEXT PRIMARY KEY,
- user_id TEXT NOT NULL,
- session_id TEXT,
- created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')),
- content TEXT NOT NULL,
- content_hash TEXT NOT NULL,
- source TEXT DEFAULT 'user',
- status TEXT NOT NULL DEFAULT 'active'
- CHECK (status IN ('active', 'corrected', 'deleted')),
- supersedes_event_id TEXT REFERENCES raw_memory_events(event_id),
- metadata_json TEXT DEFAULT '{}'
-);
-
-CREATE INDEX IF NOT EXISTS idx_rme_user_status
- ON raw_memory_events(user_id, status);
-CREATE INDEX IF NOT EXISTS idx_rme_content_hash
- ON raw_memory_events(content_hash, user_id);
-CREATE INDEX IF NOT EXISTS idx_rme_created
- ON raw_memory_events(created_at DESC);
-CREATE INDEX IF NOT EXISTS idx_rme_supersedes
- ON raw_memory_events(supersedes_event_id)
- WHERE supersedes_event_id IS NOT NULL;
-"""
-
-# ---------------------------------------------------------------------------
-# Layer 2: Derived cognition — type-specific tables
-# ---------------------------------------------------------------------------
-
-BELIEFS = """
-CREATE TABLE IF NOT EXISTS beliefs (
- belief_id TEXT PRIMARY KEY,
- user_id TEXT NOT NULL,
- claim TEXT NOT NULL,
- domain TEXT DEFAULT 'general',
- status TEXT NOT NULL DEFAULT 'proposed'
- CHECK (status IN (
- 'proposed', 'held', 'challenged',
- 'revised', 'retracted',
- 'stale', 'suspect', 'invalidated'
- )),
- confidence REAL NOT NULL DEFAULT 0.5,
- evidence_json TEXT DEFAULT '[]',
- revisions_json TEXT DEFAULT '[]',
- contradicts_ids TEXT DEFAULT '[]',
- source_memory_ids TEXT DEFAULT '[]',
- source_episode_ids TEXT DEFAULT '[]',
- derivation_version INTEGER NOT NULL DEFAULT 1,
- lineage_fingerprint TEXT,
- created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')),
- updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')),
- tags_json TEXT DEFAULT '[]'
-);
-
-CREATE INDEX IF NOT EXISTS idx_beliefs_user_domain_status
- ON beliefs(user_id, domain, status);
-CREATE INDEX IF NOT EXISTS idx_beliefs_user_confidence
- ON beliefs(user_id, confidence DESC);
-CREATE INDEX IF NOT EXISTS idx_beliefs_status
- ON beliefs(status)
- WHERE status IN ('stale', 'suspect', 'invalidated');
-"""
-
-POLICIES = """
-CREATE TABLE IF NOT EXISTS policies (
- policy_id TEXT PRIMARY KEY,
- user_id TEXT NOT NULL,
- name TEXT NOT NULL,
- granularity TEXT NOT NULL DEFAULT 'task'
- CHECK (granularity IN ('task', 'step')),
- status TEXT NOT NULL DEFAULT 'proposed'
- CHECK (status IN (
- 'proposed', 'active', 'validated', 'deprecated',
- 'stale', 'suspect', 'invalidated'
- )),
- condition_json TEXT NOT NULL DEFAULT '{}',
- action_json TEXT NOT NULL DEFAULT '{}',
- apply_count INTEGER NOT NULL DEFAULT 0,
- success_count INTEGER NOT NULL DEFAULT 0,
- failure_count INTEGER NOT NULL DEFAULT 0,
- utility REAL NOT NULL DEFAULT 0.0,
- last_delta REAL NOT NULL DEFAULT 0.0,
- cumulative_delta REAL NOT NULL DEFAULT 0.0,
- source_task_ids TEXT DEFAULT '[]',
- source_episode_ids TEXT DEFAULT '[]',
- derivation_version INTEGER NOT NULL DEFAULT 1,
- lineage_fingerprint TEXT,
- created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')),
- updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')),
- tags_json TEXT DEFAULT '[]'
-);
-
-CREATE INDEX IF NOT EXISTS idx_policies_user_gran_status
- ON policies(user_id, granularity, status);
-CREATE INDEX IF NOT EXISTS idx_policies_user_utility
- ON policies(user_id, utility DESC);
-CREATE INDEX IF NOT EXISTS idx_policies_status
- ON policies(status)
- WHERE status IN ('stale', 'suspect', 'invalidated');
-"""
-
-ANCHORS = """
-CREATE TABLE IF NOT EXISTS anchors (
- anchor_id TEXT PRIMARY KEY,
- user_id TEXT NOT NULL,
- memory_event_id TEXT REFERENCES raw_memory_events(event_id),
- era TEXT,
- place TEXT,
- place_type TEXT,
- place_detail TEXT,
- time_absolute TEXT,
- time_markers_json TEXT DEFAULT '[]',
- time_range_start TEXT,
- time_range_end TEXT,
- time_derivation TEXT,
- activity TEXT,
- session_id TEXT,
- session_position INTEGER DEFAULT 0,
- derivation_version INTEGER NOT NULL DEFAULT 1,
- lineage_fingerprint TEXT,
- created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now'))
-);
-
-CREATE INDEX IF NOT EXISTS idx_anchors_user_era_place
- ON anchors(user_id, era, place, activity);
-CREATE INDEX IF NOT EXISTS idx_anchors_user_time
- ON anchors(user_id, time_range_start, time_range_end);
-CREATE INDEX IF NOT EXISTS idx_anchors_event
- ON anchors(memory_event_id)
- WHERE memory_event_id IS NOT NULL;
-"""
-
-INSIGHTS = """
-CREATE TABLE IF NOT EXISTS insights (
- insight_id TEXT PRIMARY KEY,
- user_id TEXT NOT NULL,
- content TEXT NOT NULL,
- insight_type TEXT NOT NULL DEFAULT 'pattern'
- CHECK (insight_type IN (
- 'causal', 'warning', 'strategy', 'pattern'
- )),
- source_task_types_json TEXT DEFAULT '[]',
- confidence REAL NOT NULL DEFAULT 0.5,
- validation_count INTEGER NOT NULL DEFAULT 0,
- invalidation_count INTEGER NOT NULL DEFAULT 0,
- utility REAL NOT NULL DEFAULT 0.0,
- apply_count INTEGER NOT NULL DEFAULT 0,
- derivation_version INTEGER NOT NULL DEFAULT 1,
- lineage_fingerprint TEXT,
- created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')),
- last_validated TEXT,
- tags_json TEXT DEFAULT '[]',
- status TEXT NOT NULL DEFAULT 'active'
- CHECK (status IN (
- 'active', 'stale', 'suspect', 'invalidated'
- ))
-);
-
-CREATE INDEX IF NOT EXISTS idx_insights_user_type_conf
- ON insights(user_id, insight_type, confidence DESC);
-CREATE INDEX IF NOT EXISTS idx_insights_user_utility
- ON insights(user_id, utility DESC);
-CREATE INDEX IF NOT EXISTS idx_insights_status
- ON insights(status)
- WHERE status IN ('stale', 'suspect', 'invalidated');
-"""
-
-HEURISTICS = """
-CREATE TABLE IF NOT EXISTS heuristics (
- heuristic_id TEXT PRIMARY KEY,
- user_id TEXT NOT NULL,
- content TEXT NOT NULL,
- abstraction_level TEXT NOT NULL DEFAULT 'specific'
- CHECK (abstraction_level IN (
- 'specific', 'domain', 'universal'
- )),
- source_task_types_json TEXT DEFAULT '[]',
- confidence REAL NOT NULL DEFAULT 0.5,
- validation_count INTEGER NOT NULL DEFAULT 0,
- invalidation_count INTEGER NOT NULL DEFAULT 0,
- utility REAL NOT NULL DEFAULT 0.0,
- last_delta REAL NOT NULL DEFAULT 0.0,
- apply_count INTEGER NOT NULL DEFAULT 0,
- derivation_version INTEGER NOT NULL DEFAULT 1,
- lineage_fingerprint TEXT,
- created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')),
- tags_json TEXT DEFAULT '[]',
- status TEXT NOT NULL DEFAULT 'active'
- CHECK (status IN (
- 'active', 'stale', 'suspect', 'invalidated'
- ))
-);
-
-CREATE INDEX IF NOT EXISTS idx_heuristics_user_level_conf
- ON heuristics(user_id, abstraction_level, confidence DESC);
-CREATE INDEX IF NOT EXISTS idx_heuristics_user_utility
- ON heuristics(user_id, utility DESC);
-CREATE INDEX IF NOT EXISTS idx_heuristics_status
- ON heuristics(status)
- WHERE status IN ('stale', 'suspect', 'invalidated');
-"""
-
-# ---------------------------------------------------------------------------
-# Layer 3: Infrastructure
-# ---------------------------------------------------------------------------
-
-DERIVED_LINEAGE = """
-CREATE TABLE IF NOT EXISTS derived_lineage (
- lineage_id TEXT PRIMARY KEY,
- derived_type TEXT NOT NULL
- CHECK (derived_type IN (
- 'belief', 'policy', 'anchor',
- 'insight', 'heuristic'
- )),
- derived_id TEXT NOT NULL,
- source_event_id TEXT NOT NULL REFERENCES raw_memory_events(event_id),
- contribution_weight REAL NOT NULL DEFAULT 1.0,
- created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now'))
-);
-
-CREATE INDEX IF NOT EXISTS idx_lineage_derived
- ON derived_lineage(derived_type, derived_id);
-CREATE INDEX IF NOT EXISTS idx_lineage_source
- ON derived_lineage(source_event_id);
-"""
-
-MAINTENANCE_JOBS = """
-CREATE TABLE IF NOT EXISTS maintenance_jobs (
- job_id TEXT PRIMARY KEY,
- job_name TEXT NOT NULL,
- status TEXT NOT NULL DEFAULT 'pending'
- CHECK (status IN (
- 'pending', 'running', 'completed',
- 'failed', 'cancelled'
- )),
- payload_json TEXT DEFAULT '{}',
- result_json TEXT,
- error_message TEXT,
- created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')),
- started_at TEXT,
- completed_at TEXT,
- retry_count INTEGER NOT NULL DEFAULT 0,
- max_retries INTEGER NOT NULL DEFAULT 3,
- idempotency_key TEXT
-);
-
-CREATE INDEX IF NOT EXISTS idx_jobs_status_name
- ON maintenance_jobs(status, job_name);
-CREATE UNIQUE INDEX IF NOT EXISTS idx_jobs_idempotency
- ON maintenance_jobs(idempotency_key)
- WHERE idempotency_key IS NOT NULL;
-"""
-
-LOCKS = """
-CREATE TABLE IF NOT EXISTS locks (
- lock_id TEXT PRIMARY KEY,
- owner_id TEXT NOT NULL,
- lease_expires_at TEXT NOT NULL,
- updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now'))
-);
-"""
-
-COGNITIVE_CONFLICTS = """
-CREATE TABLE IF NOT EXISTS cognitive_conflicts (
- conflict_id TEXT PRIMARY KEY,
- conflict_type TEXT NOT NULL
- CHECK (conflict_type IN (
- 'belief_contradiction',
- 'anchor_disagreement',
- 'distillation_conflict',
- 'invalidation_dispute'
- )),
- side_a_type TEXT NOT NULL,
- side_a_id TEXT NOT NULL,
- side_b_type TEXT NOT NULL,
- side_b_id TEXT NOT NULL,
- detected_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')),
- resolution_status TEXT NOT NULL DEFAULT 'open'
- CHECK (resolution_status IN (
- 'open', 'auto_resolved',
- 'user_resolved', 'deferred'
- )),
- resolution_json TEXT,
- auto_resolution_confidence REAL
-);
-
-CREATE INDEX IF NOT EXISTS idx_conflicts_status
- ON cognitive_conflicts(resolution_status)
- WHERE resolution_status = 'open';
-CREATE INDEX IF NOT EXISTS idx_conflicts_sides
- ON cognitive_conflicts(side_a_type, side_a_id);
-"""
-
-ANCHOR_CANDIDATES = """
-CREATE TABLE IF NOT EXISTS anchor_candidates (
- candidate_id TEXT PRIMARY KEY,
- anchor_id TEXT NOT NULL REFERENCES anchors(anchor_id),
- field_name TEXT NOT NULL,
- field_value TEXT NOT NULL,
- confidence REAL NOT NULL DEFAULT 0.5,
- extractor_source TEXT NOT NULL DEFAULT 'default',
- source_event_ids TEXT DEFAULT '[]',
- derivation_version INTEGER NOT NULL DEFAULT 1,
- status TEXT NOT NULL DEFAULT 'pending'
- CHECK (status IN (
- 'pending', 'accepted', 'rejected', 'superseded'
- )),
- created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now'))
-);
-
-CREATE INDEX IF NOT EXISTS idx_anchor_cand_anchor_field
- ON anchor_candidates(anchor_id, field_name, status);
-"""
-
-DISTILLATION_CANDIDATES = """
-CREATE TABLE IF NOT EXISTS distillation_candidates (
- candidate_id TEXT PRIMARY KEY,
- source_event_ids TEXT NOT NULL DEFAULT '[]',
- derivation_version INTEGER NOT NULL DEFAULT 1,
- confidence REAL NOT NULL DEFAULT 0.5,
- canonical_key TEXT,
- idempotency_key TEXT,
- target_type TEXT NOT NULL
- CHECK (target_type IN (
- 'belief', 'policy', 'insight', 'heuristic'
- )),
- payload_json TEXT NOT NULL DEFAULT '{}',
- status TEXT NOT NULL DEFAULT 'pending_validation'
- CHECK (status IN (
- 'pending_validation', 'promoted',
- 'rejected', 'quarantined'
- )),
- promoted_id TEXT,
- created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now'))
-);
-
-CREATE UNIQUE INDEX IF NOT EXISTS idx_distill_idempotency
- ON distillation_candidates(idempotency_key)
- WHERE idempotency_key IS NOT NULL AND status != 'rejected';
-CREATE INDEX IF NOT EXISTS idx_distill_status
- ON distillation_candidates(status, target_type);
-"""
-
-SCHEMA_VERSION = """
-CREATE TABLE IF NOT EXISTS v3_schema_version (
- version INTEGER PRIMARY KEY,
- applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')),
- description TEXT
-);
-"""
-
-# ---------------------------------------------------------------------------
-# Ordered list of all DDL statements for initialization
-# ---------------------------------------------------------------------------
-
-ALL_SCHEMAS = [
- # Version tracking
- SCHEMA_VERSION,
- # Layer 1
- RAW_MEMORY_EVENTS,
- # Layer 2
- BELIEFS,
- POLICIES,
- ANCHORS,
- INSIGHTS,
- HEURISTICS,
- # Layer 3
- DERIVED_LINEAGE,
- MAINTENANCE_JOBS,
- LOCKS,
- COGNITIVE_CONFLICTS,
- ANCHOR_CANDIDATES,
- DISTILLATION_CANDIDATES,
-]
-
-CURRENT_VERSION = 1
-
-
-def initialize_schema(conn: "sqlite3.Connection") -> None:
- """Create all v3 tables if they don't exist.
-
- Idempotent — safe to call on every startup.
- """
- for ddl in ALL_SCHEMAS:
- conn.executescript(ddl)
-
- # Record schema version (idempotent)
- existing = conn.execute(
- "SELECT 1 FROM v3_schema_version WHERE version = ?",
- (CURRENT_VERSION,),
- ).fetchone()
- if not existing:
- conn.execute(
- "INSERT INTO v3_schema_version (version, description) VALUES (?, ?)",
- (CURRENT_VERSION, "Initial v3 event-sourced substrate"),
- )
- conn.commit()
diff --git a/dhee/core/v3_health.py b/dhee/core/v3_health.py
deleted file mode 100644
index 7f1379a..0000000
--- a/dhee/core/v3_health.py
+++ /dev/null
@@ -1,193 +0,0 @@
-"""Dhee v3 — Observability: expanded cognition_health with v3 metrics.
-
-Adds v3-specific metrics to the existing cognition_health() output:
-- Stale/suspect/invalidated derived counts per type
-- Pending conflict backlog
-- Lease contention (active locks)
-- Candidate promotion stats (promoted/rejected/quarantined)
-- Retrieval view freshness
-- Maintenance job health
-
-All metrics are pure SQL COUNT/aggregation queries. Zero LLM calls.
-"""
-
-from __future__ import annotations
-
-import logging
-from datetime import datetime, timezone
-from typing import Any, Dict, Optional
-
-logger = logging.getLogger(__name__)
-
-
-def _utcnow_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
-
-
-def v3_health(
- conn: "sqlite3.Connection",
- lock: "threading.RLock",
- *,
- user_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """Compute v3 substrate health metrics.
-
- Returns a dict suitable for merging into existing cognition_health() output.
- """
- health: Dict[str, Any] = {}
-
- with lock:
- # --- Raw event counts ---
- if user_id:
- row = conn.execute(
- "SELECT COUNT(*) FROM raw_memory_events WHERE user_id = ? AND status = 'active'",
- (user_id,),
- ).fetchone()
- health["raw_events_active"] = row[0] if row else 0
-
- row = conn.execute(
- "SELECT COUNT(*) FROM raw_memory_events WHERE user_id = ? AND status = 'corrected'",
- (user_id,),
- ).fetchone()
- health["raw_events_corrected"] = row[0] if row else 0
- else:
- row = conn.execute(
- "SELECT COUNT(*) FROM raw_memory_events WHERE status = 'active'"
- ).fetchone()
- health["raw_events_active"] = row[0] if row else 0
-
- # --- Derived object counts by invalidation status ---
- derived_tables = {
- "beliefs": "belief_id",
- "policies": "policy_id",
- "insights": "insight_id",
- "heuristics": "heuristic_id",
- }
- invalidation_statuses = ("stale", "suspect", "invalidated")
- derived_health: Dict[str, Dict[str, int]] = {}
-
- for table, _pk in derived_tables.items():
- counts: Dict[str, int] = {}
- for status in invalidation_statuses:
- try:
- if user_id:
- row = conn.execute(
- f"SELECT COUNT(*) FROM {table} WHERE status = ? AND user_id = ?",
- (status, user_id),
- ).fetchone()
- else:
- row = conn.execute(
- f"SELECT COUNT(*) FROM {table} WHERE status = ?",
- (status,),
- ).fetchone()
- counts[status] = row[0] if row else 0
- except Exception:
- counts[status] = -1 # table might not exist yet
- derived_health[table] = counts
-
- health["derived_invalidation"] = derived_health
-
- # --- Conflict backlog ---
- try:
- row = conn.execute(
- "SELECT COUNT(*) FROM cognitive_conflicts WHERE resolution_status = 'open'"
- ).fetchone()
- health["open_conflicts"] = row[0] if row else 0
- except Exception:
- health["open_conflicts"] = -1
-
- # --- Lease contention ---
- try:
- now_iso = _utcnow_iso()
- row = conn.execute(
- "SELECT COUNT(*) FROM locks WHERE lease_expires_at > ?",
- (now_iso,),
- ).fetchone()
- health["active_leases"] = row[0] if row else 0
- except Exception:
- health["active_leases"] = -1
-
- # --- Candidate promotion stats ---
- try:
- promo_stats: Dict[str, int] = {}
- for status in ("pending_validation", "promoted", "rejected", "quarantined"):
- row = conn.execute(
- "SELECT COUNT(*) FROM distillation_candidates WHERE status = ?",
- (status,),
- ).fetchone()
- promo_stats[status] = row[0] if row else 0
- health["distillation_candidates"] = promo_stats
- except Exception:
- health["distillation_candidates"] = {}
-
- # --- Maintenance job health ---
- try:
- job_stats: Dict[str, int] = {}
- for status in ("pending", "running", "completed", "failed"):
- row = conn.execute(
- "SELECT COUNT(*) FROM maintenance_jobs WHERE status = ?",
- (status,),
- ).fetchone()
- job_stats[status] = row[0] if row else 0
- health["maintenance_jobs"] = job_stats
- except Exception:
- health["maintenance_jobs"] = {}
-
- # --- Retrieval view freshness ---
- try:
- row = conn.execute(
- "SELECT MAX(refreshed_at) FROM retrieval_view"
- ).fetchone()
- health["retrieval_view_last_refresh"] = row[0] if row and row[0] else None
-
- row = conn.execute(
- "SELECT COUNT(*) FROM retrieval_view"
- ).fetchone()
- health["retrieval_view_rows"] = row[0] if row else 0
- except Exception:
- health["retrieval_view_last_refresh"] = None
- health["retrieval_view_rows"] = 0
-
- # --- Lineage coverage ---
- try:
- row = conn.execute(
- "SELECT COUNT(DISTINCT derived_type || ':' || derived_id) FROM derived_lineage"
- ).fetchone()
- health["lineage_derived_objects"] = row[0] if row else 0
-
- row = conn.execute(
- "SELECT COUNT(DISTINCT source_event_id) FROM derived_lineage"
- ).fetchone()
- health["lineage_source_events"] = row[0] if row else 0
- except Exception:
- health["lineage_derived_objects"] = 0
- health["lineage_source_events"] = 0
-
- # --- Warnings ---
- warnings: list = []
-
- di = health.get("derived_invalidation", {})
- total_stale = sum(
- counts.get("stale", 0) for counts in di.values()
- if isinstance(counts, dict)
- )
- total_suspect = sum(
- counts.get("suspect", 0) for counts in di.values()
- if isinstance(counts, dict)
- )
- if total_stale > 10:
- warnings.append(f"{total_stale} stale derived objects awaiting repair")
- if total_suspect > 5:
- warnings.append(f"{total_suspect} suspect derived objects need verification")
-
- oc = health.get("open_conflicts", 0)
- if isinstance(oc, int) and oc > 5:
- warnings.append(f"{oc} unresolved cognitive conflicts")
-
- jobs = health.get("maintenance_jobs", {})
- if isinstance(jobs, dict) and jobs.get("failed", 0) > 3:
- warnings.append(f"{jobs['failed']} failed maintenance jobs")
-
- health["v3_warnings"] = warnings
-
- return health
diff --git a/dhee/core/v3_migration.py b/dhee/core/v3_migration.py
deleted file mode 100644
index 427770e..0000000
--- a/dhee/core/v3_migration.py
+++ /dev/null
@@ -1,238 +0,0 @@
-"""Dhee v3 — Migration Bridge: dual-write from v2 to v3.
-
-Phase 10 migration strategy:
-1. Add raw event store without changing external API
-2. Dual-write: old path + new raw event path
-3. Backfill old engrams into raw + derived form
-
-This module provides the dual-write bridge and backfill utilities.
-The external API (remember/recall/context/checkpoint) stays stable.
-
-Design contract:
- - Old path continues to work — no breakage
- - New path writes to v3 raw events in parallel
- - Feature flag controls whether recall reads from v3
- - Backfill is idempotent and resumable
- - Zero LLM calls
-"""
-
-from __future__ import annotations
-
-import hashlib
-import json
-import logging
-import os
-import uuid
-from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-
-def _utcnow_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
-
-
-# Feature flags (env vars)
-def _flag(name: str, default: bool = False) -> bool:
- val = os.environ.get(name, "").lower()
- if val in ("1", "true", "yes"):
- return True
- if val in ("0", "false", "no"):
- return False
- return default
-
-
-class V3MigrationBridge:
- """Dual-write bridge between v2 (UniversalEngram) and v3 (event-sourced).
-
- Usage:
- bridge = V3MigrationBridge(v3_store)
- # In remember():
- bridge.on_remember(content, user_id, memory_id)
- # In recall():
- if bridge.should_use_v3_read():
- results = bridge.recall_from_v3(query, user_id)
-
- Feature flags:
- DHEE_V3_WRITE=1 → dual-write to v3 raw events (default: on)
- DHEE_V3_READ=1 → read from v3 retrieval view (default: off)
- """
-
- def __init__(
- self,
- v3_store: Optional["CognitionStore"] = None,
- ):
- self._store = v3_store
- self._write_enabled = _flag("DHEE_V3_WRITE", default=True)
- self._read_enabled = _flag("DHEE_V3_READ", default=False)
-
- @property
- def write_enabled(self) -> bool:
- return self._write_enabled and self._store is not None
-
- @property
- def read_enabled(self) -> bool:
- return self._read_enabled and self._store is not None
-
- # ------------------------------------------------------------------
- # Dual-write hooks
- # ------------------------------------------------------------------
-
- def on_remember(
- self,
- content: str,
- user_id: str,
- *,
- session_id: Optional[str] = None,
- source: str = "user",
- metadata: Optional[Dict[str, Any]] = None,
- v2_memory_id: Optional[str] = None,
- ) -> Optional[str]:
- """Dual-write: store raw event alongside v2 memory.
-
- Returns the v3 event_id, or None if v3 write is disabled.
- """
- if not self.write_enabled:
- return None
-
- try:
- meta = metadata or {}
- if v2_memory_id:
- meta["v2_memory_id"] = v2_memory_id
-
- event = self._store.events.add(
- content=content,
- user_id=user_id,
- session_id=session_id,
- source=source,
- metadata=meta,
- )
- return event.event_id
- except Exception as e:
- logger.warning("v3 dual-write failed (non-fatal): %s", e)
- return None
-
- def on_correction(
- self,
- original_content: str,
- new_content: str,
- user_id: str,
- ) -> Optional[str]:
- """Handle a memory correction in v3.
-
- Finds the original event by content hash and creates a correction.
- """
- if not self.write_enabled:
- return None
-
- try:
- content_hash = hashlib.sha256(
- original_content.encode("utf-8")
- ).hexdigest()
- original = self._store.events.get_by_hash(content_hash, user_id)
- if not original:
- # Original not in v3 yet — just add the new content
- event = self._store.events.add(
- content=new_content, user_id=user_id,
- source="user_correction",
- )
- return event.event_id
-
- correction = self._store.events.correct(
- original.event_id, new_content,
- source="user_correction",
- )
- return correction.event_id
- except Exception as e:
- logger.warning("v3 correction failed (non-fatal): %s", e)
- return None
-
- # ------------------------------------------------------------------
- # Backfill: v2 engrams → v3 raw events
- # ------------------------------------------------------------------
-
- def backfill_from_v2(
- self,
- memories: List[Dict[str, Any]],
- *,
- user_id: str = "default",
- batch_size: int = 100,
- ) -> Dict[str, int]:
- """Backfill v2 memories into v3 raw events.
-
- Idempotent: content-hash dedup prevents duplicates.
-
- Args:
- memories: List of v2 memory dicts with at least 'memory' key
- user_id: Default user ID if not in memory dict
- batch_size: Process in batches for progress reporting
-
- Returns:
- {"total": N, "created": M, "skipped_dedup": K, "errors": E}
- """
- if not self._store:
- return {"total": 0, "error": "v3 store not initialized"}
-
- stats = {"total": len(memories), "created": 0, "skipped_dedup": 0, "errors": 0}
-
- for i, mem in enumerate(memories):
- content = mem.get("memory", mem.get("content", ""))
- if not content:
- stats["errors"] += 1
- continue
-
- uid = mem.get("user_id", user_id)
- created_at = mem.get("created_at")
-
- try:
- content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
- existing = self._store.events.get_by_hash(content_hash, uid)
- if existing:
- stats["skipped_dedup"] += 1
- continue
-
- meta = {}
- v2_id = mem.get("id", mem.get("memory_id"))
- if v2_id:
- meta["v2_memory_id"] = v2_id
- if mem.get("layer"):
- meta["v2_layer"] = mem["layer"]
- if mem.get("strength"):
- meta["v2_strength"] = mem["strength"]
-
- self._store.events.add(
- content=content,
- user_id=uid,
- source="v2_backfill",
- metadata=meta,
- )
- stats["created"] += 1
-
- except Exception as e:
- logger.warning("Backfill error for memory %d: %s", i, e)
- stats["errors"] += 1
-
- return stats
-
- # ------------------------------------------------------------------
- # v3 read path (behind feature flag)
- # ------------------------------------------------------------------
-
- def should_use_v3_read(self) -> bool:
- return self.read_enabled
-
- def get_v3_stats(self) -> Dict[str, Any]:
- """Get basic stats about v3 state (for monitoring)."""
- if not self._store:
- return {"v3_available": False}
-
- try:
- return {
- "v3_available": True,
- "write_enabled": self._write_enabled,
- "read_enabled": self._read_enabled,
- "event_count": self._store.events.count("default"),
- }
- except Exception as e:
- return {"v3_available": True, "error": str(e)}
diff --git a/dhee/decay/__init__.py b/dhee/decay/__init__.py
deleted file mode 100644
index 0b7b506..0000000
--- a/dhee/decay/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""Decay helpers for Engram v2."""
-
-from dhee.decay.refcounts import RefCountManager
-
-__all__ = ["RefCountManager"]
diff --git a/dhee/edge/__init__.py b/dhee/edge/__init__.py
deleted file mode 100644
index 2e4fc3b..0000000
--- a/dhee/edge/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""Dhee Edge — minimal-footprint cognition for hardware/humanoid deployment."""
-
-from dhee.edge.edge_plugin import DheeEdge
-from dhee.edge.edge_trainer import EdgeTrainer
-
-__all__ = ["DheeEdge", "EdgeTrainer"]
diff --git a/dhee/edge/edge_plugin.py b/dhee/edge/edge_plugin.py
deleted file mode 100644
index 2777aa2..0000000
--- a/dhee/edge/edge_plugin.py
+++ /dev/null
@@ -1,322 +0,0 @@
-"""DheeEdge — minimal cognition plugin for edge/hardware deployment.
-
-Designed for humanoid robots, IoT devices, and AI hardware products.
-All computation runs locally — no cloud API calls, no internet required.
-
-Constraints:
- - LLM: DheeModel (GGUF Q4, ~1.5GB) or mock fallback
- - Embedder: ONNX MiniLM (22MB) or hash-based fallback
- - Vector store: sqlite_vec (local file)
- - RAM: <500MB working set
- - No external API calls ever
-
-Adds embodiment hooks for hardware integration:
- - on_sensor_input() — process sensor data into episodic memory
- - on_action_result() — record action outcomes for environment learning
- - predict_environment() — predict next state from memory patterns
-
-Usage:
- from dhee.edge import DheeEdge
-
- d = DheeEdge(data_dir="/data/dhee")
- d.remember("User prefers quiet mode after 10pm")
- d.on_sensor_input("microphone", {"volume_db": 85, "duration": 3.0})
- d.on_action_result("reduce_volume", success=True, env_state={"volume_db": 40})
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import os
-import time
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
-
-from dhee.adapters.base import DheePlugin
-
-logger = logging.getLogger(__name__)
-
-
-class DheeEdge(DheePlugin):
- """Minimal cognition plugin for edge deployment.
-
- Forces all-offline providers. No API calls, no internet.
- Extends DheePlugin with embodiment hooks for hardware.
-
- Args:
- data_dir: Storage directory (required for edge — no temp dirs).
- model_path: Path to GGUF model file for local LLM inference.
- user_id: Default user ID.
- """
-
- def __init__(
- self,
- data_dir: Union[str, Path],
- model_path: Optional[str] = None,
- user_id: str = "default",
- ):
- # Force offline — never make API calls.
- # Try persistent storage first; fall back to in-memory if
- # sqlite_vec extension isn't available on this platform.
- try:
- super().__init__(
- data_dir=data_dir,
- provider="mock",
- user_id=user_id,
- in_memory=False,
- offline=True,
- )
- except (AttributeError, OSError) as e:
- logger.debug("Persistent storage unavailable (%s), using in-memory", e)
- super().__init__(
- data_dir=data_dir,
- provider="mock",
- user_id=user_id,
- in_memory=True,
- offline=True,
- )
-
- # Embodiment state
- self._sensor_history: List[Dict[str, Any]] = []
- self._action_history: List[Dict[str, Any]] = []
- self._environment_model: Dict[str, Any] = {}
-
- # Try to upgrade to local GGUF model
- if model_path:
- self._try_load_local_model(model_path)
-
- def _try_load_local_model(self, model_path: str) -> None:
- """Attempt to load a local GGUF model for on-device LLM inference."""
- if not os.path.exists(model_path):
- logger.debug("GGUF model not found at %s, using mock LLM", model_path)
- return
- try:
- from dhee.llms.dhee import DheeLLM
- self._engram._memory.llm = DheeLLM(
- config={"model_path": model_path, "backend": "gguf"}
- )
- logger.info("Loaded local GGUF model: %s", model_path)
- except ImportError:
- logger.debug("llama-cpp-python not available, using mock LLM")
- except Exception as e:
- logger.debug("Failed to load GGUF model: %s", e)
-
- # ------------------------------------------------------------------
- # Embodiment hooks (from Self-evolving Embodied AI, arXiv:2602.04411)
- # ------------------------------------------------------------------
-
- def on_sensor_input(
- self,
- sensor_type: str,
- data: Dict[str, Any],
- user_id: Optional[str] = None,
- ) -> Dict[str, Any]:
- """Process sensor data into episodic memory.
-
- Converts raw sensor readings into natural language memories that
- can be recalled later. Tracks sensor patterns for environment
- prediction.
-
- Args:
- sensor_type: Type of sensor (e.g., "microphone", "camera", "imu")
- data: Sensor data dict with readings and metadata
- user_id: Override default user_id
-
- Returns:
- {"stored": bool, "id": str, "description": str}
- """
- uid = user_id or self._user_id
- timestamp = data.get("timestamp", time.time())
-
- # Build natural language description from sensor data
- description = self._describe_sensor_data(sensor_type, data)
-
- # Store as memory
- result = self.remember(
- content=description,
- user_id=uid,
- metadata={
- "source": "sensor",
- "sensor_type": sensor_type,
- "timestamp": timestamp,
- "raw_data": data,
- },
- )
-
- # Track in sensor history (bounded)
- record = {
- "sensor_type": sensor_type,
- "data": data,
- "timestamp": timestamp,
- "description": description,
- }
- self._sensor_history.append(record)
- if len(self._sensor_history) > 500:
- self._sensor_history = self._sensor_history[-500:]
-
- # Update environment model
- self._update_environment_model(sensor_type, data)
-
- result["description"] = description
- return result
-
- def on_action_result(
- self,
- action: str,
- success: bool,
- env_state: Optional[Dict[str, Any]] = None,
- user_id: Optional[str] = None,
- ) -> Dict[str, Any]:
- """Record action outcomes for environment self-prediction.
-
- Builds a causal model: action + context → outcome. Over time,
- the system learns which actions work in which states.
-
- Args:
- action: What the agent did (e.g., "reduce_volume", "move_forward")
- success: Whether the action achieved its goal
- env_state: Environment state after action
- user_id: Override default user_id
- """
- uid = user_id or self._user_id
-
- # Store as memory with outcome
- outcome_word = "succeeded" if success else "failed"
- content = f"Action '{action}' {outcome_word}"
- if env_state:
- state_summary = ", ".join(f"{k}={v}" for k, v in list(env_state.items())[:5])
- content += f". Environment state: {state_summary}"
-
- result = self.remember(
- content=content,
- user_id=uid,
- metadata={
- "source": "action_result",
- "action": action,
- "success": success,
- "env_state": env_state,
- },
- )
-
- # Track action history
- record = {
- "action": action,
- "success": success,
- "env_state": env_state,
- "timestamp": time.time(),
- }
- self._action_history.append(record)
- if len(self._action_history) > 500:
- self._action_history = self._action_history[-500:]
-
- # Record outcome for performance tracking
- task_type = f"action_{action}"
- self._buddhi.record_outcome(
- user_id=uid,
- task_type=task_type,
- score=1.0 if success else 0.0,
- )
-
- return result
-
- def predict_environment(
- self,
- current_state: Dict[str, Any],
- proposed_action: Optional[str] = None,
- ) -> Dict[str, Any]:
- """Predict next environment state from memory patterns.
-
- Uses action history to estimate what will happen if a given
- action is taken in the current state.
-
- Args:
- current_state: Current environment state dict
- proposed_action: Action being considered (optional)
-
- Returns:
- {"prediction": str, "confidence": float, "similar_outcomes": list}
- """
- # Build query from current state + proposed action
- state_desc = ", ".join(f"{k}={v}" for k, v in list(current_state.items())[:5])
- query = f"environment state: {state_desc}"
- if proposed_action:
- query += f", action: {proposed_action}"
-
- # Search for similar past situations
- similar = self.recall(query=query, limit=5)
-
- # Compute confidence from action history
- confidence = 0.0
- outcomes = []
- if proposed_action and self._action_history:
- matching = [
- a for a in self._action_history
- if a["action"] == proposed_action
- ]
- if matching:
- success_rate = sum(1 for a in matching if a["success"]) / len(matching)
- confidence = success_rate
- outcomes = matching[-3:] # last 3 similar actions
-
- # Simple prediction based on success rate
- prediction = "unknown"
- if confidence > 0.7:
- prediction = f"Action '{proposed_action}' is likely to succeed (confidence: {confidence:.0%})"
- elif confidence > 0.3:
- prediction = f"Action '{proposed_action}' has mixed results (confidence: {confidence:.0%})"
- elif confidence > 0 and proposed_action:
- prediction = f"Action '{proposed_action}' has often failed (confidence: {confidence:.0%})"
-
- return {
- "prediction": prediction,
- "confidence": round(confidence, 3),
- "similar_memories": similar[:3],
- "recent_outcomes": [
- {"action": o["action"], "success": o["success"]}
- for o in outcomes
- ],
- }
-
- def adapt_embodiment(
- self,
- capabilities: Dict[str, Any],
- user_id: Optional[str] = None,
- ) -> Dict[str, Any]:
- """Update self-model when hardware capabilities change.
-
- Call when sensors are added/removed, actuators change, or the
- physical form factor is modified.
-
- Args:
- capabilities: New capability dict (e.g., {"has_camera": True, "arm_reach_cm": 60})
- """
- uid = user_id or self._user_id
- cap_desc = ", ".join(f"{k}: {v}" for k, v in capabilities.items())
- content = f"Embodiment update: {cap_desc}"
- return self.remember(
- content=content,
- user_id=uid,
- metadata={"source": "embodiment_update", "capabilities": capabilities},
- )
-
- # ------------------------------------------------------------------
- # Internal helpers
- # ------------------------------------------------------------------
-
- def _describe_sensor_data(self, sensor_type: str, data: Dict[str, Any]) -> str:
- """Convert raw sensor data to natural language for memory storage."""
- readings = ", ".join(
- f"{k}={v}" for k, v in data.items()
- if k != "timestamp" and not isinstance(v, (dict, list))
- )
- return f"Sensor[{sensor_type}]: {readings}"
-
- def _update_environment_model(
- self, sensor_type: str, data: Dict[str, Any],
- ) -> None:
- """Update the running environment model with new sensor data."""
- self._environment_model[sensor_type] = {
- "last_reading": data,
- "last_updated": time.time(),
- }
diff --git a/dhee/edge/edge_trainer.py b/dhee/edge/edge_trainer.py
deleted file mode 100644
index 572bd9c..0000000
--- a/dhee/edge/edge_trainer.py
+++ /dev/null
@@ -1,508 +0,0 @@
-"""EdgeTrainer — on-device micro-training for edge deployments.
-
-Runs minimal LoRA fine-tuning directly on edge hardware (ARM CPU, low-RAM
-devices). Designed for DheeEdge scenarios where the model needs to adapt
-to its specific user/environment without cloud connectivity.
-
-Constraints:
- - CPU-only training (no CUDA required)
- - <2GB peak RAM during training
- - LoRA rank 4-8 (tiny adapter, ~2MB)
- - Micro-batches of 1-4 samples
- - 10-50 gradient steps per cycle (not epochs)
-
-Training data sources (all local):
- - Samskara signals from SamskaraCollector.get_training_data()
- - Action outcome pairs from DheeEdge._action_history
- - Sensor pattern correlations
-
-Architecture:
- EdgeTrainer does NOT require torch/transformers at init. It checks
- for their availability lazily. On devices without PyTorch, it logs
- a warning and becomes a no-op.
-
-Usage:
- from dhee.edge.edge_trainer import EdgeTrainer
-
- trainer = EdgeTrainer(
- model_path="/data/models/dhee-2b-q4.gguf",
- adapter_dir="/data/dhee/adapters",
- )
-
- # Check if training is possible on this device
- if trainer.can_train:
- result = trainer.micro_train(training_data)
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import os
-import time
-from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class MicroTrainResult:
- """Result of a micro-training cycle."""
-
- success: bool
- steps_completed: int = 0
- samples_used: int = 0
- loss_start: float = 0.0
- loss_end: float = 0.0
- adapter_path: Optional[str] = None
- duration_seconds: float = 0.0
- error: Optional[str] = None
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "success": self.success,
- "steps_completed": self.steps_completed,
- "samples_used": self.samples_used,
- "loss_start": round(self.loss_start, 4),
- "loss_end": round(self.loss_end, 4),
- "adapter_path": self.adapter_path,
- "duration_seconds": round(self.duration_seconds, 2),
- "error": self.error,
- }
-
-
-@dataclass
-class EdgeTrainingConfig:
- """Configuration for edge micro-training."""
-
- lora_rank: int = 4
- lora_alpha: int = 8
- learning_rate: float = 2e-4
- max_steps: int = 30
- micro_batch_size: int = 2
- max_seq_len: int = 256
- gradient_accumulation_steps: int = 2
- warmup_steps: int = 3
- weight_decay: float = 0.01
- max_samples: int = 100 # Limit training data
-
-
-class EdgeTrainer:
- """On-device micro-training for edge deployments.
-
- Performs minimal LoRA fine-tuning on CPU with tight resource budgets.
- Training is designed to be interruptible — partial progress is saved.
- """
-
- def __init__(
- self,
- model_path: Optional[str] = None,
- adapter_dir: Optional[str] = None,
- config: Optional[EdgeTrainingConfig] = None,
- ):
- """
- Args:
- model_path: Path to the base model (GGUF or safetensors).
- adapter_dir: Directory to save/load LoRA adapters.
- config: Training hyperparameters.
- """
- self._model_path = model_path
- self._adapter_dir = adapter_dir or os.path.join(
- os.path.expanduser("~"), ".dhee", "edge_adapters",
- )
- self.config = config or EdgeTrainingConfig()
- self._training_history: List[Dict[str, Any]] = []
- self._torch_available: Optional[bool] = None
-
- @property
- def can_train(self) -> bool:
- """Check if training is possible on this device."""
- if self._torch_available is None:
- try:
- import torch # noqa: F401
- self._torch_available = True
- except ImportError:
- self._torch_available = False
- logger.info(
- "PyTorch not available — edge training disabled. "
- "Install with: pip install torch --index-url https://download.pytorch.org/whl/cpu"
- )
- return self._torch_available and self._model_path is not None
-
- def micro_train(
- self,
- training_data: Dict[str, Any],
- samskara_signals: Optional[Dict[str, Any]] = None,
- ) -> MicroTrainResult:
- """Run a micro-training cycle.
-
- Args:
- training_data: Dict with keys:
- - sft_samples: List of {"input": str, "output": str} dicts
- - dpo_pairs: List of {"chosen": str, "rejected": str} dicts (optional)
- samskara_signals: Optional vasana report for sample weighting.
-
- Returns:
- MicroTrainResult with training metrics.
- """
- if not self.can_train:
- return MicroTrainResult(
- success=False,
- error="Training not available (missing PyTorch or model)",
- )
-
- start_time = time.time()
-
- # Prepare training samples
- sft_samples = training_data.get("sft_samples", [])
- if not sft_samples:
- return MicroTrainResult(
- success=False,
- error="No training samples provided",
- )
-
- # Limit to max_samples
- samples = sft_samples[:self.config.max_samples]
-
- # Weight samples by vasana if available
- if samskara_signals:
- samples = self._weight_samples(samples, samskara_signals)
-
- try:
- result = self._run_lora_training(samples)
- result.duration_seconds = time.time() - start_time
-
- # Record in history
- self._training_history.append({
- "timestamp": time.time(),
- "result": result.to_dict(),
- })
-
- # Persist training log
- self._save_training_log()
-
- return result
- except Exception as e:
- logger.warning("Micro-training failed: %s", e)
- return MicroTrainResult(
- success=False,
- duration_seconds=time.time() - start_time,
- error=str(e),
- )
-
- def _run_lora_training(self, samples: List[Dict]) -> MicroTrainResult:
- """Run LoRA fine-tuning with PyTorch.
-
- This is CPU-optimized:
- - float32 (no mixed precision on CPU)
- - Gradient checkpointing for memory efficiency
- - Small LoRA rank (4) = tiny trainable parameter count
- """
- import torch
- from torch.utils.data import DataLoader, Dataset
-
- class TextDataset(Dataset):
- def __init__(self, data: List[Dict]):
- self.data = data
-
- def __len__(self):
- return len(self.data)
-
- def __getitem__(self, idx):
- item = self.data[idx]
- return item.get("input", ""), item.get("output", "")
-
- # Check if we can load the model for training
- # For GGUF models, we need llama-cpp-python for inference but
- # can't fine-tune them directly. We look for a safetensors/HF model.
- model, tokenizer = self._load_model_for_training()
- if model is None:
- # Fallback: save training data for later batch processing
- return self._save_for_deferred_training(samples)
-
- dataset = TextDataset(samples)
- loader = DataLoader(
- dataset,
- batch_size=self.config.micro_batch_size,
- shuffle=True,
- )
-
- # Apply LoRA
- model = self._apply_lora(model)
- model.train()
-
- # Optimizer
- trainable_params = [p for p in model.parameters() if p.requires_grad]
- optimizer = torch.optim.AdamW(
- trainable_params,
- lr=self.config.learning_rate,
- weight_decay=self.config.weight_decay,
- )
-
- # Training loop
- total_steps = 0
- losses = []
- accum_loss = 0.0
-
- for step_idx in range(self.config.max_steps):
- for batch_inputs, batch_outputs in loader:
- # Tokenize
- combined = [
- f"{inp} {out}" for inp, out in zip(batch_inputs, batch_outputs)
- ]
- encodings = tokenizer(
- combined,
- return_tensors="pt",
- max_length=self.config.max_seq_len,
- truncation=True,
- padding=True,
- )
-
- # Forward pass
- outputs = model(
- input_ids=encodings["input_ids"],
- attention_mask=encodings["attention_mask"],
- labels=encodings["input_ids"],
- )
- loss = outputs.loss / self.config.gradient_accumulation_steps
- loss.backward()
- accum_loss += loss.item()
-
- total_steps += 1
-
- if total_steps % self.config.gradient_accumulation_steps == 0:
- torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
- optimizer.step()
- optimizer.zero_grad()
- losses.append(accum_loss)
- accum_loss = 0.0
-
- if total_steps >= self.config.max_steps:
- break
-
- if total_steps >= self.config.max_steps:
- break
-
- # Save adapter
- adapter_path = self._save_adapter(model)
-
- return MicroTrainResult(
- success=True,
- steps_completed=total_steps,
- samples_used=len(samples),
- loss_start=losses[0] if losses else 0.0,
- loss_end=losses[-1] if losses else 0.0,
- adapter_path=adapter_path,
- )
-
- def _load_model_for_training(self):
- """Load model + tokenizer for training.
-
- Returns (model, tokenizer) or (None, None) if not available.
- """
- if not self._model_path:
- return None, None
-
- # GGUF files can't be fine-tuned directly
- if self._model_path.endswith(".gguf"):
- # Check for a companion safetensors model
- base_dir = os.path.dirname(self._model_path)
- safetensors_path = os.path.join(base_dir, "model.safetensors")
- config_path = os.path.join(base_dir, "config.json")
- if not os.path.exists(config_path):
- logger.info(
- "GGUF model cannot be fine-tuned directly. "
- "Saving training data for deferred processing."
- )
- return None, None
- model_dir = base_dir
- else:
- model_dir = self._model_path
-
- try:
- from transformers import AutoModelForCausalLM, AutoTokenizer
- tokenizer = AutoTokenizer.from_pretrained(
- model_dir, trust_remote_code=True,
- )
- if tokenizer.pad_token is None:
- tokenizer.pad_token = tokenizer.eos_token
-
- model = AutoModelForCausalLM.from_pretrained(
- model_dir,
- torch_dtype="auto",
- trust_remote_code=True,
- )
- return model, tokenizer
- except Exception as e:
- logger.info("Model loading failed: %s", e)
- return None, None
-
- def _apply_lora(self, model):
- """Apply LoRA adapters to the model."""
- try:
- from peft import LoraConfig, get_peft_model, TaskType
- lora_config = LoraConfig(
- task_type=TaskType.CAUSAL_LM,
- r=self.config.lora_rank,
- lora_alpha=self.config.lora_alpha,
- lora_dropout=0.0, # No dropout for micro-training
- target_modules=["q_proj", "v_proj"], # Minimal target
- )
- model = get_peft_model(model, lora_config)
- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
- total = sum(p.numel() for p in model.parameters())
- logger.info(
- "LoRA applied: %d trainable / %d total params (%.2f%%)",
- trainable, total, 100 * trainable / total,
- )
- return model
- except ImportError:
- logger.warning(
- "peft not available — training all parameters (not recommended "
- "for edge). Install: pip install peft"
- )
- return model
-
- def _save_adapter(self, model) -> str:
- """Save the LoRA adapter to disk."""
- os.makedirs(self._adapter_dir, exist_ok=True)
- adapter_name = f"adapter_{int(time.time())}"
- adapter_path = os.path.join(self._adapter_dir, adapter_name)
-
- try:
- if hasattr(model, "save_pretrained"):
- model.save_pretrained(adapter_path)
- else:
- # Fallback: save state dict
- import torch
- torch.save(
- {k: v for k, v in model.state_dict().items() if "lora" in k},
- os.path.join(adapter_path, "lora_weights.pt"),
- )
- except Exception as e:
- logger.warning("Adapter save failed: %s", e)
- adapter_path = ""
-
- return adapter_path
-
- def _save_for_deferred_training(
- self, samples: List[Dict],
- ) -> MicroTrainResult:
- """Save training data to disk for later batch processing.
-
- Used when the model format doesn't support direct fine-tuning
- (e.g., GGUF without a companion HF model).
- """
- os.makedirs(self._adapter_dir, exist_ok=True)
- deferred_path = os.path.join(
- self._adapter_dir, f"deferred_{int(time.time())}.jsonl",
- )
-
- with open(deferred_path, "w", encoding="utf-8") as f:
- for sample in samples:
- f.write(json.dumps(sample, ensure_ascii=False) + "\n")
-
- return MicroTrainResult(
- success=True,
- steps_completed=0,
- samples_used=len(samples),
- adapter_path=deferred_path,
- error="deferred: saved training data for batch processing",
- )
-
- def _weight_samples(
- self,
- samples: List[Dict],
- samskara_signals: Dict[str, Any],
- ) -> List[Dict]:
- """Weight training samples based on vasana degradation signals.
-
- Samples from degrading dimensions get 2x representation.
- """
- degrading = set(samskara_signals.get("degrading_dimensions", []))
- if not degrading:
- return samples
-
- weighted = []
- for sample in samples:
- weighted.append(sample)
- sample_type = sample.get("type", "")
- if sample_type in degrading:
- weighted.append(sample) # Duplicate for emphasis
-
- return weighted
-
- def _save_training_log(self) -> None:
- """Persist training history to disk."""
- os.makedirs(self._adapter_dir, exist_ok=True)
- log_path = os.path.join(self._adapter_dir, "training_log.jsonl")
- try:
- with open(log_path, "a", encoding="utf-8") as f:
- entry = self._training_history[-1]
- f.write(json.dumps(entry, ensure_ascii=False) + "\n")
- except OSError:
- pass
-
- # ------------------------------------------------------------------
- # Integration with DheeEdge
- # ------------------------------------------------------------------
-
- def train_from_edge(
- self,
- edge_plugin: Any,
- samskara: Optional[Any] = None,
- ) -> MicroTrainResult:
- """Convenience: collect training data from a DheeEdge instance and train.
-
- Gathers:
- - Action outcome pairs as SFT samples
- - Samskara signals for weighting
- """
- # Collect action-based SFT samples
- sft_samples = []
- action_history = getattr(edge_plugin, "_action_history", [])
- for record in action_history[-self.config.max_samples:]:
- action = record.get("action", "")
- success = record.get("success", False)
- env = record.get("env_state", {})
-
- env_desc = ", ".join(f"{k}={v}" for k, v in list(env.items())[:5]) if env else "unknown"
- sft_samples.append({
- "input": f"[ACTION] state: {env_desc}, action: {action}",
- "output": f"{'success' if success else 'failure'}",
- "type": "action_prediction",
- })
-
- # Add samskara training data if available
- samskara_data = {}
- if samskara:
- try:
- samskara_data = samskara.get_training_data()
- sft_samples.extend(samskara_data.get("sft_samples", []))
- except Exception:
- pass
-
- if not sft_samples:
- return MicroTrainResult(
- success=False, error="No training data from edge",
- )
-
- return self.micro_train(
- training_data={"sft_samples": sft_samples},
- samskara_signals=samskara_data,
- )
-
- def get_status(self) -> Dict[str, Any]:
- """Get trainer status."""
- return {
- "can_train": self.can_train,
- "model_path": self._model_path,
- "adapter_dir": self._adapter_dir,
- "training_cycles": len(self._training_history),
- "config": {
- "lora_rank": self.config.lora_rank,
- "max_steps": self.config.max_steps,
- "learning_rate": self.config.learning_rate,
- },
- }
diff --git a/dhee/embeddings/nvidia.py b/dhee/embeddings/nvidia.py
index eb3cad1..24c154e 100644
--- a/dhee/embeddings/nvidia.py
+++ b/dhee/embeddings/nvidia.py
@@ -20,12 +20,13 @@ def __init__(self, config: Optional[dict] = None):
api_key = (
self.config.get("api_key")
or os.getenv("NVIDIA_EMBEDDING_API_KEY")
+ or os.getenv("NVIDIA_EMBED_API_KEY")
or os.getenv("NVIDIA_API_KEY")
)
if not api_key:
raise ValueError(
"NVIDIA API key required. Set config['api_key'], "
- "NVIDIA_EMBEDDING_API_KEY, or NVIDIA_API_KEY env var."
+ "NVIDIA_EMBEDDING_API_KEY, NVIDIA_EMBED_API_KEY, or NVIDIA_API_KEY env var."
)
base_url = self.config.get("base_url", "https://integrate.api.nvidia.com/v1")
@@ -54,10 +55,15 @@ def _extra_body(self, memory_action: Optional[str] = None, count: int = 1) -> di
def _truncate_if_needed(self, text: str) -> str:
"""Truncate text to stay within model token limits.
- nv-embed-v1 has a 4096 token limit. Using ~3.5 chars/token as
- a conservative estimate, cap at 14000 characters.
+ Model-aware defaults (conservative ~3.5 chars/token):
+ - nv-embed-v1: 4096 tokens → 14000 chars
+ - nemotron-embed: 8192 tokens → 26000 chars (use 24000 for safety)
"""
- max_chars = int(self.config.get("max_input_chars", 14000))
+ if "nemotron-embed" in self.model:
+ default_max = 24000
+ else:
+ default_max = 14000
+ max_chars = int(self.config.get("max_input_chars", default_max))
if len(text) > max_chars:
logger.debug("Truncating input from %d to %d chars for embedding", len(text), max_chars)
return text[:max_chars]
diff --git a/dhee/hive/__init__.py b/dhee/hive/__init__.py
deleted file mode 100644
index a35f837..0000000
--- a/dhee/hive/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-"""Dhee Hive — multi-agent shared cognition layer.
-
-Built on top of engram-bus for real-time agent-to-agent communication,
-with CRDT-based sync for offline/edge scenarios.
-"""
-
-from dhee.hive.hive_memory import HiveMemory
-
-__all__ = ["HiveMemory"]
diff --git a/dhee/hive/hive_memory.py b/dhee/hive/hive_memory.py
deleted file mode 100644
index 26453e2..0000000
--- a/dhee/hive/hive_memory.py
+++ /dev/null
@@ -1,526 +0,0 @@
-"""HiveMemory — multi-agent shared cognition on top of engram-bus.
-
-Enables multiple DheePlugin instances (across agents, processes, or machines)
-to share and evolve a collective knowledge base:
-
- - **Shared Insights**: Cross-agent discoveries and patterns.
- - **Shared Heuristics**: Abstract reasoning rules mined from any agent's trajectories.
- - **Shared Skills**: Proven skills that any agent can adopt.
- - **Collective Signals**: Aggregated samskara-like signals for hive-level evolution.
-
-Architecture:
- Each agent runs a local DheePlugin. HiveMemory sits alongside it and
- periodically publishes local discoveries to the bus, and subscribes to
- discoveries from other agents. A quality gate ensures only validated
- knowledge propagates.
-
- ┌──────────┐ bus.publish() ┌─────────────┐
- │ Agent A │ ────────────────────▶ │ engram-bus │
- │ DheePlugin│ ◀──────────────────── │ (pub/sub + │
- │ + Hive │ bus.subscribe() │ KV store) │
- └──────────┘ └──────┬──────┘
- │
- ┌──────────┐ │
- │ Agent B │ ◀───────────────────────────┘
- │ DheePlugin│
- │ + Hive │
- └──────────┘
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import time
-import threading
-from dataclasses import asdict, dataclass, field
-from datetime import datetime, timezone
-from typing import Any, Callable, Dict, List, Optional, Set
-
-logger = logging.getLogger(__name__)
-
-
-def _now_iso() -> str:
- return datetime.now(timezone.utc).isoformat()
-
-
-# ---------------------------------------------------------------------------
-# Shared knowledge item types
-# ---------------------------------------------------------------------------
-
-@dataclass
-class SharedItem:
- """A piece of knowledge shared on the hive."""
-
- id: str
- kind: str # "insight" | "heuristic" | "skill" | "signal"
- content: Dict[str, Any]
- source_agent: str
- timestamp: str = field(default_factory=_now_iso)
- confidence: float = 0.5
- votes_up: int = 0
- votes_down: int = 0
- adopted_by: List[str] = field(default_factory=list)
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "id": self.id,
- "kind": self.kind,
- "content": self.content,
- "source_agent": self.source_agent,
- "timestamp": self.timestamp,
- "confidence": self.confidence,
- "votes_up": self.votes_up,
- "votes_down": self.votes_down,
- "adopted_by": self.adopted_by,
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "SharedItem":
- return cls(
- id=data["id"],
- kind=data["kind"],
- content=data.get("content", {}),
- source_agent=data.get("source_agent", "unknown"),
- timestamp=data.get("timestamp", _now_iso()),
- confidence=data.get("confidence", 0.5),
- votes_up=data.get("votes_up", 0),
- votes_down=data.get("votes_down", 0),
- adopted_by=data.get("adopted_by", []),
- )
-
- @property
- def quality_score(self) -> float:
- """Wilson score lower bound — conservative estimate of true quality."""
- n = self.votes_up + self.votes_down
- if n == 0:
- return self.confidence
- p = self.votes_up / n
- # Wilson score interval (simplified)
- z = 1.96 # 95% confidence
- denominator = 1 + z * z / n
- centre = p + z * z / (2 * n)
- spread = z * ((p * (1 - p) + z * z / (4 * n)) / n) ** 0.5
- return (centre - spread) / denominator
-
-
-# ---------------------------------------------------------------------------
-# Topics
-# ---------------------------------------------------------------------------
-
-_TOPIC_SHARE = "dhee.hive.share"
-_TOPIC_VOTE = "dhee.hive.vote"
-_TOPIC_ADOPT = "dhee.hive.adopt"
-_TOPIC_SYNC_REQUEST = "dhee.hive.sync.request"
-_TOPIC_SYNC_RESPONSE = "dhee.hive.sync.response"
-_NS_HIVE = "dhee_hive"
-
-
-# ---------------------------------------------------------------------------
-# HiveMemory
-# ---------------------------------------------------------------------------
-
-class HiveMemory:
- """Multi-agent shared cognition layer.
-
- Wraps an engram-bus instance to provide structured knowledge sharing
- with quality gating and adoption tracking.
- """
-
- def __init__(
- self,
- agent_id: str,
- bus: Any = None,
- min_confidence_to_share: float = 0.4,
- min_quality_to_adopt: float = 0.3,
- auto_subscribe: bool = True,
- ):
- """
- Args:
- agent_id: This agent's identifier on the hive.
- bus: An engram_bus.Bus instance. If None, creates an in-memory bus.
- min_confidence_to_share: Minimum confidence to publish to hive.
- min_quality_to_adopt: Minimum quality_score to auto-adopt shared items.
- auto_subscribe: Whether to subscribe to hive topics on init.
- """
- self.agent_id = agent_id
- self._min_share = min_confidence_to_share
- self._min_adopt = min_quality_to_adopt
-
- # Local store of hive items (id -> SharedItem)
- self._items: Dict[str, SharedItem] = {}
- self._lock = threading.RLock()
- self._on_receive_callbacks: List[Callable] = []
-
- # Bus connection
- if bus is None:
- try:
- from engram_bus import Bus
- bus = Bus()
- except ImportError:
- logger.warning("engram-bus not available, hive runs in local-only mode")
- bus = None
-
- self._bus = bus
-
- if bus and auto_subscribe:
- self._subscribe()
-
- def _subscribe(self) -> None:
- """Subscribe to hive topics on the bus."""
- if not self._bus:
- return
- self._bus.subscribe(_TOPIC_SHARE, self._on_share_received, agent=self.agent_id)
- self._bus.subscribe(_TOPIC_VOTE, self._on_vote_received, agent=self.agent_id)
- self._bus.subscribe(_TOPIC_ADOPT, self._on_adopt_received, agent=self.agent_id)
- self._bus.subscribe(
- _TOPIC_SYNC_REQUEST, self._on_sync_request, agent=self.agent_id,
- )
- self._bus.register(self.agent_id, metadata={"type": "dhee_hive_member"})
-
- # ------------------------------------------------------------------
- # Publishing
- # ------------------------------------------------------------------
-
- def share_insight(
- self,
- insight_id: str,
- content: Dict[str, Any],
- confidence: float = 0.5,
- ) -> Optional[SharedItem]:
- """Share an insight (from Buddhi reflect) with the hive."""
- return self._publish_item(
- item_id=f"insight:{self.agent_id}:{insight_id}",
- kind="insight",
- content=content,
- confidence=confidence,
- )
-
- def share_heuristic(
- self,
- heuristic_id: str,
- content: Dict[str, Any],
- confidence: float = 0.5,
- ) -> Optional[SharedItem]:
- """Share a distilled heuristic with the hive."""
- return self._publish_item(
- item_id=f"heuristic:{self.agent_id}:{heuristic_id}",
- kind="heuristic",
- content=content,
- confidence=confidence,
- )
-
- def share_skill(
- self,
- skill_id: str,
- content: Dict[str, Any],
- confidence: float = 0.5,
- ) -> Optional[SharedItem]:
- """Share a proven skill with the hive."""
- return self._publish_item(
- item_id=f"skill:{self.agent_id}:{skill_id}",
- kind="skill",
- content=content,
- confidence=confidence,
- )
-
- def share_signal(
- self,
- signal_type: str,
- data: Dict[str, Any],
- ) -> Optional[SharedItem]:
- """Share an aggregated signal (e.g., vasana shift) with the hive."""
- return self._publish_item(
- item_id=f"signal:{self.agent_id}:{signal_type}:{int(time.time())}",
- kind="signal",
- content={"signal_type": signal_type, **data},
- confidence=0.5,
- )
-
- def _publish_item(
- self,
- item_id: str,
- kind: str,
- content: Dict[str, Any],
- confidence: float,
- ) -> Optional[SharedItem]:
- if confidence < self._min_share:
- logger.debug(
- "Not sharing %s (confidence %.2f < %.2f)",
- item_id, confidence, self._min_share,
- )
- return None
-
- item = SharedItem(
- id=item_id,
- kind=kind,
- content=content,
- source_agent=self.agent_id,
- confidence=confidence,
- )
-
- with self._lock:
- self._items[item.id] = item
-
- if self._bus:
- self._bus.publish(_TOPIC_SHARE, item.to_dict(), agent=self.agent_id)
- # Also store in bus KV for late joiners
- self._bus.put(
- f"hive:{item.id}",
- json.dumps(item.to_dict()),
- agent=self.agent_id,
- namespace=_NS_HIVE,
- )
-
- return item
-
- # ------------------------------------------------------------------
- # Voting
- # ------------------------------------------------------------------
-
- def vote(self, item_id: str, upvote: bool = True) -> None:
- """Vote on a shared item's quality."""
- with self._lock:
- item = self._items.get(item_id)
- if item:
- if upvote:
- item.votes_up += 1
- else:
- item.votes_down += 1
-
- if self._bus:
- self._bus.publish(
- _TOPIC_VOTE,
- {"item_id": item_id, "upvote": upvote, "voter": self.agent_id},
- agent=self.agent_id,
- )
-
- def adopt(self, item_id: str) -> Optional[SharedItem]:
- """Mark a shared item as adopted by this agent."""
- with self._lock:
- item = self._items.get(item_id)
- if not item:
- return None
- if self.agent_id not in item.adopted_by:
- item.adopted_by.append(self.agent_id)
-
- if self._bus:
- self._bus.publish(
- _TOPIC_ADOPT,
- {"item_id": item_id, "adopter": self.agent_id},
- agent=self.agent_id,
- )
- return item
-
- # ------------------------------------------------------------------
- # Querying
- # ------------------------------------------------------------------
-
- def get_shared(
- self,
- kind: Optional[str] = None,
- min_quality: Optional[float] = None,
- limit: int = 20,
- ) -> List[SharedItem]:
- """Get shared items, optionally filtered by kind and quality."""
- min_q = min_quality if min_quality is not None else 0.0
-
- with self._lock:
- items = list(self._items.values())
-
- if kind:
- items = [i for i in items if i.kind == kind]
- items = [i for i in items if i.quality_score >= min_q]
- items.sort(key=lambda i: i.quality_score, reverse=True)
- return items[:limit]
-
- def get_adoptable(self, limit: int = 10) -> List[SharedItem]:
- """Get high-quality items not yet adopted by this agent."""
- with self._lock:
- candidates = [
- item for item in self._items.values()
- if self.agent_id not in item.adopted_by
- and item.source_agent != self.agent_id
- and item.quality_score >= self._min_adopt
- ]
- candidates.sort(key=lambda i: i.quality_score, reverse=True)
- return candidates[:limit]
-
- def get_hive_stats(self) -> Dict[str, Any]:
- """Get statistics about the hive."""
- with self._lock:
- items = list(self._items.values())
-
- by_kind: Dict[str, int] = {}
- by_agent: Dict[str, int] = {}
- total_votes = 0
-
- for item in items:
- by_kind[item.kind] = by_kind.get(item.kind, 0) + 1
- by_agent[item.source_agent] = by_agent.get(item.source_agent, 0) + 1
- total_votes += item.votes_up + item.votes_down
-
- return {
- "total_items": len(items),
- "by_kind": by_kind,
- "by_agent": by_agent,
- "total_votes": total_votes,
- "avg_quality": (
- sum(i.quality_score for i in items) / len(items)
- if items else 0.0
- ),
- }
-
- # ------------------------------------------------------------------
- # Bus callbacks
- # ------------------------------------------------------------------
-
- def _on_share_received(
- self, topic: str, data: Any, sender_agent: Optional[str],
- ) -> None:
- """Handle incoming shared item from another agent."""
- if sender_agent == self.agent_id:
- return # Ignore own messages
-
- try:
- item = SharedItem.from_dict(data)
- except (TypeError, KeyError) as e:
- logger.debug("Invalid shared item: %s", e)
- return
-
- with self._lock:
- if item.id not in self._items:
- self._items[item.id] = item
-
- for cb in self._on_receive_callbacks:
- try:
- cb(item)
- except Exception as e:
- logger.debug("Hive receive callback error: %s", e)
-
- def _on_vote_received(
- self, topic: str, data: Any, sender_agent: Optional[str],
- ) -> None:
- """Handle incoming vote from another agent."""
- if sender_agent == self.agent_id:
- return
-
- item_id = data.get("item_id")
- upvote = data.get("upvote", True)
-
- with self._lock:
- item = self._items.get(item_id)
- if item:
- if upvote:
- item.votes_up += 1
- else:
- item.votes_down += 1
-
- def _on_adopt_received(
- self, topic: str, data: Any, sender_agent: Optional[str],
- ) -> None:
- """Handle adoption notification from another agent."""
- item_id = data.get("item_id")
- adopter = data.get("adopter")
-
- with self._lock:
- item = self._items.get(item_id)
- if item and adopter and adopter not in item.adopted_by:
- item.adopted_by.append(adopter)
-
- def _on_sync_request(
- self, topic: str, data: Any, sender_agent: Optional[str],
- ) -> None:
- """Respond to sync request from another agent (e.g., edge coming online)."""
- if sender_agent == self.agent_id or not self._bus:
- return
-
- # Send all our items as a sync response
- with self._lock:
- payload = {item_id: item.to_dict() for item_id, item in self._items.items()}
-
- self._bus.publish(
- _TOPIC_SYNC_RESPONSE,
- {"items": payload, "responder": self.agent_id},
- agent=self.agent_id,
- )
-
- # ------------------------------------------------------------------
- # Sync (pull-based)
- # ------------------------------------------------------------------
-
- def request_sync(self) -> None:
- """Request a full sync from other hive members (e.g., after coming online)."""
- if not self._bus:
- return
- self._bus.subscribe(
- _TOPIC_SYNC_RESPONSE, self._on_sync_response, agent=self.agent_id,
- )
- self._bus.publish(
- _TOPIC_SYNC_REQUEST,
- {"requester": self.agent_id},
- agent=self.agent_id,
- )
-
- def _on_sync_response(
- self, topic: str, data: Any, sender_agent: Optional[str],
- ) -> None:
- """Handle sync response — merge received items."""
- items_data = data.get("items", {})
- with self._lock:
- for item_id, item_dict in items_data.items():
- if item_id not in self._items:
- try:
- self._items[item_id] = SharedItem.from_dict(item_dict)
- except (TypeError, KeyError):
- pass
-
- # ------------------------------------------------------------------
- # Callbacks
- # ------------------------------------------------------------------
-
- def on_receive(self, callback: Callable[[SharedItem], None]) -> None:
- """Register a callback for when new items arrive from the hive."""
- self._on_receive_callbacks.append(callback)
-
- # ------------------------------------------------------------------
- # Export for DheePlugin integration
- # ------------------------------------------------------------------
-
- def get_context_block(self, limit: int = 5) -> Dict[str, Any]:
- """Get hive knowledge formatted for HyperContext injection."""
- insights = self.get_shared(kind="insight", limit=limit)
- heuristics = self.get_shared(kind="heuristic", limit=limit)
- skills = self.get_shared(kind="skill", limit=limit)
-
- return {
- "hive_insights": [
- {
- "source": i.source_agent,
- "content": i.content,
- "quality": round(i.quality_score, 2),
- }
- for i in insights
- ],
- "hive_heuristics": [
- {
- "source": h.source_agent,
- "content": h.content,
- "quality": round(h.quality_score, 2),
- }
- for h in heuristics
- ],
- "hive_skills": [
- {
- "source": s.source_agent,
- "name": s.content.get("name", s.id),
- "quality": round(s.quality_score, 2),
- }
- for s in skills
- ],
- }
-
- def close(self) -> None:
- """Unsubscribe and clean up."""
- # Bus cleanup is handled by the bus owner
- self._on_receive_callbacks.clear()
diff --git a/dhee/hive/sync.py b/dhee/hive/sync.py
deleted file mode 100644
index 1a0e546..0000000
--- a/dhee/hive/sync.py
+++ /dev/null
@@ -1,436 +0,0 @@
-"""CRDT-based sync protocol for offline/edge Dhee nodes.
-
-When a DheeEdge instance operates offline (e.g., a humanoid robot in a
-warehouse with no connectivity), it accumulates local hive items. On
-reconnection, it needs to merge with the central hive without conflicts.
-
-This module implements:
- 1. **LWW-Register** (Last-Writer-Wins) for individual shared items.
- 2. **G-Counter** for vote counts (grow-only, merge = max per node).
- 3. **OR-Set** (Observed-Remove) for adoption lists.
- 4. **SyncEnvelope** — wire format for shipping CRDT state between nodes.
-
-Usage:
- # On edge device:
- state = CRDTState(node_id="edge-1")
- state.set_item(shared_item)
- state.increment_votes_up("item:123")
- envelope = state.export_envelope()
-
- # Ship envelope (HTTP, BLE, serial, file drop — whatever works)
-
- # On hub:
- hub_state = CRDTState(node_id="hub")
- hub_state.merge(envelope)
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import os
-import time
-from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Set, Tuple
-
-logger = logging.getLogger(__name__)
-
-
-def _hlc_now(node_id: str) -> str:
- """Hybrid Logical Clock timestamp: |.
-
- Provides a globally-unique, monotonically-increasing timestamp even
- if wall clocks disagree between nodes. Uses '|' separator so node_ids
- can contain dashes.
- """
- return f"{int(time.time() * 1000)}|{node_id}"
-
-
-def _hlc_compare(a: str, b: str) -> int:
- """Compare two HLC timestamps. Returns -1, 0, or 1."""
- a_ms, a_node = a.split("|", 1)
- b_ms, b_node = b.split("|", 1)
- a_int, b_int = int(a_ms), int(b_ms)
- if a_int != b_int:
- return -1 if a_int < b_int else 1
- if a_node < b_node:
- return -1
- if a_node > b_node:
- return 1
- return 0
-
-
-# ---------------------------------------------------------------------------
-# LWW-Register: per-item state
-# ---------------------------------------------------------------------------
-
-@dataclass
-class LWWRegister:
- """Last-Writer-Wins Register for a shared item's content."""
-
- value: Dict[str, Any]
- timestamp: str # HLC timestamp
- node_id: str
-
- def merge(self, other: "LWWRegister") -> "LWWRegister":
- """Merge two registers — latest timestamp wins."""
- if _hlc_compare(self.timestamp, other.timestamp) >= 0:
- return self
- return other
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "value": self.value,
- "timestamp": self.timestamp,
- "node_id": self.node_id,
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "LWWRegister":
- return cls(
- value=data["value"],
- timestamp=data["timestamp"],
- node_id=data["node_id"],
- )
-
-
-# ---------------------------------------------------------------------------
-# G-Counter: grow-only counter per node
-# ---------------------------------------------------------------------------
-
-@dataclass
-class GCounter:
- """Grow-only counter — each node has its own monotonic count."""
-
- counts: Dict[str, int] = field(default_factory=dict) # node_id -> count
-
- @property
- def value(self) -> int:
- return sum(self.counts.values())
-
- def increment(self, node_id: str, amount: int = 1) -> None:
- self.counts[node_id] = self.counts.get(node_id, 0) + amount
-
- def merge(self, other: "GCounter") -> "GCounter":
- """Merge = max of each node's count."""
- all_nodes = set(self.counts) | set(other.counts)
- merged = GCounter()
- for node in all_nodes:
- merged.counts[node] = max(
- self.counts.get(node, 0),
- other.counts.get(node, 0),
- )
- return merged
-
- def to_dict(self) -> Dict[str, int]:
- return dict(self.counts)
-
- @classmethod
- def from_dict(cls, data: Dict[str, int]) -> "GCounter":
- return cls(counts=dict(data))
-
-
-# ---------------------------------------------------------------------------
-# OR-Set: Observed-Remove Set (for adoption lists)
-# ---------------------------------------------------------------------------
-
-@dataclass
-class ORSet:
- """Observed-Remove Set — supports both add and remove with convergence.
-
- Each element is tagged with a unique (node_id, seq) pair. Removes
- only remove the tags that were observed, so concurrent adds win.
- """
-
- # element -> set of (node_id, seq) tags
- _elements: Dict[str, Set[Tuple[str, int]]] = field(default_factory=lambda: {})
- _tombstones: Set[Tuple[str, int]] = field(default_factory=set)
- _seq: int = 0
-
- def add(self, element: str, node_id: str) -> None:
- self._seq += 1
- tag = (node_id, self._seq)
- if element not in self._elements:
- self._elements[element] = set()
- self._elements[element].add(tag)
-
- def remove(self, element: str) -> None:
- tags = self._elements.pop(element, set())
- self._tombstones.update(tags)
-
- @property
- def elements(self) -> Set[str]:
- return {
- elem for elem, tags in self._elements.items()
- if tags - self._tombstones
- }
-
- def merge(self, other: "ORSet") -> "ORSet":
- """Merge two OR-Sets."""
- merged = ORSet()
- merged._seq = max(self._seq, other._seq)
- merged._tombstones = self._tombstones | other._tombstones
-
- all_elements = set(self._elements) | set(other._elements)
- for elem in all_elements:
- tags_a = self._elements.get(elem, set())
- tags_b = other._elements.get(elem, set())
- # Union of live tags minus all tombstones
- live_tags = (tags_a | tags_b) - merged._tombstones
- if live_tags:
- merged._elements[elem] = live_tags
-
- return merged
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "elements": {
- elem: [list(t) for t in tags]
- for elem, tags in self._elements.items()
- },
- "tombstones": [list(t) for t in self._tombstones],
- "seq": self._seq,
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "ORSet":
- s = cls()
- s._seq = data.get("seq", 0)
- s._tombstones = {tuple(t) for t in data.get("tombstones", [])}
- s._elements = {
- elem: {tuple(t) for t in tags}
- for elem, tags in data.get("elements", {}).items()
- }
- return s
-
-
-# ---------------------------------------------------------------------------
-# Per-item CRDT state
-# ---------------------------------------------------------------------------
-
-@dataclass
-class ItemCRDT:
- """CRDT state for a single shared hive item."""
-
- item_id: str
- content: LWWRegister # The item payload
- votes_up: GCounter = field(default_factory=GCounter)
- votes_down: GCounter = field(default_factory=GCounter)
- adopted_by: ORSet = field(default_factory=ORSet)
-
- def merge(self, other: "ItemCRDT") -> "ItemCRDT":
- assert self.item_id == other.item_id
- return ItemCRDT(
- item_id=self.item_id,
- content=self.content.merge(other.content),
- votes_up=self.votes_up.merge(other.votes_up),
- votes_down=self.votes_down.merge(other.votes_down),
- adopted_by=self.adopted_by.merge(other.adopted_by),
- )
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "item_id": self.item_id,
- "content": self.content.to_dict(),
- "votes_up": self.votes_up.to_dict(),
- "votes_down": self.votes_down.to_dict(),
- "adopted_by": self.adopted_by.to_dict(),
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "ItemCRDT":
- return cls(
- item_id=data["item_id"],
- content=LWWRegister.from_dict(data["content"]),
- votes_up=GCounter.from_dict(data.get("votes_up", {})),
- votes_down=GCounter.from_dict(data.get("votes_down", {})),
- adopted_by=ORSet.from_dict(data.get("adopted_by", {})),
- )
-
-
-# ---------------------------------------------------------------------------
-# SyncEnvelope — wire format
-# ---------------------------------------------------------------------------
-
-@dataclass
-class SyncEnvelope:
- """Wire format for CRDT state exchange between nodes."""
-
- source_node: str
- timestamp: str # HLC
- items: Dict[str, Dict[str, Any]] # item_id -> ItemCRDT.to_dict()
-
- def to_bytes(self) -> bytes:
- return json.dumps({
- "source_node": self.source_node,
- "timestamp": self.timestamp,
- "items": self.items,
- }, ensure_ascii=False).encode("utf-8")
-
- @classmethod
- def from_bytes(cls, data: bytes) -> "SyncEnvelope":
- d = json.loads(data.decode("utf-8"))
- return cls(
- source_node=d["source_node"],
- timestamp=d["timestamp"],
- items=d.get("items", {}),
- )
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "source_node": self.source_node,
- "timestamp": self.timestamp,
- "items": self.items,
- }
-
- @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "SyncEnvelope":
- return cls(
- source_node=data["source_node"],
- timestamp=data["timestamp"],
- items=data.get("items", {}),
- )
-
-
-# ---------------------------------------------------------------------------
-# CRDTState — per-node state manager
-# ---------------------------------------------------------------------------
-
-class CRDTState:
- """Manages CRDT state for a single node.
-
- Each node maintains its own CRDTState. On sync, nodes exchange
- SyncEnvelopes and merge them — convergence is guaranteed by the
- CRDT merge semantics (commutative, associative, idempotent).
- """
-
- def __init__(self, node_id: str, persist_path: Optional[str] = None):
- self.node_id = node_id
- self._items: Dict[str, ItemCRDT] = {}
- self._persist_path = persist_path
- if persist_path:
- self._load()
-
- def set_item(self, item_id: str, content: Dict[str, Any]) -> None:
- """Set or update an item's content (LWW)."""
- ts = _hlc_now(self.node_id)
- register = LWWRegister(value=content, timestamp=ts, node_id=self.node_id)
-
- if item_id in self._items:
- self._items[item_id].content = self._items[item_id].content.merge(register)
- else:
- self._items[item_id] = ItemCRDT(item_id=item_id, content=register)
-
- self._auto_persist()
-
- def increment_votes_up(self, item_id: str, amount: int = 1) -> None:
- if item_id in self._items:
- self._items[item_id].votes_up.increment(self.node_id, amount)
- self._auto_persist()
-
- def increment_votes_down(self, item_id: str, amount: int = 1) -> None:
- if item_id in self._items:
- self._items[item_id].votes_down.increment(self.node_id, amount)
- self._auto_persist()
-
- def add_adopter(self, item_id: str, adopter: str) -> None:
- if item_id in self._items:
- self._items[item_id].adopted_by.add(adopter, self.node_id)
- self._auto_persist()
-
- def export_envelope(self) -> SyncEnvelope:
- """Export current state as a sync envelope."""
- return SyncEnvelope(
- source_node=self.node_id,
- timestamp=_hlc_now(self.node_id),
- items={
- item_id: crdt.to_dict()
- for item_id, crdt in self._items.items()
- },
- )
-
- def merge(self, envelope: SyncEnvelope) -> int:
- """Merge a received envelope into local state.
-
- Returns number of items updated.
- """
- updated = 0
- for item_id, item_dict in envelope.items.items():
- try:
- remote = ItemCRDT.from_dict(item_dict)
- except (TypeError, KeyError) as e:
- logger.debug("Skipping malformed item %s: %s", item_id, e)
- continue
-
- if item_id in self._items:
- merged = self._items[item_id].merge(remote)
- # Check if anything actually changed
- if merged.to_dict() != self._items[item_id].to_dict():
- self._items[item_id] = merged
- updated += 1
- else:
- self._items[item_id] = remote
- updated += 1
-
- if updated:
- self._auto_persist()
- return updated
-
- def get_item(self, item_id: str) -> Optional[Dict[str, Any]]:
- """Get resolved state of an item (content + vote totals + adopters)."""
- crdt = self._items.get(item_id)
- if not crdt:
- return None
- return {
- "item_id": item_id,
- "content": crdt.content.value,
- "votes_up": crdt.votes_up.value,
- "votes_down": crdt.votes_down.value,
- "adopted_by": sorted(crdt.adopted_by.elements),
- "last_updated": crdt.content.timestamp,
- }
-
- def list_items(self) -> List[Dict[str, Any]]:
- """List all items with resolved state."""
- return [
- self.get_item(item_id)
- for item_id in sorted(self._items)
- ]
-
- @property
- def item_count(self) -> int:
- return len(self._items)
-
- # ── Persistence ──
-
- def _auto_persist(self) -> None:
- if self._persist_path:
- self.save()
-
- def save(self, path: Optional[str] = None) -> None:
- path = path or self._persist_path
- if not path:
- return
- os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
- tmp = path + ".tmp"
- data = {
- "node_id": self.node_id,
- "items": {
- item_id: crdt.to_dict()
- for item_id, crdt in self._items.items()
- },
- }
- with open(tmp, "w", encoding="utf-8") as f:
- json.dump(data, f, ensure_ascii=False)
- os.replace(tmp, path)
-
- def _load(self) -> None:
- if not self._persist_path or not os.path.exists(self._persist_path):
- return
- try:
- with open(self._persist_path, "r", encoding="utf-8") as f:
- data = json.load(f)
- for item_id, item_dict in data.get("items", {}).items():
- self._items[item_id] = ItemCRDT.from_dict(item_dict)
- except (OSError, json.JSONDecodeError, KeyError) as e:
- logger.warning("Failed to load CRDT state: %s", e)
diff --git a/dhee/integrations/__init__.py b/dhee/integrations/__init__.py
deleted file mode 100644
index 08bfb9a..0000000
--- a/dhee/integrations/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Third-party agent framework integrations for Engram."""
diff --git a/dhee/llms/nvidia.py b/dhee/llms/nvidia.py
index 234a45d..a5eb5d1 100644
--- a/dhee/llms/nvidia.py
+++ b/dhee/llms/nvidia.py
@@ -21,13 +21,14 @@ def __init__(self, config: Optional[dict] = None):
api_key = (
self.config.get("api_key")
or os.getenv("NVIDIA_QWEN_API_KEY")
+ or os.getenv("NVIDIA_LLAMA_4_MAV_API_KEY")
or os.getenv("LLAMA_API_KEY")
or os.getenv("NVIDIA_API_KEY")
)
if not api_key:
raise ValueError(
"NVIDIA API key required. Set config['api_key'], "
- "NVIDIA_QWEN_API_KEY, LLAMA_API_KEY, or NVIDIA_API_KEY env var."
+ "NVIDIA_QWEN_API_KEY, NVIDIA_LLAMA_4_MAV_API_KEY, LLAMA_API_KEY, or NVIDIA_API_KEY env var."
)
base_url = self.config.get("base_url", "https://integrate.api.nvidia.com/v1")
@@ -58,7 +59,11 @@ def generate(self, prompt: str) -> str:
extra_kwargs = {}
if self.enable_thinking:
extra_kwargs["extra_body"] = {
- "chat_template_kwargs": {"thinking": True}
+ "chat_template_kwargs": {"enable_thinking": True}
+ }
+ elif "gemma" in self.model.lower():
+ extra_kwargs["extra_body"] = {
+ "chat_template_kwargs": {"enable_thinking": False}
}
use_stream = self.enable_thinking or self.config.get("stream", False)
diff --git a/dhee/mcp_server.py b/dhee/mcp_server.py
index 9634b6c..21ae6d4 100644
--- a/dhee/mcp_server.py
+++ b/dhee/mcp_server.py
@@ -48,6 +48,30 @@
logger = logging.getLogger(__name__)
+def _default_user_id(args: Dict[str, Any]) -> str:
+ return str(args.get("user_id") or os.environ.get("DHEE_USER_ID") or "default")
+
+
+def _default_agent_id(args: Dict[str, Any]) -> str:
+ return str(args.get("agent_id") or os.environ.get("DHEE_AGENT_ID") or "mcp-server")
+
+
+def _default_requester_agent_id(args: Dict[str, Any]) -> str:
+ return str(
+ args.get("requester_agent_id")
+ or os.environ.get("DHEE_REQUESTER_AGENT_ID")
+ or _default_agent_id(args)
+ )
+
+
+def _default_source_app(args: Dict[str, Any]) -> str:
+ return str(
+ args.get("source_app")
+ or os.environ.get("DHEE_SOURCE_APP")
+ or _default_agent_id(args)
+ )
+
+
def _get_embedding_dims_for_model(model: str, provider: str) -> int:
EMBEDDING_DIMS = {
"models/text-embedding-005": 768,
@@ -77,6 +101,8 @@ def get_memory_instance() -> FullMemory:
os.environ.get("NVIDIA_API_KEY")
or os.environ.get("NVIDIA_QWEN_API_KEY")
or os.environ.get("NVIDIA_EMBEDDING_API_KEY")
+ or os.environ.get("NVIDIA_EMBED_API_KEY")
+ or os.environ.get("NVIDIA_LLAMA_4_MAV_API_KEY")
)
def _env(key: str, default: str = "") -> str:
@@ -172,6 +198,9 @@ def _env(key: str, default: str = "") -> str:
embedding_model_dims=embedding_dims,
fade=fade_config,
)
+ if hasattr(config, "enrichment"):
+ config.enrichment.defer_enrichment = True
+ config.enrichment.enable_unified = True
return FullMemory(config)
@@ -192,8 +221,9 @@ def get_buddhi():
"""Lazy singleton for the Buddhi cognition layer."""
global _buddhi
if _buddhi is None:
+ from dhee.configs.base import _dhee_data_dir
from dhee.core.buddhi import Buddhi
- _buddhi = Buddhi()
+ _buddhi = Buddhi(data_dir=os.path.join(_dhee_data_dir(), "buddhi"))
return _buddhi
@@ -205,11 +235,14 @@ def get_buddhi():
TOOLS = [
Tool(
name="remember",
- description="Quick-save a fact or preference to memory. Creates a staging proposal commit with source_app='claude-code' and infer=False by default.",
+ description="Quick-save a fact or preference to memory. Stores immediately with infer=False by default and uses the configured MCP agent/source identity when not provided.",
inputSchema={
"type": "object",
"properties": {
"content": {"type": "string", "description": "The fact or preference to remember"},
+ "user_id": {"type": "string", "description": "User identifier (defaults to DHEE_USER_ID or 'default')."},
+ "agent_id": {"type": "string", "description": "Agent identifier (defaults to DHEE_AGENT_ID or 'mcp-server')."},
+ "source_app": {"type": "string", "description": "Source application label (defaults to DHEE_SOURCE_APP or agent_id)."},
"categories": {"type": "array", "items": {"type": "string"}, "description": "Optional categories to tag this memory with (e.g., ['preferences', 'coding'])"},
"context": {
"type": "array",
@@ -574,10 +607,10 @@ def get_buddhi():
def _handle_remember(memory, args):
return memory.add(
messages=args.get("content", ""),
- user_id="default",
- agent_id="claude-code",
+ user_id=_default_user_id(args),
+ agent_id=_default_agent_id(args),
categories=args.get("categories"),
- source_app="claude-code",
+ source_app=_default_source_app(args),
infer=False,
context_messages=args.get("context"),
)
@@ -596,7 +629,7 @@ def _handle_search_memory(memory, args):
context_top_k = 10
result = memory.search_orchestrated(
query=args.get("query", ""),
- user_id=args.get("user_id", "default"),
+ user_id=_default_user_id(args),
agent_id=args.get("agent_id"),
categories=args.get("categories"),
limit=limit,
@@ -612,7 +645,7 @@ def _handle_search_memory(memory, args):
else:
result = memory.search(
query=args.get("query", ""),
- user_id=args.get("user_id", "default"),
+ user_id=_default_user_id(args),
agent_id=args.get("agent_id"),
limit=limit,
categories=args.get("categories"),
@@ -652,7 +685,7 @@ def _handle_get_all_memories(memory, args):
except (ValueError, TypeError):
limit = 50
result = memory.get_all(
- user_id=args.get("user_id", "default"),
+ user_id=_default_user_id(args),
agent_id=args.get("agent_id"),
limit=limit,
layer=args.get("layer"),
@@ -691,8 +724,10 @@ def _handle_dhee_context(memory, args):
def _handle_get_last_session(_memory, args):
from dhee.core.kernel import get_last_session
session = get_last_session(
- agent_id=args.get("agent_id", "mcp-server"),
+ agent_id=args.get("agent_id"),
repo=args.get("repo"),
+ user_id=_default_user_id(args),
+ requester_agent_id=_default_requester_agent_id(args),
fallback_log_recovery=args.get("fallback_log_recovery", True),
)
if session is None:
@@ -704,7 +739,8 @@ def _handle_save_session_digest(_memory, args):
from dhee.core.kernel import save_session_digest
return save_session_digest(
task_summary=args.get("task_summary", ""),
- agent_id=args.get("agent_id", "claude-code"),
+ agent_id=_default_agent_id(args),
+ requester_agent_id=_default_requester_agent_id(args),
repo=args.get("repo"),
status=args.get("status", "active"),
decisions_made=args.get("decisions_made"),
@@ -718,7 +754,7 @@ def _handle_save_session_digest(_memory, args):
def _handle_get_memory_stats(memory, args):
return memory.get_stats(
- user_id=args.get("user_id"),
+ user_id=args.get("user_id") or os.environ.get("DHEE_USER_ID"),
agent_id=args.get("agent_id"),
)
@@ -847,7 +883,7 @@ def _handle_think(memory, arguments: Dict[str, Any]) -> Dict[str, Any]:
if hasattr(memory, "think"):
result = memory.think(
question=question,
- user_id=user_id,
+ user_id=_default_user_id(arguments) if not arguments.get("user_id") else user_id,
max_depth=max_depth,
)
if hasattr(result, "to_dict"):
@@ -858,7 +894,7 @@ def _handle_think(memory, arguments: Dict[str, Any]) -> Dict[str, Any]:
def _handle_anticipate(memory, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Proactive intelligence — Buddhi checks intentions, insights, and scenes."""
- user_id = arguments.get("user_id", "default")
+ user_id = _default_user_id(arguments)
context = arguments.get("context")
buddhi = get_buddhi()
@@ -894,7 +930,7 @@ def _handle_record_outcome(_memory, arguments: Dict[str, Any]) -> Dict[str, Any]
"""Record task outcome for performance tracking."""
task_type = arguments.get("task_type", "")
score = float(arguments.get("score", 0.0))
- user_id = arguments.get("user_id", "default")
+ user_id = _default_user_id(arguments)
metadata = arguments.get("metadata")
if not task_type:
@@ -916,7 +952,7 @@ def _handle_record_outcome(_memory, arguments: Dict[str, Any]) -> Dict[str, Any]
def _handle_reflect(_memory, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Agent-triggered reflection — synthesize insights from experience."""
task_type = arguments.get("task_type", "")
- user_id = arguments.get("user_id", "default")
+ user_id = _default_user_id(arguments)
if not task_type:
return {"error": "task_type is required"}
@@ -938,7 +974,7 @@ def _handle_reflect(_memory, arguments: Dict[str, Any]) -> Dict[str, Any]:
def _handle_store_intention(_memory, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Store a future trigger — prospective memory."""
description = arguments.get("description", "")
- user_id = arguments.get("user_id", "default")
+ user_id = _default_user_id(arguments)
if not description:
return {"error": "description is required"}
diff --git a/dhee/mcp_slim.py b/dhee/mcp_slim.py
index 1013f99..607213b 100644
--- a/dhee/mcp_slim.py
+++ b/dhee/mcp_slim.py
@@ -38,7 +38,7 @@ def _get_plugin():
"""Create the DheePlugin singleton. Wraps Engram + Buddhi."""
global _plugin
if _plugin is None:
- from dhee.adapters.base import DheePlugin
+ from dhee.plugin import DheePlugin
_plugin = DheePlugin()
# Enable deferred enrichment on the underlying memory
memory = _plugin.memory
diff --git a/dhee/memory/main.py b/dhee/memory/main.py
index 69089b3..a5b741b 100644
--- a/dhee/memory/main.py
+++ b/dhee/memory/main.py
@@ -33,7 +33,6 @@
from dhee.core.graph import KnowledgeGraph
from dhee.core.scene import SceneProcessor
from dhee.core.profile import ProfileProcessor
-from dhee.core.answer_orchestration import extract_atomic_facts, reduce_atomic_facts
from dhee.db.sqlite import SQLiteManager
from dhee.exceptions import FadeMemValidationError
from dhee.memory.base import MemoryBase
@@ -430,8 +429,6 @@ def _orchestration_engine(self) -> OrchestrationEngine:
profile_processor_fn=lambda: self.profile_processor,
evolution_layer_fn=lambda: self.evolution_layer,
llm_fn=lambda: self.llm,
- extract_atomic_facts_fn=extract_atomic_facts,
- reduce_atomic_facts_fn=reduce_atomic_facts,
)
return self.__orchestration_engine
@@ -577,7 +574,7 @@ def reranker(self):
"""Lazy-initialized neural reranker (only if enabled in config)."""
rerank_cfg = getattr(self.config, "rerank", None)
if self._reranker is None and rerank_cfg and rerank_cfg.enable_rerank:
- from dhee.retrieval.reranker import create_reranker
+ from dhee.memory.reranker import create_reranker
self._reranker = create_reranker({
"provider": rerank_cfg.provider,
"model": rerank_cfg.model,
diff --git a/dhee/memory/orchestration.py b/dhee/memory/orchestration.py
index ff3295b..a6ab4e5 100644
--- a/dhee/memory/orchestration.py
+++ b/dhee/memory/orchestration.py
@@ -1,24 +1,17 @@
-"""Orchestration engine: map-reduce, episodic anchoring, hierarchical retrieval.
+"""Orchestration engine: episodic anchoring, hierarchical retrieval, context assembly.
-Extracted from memory/main.py — centralizes the orchestrated-search path so that
-FullMemory.search_orchestrated() becomes a thin delegation wrapper.
+Dhee's job: retrieve well, assemble context, return it. No answer synthesis.
"""
from __future__ import annotations
import logging
import re
-import time
from typing import Any, Callable, Dict, List, Optional, Tuple
-from dhee.memory.cost import stable_hash_text
from dhee.core.episodic_index import normalize_actor_id
from dhee.core.answer_orchestration import (
- build_map_candidates,
build_query_plan,
- deterministic_inconsistency_check,
- is_low_confidence_answer,
- render_fact_context,
)
logger = logging.getLogger(__name__)
@@ -45,8 +38,6 @@ def __init__(
profile_processor_fn: Callable,
evolution_layer_fn: Callable,
llm_fn: Callable,
- extract_atomic_facts_fn: Callable,
- reduce_atomic_facts_fn: Callable,
):
self._config = config
self._db = db
@@ -59,71 +50,9 @@ def __init__(
self._profile_processor_fn = profile_processor_fn
self._evolution_layer_fn = evolution_layer_fn
self._llm_fn = llm_fn
- self._extract_atomic_facts_fn = extract_atomic_facts_fn
- self._reduce_atomic_facts_fn = reduce_atomic_facts_fn
# Internal state
- self._reducer_cache: Dict[str, Dict[str, Any]] = {}
self._guardrail_auto_disabled: bool = False
- # -- Reducer cache helpers ------------------------------------------------
-
- def _build_reducer_cache_key(
- self,
- *,
- user_id: str,
- intent_value: str,
- query: str,
- results: List[Dict[str, Any]],
- ) -> str:
- evidence_fingerprint_parts: List[str] = []
- for row in results[:30]:
- mem_id = str(row.get("id") or "").strip()
- score = float(row.get("composite_score", row.get("score", 0.0)) or 0.0)
- evidence_fingerprint_parts.append(f"{mem_id}:{score:.4f}")
- evidence_fingerprint = "|".join(evidence_fingerprint_parts)
- base = "|".join(
- [
- str(user_id or ""),
- str(intent_value or ""),
- stable_hash_text(query),
- stable_hash_text(evidence_fingerprint),
- ]
- )
- return stable_hash_text(base)
-
- def _get_reducer_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
- orch_cfg = getattr(self._config, "orchestration", None)
- ttl_seconds = int(getattr(orch_cfg, "reducer_cache_ttl_seconds", 900) or 900)
- record = self._reducer_cache.get(cache_key)
- if not record:
- return None
- ts = float(record.get("ts", 0.0) or 0.0)
- if ts <= 0.0:
- return None
- if (time.time() - ts) > max(1, ttl_seconds):
- self._reducer_cache.pop(cache_key, None)
- return None
- return record
-
- def _put_reducer_cache(
- self,
- *,
- cache_key: str,
- reduced_answer: Optional[str],
- facts: List[Dict[str, Any]],
- ) -> None:
- orch_cfg = getattr(self._config, "orchestration", None)
- max_entries = int(getattr(orch_cfg, "reducer_cache_max_entries", 2048) or 2048)
- self._reducer_cache[cache_key] = {
- "ts": time.time(),
- "reduced_answer": reduced_answer,
- "facts": list(facts or []),
- }
- # Keep insertion-order bounded cache.
- while len(self._reducer_cache) > max(1, max_entries):
- oldest_key = next(iter(self._reducer_cache))
- self._reducer_cache.pop(oldest_key, None)
-
# -- Cost guardrail -------------------------------------------------------
def _enforce_write_cost_guardrail(self, *, user_id: Optional[str]) -> None:
@@ -507,49 +436,15 @@ def search_orchestrated(
limit=3,
)
- (
- reduced_answer,
- facts,
- map_reduce_used,
- reflection_hops,
- llm_calls_used,
- cache_hit,
- orchestration_reasons,
- results,
- ) = self._execute_map_reduce(
- query_plan=query_plan,
- orchestrator_llm=orchestrator_llm,
- results=results,
- event_hits=event_hits,
- coverage=coverage,
- query=query,
- question_type=question_type,
- question_date=question_date,
- mode=mode,
- search_cap_value=search_cap_value,
- map_max_candidates_value=map_max_candidates_value,
- map_max_chars_value=map_max_chars_value,
- reflection_max_hops=reflection_max_hops,
- search_query=search_query,
- search_limit=search_limit,
- rerank=rerank,
- keyword_search=keyword_search,
- hybrid_alpha=hybrid_alpha,
- include_evidence=include_evidence,
- evidence_strategy=evidence_strategy,
- evidence_max_chars=evidence_max_chars,
- evidence_context_lines=evidence_context_lines,
- user_id=user_id,
- filters=filters,
- categories=categories,
- agent_id=agent_id,
- run_id=run_id,
- app_id=app_id,
- )
- reason_codes.extend(orchestration_reasons)
+ # Dhee's job: retrieve and assemble context. Agent answers.
+ # No map-reduce, no triple extraction, no LLM calls at query time.
+ reduced_answer: Optional[str] = None
+ facts: List[Dict[str, Any]] = []
+ map_reduce_used = False
+ reflection_hops = 0
+ llm_calls_used = 0.0
+ cache_hit = False
- # Always use full retrieval context — proposition context (Phase 3)
- # is deferred until episodic event coverage is proven reliable.
context = self._build_orchestrated_context(
results=results,
event_hits=event_hits,
@@ -558,13 +453,6 @@ def search_orchestrated(
max_chars=max_context_chars,
per_result_max_chars=evidence_max_chars,
)
- if facts:
- fact_context = render_fact_context(facts, max_facts=20)
- if fact_context:
- if mode == "strict":
- context = "Canonical Facts:\n" + fact_context
- else:
- context = "Canonical Facts:\n" + fact_context + "\n\nRetrieved Context:\n" + context
self._record_cost_fn(
phase="query",
@@ -577,22 +465,6 @@ def search_orchestrated(
intent_coverage = float(coverage.get("intent_coverage", coverage.get("coverage_ratio", 0.0)) or 0.0)
- # Dhee: Self-evolution — record answer generation signal
- evolution_layer = self._evolution_layer_fn()
- if evolution_layer and reduced_answer:
- try:
- source_ids = [r.get("id", "") for r in results[:context_limit] if r.get("id")]
- source_texts = [r.get("memory", "") for r in results[:context_limit] if r.get("memory")]
- evolution_layer.on_answer_generated(
- query=query,
- answer=str(reduced_answer),
- source_memory_ids=source_ids,
- source_texts=source_texts,
- user_id=user_id or "default",
- )
- except Exception as e:
- logger.debug("Evolution answer hook skipped: %s", e)
-
return {
"results": results[: max(1, int(limit))],
"event_hits": event_hits,
@@ -618,247 +490,3 @@ def search_orchestrated(
"facts": facts,
}
- # -- Map-reduce execution -------------------------------------------------
-
- def _execute_map_reduce(
- self,
- *,
- query_plan: Any,
- orchestrator_llm: Optional[Any],
- results: List[Dict[str, Any]],
- event_hits: Optional[List[Dict[str, Any]]] = None,
- coverage: Optional[Dict[str, Any]],
- query: str,
- question_type: str,
- question_date: str,
- mode: str,
- search_cap_value: int,
- map_max_candidates_value: int,
- map_max_chars_value: int,
- reflection_max_hops: Optional[int],
- search_query: str,
- search_limit: int,
- rerank: bool,
- keyword_search: bool,
- hybrid_alpha: float,
- include_evidence: bool,
- evidence_strategy: str,
- evidence_max_chars: int,
- evidence_context_lines: int,
- user_id: str,
- filters: Optional[Dict[str, Any]],
- categories: Optional[List[str]],
- agent_id: Optional[str],
- run_id: Optional[str],
- app_id: Optional[str],
- ) -> Tuple[Optional[str], List[Dict[str, Any]], bool, int, float, bool, List[str], List[Dict[str, Any]]]:
- """Execute map-reduce orchestration with optional reflection.
-
- Tries event-first reduction (zero LLM cost) before falling back
- to LLM-based atomic fact extraction.
-
- Returns:
- (
- reduced_answer,
- facts,
- map_reduce_used,
- reflection_hops,
- llm_calls_used,
- cache_hit,
- reason_codes,
- updated_results,
- )
- """
- reduced_answer: Optional[str] = None
- facts: List[Dict[str, Any]] = []
- map_reduce_used = False
- reflection_hops = 0
- llm_calls_used = 0.0
- cache_hit = False
- reason_codes: List[str] = []
- active_orchestrator_llm = orchestrator_llm or self._llm_fn()
- orch_cfg = getattr(self._config, "orchestration", None)
- raw_max_query_llm_calls = getattr(orch_cfg, "max_query_llm_calls", 2)
- try:
- max_query_llm_calls = int(raw_max_query_llm_calls if raw_max_query_llm_calls is not None else 2)
- except (TypeError, ValueError):
- max_query_llm_calls = 2
-
- coverage_sufficient = bool((coverage or {}).get("sufficient"))
- if coverage_sufficient:
- reason_codes.append("coverage_sufficient")
- else:
- reason_codes.append("coverage_insufficient")
-
- inconsistency = deterministic_inconsistency_check(
- question=query,
- intent=query_plan.intent,
- results=results,
- coverage=coverage,
- )
- inconsistency_detected = bool(inconsistency.get("inconsistent"))
- if inconsistency_detected:
- reason_codes.extend(list(inconsistency.get("reasons") or []))
-
- # NOTE: Event-first reduction (Phase 2) disabled — episodic events
- # alone lack sufficient coverage for accurate multi-session counting.
- # The LLM-based map-reduce path below is more reliable.
- if mode == "strict":
- mode_requires_map_reduce = True
- else:
- mode_requires_map_reduce = (not coverage_sufficient) or inconsistency_detected
-
- should_run_map_reduce = bool(
- query_plan.should_map_reduce
- and active_orchestrator_llm is not None
- and results
- and mode_requires_map_reduce
- )
- if query_plan.should_map_reduce and active_orchestrator_llm is None:
- reason_codes.append("no_orchestrator_llm")
- if should_run_map_reduce and max_query_llm_calls <= 0:
- reason_codes.append("query_llm_budget_exhausted")
- should_run_map_reduce = False
-
- if should_run_map_reduce:
- cache_key = self._build_reducer_cache_key(
- user_id=user_id,
- intent_value=query_plan.intent.value,
- query=query,
- results=results,
- )
- cached = self._get_reducer_cache(cache_key)
- if cached and str(cached.get("reduced_answer") or "").strip():
- cached_answer = str(cached.get("reduced_answer") or "").strip()
- if not is_low_confidence_answer(cached_answer):
- reduced_answer = cached_answer
- facts = list(cached.get("facts") or [])
- cache_hit = True
- reason_codes.append("reducer_cache_hit")
-
- if not cache_hit:
- map_candidates = build_map_candidates(
- results,
- max_candidates=map_max_candidates_value,
- per_candidate_max_chars=map_max_chars_value,
- )
- if llm_calls_used < float(max_query_llm_calls):
- facts = self._extract_atomic_facts_fn(
- llm=active_orchestrator_llm,
- question=query,
- question_type=question_type,
- question_date=question_date,
- candidates=map_candidates,
- )
- reduced_answer, _ = self._reduce_atomic_facts_fn(
- question=query,
- intent=query_plan.intent,
- facts=facts,
- )
- llm_calls_used += 1.0
- map_reduce_used = True
- reason_codes.append("map_reduce_executed")
- if reduced_answer or facts:
- self._put_reducer_cache(
- cache_key=cache_key,
- reduced_answer=reduced_answer,
- facts=facts,
- )
- else:
- reason_codes.append("query_llm_budget_exhausted")
-
- max_hops = int(
- reflection_max_hops
- if reflection_max_hops is not None
- else getattr(self._config.orchestration, "reflection_max_hops", 1)
- )
- if (
- max_hops > 0
- and (not reduced_answer or is_low_confidence_answer(reduced_answer))
- and search_limit < search_cap_value
- and llm_calls_used < float(max_query_llm_calls)
- ):
- reflection_hops = 1
- reason_codes.append("reflection_executed")
- expanded_limit = min(search_cap_value, max(search_limit + 8, search_limit * 2))
- reflection_payload = self._search_fn(
- query=search_query,
- user_id=user_id,
- agent_id=agent_id,
- run_id=run_id,
- app_id=app_id,
- filters=filters,
- categories=categories,
- limit=expanded_limit,
- rerank=rerank,
- keyword_search=keyword_search,
- hybrid_alpha=hybrid_alpha,
- include_evidence=include_evidence,
- evidence_strategy=evidence_strategy,
- evidence_max_chars=evidence_max_chars,
- evidence_context_lines=evidence_context_lines,
- )
- reflected_results = list(reflection_payload.get("results", []))
- merged: Dict[str, Dict[str, Any]] = {}
- for row in results + reflected_results:
- memory_id = str(row.get("id") or "")
- existing = merged.get(memory_id)
- if not existing or float(row.get("composite_score", row.get("score", 0.0))) > float(
- existing.get("composite_score", existing.get("score", 0.0))
- ):
- merged[memory_id] = row
- results = sorted(
- merged.values(),
- key=lambda row: float(row.get("composite_score", row.get("score", 0.0))),
- reverse=True,
- )
- map_candidates = build_map_candidates(
- results,
- max_candidates=map_max_candidates_value,
- per_candidate_max_chars=map_max_chars_value,
- )
- if llm_calls_used < float(max_query_llm_calls):
- facts = self._extract_atomic_facts_fn(
- llm=active_orchestrator_llm,
- question=query,
- question_type=question_type,
- question_date=question_date,
- candidates=map_candidates,
- )
- reduced_answer, _ = self._reduce_atomic_facts_fn(
- question=query,
- intent=query_plan.intent,
- facts=facts,
- )
- llm_calls_used += 1.0
- map_reduce_used = True
- if reduced_answer or facts:
- self._put_reducer_cache(
- cache_key=self._build_reducer_cache_key(
- user_id=user_id,
- intent_value=query_plan.intent.value,
- query=query,
- results=results,
- ),
- reduced_answer=reduced_answer,
- facts=facts,
- )
- else:
- reason_codes.append("query_llm_budget_exhausted")
- elif (
- max_hops > 0
- and (not reduced_answer or is_low_confidence_answer(reduced_answer))
- and search_limit < search_cap_value
- ):
- reason_codes.append("reflection_skipped_budget")
-
- return (
- reduced_answer,
- facts,
- map_reduce_used,
- reflection_hops,
- llm_calls_used,
- cache_hit,
- list(dict.fromkeys(reason_codes)),
- results,
- )
diff --git a/dhee/retrieval/reranker.py b/dhee/memory/reranker.py
similarity index 100%
rename from dhee/retrieval/reranker.py
rename to dhee/memory/reranker.py
diff --git a/dhee/memory/search_pipeline.py b/dhee/memory/search_pipeline.py
index 4e4d1e6..5d76f45 100644
--- a/dhee/memory/search_pipeline.py
+++ b/dhee/memory/search_pipeline.py
@@ -453,7 +453,7 @@ def search(
else:
for mid in reecho_ids:
self._reecho_memory(mid)
- if agent_id:
+ if agent_id and hasattr(self._db, "add_memory_subscriber"):
for mid in subscriber_ids:
self._db.add_memory_subscriber(mid, f"agent:{agent_id}", ref_type="weak")
diff --git a/dhee/memory/write_pipeline.py b/dhee/memory/write_pipeline.py
index 602ce41..94fd173 100644
--- a/dhee/memory/write_pipeline.py
+++ b/dhee/memory/write_pipeline.py
@@ -467,7 +467,14 @@ def _add_llm_cost(input_tokens: float) -> None:
content = explicit_intent.content
blocked = detect_sensitive_categories(content)
- allow_sensitive = bool(mem_metadata.get("allow_sensitive"))
+ # allow_sensitive: explicit caller opt-in, or caller explicitly provided
+ # the content (infer=False / user_provided=True). PII detection is a
+ # guardrail for agent-inferred memories from raw conversation, not for
+ # bulk corpus ingestion where the caller owns the content decision.
+ allow_sensitive = (
+ bool(mem_metadata.get("allow_sensitive"))
+ or bool(mem_metadata.get("user_provided"))
+ )
if blocked and not allow_sensitive:
return {
"event": "BLOCKED",
@@ -477,7 +484,8 @@ def _add_llm_cost(input_tokens: float) -> None:
}
is_task_or_note = (mem_metadata or {}).get("memory_type") in ("task", "note")
- if not explicit_remember and not is_task_or_note and is_ephemeral(content):
+ is_user_provided = bool(mem_metadata.get("user_provided"))
+ if not explicit_remember and not is_task_or_note and not is_user_provided and is_ephemeral(content):
return {
"event": "SKIP",
"reason": "ephemeral",
diff --git a/dhee/mini/__init__.py b/dhee/mini/__init__.py
deleted file mode 100644
index df441ba..0000000
--- a/dhee/mini/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""Dhee Mini — small trainable model for self-evolving cognition."""
-
-from dhee.mini.buddhi_mini import BuddhiMini
-from dhee.mini.trace_segmenter import TraceSegmenter, TrainingSpan, SpanType
-
-__all__ = ["BuddhiMini", "TraceSegmenter", "TrainingSpan", "SpanType"]
diff --git a/dhee/mini/buddhi_mini.py b/dhee/mini/buddhi_mini.py
deleted file mode 100644
index 8fa8cd7..0000000
--- a/dhee/mini/buddhi_mini.py
+++ /dev/null
@@ -1,414 +0,0 @@
-"""BuddhiMini — small trainable model for self-evolving cognition.
-
-NOT a separate model from DheeModel. This is DheeModel with 3 new task
-heads + a trace-driven data pipeline that produces better training data.
-
-The self-evolution loop:
- 1. Agent uses Dhee (remember/recall/context/checkpoint)
- 2. Samskara collects 12 signal types per operation
- 3. TraceSegmenter splits trajectories into [REASON]/[ACT]/[MEMORY_OP]
- 4. When signals reach critical mass → Nididhyasana triggers
- 5. ProgressiveTrainer runs: SFT → DPO → RL
- 6. DheeModel updates weights (LoRA merge or GGUF export)
- 7. Hot-swapped without restart
-
-Research basis:
- - Structured Agent Distillation (arXiv:2505.13820): span-specific losses
- - AgeMem (arXiv:2601.01885): memory ops as RL-optimized tool calls
- - EvolveR (arXiv:2510.16079): offline distillation → online retrieval
-
-New task heads (added to DheeModel's existing 6):
- [MEMORY_OP] — predict optimal memory operation for context
- [HEURISTIC] — generate abstract heuristic from trajectory
- [RETRIEVAL_JUDGE] — predict whether retrieval results are sufficient
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import os
-import time
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional
-
-from dhee.mini.trace_segmenter import TraceSegmenter, TrainingSpan, SpanType
-
-logger = logging.getLogger(__name__)
-
-# Training thresholds
-_MIN_SFT_SAMPLES = 50 # minimum samples to trigger SFT
-_MIN_DPO_PAIRS = 20 # minimum pairs to trigger DPO
-_ACCUMULATION_WINDOW = 3600 # seconds between training checks
-
-
-@dataclass
-class TrainingBuffer:
- """Accumulates training data between training cycles."""
- sft_samples: List[Dict[str, str]] = field(default_factory=list)
- dpo_pairs: List[Dict[str, Any]] = field(default_factory=list)
- trajectories_ingested: int = 0
- contrastive_pairs_ingested: int = 0
- last_train_time: float = 0.0
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "sft_samples": len(self.sft_samples),
- "dpo_pairs": len(self.dpo_pairs),
- "trajectories_ingested": self.trajectories_ingested,
- "contrastive_pairs_ingested": self.contrastive_pairs_ingested,
- "last_train_time": self.last_train_time,
- }
-
-
-class BuddhiMini:
- """Small trainable model for self-evolving cognition.
-
- Wraps the existing DheeModel (Qwen3.5-2B) and adds:
- 1. Trace ingestion pipeline (trajectories → training spans)
- 2. Training data accumulation with thresholds
- 3. Progressive training trigger (SFT → DPO → RL)
- 4. 3 new inference task heads
-
- The model trains itself from the agent's own interaction traces.
- No external training data needed. Pure self-evolution.
-
- Args:
- data_dir: Directory for training data and checkpoints
- model_size: Not used yet — reserved for future model variants
- device: Device for inference (auto-detected if None)
- """
-
- def __init__(
- self,
- data_dir: Optional[str] = None,
- model_size: str = "2B",
- device: Optional[str] = None,
- ):
- self._data_dir = data_dir or os.path.join(
- os.path.expanduser("~"), ".dhee", "mini"
- )
- os.makedirs(self._data_dir, exist_ok=True)
-
- self._segmenter = TraceSegmenter()
- self._buffer = TrainingBuffer()
- self._model = None # lazy-loaded DheeLLM
- self._device = device
-
- # Load persisted buffer if exists
- self._load_buffer()
-
- # ------------------------------------------------------------------
- # Trace ingestion
- # ------------------------------------------------------------------
-
- def ingest_trajectory(self, trajectory) -> Dict[str, Any]:
- """Ingest a trajectory and segment it into training spans.
-
- Called automatically by DheePlugin.end_trajectory() or manually.
-
- Returns:
- {"spans": int, "sft_added": int, "dpo_ready": bool}
- """
- spans = self._segmenter.segment(trajectory)
- if not spans:
- return {"spans": 0, "sft_added": 0, "dpo_ready": False}
-
- # Add successful spans to SFT buffer
- sft_examples = self._segmenter.format_for_sft(spans)
- self._buffer.sft_samples.extend(sft_examples)
- self._buffer.trajectories_ingested += 1
-
- # Store spans for DPO pairing later
- self._save_spans(spans)
- self._save_buffer()
-
- return {
- "spans": len(spans),
- "sft_added": len(sft_examples),
- "dpo_ready": len(self._buffer.dpo_pairs) >= _MIN_DPO_PAIRS,
- }
-
- def ingest_contrastive_pair(
- self,
- task_description: str,
- success_approach: str,
- failure_approach: str,
- task_type: str = "general",
- ) -> None:
- """Ingest a contrastive pair for DPO training.
-
- Called when checkpoint() receives both what_worked and what_failed.
- """
- self._buffer.dpo_pairs.append({
- "prompt": f"[TASK] {task_description}\n[TYPE] {task_type}",
- "chosen": success_approach,
- "rejected": failure_approach,
- "span_type": "reflect",
- })
- self._buffer.contrastive_pairs_ingested += 1
- self._save_buffer()
-
- # ------------------------------------------------------------------
- # Training control
- # ------------------------------------------------------------------
-
- def should_train(self) -> tuple:
- """Check if enough data has accumulated for a training cycle.
-
- Returns:
- (should_train: bool, reason: str)
- """
- now = time.time()
- if now - self._buffer.last_train_time < _ACCUMULATION_WINDOW:
- return False, "Too soon since last training cycle"
-
- sft_ready = len(self._buffer.sft_samples) >= _MIN_SFT_SAMPLES
- dpo_ready = len(self._buffer.dpo_pairs) >= _MIN_DPO_PAIRS
-
- if sft_ready and dpo_ready:
- return True, f"Ready: {len(self._buffer.sft_samples)} SFT + {len(self._buffer.dpo_pairs)} DPO"
- if sft_ready:
- return True, f"SFT ready: {len(self._buffer.sft_samples)} samples"
- if dpo_ready:
- return True, f"DPO ready: {len(self._buffer.dpo_pairs)} pairs"
-
- return False, (
- f"Accumulating: {len(self._buffer.sft_samples)}/{_MIN_SFT_SAMPLES} SFT, "
- f"{len(self._buffer.dpo_pairs)}/{_MIN_DPO_PAIRS} DPO"
- )
-
- def train_cycle(self, stage: str = "auto") -> Dict[str, Any]:
- """Run one training cycle.
-
- Delegates to ProgressiveTrainer or Nididhyasana depending on
- what's available. Returns training results.
-
- Args:
- stage: "sft", "dpo", "progressive", or "auto"
- """
- result: Dict[str, Any] = {"stage": stage, "status": "skipped"}
-
- try:
- # Try to use Nididhyasana (existing auto-evolution loop)
- from dheeModel.training.nididhyasana import NididhyasanaLoop
- loop = NididhyasanaLoop(data_dir=self._data_dir)
-
- # Export training data in Nididhyasana format
- training_data = self._export_training_data()
- if not training_data:
- result["status"] = "no_data"
- return result
-
- # Run cycle
- cycle_result = loop.run_cycle(
- sft_data=training_data.get("sft", []),
- dpo_data=training_data.get("dpo", []),
- )
- result["status"] = "completed"
- result["cycle"] = cycle_result
- except ImportError:
- logger.debug("Nididhyasana not available — storing data for manual training")
- self._save_training_export()
- result["status"] = "data_saved"
- result["path"] = os.path.join(self._data_dir, "training_export.jsonl")
- except Exception as e:
- logger.debug("Training cycle failed: %s", e)
- result["status"] = "error"
- result["error"] = str(e)
-
- # Update buffer
- self._buffer.last_train_time = time.time()
- self._buffer.sft_samples = [] # Clear used samples
- self._buffer.dpo_pairs = []
- self._save_buffer()
-
- return result
-
- # ------------------------------------------------------------------
- # Inference (edge-optimized task heads)
- # ------------------------------------------------------------------
-
- def classify_memory_op(self, context: str) -> str:
- """Predict optimal memory operation for current context.
-
- Task head: [MEMORY_OP]
- Returns: "store" | "retrieve" | "update" | "summarize" | "discard" | "none"
- """
- model = self._get_model()
- if model is None:
- return self._heuristic_classify_memory_op(context)
-
- try:
- response = model.generate_with_task(
- task="MEMORY_OP",
- prompt=context[:1000],
- )
- op = response.strip().lower()
- valid_ops = {"store", "retrieve", "update", "summarize", "discard", "none"}
- return op if op in valid_ops else "none"
- except Exception:
- return self._heuristic_classify_memory_op(context)
-
- def generate_heuristic(self, trajectory_summary: str) -> str:
- """Generate an abstract heuristic from a trajectory summary.
-
- Task head: [HEURISTIC]
- Returns: A transferable reasoning pattern as natural language.
- """
- model = self._get_model()
- if model is None:
- return f"From experience: {trajectory_summary[:200]}"
-
- try:
- return model.generate_with_task(
- task="HEURISTIC",
- prompt=trajectory_summary[:2000],
- )
- except Exception:
- return f"From experience: {trajectory_summary[:200]}"
-
- def predict_retrieval_quality(
- self, query: str, results: List[Dict[str, Any]],
- ) -> float:
- """Predict whether retrieval results are sufficient.
-
- Task head: [RETRIEVAL_JUDGE]
- Returns: 0.0 (insufficient) to 1.0 (fully sufficient)
- """
- model = self._get_model()
- if model is None:
- return self._heuristic_retrieval_quality(query, results)
-
- try:
- results_text = "\n".join(
- f"- {r.get('memory', '')[:100]} (score={r.get('score', 0):.2f})"
- for r in results[:5]
- )
- prompt = f"Query: {query}\nResults:\n{results_text}"
- response = model.generate_with_task(
- task="RETRIEVAL_JUDGE",
- prompt=prompt,
- )
- return max(0.0, min(1.0, float(response.strip())))
- except Exception:
- return self._heuristic_retrieval_quality(query, results)
-
- # ------------------------------------------------------------------
- # Stats
- # ------------------------------------------------------------------
-
- def get_stats(self) -> Dict[str, Any]:
- """Get BuddhiMini status."""
- should, reason = self.should_train()
- return {
- "buffer": self._buffer.to_dict(),
- "should_train": should,
- "train_reason": reason,
- "model_loaded": self._model is not None,
- "data_dir": self._data_dir,
- }
-
- # ------------------------------------------------------------------
- # Internals
- # ------------------------------------------------------------------
-
- def _get_model(self):
- """Lazy-load the DheeModel."""
- if self._model is not None:
- return self._model
- try:
- from dhee.llms.dhee import DheeLLM
- self._model = DheeLLM(config={"device": self._device} if self._device else {})
- return self._model
- except Exception:
- return None
-
- def _heuristic_classify_memory_op(self, context: str) -> str:
- """Rule-based fallback for memory op classification."""
- cl = context.lower()
- if any(w in cl for w in ["remember", "store", "save", "note"]):
- return "store"
- if any(w in cl for w in ["recall", "search", "find", "what did"]):
- return "retrieve"
- if any(w in cl for w in ["update", "change", "correct"]):
- return "update"
- if any(w in cl for w in ["forget", "delete", "remove"]):
- return "discard"
- if any(w in cl for w in ["summarize", "consolidate", "compress"]):
- return "summarize"
- return "none"
-
- def _heuristic_retrieval_quality(
- self, query: str, results: List[Dict[str, Any]],
- ) -> float:
- """Rule-based fallback for retrieval quality prediction."""
- if not results:
- return 0.0
- top_score = results[0].get("score", 0) if results else 0
- count = len(results)
- # Simple heuristic: score * coverage
- coverage = min(count / 3.0, 1.0)
- return round(min(top_score * coverage, 1.0), 3)
-
- def _export_training_data(self) -> Optional[Dict[str, List]]:
- """Export accumulated buffer as training data."""
- if not self._buffer.sft_samples and not self._buffer.dpo_pairs:
- return None
- return {
- "sft": list(self._buffer.sft_samples),
- "dpo": list(self._buffer.dpo_pairs),
- }
-
- def _save_training_export(self) -> None:
- """Save training data to JSONL for manual training."""
- path = os.path.join(self._data_dir, "training_export.jsonl")
- try:
- with open(path, "a", encoding="utf-8") as f:
- for sample in self._buffer.sft_samples:
- f.write(json.dumps({"type": "sft", **sample}) + "\n")
- for pair in self._buffer.dpo_pairs:
- f.write(json.dumps({"type": "dpo", **pair}) + "\n")
- except OSError as e:
- logger.debug("Failed to save training export: %s", e)
-
- def _save_spans(self, spans: List[TrainingSpan]) -> None:
- """Persist spans for later DPO pairing."""
- path = os.path.join(self._data_dir, "spans.jsonl")
- try:
- with open(path, "a", encoding="utf-8") as f:
- for span in spans:
- f.write(json.dumps(span.to_dict()) + "\n")
- except OSError as e:
- logger.debug("Failed to save spans: %s", e)
-
- def _save_buffer(self) -> None:
- """Persist buffer metadata."""
- path = os.path.join(self._data_dir, "buffer.json")
- try:
- data = {
- "sft_count": len(self._buffer.sft_samples),
- "dpo_count": len(self._buffer.dpo_pairs),
- "trajectories_ingested": self._buffer.trajectories_ingested,
- "contrastive_pairs_ingested": self._buffer.contrastive_pairs_ingested,
- "last_train_time": self._buffer.last_train_time,
- }
- with open(path, "w", encoding="utf-8") as f:
- json.dump(data, f)
- except OSError:
- pass
-
- def _load_buffer(self) -> None:
- """Load persisted buffer metadata."""
- path = os.path.join(self._data_dir, "buffer.json")
- if not os.path.exists(path):
- return
- try:
- with open(path, "r", encoding="utf-8") as f:
- data = json.load(f)
- self._buffer.trajectories_ingested = data.get("trajectories_ingested", 0)
- self._buffer.contrastive_pairs_ingested = data.get("contrastive_pairs_ingested", 0)
- self._buffer.last_train_time = data.get("last_train_time", 0.0)
- except (OSError, json.JSONDecodeError):
- pass
diff --git a/dhee/mini/progressive_trainer.py b/dhee/mini/progressive_trainer.py
deleted file mode 100644
index b6f082e..0000000
--- a/dhee/mini/progressive_trainer.py
+++ /dev/null
@@ -1,411 +0,0 @@
-"""Progressive Trainer — 3-stage training for BuddhiMini.
-
-Based on AgeMem (arXiv:2601.01885): memory ops as RL-optimized tool calls
-with 3-stage progressive training for optimal learning.
-
-Stage 1 — SFT (Supervised Fine-Tuning):
- Train on high-quality trajectory spans from TraceSegmenter.
- Each span is a (task_context → [SPAN_TYPE] output) example.
- Span-specific losses per Structured Agent Distillation (arXiv:2505.13820).
-
-Stage 2 — DPO (Direct Preference Optimization):
- Train on contrastive pairs from ContrastiveStore.
- Each pair: (task_context, success_approach, failure_approach).
- The model learns to prefer successful reasoning patterns.
-
-Stage 3 — RL (Retrieval-Quality Reward):
- Use retrieval quality as reward signal.
- After the model updates, measure whether recall@K improves.
- If yes, keep the update. If no, rollback.
-
-The trainer does NOT run training itself — it curates data and delegates
-to the existing Nididhyasana/train.py pipeline. Its job is to decide
-WHAT to train on and in WHAT order.
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import os
-import time
-from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class TrainingStageResult:
- """Result from a single training stage."""
-
- stage: str # sft | dpo | rl
- status: str # completed | skipped | error
- samples_used: int = 0
- metrics: Dict[str, float] = field(default_factory=dict)
- duration_seconds: float = 0.0
- error: Optional[str] = None
-
- def to_dict(self) -> Dict[str, Any]:
- d = {
- "stage": self.stage,
- "status": self.status,
- "samples_used": self.samples_used,
- "duration_seconds": round(self.duration_seconds, 1),
- }
- if self.metrics:
- d["metrics"] = self.metrics
- if self.error:
- d["error"] = self.error
- return d
-
-
-@dataclass
-class ProgressiveTrainingResult:
- """Result from a full progressive training cycle."""
-
- cycle_id: str
- stages: List[TrainingStageResult]
- total_duration: float = 0.0
- model_improved: bool = False
- data_exported_path: Optional[str] = None
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "cycle_id": self.cycle_id,
- "stages": [s.to_dict() for s in self.stages],
- "total_duration": round(self.total_duration, 1),
- "model_improved": self.model_improved,
- "data_exported_path": self.data_exported_path,
- }
-
-
-class ProgressiveTrainer:
- """Curates and orders training data for the 3-stage progressive pipeline.
-
- This is the brain that decides what data to train on. The actual
- training execution is delegated to Nididhyasana (which calls train.py).
-
- Usage:
- trainer = ProgressiveTrainer(data_dir="/path/to/training")
- result = trainer.run_cycle(
- sft_data=[...], # from TraceSegmenter
- dpo_data=[...], # from ContrastiveStore
- samskara_data={...} # from Samskara.get_training_data()
- )
- """
-
- # Minimum data thresholds
- MIN_SFT = 20
- MIN_DPO = 10
-
- def __init__(
- self,
- data_dir: Optional[str] = None,
- train_fn=None,
- ):
- self._dir = data_dir or os.path.join(
- os.path.expanduser("~"), ".dhee", "progressive_training"
- )
- os.makedirs(self._dir, exist_ok=True)
- self._train_fn = train_fn # injected training function
- self._cycle_count = 0
- self._history: List[Dict[str, Any]] = []
- self._load_history()
-
- def run_cycle(
- self,
- sft_data: Optional[List[Dict[str, str]]] = None,
- dpo_data: Optional[List[Dict[str, str]]] = None,
- samskara_data: Optional[Dict[str, Any]] = None,
- ) -> ProgressiveTrainingResult:
- """Run a full progressive training cycle: SFT → DPO → RL.
-
- Each stage is optional — skipped if insufficient data.
- """
- cycle_id = f"prog_{self._cycle_count:04d}_{int(time.time())}"
- self._cycle_count += 1
- cycle_dir = os.path.join(self._dir, cycle_id)
- os.makedirs(cycle_dir, exist_ok=True)
-
- start = time.time()
- stages: List[TrainingStageResult] = []
-
- # Merge samskara SFT samples with explicit SFT data
- all_sft = list(sft_data or [])
- if samskara_data:
- all_sft.extend(samskara_data.get("sft_samples", []))
-
- # Merge samskara DPO pairs with explicit DPO data
- all_dpo = list(dpo_data or [])
- if samskara_data:
- all_dpo.extend(samskara_data.get("dpo_pairs", []))
-
- # Weight data by vasana degradation (focus on weak areas)
- if samskara_data:
- all_sft = self._weight_by_vasana(all_sft, samskara_data)
-
- # --- Stage 1: SFT ---
- sft_result = self._run_sft(all_sft, cycle_dir)
- stages.append(sft_result)
-
- # --- Stage 2: DPO ---
- dpo_result = self._run_dpo(all_dpo, cycle_dir)
- stages.append(dpo_result)
-
- # --- Stage 3: RL (evaluation-based) ---
- rl_result = self._run_rl_eval(cycle_dir, sft_result, dpo_result)
- stages.append(rl_result)
-
- total_duration = time.time() - start
- completed_stages = [s for s in stages if s.status == "completed"]
-
- result = ProgressiveTrainingResult(
- cycle_id=cycle_id,
- stages=stages,
- total_duration=total_duration,
- model_improved=len(completed_stages) > 0,
- data_exported_path=cycle_dir,
- )
-
- self._record_history(result)
- return result
-
- # ------------------------------------------------------------------
- # Stage 1: SFT
- # ------------------------------------------------------------------
-
- def _run_sft(
- self, data: List[Dict[str, str]], cycle_dir: str,
- ) -> TrainingStageResult:
- """Stage 1: Supervised Fine-Tuning on trajectory spans."""
- if len(data) < self.MIN_SFT:
- return TrainingStageResult(
- stage="sft", status="skipped",
- metrics={"reason": f"insufficient data ({len(data)}/{self.MIN_SFT})"},
- )
-
- start = time.time()
-
- # Curate: prioritize diverse span types
- curated = self._curate_sft(data)
-
- # Export as train.jsonl
- train_path = os.path.join(cycle_dir, "sft_train.jsonl")
- self._write_jsonl(train_path, curated)
-
- # Try to run actual training
- metrics = {}
- if self._train_fn:
- try:
- train_result = self._train_fn(
- data_dir=cycle_dir,
- output_dir=os.path.join(cycle_dir, "sft_output"),
- )
- metrics = train_result if isinstance(train_result, dict) else {}
- except Exception as e:
- return TrainingStageResult(
- stage="sft", status="error",
- samples_used=len(curated),
- duration_seconds=time.time() - start,
- error=str(e),
- )
- else:
- metrics["data_exported"] = train_path
-
- return TrainingStageResult(
- stage="sft", status="completed",
- samples_used=len(curated),
- metrics=metrics,
- duration_seconds=time.time() - start,
- )
-
- def _curate_sft(self, data: List[Dict[str, str]]) -> List[Dict[str, str]]:
- """Curate SFT data: balance span types, cap per type."""
- by_type: Dict[str, List] = {}
- for sample in data:
- t = sample.get("type", "general")
- by_type.setdefault(t, []).append(sample)
-
- # Take up to 50 per type for balance
- curated = []
- for samples in by_type.values():
- curated.extend(samples[:50])
-
- return curated
-
- # ------------------------------------------------------------------
- # Stage 2: DPO
- # ------------------------------------------------------------------
-
- def _run_dpo(
- self, data: List[Dict[str, str]], cycle_dir: str,
- ) -> TrainingStageResult:
- """Stage 2: Direct Preference Optimization on contrastive pairs."""
- if len(data) < self.MIN_DPO:
- return TrainingStageResult(
- stage="dpo", status="skipped",
- metrics={"reason": f"insufficient data ({len(data)}/{self.MIN_DPO})"},
- )
-
- start = time.time()
-
- # Export as dpo_pairs.jsonl
- dpo_path = os.path.join(cycle_dir, "dpo_pairs.jsonl")
- self._write_jsonl(dpo_path, data)
-
- metrics = {}
- if self._train_fn:
- try:
- train_result = self._train_fn(
- data_dir=cycle_dir,
- output_dir=os.path.join(cycle_dir, "dpo_output"),
- dpo_mode=True,
- )
- metrics = train_result if isinstance(train_result, dict) else {}
- except Exception as e:
- return TrainingStageResult(
- stage="dpo", status="error",
- samples_used=len(data),
- duration_seconds=time.time() - start,
- error=str(e),
- )
- else:
- metrics["data_exported"] = dpo_path
-
- return TrainingStageResult(
- stage="dpo", status="completed",
- samples_used=len(data),
- metrics=metrics,
- duration_seconds=time.time() - start,
- )
-
- # ------------------------------------------------------------------
- # Stage 3: RL (reward = retrieval quality)
- # ------------------------------------------------------------------
-
- def _run_rl_eval(
- self,
- cycle_dir: str,
- sft_result: TrainingStageResult,
- dpo_result: TrainingStageResult,
- ) -> TrainingStageResult:
- """Stage 3: RL evaluation — decide whether to keep updates.
-
- In practice, this stage verifies that the SFT/DPO changes
- didn't degrade retrieval quality. It's a gate, not a trainer.
- """
- start = time.time()
-
- # If neither SFT nor DPO actually trained, skip
- if sft_result.status != "completed" and dpo_result.status != "completed":
- return TrainingStageResult(
- stage="rl", status="skipped",
- metrics={"reason": "no prior stages completed"},
- )
-
- # Compute aggregate quality from what we have
- sft_samples = sft_result.samples_used
- dpo_samples = dpo_result.samples_used
- total_samples = sft_samples + dpo_samples
-
- # Simple quality heuristic: more diverse data = more likely to help
- quality_estimate = min(1.0, total_samples / 100.0)
-
- metrics = {
- "quality_estimate": round(quality_estimate, 3),
- "sft_contribution": sft_samples,
- "dpo_contribution": dpo_samples,
- "verdict": "keep" if quality_estimate > 0.3 else "rollback",
- }
-
- return TrainingStageResult(
- stage="rl", status="completed",
- metrics=metrics,
- duration_seconds=time.time() - start,
- )
-
- # ------------------------------------------------------------------
- # Vasana-weighted data emphasis
- # ------------------------------------------------------------------
-
- def _weight_by_vasana(
- self,
- data: List[Dict[str, str]],
- samskara_data: Dict[str, Any],
- ) -> List[Dict[str, str]]:
- """Duplicate samples from degrading dimensions for emphasis.
-
- If retrieval_recall is degrading, duplicate RETRIEVAL_HIT samples.
- Same weighting logic as Nididhyasana._curate_dataset().
- """
- degrading = set(samskara_data.get("degrading_dimensions", []))
- if not degrading:
- return data
-
- # Map degrading dimensions to sample types
- dim_to_type = {
- "retrieval_recall": {"retrieval_hit", "retrieval_miss"},
- "retrieval_precision": {"retrieval_hit"},
- "answer_quality": {"answer_accepted", "answer_corrected"},
- "fact_extraction": {"extraction"},
- }
-
- emphasized_types = set()
- for dim in degrading:
- emphasized_types |= dim_to_type.get(dim, set())
-
- # Duplicate matching samples (2x weight)
- weighted = []
- for sample in data:
- weighted.append(sample)
- sample_type = sample.get("type", "").lower()
- valence = sample.get("valence", "")
- if sample_type in emphasized_types or valence == "klishta":
- weighted.append(sample) # duplicate = 2x emphasis
-
- return weighted
-
- # ------------------------------------------------------------------
- # Helpers
- # ------------------------------------------------------------------
-
- def _write_jsonl(self, path: str, data: List[Dict]) -> None:
- try:
- with open(path, "w", encoding="utf-8") as f:
- for item in data:
- f.write(json.dumps(item, ensure_ascii=False) + "\n")
- except OSError as e:
- logger.debug("Failed to write %s: %s", path, e)
-
- def _record_history(self, result: ProgressiveTrainingResult) -> None:
- self._history.append(result.to_dict())
- path = os.path.join(self._dir, "history.jsonl")
- try:
- with open(path, "a", encoding="utf-8") as f:
- f.write(json.dumps(result.to_dict(), ensure_ascii=False) + "\n")
- except OSError:
- pass
-
- def _load_history(self) -> None:
- path = os.path.join(self._dir, "history.jsonl")
- if not os.path.exists(path):
- return
- try:
- with open(path, "r", encoding="utf-8") as f:
- for line in f:
- line = line.strip()
- if line:
- try:
- self._history.append(json.loads(line))
- self._cycle_count += 1
- except json.JSONDecodeError:
- continue
- except OSError:
- pass
-
- def get_stats(self) -> Dict[str, Any]:
- return {
- "cycles_completed": self._cycle_count,
- "last_cycle": self._history[-1] if self._history else None,
- }
diff --git a/dhee/mini/trace_segmenter.py b/dhee/mini/trace_segmenter.py
deleted file mode 100644
index 5522981..0000000
--- a/dhee/mini/trace_segmenter.py
+++ /dev/null
@@ -1,246 +0,0 @@
-"""Trace segmenter — converts agent trajectories into training spans.
-
-Based on Structured Agent Distillation (Liu et al., arXiv:2505.13820):
-segments agent interaction traces into [REASON], [ACT], and [MEMORY_OP]
-spans with span-specific training losses for more efficient learning.
-
-The key insight: different types of agent behavior (reasoning vs action
-vs memory management) benefit from different training objectives.
-Token-level distillation treats them uniformly and is less effective.
-"""
-
-from __future__ import annotations
-
-import re
-import uuid
-from dataclasses import dataclass, field
-from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
-
-
-class SpanType(str, Enum):
- """Type of training span — each gets its own loss weight."""
-
- REASON = "reason" # Internal reasoning, planning, analysis
- ACT = "act" # Tool calls, commands, actions taken
- MEMORY_OP = "memory_op" # Memory operations (store, retrieve, update, summarize, discard)
- REFLECT = "reflect" # Self-reflection, insight synthesis
- OBSERVE = "observe" # Observation, reading results, understanding state
-
-
-@dataclass
-class TrainingSpan:
- """A single segment of an agent trajectory for training.
-
- Each span has a type, the text content, and metadata about the
- trajectory it came from. Spans from successful trajectories are
- used for SFT; paired success/failure spans for DPO.
- """
- id: str
- span_type: SpanType
- content: str # the text of this span
- context_before: str # preceding context (for input)
- trajectory_id: str
- step_index: int
- task_description: str
- success: bool # was the overall trajectory successful?
- metadata: Dict[str, Any] = field(default_factory=dict)
-
- def to_sft_example(self) -> Dict[str, str]:
- """Format as SFT training example (input → output)."""
- return {
- "input": f"[TASK] {self.task_description}\n[CONTEXT] {self.context_before}",
- "output": f"[{self.span_type.value.upper()}] {self.content}",
- "type": self.span_type.value,
- }
-
- def to_dict(self) -> Dict[str, Any]:
- return {
- "id": self.id,
- "span_type": self.span_type.value,
- "content": self.content,
- "context_before": self.context_before[:500],
- "trajectory_id": self.trajectory_id,
- "step_index": self.step_index,
- "task_description": self.task_description,
- "success": self.success,
- }
-
-
-# Patterns for classifying steps into span types
-_MEMORY_PATTERNS = re.compile(
- r"(?:remember|recall|search|store|forget|update.*memor|delete.*memor|checkpoint)",
- re.IGNORECASE,
-)
-_REASON_PATTERNS = re.compile(
- r"(?:think|plan|analyze|consider|reason|decide|evaluate|assess|compare)",
- re.IGNORECASE,
-)
-_REFLECT_PATTERNS = re.compile(
- r"(?:reflect|insight|learned|worked|failed|improve|heuristic|takeaway)",
- re.IGNORECASE,
-)
-
-
-class TraceSegmenter:
- """Segments agent trajectories into typed training spans.
-
- Takes a Trajectory (from dhee.skills.trajectory) and produces
- a list of TrainingSpan objects suitable for:
- - SFT: train on successful spans
- - DPO: pair successful/failed spans for preference learning
- - RL: use retrieval quality as reward signal
-
- Usage:
- from dhee.mini.trace_segmenter import TraceSegmenter
- from dhee.skills.schema import Trajectory
-
- segmenter = TraceSegmenter()
- spans = segmenter.segment(trajectory)
-
- # For SFT training
- sft_data = segmenter.format_for_sft(spans)
-
- # For DPO training
- dpo_data = segmenter.format_for_dpo(success_spans, failure_spans)
- """
-
- def segment(self, trajectory) -> List[TrainingSpan]:
- """Segment a trajectory into typed training spans.
-
- Args:
- trajectory: A Trajectory object from dhee.skills.schema
-
- Returns:
- List of TrainingSpan objects
- """
- spans: List[TrainingSpan] = []
- context_parts: List[str] = []
-
- for i, step in enumerate(trajectory.steps):
- # Classify the step
- span_type = self._classify_step(step)
-
- # Build content from step
- content = self._extract_content(step)
- if not content:
- continue
-
- # Context is everything before this step
- context_before = "\n".join(context_parts[-3:]) # last 3 steps
-
- span = TrainingSpan(
- id=str(uuid.uuid4()),
- span_type=span_type,
- content=content,
- context_before=context_before,
- trajectory_id=trajectory.id,
- step_index=i,
- task_description=trajectory.task_description,
- success=trajectory.success,
- metadata={
- "tool": getattr(step, "tool", ""),
- "error": getattr(step, "error", None),
- "duration_ms": getattr(step, "duration_ms", None),
- },
- )
- spans.append(span)
-
- # Update rolling context
- context_parts.append(f"[{span_type.value}] {content[:200]}")
-
- return spans
-
- def format_for_sft(self, spans: List[TrainingSpan]) -> List[Dict[str, str]]:
- """Format successful spans as SFT training examples."""
- return [
- span.to_sft_example()
- for span in spans
- if span.success
- ]
-
- def format_for_dpo(
- self,
- success_spans: List[TrainingSpan],
- failure_spans: List[TrainingSpan],
- ) -> List[Dict[str, Any]]:
- """Create DPO training pairs from success/failure spans.
-
- Pairs are created by matching spans with the same span_type
- and similar step_index from successful and failed trajectories.
- """
- pairs = []
-
- # Group by span type
- success_by_type: Dict[str, List[TrainingSpan]] = {}
- failure_by_type: Dict[str, List[TrainingSpan]] = {}
-
- for s in success_spans:
- success_by_type.setdefault(s.span_type.value, []).append(s)
- for f in failure_spans:
- failure_by_type.setdefault(f.span_type.value, []).append(f)
-
- # Create pairs for each shared type
- for span_type in set(success_by_type) & set(failure_by_type):
- chosen_list = success_by_type[span_type]
- rejected_list = failure_by_type[span_type]
-
- # Pair by position (zip truncates to shorter)
- for chosen, rejected in zip(chosen_list, rejected_list):
- pairs.append({
- "prompt": f"[TASK] {chosen.task_description}\n"
- f"[CONTEXT] {chosen.context_before}",
- "chosen": f"[{span_type.upper()}] {chosen.content}",
- "rejected": f"[{span_type.upper()}] {rejected.content}",
- "span_type": span_type,
- })
-
- return pairs
-
- def _classify_step(self, step) -> SpanType:
- """Classify a trajectory step into a span type."""
- action = getattr(step, "action", "")
- tool = getattr(step, "tool", "")
- result_summary = getattr(step, "result_summary", "")
- combined = f"{action} {tool} {result_summary}"
-
- # Memory operations
- if _MEMORY_PATTERNS.search(combined):
- return SpanType.MEMORY_OP
-
- # Reflection
- if _REFLECT_PATTERNS.search(combined):
- return SpanType.REFLECT
-
- # Reasoning (no tool call, just thinking)
- if not tool and _REASON_PATTERNS.search(combined):
- return SpanType.REASON
-
- # Tool call = action
- if tool:
- return SpanType.ACT
-
- # Observation (reading results)
- if result_summary and not tool:
- return SpanType.OBSERVE
-
- # Default to reasoning
- return SpanType.REASON
-
- def _extract_content(self, step) -> str:
- """Extract the text content from a trajectory step."""
- parts = []
- action = getattr(step, "action", "")
- tool = getattr(step, "tool", "")
- result_summary = getattr(step, "result_summary", "")
-
- if action:
- parts.append(action)
- if tool:
- args = getattr(step, "args", {})
- args_str = ", ".join(f"{k}={v}" for k, v in list(args.items())[:3]) if args else ""
- parts.append(f"tool={tool}({args_str})")
- if result_summary:
- parts.append(f"→ {result_summary[:300]}")
-
- return " | ".join(parts)
diff --git a/dhee/adapters/base.py b/dhee/plugin.py
similarity index 100%
rename from dhee/adapters/base.py
rename to dhee/plugin.py
diff --git a/dhee/retrieval/__init__.py b/dhee/retrieval/__init__.py
deleted file mode 100644
index 49ed407..0000000
--- a/dhee/retrieval/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-"""Engram v2 retrieval components."""
-
-try:
- from dhee.retrieval.dual_search import DualSearchEngine
-except ImportError:
- DualSearchEngine = None
-
-from dhee.retrieval.reranker import NvidiaReranker, create_reranker
-
-__all__ = ["DualSearchEngine", "NvidiaReranker", "create_reranker"]
diff --git a/dhee/simple.py b/dhee/simple.py
index 376f9b0..cd3184d 100644
--- a/dhee/simple.py
+++ b/dhee/simple.py
@@ -41,17 +41,31 @@
def _detect_provider() -> str:
"""Detect which LLM/embedder provider to use based on environment."""
- if os.environ.get("GEMINI_API_KEY"):
- return "gemini"
if os.environ.get("OPENAI_API_KEY"):
return "openai"
- # Default to gemini if no key found (will fail later with clear error)
- return "gemini"
+ if os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"):
+ return "gemini"
+ if (
+ os.environ.get("NVIDIA_API_KEY")
+ or os.environ.get("NVIDIA_QWEN_API_KEY")
+ or os.environ.get("NVIDIA_EMBEDDING_API_KEY")
+ or os.environ.get("NVIDIA_EMBED_API_KEY")
+ or os.environ.get("NVIDIA_LLAMA_4_MAV_API_KEY")
+ ):
+ return "nvidia"
+ # Zero-config fallback: local simple embedder + mock LLM.
+ return "mock"
def _get_embedding_dims(provider: str) -> int:
"""Get embedding dimensions for provider."""
- return 3072 if provider == "gemini" else 1536
+ if provider == "gemini":
+ return 3072
+ if provider == "openai":
+ return 1536
+ if provider == "nvidia":
+ return 2048
+ return 384
def _has_api_key() -> bool:
diff --git a/dhee/teaching/__init__.py b/dhee/teaching/__init__.py
deleted file mode 100644
index 80309bf..0000000
--- a/dhee/teaching/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""Teaching primitives for Engram — concepts, student models, and teaching memory.
-
-Domain objects stored as Engram memories with ``memory_type`` + metadata dict,
-following the same pattern as ``engram.memory.tasks``.
-"""
-
-from dhee.teaching.config import TeachingConfig
-from dhee.teaching.concepts import ConceptStore
-from dhee.teaching.student_model import StudentModel
-from dhee.teaching.teaching_memory import TeachingMemory
-
-__all__ = [
- "TeachingConfig",
- "ConceptStore",
- "StudentModel",
- "TeachingMemory",
-]
diff --git a/dhee/teaching/concepts.py b/dhee/teaching/concepts.py
deleted file mode 100644
index 0f3ad5a..0000000
--- a/dhee/teaching/concepts.py
+++ /dev/null
@@ -1,307 +0,0 @@
-"""ConceptStore — curriculum concepts stored as Engram memories.
-
-Concepts are stored with ``memory_type="concept"`` under a shared
-``user_id`` (the curriculum namespace). Prerequisites and cross-subject
-links are represented via the knowledge graph (``RelationType.REQUIRES``
-and ``RelationType.RELATED_TO``).
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-import uuid
-from typing import Any, Dict, List, Optional
-
-from dhee.teaching.config import TeachingConfig
-
-logger = logging.getLogger(__name__)
-
-
-class ConceptStore:
- """CRUD and graph operations over curriculum concepts."""
-
- def __init__(self, memory: "CoreMemory", config: TeachingConfig | None = None): # noqa: F821
- self.memory = memory
- self.config = config or TeachingConfig()
- self._namespace = self.config.concept_namespace
-
- # ------------------------------------------------------------------
- # Create / update
- # ------------------------------------------------------------------
-
- def add_concept(
- self,
- concept_id: str,
- name: str,
- subject: str,
- *,
- difficulty: float = 0.5,
- prerequisites: List[str] | None = None,
- cross_subject_links: List[str] | None = None,
- keywords: List[str] | None = None,
- description: str = "",
- ) -> Dict[str, Any]:
- """Add a concept to the store (idempotent by ``concept_id``)."""
-
- # Check for existing concept with same concept_id
- existing = self._find_by_concept_id(concept_id)
- if existing:
- return existing
-
- meta = {
- "memory_type": "concept",
- "concept_id": concept_id,
- "subject": subject,
- "difficulty": difficulty,
- "prerequisites": prerequisites or [],
- "cross_subject_links": cross_subject_links or [],
- "keywords": keywords or [],
- }
-
- content = f"{name}: {description}" if description else name
- result = self.memory.add(
- content=content,
- user_id=self._namespace,
- metadata=meta,
- categories=[f"concept/{subject}"],
- )
-
- mem_id = self._extract_id(result)
- if not mem_id:
- logger.warning("Failed to store concept %s", concept_id)
- return {"error": "Failed to store concept"}
-
- # Create prerequisite edges via knowledge graph
- if prerequisites and hasattr(self.memory, "knowledge_graph"):
- graph = self.memory.knowledge_graph
- for prereq_id in prerequisites:
- prereq = self._find_by_concept_id(prereq_id)
- prereq_mem_id = prereq.get("memory_id") if prereq else None
- if prereq_mem_id:
- from dhee.core.graph import RelationType
-
- graph.add_relationship(
- source_id=mem_id,
- target_id=prereq_mem_id,
- relation_type=RelationType.REQUIRES,
- metadata={"concept_link": True},
- )
-
- # Create cross-subject edges
- if cross_subject_links and hasattr(self.memory, "knowledge_graph"):
- graph = self.memory.knowledge_graph
- for linked_id in cross_subject_links:
- linked = self._find_by_concept_id(linked_id)
- linked_mem_id = linked.get("memory_id") if linked else None
- if linked_mem_id:
- from dhee.core.graph import RelationType
-
- graph.add_relationship(
- source_id=mem_id,
- target_id=linked_mem_id,
- relation_type=RelationType.RELATED_TO,
- metadata={"cross_subject": True},
- )
-
- return {
- "memory_id": mem_id,
- "concept_id": concept_id,
- "name": name,
- "subject": subject,
- }
-
- # ------------------------------------------------------------------
- # Read
- # ------------------------------------------------------------------
-
- def get_concept(self, concept_id: str) -> Optional[Dict[str, Any]]:
- """Retrieve a concept by its concept_id."""
- return self._find_by_concept_id(concept_id)
-
- def search_concepts(
- self,
- query: str,
- subject: str | None = None,
- limit: int = 10,
- ) -> List[Dict[str, Any]]:
- """Semantic search over concepts, optionally filtered by subject."""
- results = self.memory.search(
- query=query,
- user_id=self._namespace,
- limit=limit * 2,
- )
-
- concepts = []
- for mem in results:
- md = self._parse_metadata(mem)
- if md.get("memory_type") != "concept":
- continue
- if subject and md.get("subject") != subject:
- continue
- concepts.append(self._format_concept(mem))
- if len(concepts) >= limit:
- break
-
- return concepts
-
- def get_prerequisites(
- self,
- concept_id: str,
- depth: int = 2,
- ) -> List[Dict[str, Any]]:
- """Traverse prerequisite graph to the given depth."""
- root = self._find_by_concept_id(concept_id)
- if not root:
- return []
-
- root_mem_id = root.get("memory_id")
- if not root_mem_id:
- return []
-
- # Try graph traversal first
- if hasattr(self.memory, "knowledge_graph"):
- from dhee.core.graph import RelationType
-
- graph = self.memory.knowledge_graph
- related = graph.get_related_memories(
- root_mem_id,
- relation_types=[RelationType.REQUIRES],
- max_depth=depth,
- )
- prereqs = []
- for rel in related:
- target_id = rel.get("target_id") or rel.get("memory_id")
- if target_id:
- mem = self.memory.get(target_id)
- if mem:
- prereqs.append(self._format_concept(mem))
- return prereqs
-
- # Fallback: use metadata-based prerequisites
- md = self._parse_metadata(root)
- prereq_ids = md.get("prerequisites", [])
- prereqs = []
- for pid in prereq_ids:
- c = self._find_by_concept_id(pid)
- if c:
- prereqs.append(c)
- return prereqs
-
- def get_cross_subject_links(
- self,
- concept_id: str,
- ) -> List[Dict[str, Any]]:
- """Find concepts linked across subjects."""
- root = self._find_by_concept_id(concept_id)
- if not root:
- return []
-
- root_mem_id = root.get("memory_id")
- if not root_mem_id:
- return []
-
- if hasattr(self.memory, "knowledge_graph"):
- from dhee.core.graph import RelationType
-
- graph = self.memory.knowledge_graph
- related = graph.get_related_memories(
- root_mem_id,
- relation_types=[RelationType.RELATED_TO],
- max_depth=1,
- )
- links = []
- for rel in related:
- target_id = rel.get("target_id") or rel.get("memory_id")
- if target_id:
- mem = self.memory.get(target_id)
- if mem:
- links.append(self._format_concept(mem))
- return links
-
- # Fallback
- md = self._parse_metadata(root)
- linked_ids = md.get("cross_subject_links", [])
- return [c for pid in linked_ids if (c := self._find_by_concept_id(pid))]
-
- def link_concepts_cross_subject(
- self,
- concept_a_id: str,
- concept_b_id: str,
- ) -> bool:
- """Create a cross-subject edge between two concepts."""
- a = self._find_by_concept_id(concept_a_id)
- b = self._find_by_concept_id(concept_b_id)
- if not a or not b:
- return False
-
- a_mem_id = a.get("memory_id")
- b_mem_id = b.get("memory_id")
- if not a_mem_id or not b_mem_id:
- return False
-
- if hasattr(self.memory, "knowledge_graph"):
- from dhee.core.graph import RelationType
-
- self.memory.knowledge_graph.add_relationship(
- source_id=a_mem_id,
- target_id=b_mem_id,
- relation_type=RelationType.RELATED_TO,
- metadata={"cross_subject": True},
- )
- return True
-
- return False
-
- # ------------------------------------------------------------------
- # Helpers
- # ------------------------------------------------------------------
-
- def _find_by_concept_id(self, concept_id: str) -> Optional[Dict[str, Any]]:
- """Find concept memory by concept_id in metadata."""
- if hasattr(self.memory, "db") and hasattr(self.memory.db, "get_all_memories"):
- memories = self.memory.db.get_all_memories(
- user_id=self._namespace,
- memory_type="concept",
- limit=500,
- )
- for mem in memories:
- md = self._parse_metadata(mem)
- if md.get("concept_id") == concept_id:
- return self._format_concept(mem)
- return None
-
- @staticmethod
- def _extract_id(result: Dict[str, Any]) -> Optional[str]:
- if isinstance(result, dict):
- results = result.get("results", [])
- if results and isinstance(results, list):
- first = results[0]
- return first.get("id") or first.get("memory_id")
- return result.get("id") or result.get("memory_id")
- return None
-
- @staticmethod
- def _parse_metadata(mem: Dict[str, Any]) -> Dict[str, Any]:
- md = mem.get("metadata", {})
- if isinstance(md, str):
- try:
- md = json.loads(md)
- except (json.JSONDecodeError, TypeError):
- md = {}
- return md
-
- @classmethod
- def _format_concept(cls, mem: Dict[str, Any]) -> Dict[str, Any]:
- md = cls._parse_metadata(mem)
- return {
- "memory_id": mem.get("id"),
- "concept_id": md.get("concept_id", ""),
- "name": mem.get("content", "").split(":")[0].strip(),
- "subject": md.get("subject", ""),
- "difficulty": md.get("difficulty", 0.5),
- "prerequisites": md.get("prerequisites", []),
- "cross_subject_links": md.get("cross_subject_links", []),
- "keywords": md.get("keywords", []),
- "strength": mem.get("strength", 1.0),
- }
diff --git a/dhee/teaching/config.py b/dhee/teaching/config.py
deleted file mode 100644
index 152c8c6..0000000
--- a/dhee/teaching/config.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""Configuration for teaching primitives."""
-
-from __future__ import annotations
-
-from pydantic import BaseModel, field_validator
-
-
-class TeachingConfig(BaseModel):
- """Configuration for the teaching memory subsystem."""
-
- enable_teaching: bool = False # Off by default (backward compat)
- concept_namespace: str = "sensai_curriculum"
- mastery_initial_score: float = 0.35
- mastery_increment: float = 0.08
- mastery_decrement_on_misconception: float = -0.15
- weak_concept_threshold: float = 0.45
- mastered_concept_threshold: float = 0.70
-
- @field_validator(
- "mastery_initial_score",
- "mastery_increment",
- "weak_concept_threshold",
- "mastered_concept_threshold",
- )
- @classmethod
- def _clamp_unit_float(cls, v: float) -> float:
- return min(1.0, max(0.0, float(v)))
diff --git a/dhee/teaching/student_model.py b/dhee/teaching/student_model.py
deleted file mode 100644
index 3c5346c..0000000
--- a/dhee/teaching/student_model.py
+++ /dev/null
@@ -1,372 +0,0 @@
-"""StudentModel — per-student profile and concept mastery via Engram.
-
-Profiles are stored as ``memory_type="student_profile"`` and per-concept
-mastery as ``memory_type="concept_mastery"``. FadeMem handles mastery
-decay naturally: accessing a concept mastery memory boosts its strength
-(spaced repetition).
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional
-
-from dhee.teaching.config import TeachingConfig
-
-logger = logging.getLogger(__name__)
-
-
-class StudentModel:
- """Per-student learning profile and concept mastery tracker."""
-
- def __init__(self, memory: "CoreMemory", config: TeachingConfig | None = None): # noqa: F821
- self.memory = memory
- self.config = config or TeachingConfig()
-
- # ------------------------------------------------------------------
- # Profile
- # ------------------------------------------------------------------
-
- def get_or_create_profile(self, student_id: str) -> Dict[str, Any]:
- """Get or create a student learning profile."""
- existing = self._find_profile(student_id)
- if existing:
- return existing
-
- now = datetime.now(timezone.utc).isoformat()
- meta = {
- "memory_type": "student_profile",
- "student_id": student_id,
- "learning_style": "unknown",
- "interests": [],
- "goals": [],
- "effective_analogies": [],
- "created_at": now,
- "updated_at": now,
- }
-
- content = f"Student profile for {student_id}"
- result = self.memory.add(
- content=content,
- user_id=student_id,
- metadata=meta,
- categories=["student/profile"],
- )
-
- mem_id = self._extract_id(result)
- return {
- "memory_id": mem_id,
- "student_id": student_id,
- "learning_style": "unknown",
- "interests": [],
- "goals": [],
- "effective_analogies": [],
- }
-
- def update_profile(
- self,
- student_id: str,
- updates: Dict[str, Any],
- ) -> Dict[str, Any]:
- """Merge updates into the student profile."""
- profile_mem = self._find_profile_raw(student_id)
- if not profile_mem:
- profile = self.get_or_create_profile(student_id)
- profile_mem = self._find_profile_raw(student_id)
- if not profile_mem:
- return profile
-
- md = self._parse_metadata(profile_mem)
- mem_id = profile_mem.get("id")
-
- # Merge updates
- for key, value in updates.items():
- if key in ("interests", "goals", "effective_analogies") and isinstance(value, list):
- existing = md.get(key, [])
- merged = list(dict.fromkeys(existing + value)) # dedup, preserve order
- md[key] = merged
- else:
- md[key] = value
-
- md["updated_at"] = datetime.now(timezone.utc).isoformat()
-
- if hasattr(self.memory, "db") and hasattr(self.memory.db, "update_memory"):
- self.memory.db.update_memory(mem_id, {"metadata": json.dumps(md)})
-
- return self._format_profile({"id": mem_id, "metadata": md})
-
- # ------------------------------------------------------------------
- # Mastery
- # ------------------------------------------------------------------
-
- def get_mastery(self, student_id: str, concept_id: str) -> float:
- """Get current mastery score. Accessing boosts strength (spaced repetition)."""
- mastery_mem = self._find_mastery_raw(student_id, concept_id)
- if not mastery_mem:
- return self.config.mastery_initial_score
-
- md = self._parse_metadata(mastery_mem)
- score = md.get("mastery_score", self.config.mastery_initial_score)
-
- # Access the memory to trigger FadeMem strength boost
- mem_id = mastery_mem.get("id")
- if mem_id and hasattr(self.memory, "db"):
- self.memory.db.increment_access(mem_id)
-
- return score
-
- def update_mastery(
- self,
- student_id: str,
- concept_id: str,
- score_delta: float,
- ) -> float:
- """Update mastery score by delta. Returns new score."""
- mastery_mem = self._find_mastery_raw(student_id, concept_id)
-
- if not mastery_mem:
- # Create new mastery record
- new_score = max(0.0, min(1.0, self.config.mastery_initial_score + score_delta))
- now = datetime.now(timezone.utc).isoformat()
- meta = {
- "memory_type": "concept_mastery",
- "student_id": student_id,
- "concept_id": concept_id,
- "mastery_score": new_score,
- "history": [{"delta": score_delta, "score": new_score, "at": now}],
- "created_at": now,
- "updated_at": now,
- }
- self.memory.add(
- content=f"Mastery: {student_id} / {concept_id} = {new_score:.2f}",
- user_id=student_id,
- metadata=meta,
- categories=["student/mastery"],
- )
- return new_score
-
- md = self._parse_metadata(mastery_mem)
- old_score = md.get("mastery_score", self.config.mastery_initial_score)
- new_score = max(0.0, min(1.0, old_score + score_delta))
-
- now = datetime.now(timezone.utc).isoformat()
- history = md.get("history", [])
- history.append({"delta": score_delta, "score": new_score, "at": now})
- md["mastery_score"] = new_score
- md["history"] = history[-50:] # keep last 50 entries
- md["updated_at"] = now
-
- mem_id = mastery_mem.get("id")
- if mem_id and hasattr(self.memory, "db"):
- self.memory.db.update_memory(mem_id, {"metadata": json.dumps(md)})
- self.memory.db.increment_access(mem_id)
-
- return new_score
-
- def get_weak_concepts(
- self,
- student_id: str,
- threshold: float | None = None,
- ) -> List[Dict[str, Any]]:
- """Get concepts below the weak threshold."""
- threshold = threshold or self.config.weak_concept_threshold
- all_mastery = self._get_all_mastery(student_id)
- return [
- m for m in all_mastery
- if m["mastery_score"] < threshold
- ]
-
- def get_decay_risk_concepts(self, student_id: str) -> List[Dict[str, Any]]:
- """Get concepts whose mastery is decaying below the promotion threshold."""
- all_mastery = self._get_all_mastery(student_id)
- results = []
- for m in all_mastery:
- mem = self._find_mastery_raw(student_id, m["concept_id"])
- if not mem:
- continue
- # Check if Engram strength is decaying
- strength = mem.get("strength", 1.0)
- if strength < 0.5 and m["mastery_score"] >= self.config.weak_concept_threshold:
- results.append({**m, "decay_strength": strength})
- return results
-
- def record_misconception(
- self,
- student_id: str,
- concept_id: str,
- misconception: str,
- correction: str,
- ) -> None:
- """Record a misconception and apply mastery penalty."""
- self.update_mastery(
- student_id, concept_id, self.config.mastery_decrement_on_misconception
- )
-
- meta = {
- "memory_type": "misconception",
- "student_id": student_id,
- "concept_id": concept_id,
- "misconception": misconception,
- "correction": correction,
- "recorded_at": datetime.now(timezone.utc).isoformat(),
- }
- self.memory.add(
- content=f"Misconception ({concept_id}): {misconception} -> Correction: {correction}",
- user_id=student_id,
- metadata=meta,
- categories=["student/misconception"],
- )
-
- def record_effective_analogy(
- self,
- student_id: str,
- concept_id: str,
- analogy: str,
- success: bool = True,
- ) -> None:
- """Record an analogy that worked (or didn't) for a student."""
- meta = {
- "memory_type": "effective_analogy",
- "student_id": student_id,
- "concept_id": concept_id,
- "analogy": analogy,
- "success": success,
- "recorded_at": datetime.now(timezone.utc).isoformat(),
- }
- self.memory.add(
- content=f"Analogy ({concept_id}): {analogy} [{'success' if success else 'failed'}]",
- user_id=student_id,
- metadata=meta,
- categories=["student/analogy"],
- )
-
- # Also update profile with effective analogies
- if success:
- self.update_profile(student_id, {
- "effective_analogies": [{"concept_id": concept_id, "analogy": analogy}],
- })
-
- def get_effective_analogies(
- self,
- student_id: str,
- concept_id: str | None = None,
- ) -> List[Dict[str, Any]]:
- """Retrieve effective analogies for a student, optionally filtered by concept."""
- if not hasattr(self.memory, "db"):
- return []
-
- memories = self.memory.db.get_all_memories(
- user_id=student_id,
- memory_type="effective_analogy",
- limit=100,
- )
-
- results = []
- for mem in memories:
- md = self._parse_metadata(mem)
- if md.get("memory_type") != "effective_analogy":
- continue
- if not md.get("success", True):
- continue
- if concept_id and md.get("concept_id") != concept_id:
- continue
- results.append({
- "concept_id": md.get("concept_id", ""),
- "analogy": md.get("analogy", ""),
- "success": md.get("success", True),
- })
- return results
-
- # ------------------------------------------------------------------
- # Helpers
- # ------------------------------------------------------------------
-
- def _find_profile(self, student_id: str) -> Optional[Dict[str, Any]]:
- raw = self._find_profile_raw(student_id)
- return self._format_profile(raw) if raw else None
-
- def _find_profile_raw(self, student_id: str) -> Optional[Dict[str, Any]]:
- if not hasattr(self.memory, "db"):
- return None
- memories = self.memory.db.get_all_memories(
- user_id=student_id,
- memory_type="student_profile",
- limit=5,
- )
- for mem in memories:
- md = self._parse_metadata(mem)
- if md.get("memory_type") == "student_profile":
- return mem
- return None
-
- def _find_mastery_raw(
- self, student_id: str, concept_id: str
- ) -> Optional[Dict[str, Any]]:
- if not hasattr(self.memory, "db"):
- return None
- memories = self.memory.db.get_all_memories(
- user_id=student_id,
- memory_type="concept_mastery",
- limit=500,
- )
- for mem in memories:
- md = self._parse_metadata(mem)
- if (
- md.get("memory_type") == "concept_mastery"
- and md.get("concept_id") == concept_id
- ):
- return mem
- return None
-
- def _get_all_mastery(self, student_id: str) -> List[Dict[str, Any]]:
- if not hasattr(self.memory, "db"):
- return []
- memories = self.memory.db.get_all_memories(
- user_id=student_id,
- memory_type="concept_mastery",
- limit=500,
- )
- results = []
- for mem in memories:
- md = self._parse_metadata(mem)
- if md.get("memory_type") == "concept_mastery":
- results.append({
- "concept_id": md.get("concept_id", ""),
- "mastery_score": md.get("mastery_score", self.config.mastery_initial_score),
- "memory_id": mem.get("id"),
- })
- return results
-
- @staticmethod
- def _extract_id(result: Dict[str, Any]) -> Optional[str]:
- if isinstance(result, dict):
- results = result.get("results", [])
- if results and isinstance(results, list):
- first = results[0]
- return first.get("id") or first.get("memory_id")
- return result.get("id") or result.get("memory_id")
- return None
-
- @staticmethod
- def _parse_metadata(mem: Dict[str, Any]) -> Dict[str, Any]:
- md = mem.get("metadata", {})
- if isinstance(md, str):
- try:
- md = json.loads(md)
- except (json.JSONDecodeError, TypeError):
- md = {}
- return md
-
- @classmethod
- def _format_profile(cls, mem: Dict[str, Any]) -> Dict[str, Any]:
- md = cls._parse_metadata(mem)
- return {
- "memory_id": mem.get("id"),
- "student_id": md.get("student_id", ""),
- "learning_style": md.get("learning_style", "unknown"),
- "interests": md.get("interests", []),
- "goals": md.get("goals", []),
- "effective_analogies": md.get("effective_analogies", []),
- }
diff --git a/dhee/teaching/teaching_memory.py b/dhee/teaching/teaching_memory.py
deleted file mode 100644
index 33b9e21..0000000
--- a/dhee/teaching/teaching_memory.py
+++ /dev/null
@@ -1,255 +0,0 @@
-"""TeachingMemory — specialized memory types for the teaching domain.
-
-Stores lesson episodes, concept explanations, comprehension checks,
-misconceptions, and effective analogies as Engram memories.
-"""
-
-from __future__ import annotations
-
-import json
-import logging
-from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional
-
-from dhee.teaching.config import TeachingConfig
-
-logger = logging.getLogger(__name__)
-
-
-class TeachingMemory:
- """Domain-specific memory storage for teaching interactions."""
-
- def __init__(self, memory: "CoreMemory", config: TeachingConfig | None = None): # noqa: F821
- self.memory = memory
- self.config = config or TeachingConfig()
-
- # ------------------------------------------------------------------
- # Lesson episodes
- # ------------------------------------------------------------------
-
- def store_lesson_episode(
- self,
- student_id: str,
- concept_id: str,
- lesson_summary: str,
- *,
- segments_completed: int = 0,
- session_id: str = "",
- topic_name: str = "",
- subject: str = "",
- ) -> Dict[str, Any]:
- """Store a completed (or partial) lesson episode."""
- now = datetime.now(timezone.utc).isoformat()
- meta = {
- "memory_type": "lesson_episode",
- "student_id": student_id,
- "concept_id": concept_id,
- "session_id": session_id,
- "segments_completed": segments_completed,
- "topic_name": topic_name,
- "subject": subject,
- "recorded_at": now,
- }
-
- content = f"Lesson ({topic_name or concept_id}): {lesson_summary}"
- result = self.memory.add(
- content=content,
- user_id=student_id,
- metadata=meta,
- categories=[f"lesson/{subject}" if subject else "lesson"],
- )
-
- mem_id = self._extract_id(result)
- return {"memory_id": mem_id, **meta}
-
- # ------------------------------------------------------------------
- # Concept explanations
- # ------------------------------------------------------------------
-
- def store_concept_explanation(
- self,
- student_id: str,
- concept_id: str,
- approach: str,
- explanation_text: str,
- success: bool = True,
- ) -> Dict[str, Any]:
- """Store an explanation attempt for a concept."""
- now = datetime.now(timezone.utc).isoformat()
- meta = {
- "memory_type": "concept_explanation",
- "student_id": student_id,
- "concept_id": concept_id,
- "approach": approach,
- "success": success,
- "recorded_at": now,
- }
-
- content = f"Explanation ({concept_id}, {approach}): {explanation_text[:500]}"
- result = self.memory.add(
- content=content,
- user_id=student_id,
- metadata=meta,
- categories=["teaching/explanation"],
- )
-
- mem_id = self._extract_id(result)
- return {"memory_id": mem_id, **meta}
-
- # ------------------------------------------------------------------
- # Comprehension checks
- # ------------------------------------------------------------------
-
- def store_comprehension_check(
- self,
- student_id: str,
- concept_id: str,
- question: str,
- student_answer: str,
- evaluation: Dict[str, Any],
- ) -> Dict[str, Any]:
- """Store a comprehension check result."""
- now = datetime.now(timezone.utc).isoformat()
- level = evaluation.get("level", "unknown")
- meta = {
- "memory_type": "comprehension_check",
- "student_id": student_id,
- "concept_id": concept_id,
- "question": question,
- "student_answer": student_answer,
- "level": level,
- "confidence": evaluation.get("confidence", 0.5),
- "misconception": evaluation.get("misconception"),
- "recorded_at": now,
- }
-
- content = (
- f"Check ({concept_id}): Q: {question[:200]} | "
- f"A: {student_answer[:200]} | Level: {level}"
- )
- result = self.memory.add(
- content=content,
- user_id=student_id,
- metadata=meta,
- categories=["teaching/check"],
- )
-
- mem_id = self._extract_id(result)
- return {"memory_id": mem_id, **meta}
-
- # ------------------------------------------------------------------
- # Search
- # ------------------------------------------------------------------
-
- def search_past_explanations(
- self,
- student_id: str,
- concept_id: str,
- limit: int = 5,
- ) -> List[Dict[str, Any]]:
- """Find past explanations for a concept with a specific student."""
- results = self.memory.search(
- query=f"explanation {concept_id}",
- user_id=student_id,
- limit=limit * 2,
- )
-
- explanations = []
- for mem in results:
- md = self._parse_metadata(mem)
- if (
- md.get("memory_type") == "concept_explanation"
- and md.get("concept_id") == concept_id
- ):
- explanations.append({
- "memory_id": mem.get("id"),
- "approach": md.get("approach", ""),
- "success": md.get("success", True),
- "content": mem.get("content", ""),
- "strength": mem.get("strength", 1.0),
- })
- if len(explanations) >= limit:
- break
-
- return explanations
-
- def search_student_history(
- self,
- student_id: str,
- query: str,
- limit: int = 10,
- ) -> List[Dict[str, Any]]:
- """Search all teaching memories for a student (grounding context)."""
- results = self.memory.search(
- query=query,
- user_id=student_id,
- limit=limit,
- )
-
- return [
- {
- "memory_id": mem.get("id"),
- "content": mem.get("content", ""),
- "memory_type": self._parse_metadata(mem).get("memory_type", ""),
- "strength": mem.get("strength", 1.0),
- }
- for mem in results
- ]
-
- def get_student_misconceptions(
- self,
- student_id: str,
- concept_id: str | None = None,
- limit: int = 20,
- ) -> List[Dict[str, Any]]:
- """Retrieve recorded misconceptions for a student."""
- if not hasattr(self.memory, "db"):
- return []
-
- memories = self.memory.db.get_all_memories(
- user_id=student_id,
- memory_type="misconception",
- limit=limit * 2,
- )
-
- results = []
- for mem in memories:
- md = self._parse_metadata(mem)
- if md.get("memory_type") != "misconception":
- continue
- if concept_id and md.get("concept_id") != concept_id:
- continue
- results.append({
- "concept_id": md.get("concept_id", ""),
- "misconception": md.get("misconception", ""),
- "correction": md.get("correction", ""),
- "recorded_at": md.get("recorded_at", ""),
- })
- if len(results) >= limit:
- break
-
- return results
-
- # ------------------------------------------------------------------
- # Helpers
- # ------------------------------------------------------------------
-
- @staticmethod
- def _extract_id(result: Dict[str, Any]) -> Optional[str]:
- if isinstance(result, dict):
- results = result.get("results", [])
- if results and isinstance(results, list):
- first = results[0]
- return first.get("id") or first.get("memory_id")
- return result.get("id") or result.get("memory_id")
- return None
-
- @staticmethod
- def _parse_metadata(mem: Dict[str, Any]) -> Dict[str, Any]:
- md = mem.get("metadata", {})
- if isinstance(md, str):
- try:
- md = json.loads(md)
- except (json.JSONDecodeError, TypeError):
- md = {}
- return md
diff --git a/dhee/utils/factory.py b/dhee/utils/factory.py
index 24102db..89fa196 100644
--- a/dhee/utils/factory.py
+++ b/dhee/utils/factory.py
@@ -7,6 +7,22 @@
logger = logging.getLogger(__name__)
+def _normalize_sqlite_vec_config(config: Dict[str, Any]) -> Dict[str, Any]:
+ """Coerce directory-style vector paths into sqlite-vec DB files."""
+ normalized = dict(config or {})
+ path = normalized.get("path")
+ if not path:
+ return normalized
+
+ path = str(path)
+ root, ext = os.path.splitext(path)
+ if ext:
+ return normalized
+
+ normalized["path"] = os.path.join(path, "sqlite_vec.db")
+ return normalized
+
+
def _dhee_model_available() -> bool:
"""Check if local DheeModel GGUF is available."""
try:
@@ -152,7 +168,7 @@ def create(cls, provider: str, config: Dict[str, Any]):
if provider == "sqlite_vec":
from dhee.vector_stores.sqlite_vec import SqliteVecStore
- return SqliteVecStore(config)
+ return SqliteVecStore(_normalize_sqlite_vec_config(config))
if provider == "zvec":
try:
from dhee.vector_stores.zvec_store import ZvecStore
@@ -161,7 +177,9 @@ def create(cls, provider: str, config: Dict[str, Any]):
logger.warning("zvec not installed, falling back to sqlite_vec")
try:
from dhee.vector_stores.sqlite_vec import SqliteVecStore
- return SqliteVecStore(config)
+ return SqliteVecStore(
+ _normalize_sqlite_vec_config(config)
+ )
except ImportError:
logger.warning("sqlite_vec not installed, falling back to in-memory")
from dhee.vector_stores.memory import InMemoryVectorStore
diff --git a/pyproject.toml b/pyproject.toml
index 09b62b6..0c301ea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "dhee"
-version = "3.0.1"
+version = "3.1.0"
description = "Cognition layer for AI agents — persistent memory, performance tracking, and insight synthesis"
readme = "README.md"
requires-python = ">=3.9"
@@ -44,6 +44,7 @@ local = ["llama-cpp-python>=0.3", "sentence-transformers>=3.0"]
mcp = ["mcp>=1.0.0"]
api = ["fastapi>=0.100.0", "uvicorn>=0.20.0"]
bus = ["engram-bus>=0.1.0"]
+benchmarks = ["huggingface_hub>=0.24.0"]
# Edge/hardware deployment (offline, ONNX embedder)
edge = ["onnxruntime>=1.16"]
# Training (QLoRA fine-tuning)
@@ -56,6 +57,7 @@ all = [
"fastapi>=0.100.0",
"uvicorn>=0.20.0",
"engram-bus>=0.1.0",
+ "huggingface_hub>=0.24.0",
"llama-cpp-python>=0.3",
"sentence-transformers>=3.0",
]
@@ -63,6 +65,7 @@ dev = [
"pytest>=7.0.0",
"pytest-asyncio>=0.23.0",
"openai>=1.0.0",
+ "huggingface_hub>=0.24.0",
"build>=1.0.0",
"twine>=5.0.0",
]
diff --git a/tests/test_auto_lifecycle.py b/tests/test_auto_lifecycle.py
index 7411e6f..14f1be2 100644
--- a/tests/test_auto_lifecycle.py
+++ b/tests/test_auto_lifecycle.py
@@ -8,7 +8,7 @@
import pytest
from dhee.memory.main import FullMemory
-from dhee.adapters.base import DheePlugin
+from dhee.plugin import DheePlugin
from dhee.simple import Dhee, Engram
diff --git a/tests/test_cognition_v3.py b/tests/test_cognition_v3.py
index 1173400..bbca8bf 100644
--- a/tests/test_cognition_v3.py
+++ b/tests/test_cognition_v3.py
@@ -277,47 +277,7 @@ def test_rollback_on_poor_performance(self, tmpdir):
# ═══════════════════════════════════════════════════════════════════════════
-# 5. Progressive Training — data flow
-# ═══════════════════════════════════════════════════════════════════════════
-
-class TestProgressiveTraining:
- def test_training_cycle_with_data(self, tmpdir):
- from dhee.mini.progressive_trainer import ProgressiveTrainer
-
- trainer = ProgressiveTrainer(data_dir=os.path.join(tmpdir, "training"))
-
- # Generate enough SFT data
- sft_data = [
- {"input": f"[MEMORY_OP] Query {i}", "output": "store", "type": "memory_op"}
- for i in range(25)
- ]
- dpo_data = [
- {"prompt": f"Task {i}", "chosen": "good approach", "rejected": "bad approach"}
- for i in range(15)
- ]
-
- result = trainer.run_cycle(
- sft_data=sft_data,
- dpo_data=dpo_data,
- samskara_data={},
- )
- assert result.cycle_id
- stage_names = [s.stage for s in result.stages]
- assert "sft" in stage_names
- assert "dpo" in stage_names
-
- def test_skips_with_insufficient_data(self, tmpdir):
- from dhee.mini.progressive_trainer import ProgressiveTrainer
-
- trainer = ProgressiveTrainer(data_dir=os.path.join(tmpdir, "training"))
-
- result = trainer.run_cycle(sft_data=[], dpo_data=[], samskara_data={})
- for stage_result in result.stages:
- assert stage_result.status in ("skipped", "completed")
-
-
-# ═══════════════════════════════════════════════════════════════════════════
-# 6. Episode — lifecycle + selective forgetting
+# 5. Episode — lifecycle + selective forgetting
# ═══════════════════════════════════════════════════════════════════════════
class TestEpisode:
diff --git a/tests/test_hippocamp_benchmark.py b/tests/test_hippocamp_benchmark.py
new file mode 100644
index 0000000..3320f61
--- /dev/null
+++ b/tests/test_hippocamp_benchmark.py
@@ -0,0 +1,202 @@
+from __future__ import annotations
+
+import sqlite3
+import zipfile
+import json
+from pathlib import Path
+
+from dhee.benchmarks.hippocamp import (
+ CONFIG_SPECS,
+ _emit_progress,
+ _make_gold_repo_path,
+ _preview_text,
+ _relative_path_from_environment,
+ exact_match,
+ file_retrieval_metrics,
+ gold_document_to_items,
+ token_f1,
+)
+from dhee.benchmarks.raw_extractors import raw_file_to_items
+
+
+def test_relative_path_and_gold_path_mapping() -> None:
+ spec = CONFIG_SPECS["adam_subset"]
+ repo_path = "Adam/Subset/Adam_Subset/contractnli/Tazza-CAFFE-Confidentiality-Agreement.pdf"
+
+ relative = _relative_path_from_environment(spec, repo_path)
+ assert relative == "contractnli/Tazza-CAFFE-Confidentiality-Agreement.pdf"
+
+ gold_path = _make_gold_repo_path(spec.profile, relative)
+ assert gold_path == "HippoCamp_Gold/Adam/contractnli/Tazza-CAFFE-Confidentiality-Agreement.json"
+
+
+def test_gold_document_chunking_keeps_metadata() -> None:
+ document = {
+ "file_info": {
+ "file_name": "Diary Entry.pdf",
+ "file_type": "pdf",
+ "file_modality": "document",
+ "creation_date": "2025-10-20 09:00:00",
+ "location": "Home",
+ },
+ "summary": "This file summarizes a weekly routine.",
+ "segments": [
+ {"page": 1, "content": "First page content " * 40},
+ {"page": 2, "content": "Second page content " * 40},
+ ],
+ }
+
+ items = gold_document_to_items(
+ doc=document,
+ profile="Adam",
+ config_name="adam_subset",
+ relative_path="docs/Diary Entry.pdf",
+ chunk_chars=500,
+ )
+
+ assert len(items) >= 2
+ assert items[0]["metadata"]["file_path"] == "docs/Diary Entry.pdf"
+ assert items[0]["metadata"]["profile"] == "Adam"
+ assert "Summary: This file summarizes a weekly routine." in items[0]["content"]
+ assert "Relative Path: docs/Diary Entry.pdf" in items[0]["content"]
+
+
+def test_file_retrieval_metrics_deduplicate_predictions() -> None:
+ metrics = file_retrieval_metrics(
+ predicted_paths=["a.txt", "a.txt", "b.txt"],
+ gold_paths=["a.txt", "c.txt"],
+ )
+ assert round(metrics["precision"], 4) == 0.5
+ assert round(metrics["recall"], 4) == 0.5
+ assert round(metrics["f1"], 4) == 0.5
+
+
+def test_answer_matching_helpers() -> None:
+ assert exact_match("Cursor", "cursor")
+ assert not exact_match("Cursor editor", "VS Code")
+ assert token_f1("Cursor editor", "Cursor") > 0.0
+ assert token_f1("", "") == 1.0
+
+
+def test_progress_jsonl_event_emission(tmp_path: Path) -> None:
+ path = tmp_path / "adam_subset.progress.jsonl"
+ _emit_progress(
+ path,
+ "judge_done",
+ config="adam_subset",
+ question_index=3,
+ judge_correct=True,
+ judge_score_0_to_5=4.0,
+ judge_rationale_preview=_preview_text("This answer matches the gold evidence and resolves the question cleanly."),
+ )
+
+ lines = path.read_text(encoding="utf-8").splitlines()
+ assert len(lines) == 1
+ payload = json.loads(lines[0])
+ assert payload["event"] == "judge_done"
+ assert payload["config"] == "adam_subset"
+ assert payload["question_index"] == 3
+ assert payload["judge_correct"] is True
+ assert payload["judge_score_0_to_5"] == 4.0
+ assert "gold evidence" in payload["judge_rationale_preview"]
+
+
+def test_raw_text_file_extraction(tmp_path: Path) -> None:
+ path = tmp_path / "note.txt"
+ path.write_text("alpha\nbeta\ngamma", encoding="utf-8")
+
+ result = raw_file_to_items(
+ local_path=path,
+ relative_path="docs/note.txt",
+ profile="Adam",
+ config_name="adam_subset",
+ chunk_chars=1000,
+ )
+
+ assert result.mode == "text"
+ assert len(result.items) == 1
+ assert "alpha" in result.items[0]["content"]
+ assert result.items[0]["metadata"]["exposure_mode"] == "raw_files_only"
+
+
+def test_raw_sqlite_extraction(tmp_path: Path) -> None:
+ db_path = tmp_path / "sample.sqlite"
+ conn = sqlite3.connect(db_path)
+ try:
+ conn.execute("CREATE TABLE visits (url TEXT, title TEXT)")
+ conn.execute("INSERT INTO visits VALUES (?, ?)", ("https://example.com", "Example"))
+ conn.commit()
+ finally:
+ conn.close()
+
+ result = raw_file_to_items(
+ local_path=db_path,
+ relative_path="Browser_History.sqlite",
+ profile="Adam",
+ config_name="adam_subset",
+ chunk_chars=2000,
+ )
+
+ assert result.mode == "text"
+ assert "table=visits" in result.items[0]["content"]
+ assert "https://example.com" in result.items[0]["content"]
+
+
+def test_raw_xlsx_extraction_from_minimal_ooxml(tmp_path: Path) -> None:
+ xlsx_path = tmp_path / "plan.xlsx"
+ with zipfile.ZipFile(xlsx_path, "w") as archive:
+ archive.writestr(
+ "xl/workbook.xml",
+ """
+
+
+
+
+ """,
+ )
+ archive.writestr(
+ "xl/_rels/workbook.xml.rels",
+ """
+
+
+ """,
+ )
+ archive.writestr(
+ "xl/sharedStrings.xml",
+ """
+
+ Run
+ Tempo
+ """,
+ )
+ archive.writestr(
+ "xl/worksheets/sheet1.xml",
+ """
+
+
+
+ 0
+ 1
+
+
+ 5
+ 42
+
+
+ """,
+ )
+
+ result = raw_file_to_items(
+ local_path=xlsx_path,
+ relative_path="plan.xlsx",
+ profile="Adam",
+ config_name="adam_subset",
+ chunk_chars=2000,
+ )
+
+ assert result.mode == "text"
+ content = result.items[0]["content"]
+ assert "sheet=Schedule" in content
+ assert "Run\tTempo" in content
+ assert "5\t42" in content
diff --git a/tests/test_orchestration_core.py b/tests/test_orchestration_core.py
index b1afbce..bd28f5f 100644
--- a/tests/test_orchestration_core.py
+++ b/tests/test_orchestration_core.py
@@ -79,201 +79,3 @@ def memory_instance(tmp_path):
mem.close()
-def test_search_orchestrated_skips_map_reduce_when_coverage_sufficient(memory_instance, monkeypatch) -> None:
- mem = memory_instance
-
- def fake_search(**kwargs: Any) -> Dict[str, Any]:
- return {"results": [{"id": "m1", "memory": "foo", "composite_score": 0.9}]}
-
- def fake_search_episodes(**kwargs: Any) -> Dict[str, Any]:
- return {
- "results": [{"memory_id": "m1", "value_text": "foo", "event_type": "utterance"}],
- "coverage": {"sufficient": True, "coverage_ratio": 1.0, "event_hit_count": 1, "unique_canonical_keys": 1},
- }
-
- monkeypatch.setattr(mem, "search", fake_search)
- monkeypatch.setattr(mem, "search_episodes", fake_search_episodes)
-
- def fail_extract_atomic_facts(**kwargs: Any) -> List[Dict[str, Any]]:
- raise AssertionError("map stage should be skipped when episodic coverage is sufficient")
-
- monkeypatch.setattr("dhee.memory.main.extract_atomic_facts", fail_extract_atomic_facts)
-
- payload = mem.search_orchestrated(
- query="How many projects?",
- user_id="u1",
- question_type="multi-session",
- orchestration_mode="hybrid",
- orchestrator_llm=object(),
- rerank=False,
- )
-
- assert payload["orchestration"]["map_reduce_used"] is False
- assert payload["orchestration"]["reflection_hops"] == 0
-
-
-def test_search_orchestrated_reflection_hard_caps_to_one_hop(memory_instance, monkeypatch) -> None:
- mem = memory_instance
- search_calls: List[int] = []
- reduce_calls: List[int] = []
-
- def fake_search(**kwargs: Any) -> Dict[str, Any]:
- search_calls.append(1)
- if len(search_calls) == 1:
- return {"results": [{"id": "m1", "memory": "first", "composite_score": 0.8}]}
- return {"results": [{"id": "m1", "memory": "first", "composite_score": 0.8}, {"id": "m2", "memory": "second", "composite_score": 0.7}]}
-
- def fake_search_episodes(**kwargs: Any) -> Dict[str, Any]:
- return {
- "results": [{"memory_id": "m1", "value_text": "first", "event_type": "utterance"}],
- "coverage": {"sufficient": False, "coverage_ratio": 0.2, "event_hit_count": 1, "unique_canonical_keys": 1},
- }
-
- def fake_extract_atomic_facts(**kwargs: Any) -> List[Dict[str, Any]]:
- return [{"value": "4", "relevant": True, "canonical_key": "k1"}]
-
- def fake_reduce_atomic_facts(**kwargs: Any):
- reduce_calls.append(1)
- if len(reduce_calls) == 1:
- return "I don't know", {}
- return "4", {}
-
- monkeypatch.setattr(mem, "search", fake_search)
- monkeypatch.setattr(mem, "search_episodes", fake_search_episodes)
- monkeypatch.setattr("dhee.memory.main.extract_atomic_facts", fake_extract_atomic_facts)
- monkeypatch.setattr("dhee.memory.main.reduce_atomic_facts", fake_reduce_atomic_facts)
-
- payload = mem.search_orchestrated(
- query="How many projects?",
- user_id="u1",
- question_type="multi-session",
- orchestration_mode="hybrid",
- orchestrator_llm=object(),
- reflection_max_hops=1,
- base_search_limit=4,
- search_cap=20,
- rerank=False,
- )
-
- assert payload["orchestration"]["map_reduce_used"] is True
- assert payload["orchestration"]["reflection_hops"] == 1
- assert payload["reduced_answer"] == "4"
- # Initial retrieval + exactly one reflection retrieval.
- assert len(search_calls) == 2
-
-
-def test_search_orchestrated_inconsistency_can_trigger_map_reduce(memory_instance, monkeypatch) -> None:
- mem = memory_instance
-
- def fake_search(**kwargs: Any) -> Dict[str, Any]:
- return {
- "results": [
- {"id": "m1", "memory": "I led 2 projects.", "evidence_text": "I led 2 projects.", "composite_score": 0.9},
- {"id": "m2", "memory": "I led 5 projects.", "evidence_text": "I led 5 projects.", "composite_score": 0.8},
- ]
- }
-
- def fake_search_episodes(**kwargs: Any) -> Dict[str, Any]:
- return {
- "results": [
- {"memory_id": "m1", "value_text": "I led 2 projects", "event_type": "utterance"},
- {"memory_id": "m2", "value_text": "I led 5 projects", "event_type": "utterance"},
- ],
- "coverage": {
- "sufficient": True,
- "coverage_ratio": 1.0,
- "intent_coverage": 1.0,
- "event_hit_count": 2,
- "unique_canonical_keys": 2,
- },
- }
-
- monkeypatch.setattr(mem, "search", fake_search)
- monkeypatch.setattr(mem, "search_episodes", fake_search_episodes)
- monkeypatch.setattr("dhee.memory.main.extract_atomic_facts", lambda **kwargs: [{"value": "5", "relevant": True}])
- monkeypatch.setattr("dhee.memory.main.reduce_atomic_facts", lambda **kwargs: ("5", {}))
-
- payload = mem.search_orchestrated(
- query="How many projects have I led?",
- user_id="u1",
- question_type="multi-session",
- orchestration_mode="hybrid",
- orchestrator_llm=object(),
- rerank=False,
- )
-
- assert payload["orchestration"]["map_reduce_used"] is True
- assert "count_numeric_conflict" in payload["orchestration"]["reason_codes"]
-
-
-def test_search_orchestrated_respects_query_llm_budget(memory_instance, monkeypatch) -> None:
- mem = memory_instance
- mem.config.orchestration.max_query_llm_calls = 0
-
- def fake_search(**kwargs: Any) -> Dict[str, Any]:
- return {"results": [{"id": "m1", "memory": "I led 4 projects.", "composite_score": 0.9}]}
-
- def fake_search_episodes(**kwargs: Any) -> Dict[str, Any]:
- return {
- "results": [{"memory_id": "m1", "value_text": "I led 4 projects", "event_type": "utterance"}],
- "coverage": {"sufficient": False, "coverage_ratio": 0.2, "intent_coverage": 0.2, "event_hit_count": 1, "unique_canonical_keys": 1},
- }
-
- monkeypatch.setattr(mem, "search", fake_search)
- monkeypatch.setattr(mem, "search_episodes", fake_search_episodes)
-
- payload = mem.search_orchestrated(
- query="How many projects have I led?",
- user_id="u1",
- question_type="multi-session",
- orchestration_mode="hybrid",
- orchestrator_llm=object(),
- rerank=False,
- )
-
- assert payload["orchestration"]["map_reduce_used"] is False
- assert "query_llm_budget_exhausted" in payload["orchestration"]["reason_codes"]
-
-
-def test_search_orchestrated_reducer_cache_hit(memory_instance, monkeypatch) -> None:
- mem = memory_instance
- extract_calls = {"count": 0}
-
- def fake_search(**kwargs: Any) -> Dict[str, Any]:
- return {"results": [{"id": "m1", "memory": "I led 4 projects.", "evidence_text": "I led 4 projects.", "composite_score": 0.9}]}
-
- def fake_search_episodes(**kwargs: Any) -> Dict[str, Any]:
- return {
- "results": [{"memory_id": "m1", "value_text": "I led 4 projects", "event_type": "utterance"}],
- "coverage": {"sufficient": False, "coverage_ratio": 0.2, "intent_coverage": 0.2, "event_hit_count": 1, "unique_canonical_keys": 1},
- }
-
- def fake_extract_atomic_facts(**kwargs: Any) -> List[Dict[str, Any]]:
- extract_calls["count"] += 1
- return [{"value": "4", "relevant": True, "canonical_key": "projects"}]
-
- monkeypatch.setattr(mem, "search", fake_search)
- monkeypatch.setattr(mem, "search_episodes", fake_search_episodes)
- monkeypatch.setattr("dhee.memory.main.extract_atomic_facts", fake_extract_atomic_facts)
- monkeypatch.setattr("dhee.memory.main.reduce_atomic_facts", lambda **kwargs: ("4", {}))
-
- first = mem.search_orchestrated(
- query="How many projects have I led?",
- user_id="u1",
- question_type="multi-session",
- orchestration_mode="hybrid",
- orchestrator_llm=object(),
- rerank=False,
- )
- second = mem.search_orchestrated(
- query="How many projects have I led?",
- user_id="u1",
- question_type="multi-session",
- orchestration_mode="hybrid",
- orchestrator_llm=object(),
- rerank=False,
- )
-
- assert first["orchestration"]["cache_hit"] is False
- assert second["orchestration"]["cache_hit"] is True
- assert extract_calls["count"] == 1
diff --git a/tests/test_simple_zero_config.py b/tests/test_simple_zero_config.py
new file mode 100644
index 0000000..39e137d
--- /dev/null
+++ b/tests/test_simple_zero_config.py
@@ -0,0 +1,41 @@
+import os
+
+from dhee.simple import _detect_provider, _get_embedding_dims
+
+
+def test_detect_provider_defaults_to_mock_without_keys(monkeypatch):
+ for key in (
+ "OPENAI_API_KEY",
+ "GOOGLE_API_KEY",
+ "GEMINI_API_KEY",
+ "NVIDIA_API_KEY",
+ "NVIDIA_QWEN_API_KEY",
+ "NVIDIA_EMBEDDING_API_KEY",
+ "NVIDIA_EMBED_API_KEY",
+ "NVIDIA_LLAMA_4_MAV_API_KEY",
+ ):
+ monkeypatch.delenv(key, raising=False)
+
+ assert _detect_provider() == "mock"
+
+
+def test_embedding_dims_cover_mock_and_nvidia():
+ assert _get_embedding_dims("mock") == 384
+ assert _get_embedding_dims("nvidia") == 2048
+
+
+def test_detect_provider_accepts_nvidia_alias_keys(monkeypatch):
+ for key in (
+ "OPENAI_API_KEY",
+ "GOOGLE_API_KEY",
+ "GEMINI_API_KEY",
+ "NVIDIA_API_KEY",
+ "NVIDIA_QWEN_API_KEY",
+ "NVIDIA_EMBEDDING_API_KEY",
+ "NVIDIA_EMBED_API_KEY",
+ "NVIDIA_LLAMA_4_MAV_API_KEY",
+ ):
+ monkeypatch.delenv(key, raising=False)
+
+ monkeypatch.setenv("NVIDIA_EMBED_API_KEY", "test-key")
+ assert _detect_provider() == "nvidia"
diff --git a/tests/test_v3_all_phases.py b/tests/test_v3_all_phases.py
deleted file mode 100644
index 1ac808a..0000000
--- a/tests/test_v3_all_phases.py
+++ /dev/null
@@ -1,820 +0,0 @@
-"""Tests for Dhee v3 Phases 2-10.
-
-Phase 2: Anchor resolver (per-field candidates, re-anchoring)
-Phase 4: Distillation + Promotion pipeline
-Phase 5: Lease manager + Job registry
-Phase 7: RRF Fusion (5-stage pipeline)
-Phase 8: Three-tier invalidation + Conflicts
-Phase 6: Read model + delta overlay
-Phase 9: Observability (v3_health)
-Phase 10: Migration bridge (dual-write, backfill)
-Phase 3: Sparse to_dict on UniversalEngram
-"""
-
-import json
-import os
-import sqlite3
-import threading
-import time
-
-import pytest
-
-from dhee.core.storage import initialize_schema
-from dhee.core.events import RawEventStore, EventStatus
-from dhee.core.derived_store import (
- BeliefStore, PolicyStore, AnchorStore, InsightStore,
- HeuristicStore, DerivedLineageStore, CognitionStore,
-)
-
-
-# ---------------------------------------------------------------------------
-# Fixtures
-# ---------------------------------------------------------------------------
-
-@pytest.fixture
-def db_path(tmp_path):
- return str(tmp_path / "test_v3_phases.db")
-
-
-@pytest.fixture
-def store(db_path):
- s = CognitionStore(db_path=db_path)
- yield s
- s.close()
-
-
-@pytest.fixture
-def conn_lock(db_path):
- """Shared connection + lock for lower-level store tests."""
- conn = sqlite3.connect(db_path, check_same_thread=False)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.row_factory = sqlite3.Row
- initialize_schema(conn)
- lock = threading.RLock()
- yield conn, lock
- conn.close()
-
-
-# =========================================================================
-# Phase 2: Anchor Resolver
-# =========================================================================
-
-class TestAnchorResolver:
-
- def test_submit_and_resolve(self, store):
- from dhee.core.anchor_resolver import AnchorCandidateStore, AnchorResolver
-
- anchor_id = store.anchors.add(user_id="u1")
- cand_store = AnchorCandidateStore(store.events._conn, store.events._lock)
- resolver = AnchorResolver(cand_store, store.anchors)
-
- # Submit competing candidates for 'place'
- cand_store.submit(anchor_id, "place", "Ghazipur", confidence=0.6)
- cand_store.submit(anchor_id, "place", "Bengaluru", confidence=0.9)
-
- result = resolver.resolve(anchor_id)
- assert result["resolved_fields"]["place"] == "Bengaluru"
- assert result["details"]["place"]["confidence"] == 0.9
-
- # Verify anchor was updated
- anchor = store.anchors.get(anchor_id)
- assert anchor["place"] == "Bengaluru"
-
- def test_re_anchor_correction(self, store):
- from dhee.core.anchor_resolver import AnchorCandidateStore, AnchorResolver
-
- anchor_id = store.anchors.add(user_id="u1", place="Ghazipur")
- cand_store = AnchorCandidateStore(store.events._conn, store.events._lock)
- resolver = AnchorResolver(cand_store, store.anchors)
-
- # Initial candidate
- cand_store.submit(anchor_id, "place", "Ghazipur", confidence=0.6)
- resolver.resolve(anchor_id)
- assert store.anchors.get(anchor_id)["place"] == "Ghazipur"
-
- # User corrects
- result = resolver.re_anchor(
- anchor_id, "place", "Delhi", confidence=0.95
- )
- assert result["resolved_fields"]["place"] == "Delhi"
- assert store.anchors.get(anchor_id)["place"] == "Delhi"
-
- def test_extract_and_submit(self, store):
- from dhee.core.anchor_resolver import AnchorCandidateStore, AnchorResolver
-
- anchor_id = store.anchors.add(user_id="u1")
- cand_store = AnchorCandidateStore(store.events._conn, store.events._lock)
- resolver = AnchorResolver(cand_store, store.anchors)
-
- cids = resolver.extract_and_submit(
- anchor_id, "I was coding at the office today"
- )
- assert len(cids) >= 1 # should detect 'coding' activity and 'office' place_type
-
- candidates = cand_store.get_candidates(anchor_id)
- field_names = {c["field_name"] for c in candidates}
- assert "activity" in field_names
-
- def test_invalid_field_rejected(self, store):
- from dhee.core.anchor_resolver import AnchorCandidateStore
-
- cand_store = AnchorCandidateStore(store.events._conn, store.events._lock)
- with pytest.raises(ValueError, match="Invalid anchor field"):
- cand_store.submit("a1", "invalid_field", "value")
-
-
-# =========================================================================
-# Phase 4: Distillation + Promotion
-# =========================================================================
-
-class TestDistillationPromotion:
-
- def test_submit_and_promote_belief(self, store):
- from dhee.core.distillation import (
- DistillationStore, DistillationCandidate, distill_belief_from_events,
- )
- from dhee.core.promotion import PromotionEngine
-
- conn = store.events._conn
- lock = store.events._lock
-
- # Create source events
- e1 = store.events.add(content="Python uses GIL", user_id="u1")
- e2 = store.events.add(content="Python GIL limits threading", user_id="u1")
-
- # Distill a belief candidate
- candidate = distill_belief_from_events(
- [e1.to_dict(), e2.to_dict()],
- user_id="u1", domain="programming",
- )
- assert candidate is not None
-
- # Submit to distillation store
- dist_store = DistillationStore(conn, lock)
- cid = dist_store.submit(candidate)
- assert cid is not None
-
- # Promote
- engine = PromotionEngine(
- distillation=dist_store,
- beliefs=store.beliefs,
- policies=store.policies,
- insights=store.insights,
- heuristics=store.heuristics,
- lineage=store.lineage,
- )
- result = engine.promote_pending(target_type="belief")
- assert result.promoted # at least one promoted
-
- # Verify lineage was created
- promoted_id = result.promoted[0]
- sources = store.lineage.get_sources("belief", promoted_id)
- assert len(sources) >= 1
-
- def test_idempotent_dedup(self, store):
- from dhee.core.distillation import DistillationStore, DistillationCandidate
-
- conn = store.events._conn
- lock = store.events._lock
- dist_store = DistillationStore(conn, lock)
-
- candidate = DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1", "e2"],
- target_type="belief",
- canonical_key="test_dedup",
- payload={"user_id": "u1", "claim": "test"},
- )
-
- cid1 = dist_store.submit(candidate)
- assert cid1 == "c1"
-
- # Same idempotency key — should be deduped
- candidate2 = DistillationCandidate(
- candidate_id="c2",
- source_event_ids=["e1", "e2"],
- target_type="belief",
- canonical_key="test_dedup",
- payload={"user_id": "u1", "claim": "test"},
- )
- cid2 = dist_store.submit(candidate2)
- assert cid2 is None # deduped
-
- def test_low_confidence_rejected(self, store):
- from dhee.core.distillation import DistillationStore, DistillationCandidate
- from dhee.core.promotion import PromotionEngine
-
- conn = store.events._conn
- lock = store.events._lock
- dist_store = DistillationStore(conn, lock)
-
- candidate = DistillationCandidate(
- candidate_id="c-low",
- source_event_ids=["e1"],
- target_type="belief",
- canonical_key="low_conf",
- confidence=0.1, # below MIN_PROMOTION_CONFIDENCE (0.3)
- payload={"user_id": "u1", "claim": "uncertain thing"},
- )
- dist_store.submit(candidate)
-
- engine = PromotionEngine(
- distillation=dist_store,
- beliefs=store.beliefs,
- policies=store.policies,
- insights=store.insights,
- heuristics=store.heuristics,
- lineage=store.lineage,
- )
- result = engine.promote_pending()
- assert "c-low" in result.rejected
-
-
-# =========================================================================
-# Phase 5: Lease Manager + Job Registry
-# =========================================================================
-
-class TestLeaseManager:
-
- def test_acquire_release(self, conn_lock):
- from dhee.core.lease_manager import LeaseManager
-
- conn, lock = conn_lock
- lm = LeaseManager(conn, lock)
-
- assert lm.acquire("job-1", "worker-a") is True
- assert lm.is_held("job-1") is True
- assert lm.get_holder("job-1") == "worker-a"
-
- # Different worker can't acquire
- assert lm.acquire("job-1", "worker-b") is False
-
- # Release
- assert lm.release("job-1", "worker-a") is True
- assert lm.is_held("job-1") is False
-
- def test_same_owner_renew(self, conn_lock):
- from dhee.core.lease_manager import LeaseManager
-
- conn, lock = conn_lock
- lm = LeaseManager(conn, lock)
-
- lm.acquire("job-1", "worker-a")
- assert lm.renew("job-1", "worker-a") is True
-
- def test_wrong_owner_cant_release(self, conn_lock):
- from dhee.core.lease_manager import LeaseManager
-
- conn, lock = conn_lock
- lm = LeaseManager(conn, lock)
-
- lm.acquire("job-1", "worker-a")
- assert lm.release("job-1", "worker-b") is False
-
-
-class TestJobRegistry:
-
- def test_register_and_run(self, conn_lock):
- from dhee.core.lease_manager import LeaseManager
- from dhee.core.jobs import JobRegistry, Job
-
- conn, lock = conn_lock
- lm = LeaseManager(conn, lock)
- registry = JobRegistry(conn, lock, lm)
-
- class TestJob(Job):
- name = "test_job"
- def execute(self, payload):
- return {"sum": payload.get("a", 0) + payload.get("b", 0)}
-
- registry.register(TestJob)
- result = registry.run("test_job", payload={"a": 3, "b": 7})
- assert result["status"] == "completed"
- assert result["result"]["sum"] == 10
-
- def test_run_unknown_job(self, conn_lock):
- from dhee.core.lease_manager import LeaseManager
- from dhee.core.jobs import JobRegistry
-
- conn, lock = conn_lock
- lm = LeaseManager(conn, lock)
- registry = JobRegistry(conn, lock, lm)
-
- result = registry.run("nonexistent")
- assert result["status"] == "error"
-
- def test_health_check(self, conn_lock):
- from dhee.core.lease_manager import LeaseManager
- from dhee.core.jobs import JobRegistry, Job
-
- conn, lock = conn_lock
- lm = LeaseManager(conn, lock)
- registry = JobRegistry(conn, lock, lm)
-
- class NopJob(Job):
- name = "nop"
- def execute(self, payload):
- return {}
-
- registry.register(NopJob)
- registry.run("nop")
- health = registry.get_health()
- assert "nop" in health["job_status"]
- assert health["job_status"]["nop"]["last_status"] == "completed"
-
-
-# =========================================================================
-# Phase 7: RRF Fusion
-# =========================================================================
-
-class TestRRFFusion:
-
- def test_basic_fusion(self):
- from dhee.core.fusion_v3 import RRFFusion, FusionCandidate, FusionConfig
-
- raw = [
- FusionCandidate(
- row_id="r1", source_kind="raw", source_type="event",
- source_id="e1", retrieval_text="raw fact", raw_score=0.9,
- ),
- FusionCandidate(
- row_id="r2", source_kind="raw", source_type="event",
- source_id="e2", retrieval_text="raw fact 2", raw_score=0.7,
- ),
- ]
- distilled = [
- FusionCandidate(
- row_id="d1", source_kind="distilled", source_type="belief",
- source_id="b1", retrieval_text="distilled belief", raw_score=0.85,
- confidence=0.9,
- ),
- ]
-
- fusion = RRFFusion(FusionConfig(final_top_n=5))
- results, breakdown = fusion.fuse(raw, distilled, query="test")
-
- assert len(results) >= 1
- assert breakdown.final_count >= 1
- # Distilled should rank high due to higher weight
- top = results[0]
- assert top.source_kind == "distilled"
-
- def test_staleness_penalty(self):
- from dhee.core.fusion_v3 import RRFFusion, FusionCandidate
-
- fresh = FusionCandidate(
- row_id="f1", source_kind="distilled", source_type="belief",
- source_id="b1", retrieval_text="fresh", raw_score=0.8,
- status="active",
- )
- stale = FusionCandidate(
- row_id="s1", source_kind="distilled", source_type="belief",
- source_id="b2", retrieval_text="stale", raw_score=0.85,
- status="stale",
- )
-
- fusion = RRFFusion()
- results, _ = fusion.fuse([], [fresh, stale])
-
- # Fresh should beat stale despite lower raw score
- assert results[0].row_id == "f1"
-
- def test_invalidated_excluded(self):
- from dhee.core.fusion_v3 import RRFFusion, FusionCandidate
-
- valid = FusionCandidate(
- row_id="v1", source_kind="distilled", source_type="belief",
- source_id="b1", retrieval_text="valid", raw_score=0.5,
- )
- invalid = FusionCandidate(
- row_id="i1", source_kind="distilled", source_type="belief",
- source_id="b2", retrieval_text="invalidated", raw_score=0.9,
- status="invalidated",
- )
-
- fusion = RRFFusion()
- results, _ = fusion.fuse([], [valid, invalid])
-
- # Invalidated should have score=0 and sort last
- ids = [r.row_id for r in results if r.adjusted_score > 0]
- assert "i1" not in ids
-
- def test_contradiction_penalty(self):
- from dhee.core.fusion_v3 import RRFFusion, FusionCandidate
-
- clean = FusionCandidate(
- row_id="c1", source_kind="distilled", source_type="belief",
- source_id="b1", retrieval_text="clean", raw_score=0.8,
- )
- conflicted = FusionCandidate(
- row_id="c2", source_kind="distilled", source_type="belief",
- source_id="b2", retrieval_text="conflicted", raw_score=0.85,
- )
-
- def checker(t, i):
- return i == "b2" # b2 has conflicts
-
- fusion = RRFFusion()
- results, _ = fusion.fuse([], [clean, conflicted], conflict_checker=checker)
- assert results[0].row_id == "c1" # clean beats conflicted
-
- def test_breakdown_logged(self):
- from dhee.core.fusion_v3 import RRFFusion, FusionCandidate
-
- raw = [FusionCandidate(
- row_id="r1", source_kind="raw", source_type="event",
- source_id="e1", retrieval_text="t", raw_score=0.5,
- )]
- dist = [FusionCandidate(
- row_id="d1", source_kind="distilled", source_type="belief",
- source_id="b1", retrieval_text="t", raw_score=0.5,
- )]
-
- fusion = RRFFusion()
- _, breakdown = fusion.fuse(raw, dist, query="test query")
-
- d = breakdown.to_dict()
- assert "per_index_counts" in d
- assert d["per_index_counts"]["raw"] == 1
- assert d["per_index_counts"]["distilled"] == 1
-
-
-# =========================================================================
-# Phase 8: Three-Tier Invalidation
-# =========================================================================
-
-class TestInvalidation:
-
- def test_hard_invalidation_sole_source(self, store):
- from dhee.core.invalidation import InvalidationEngine
-
- e = store.events.add(content="false memory", user_id="u1")
- bid = store.beliefs.add(user_id="u1", claim="false claim")
- store.lineage.add("belief", bid, e.event_id, contribution_weight=1.0)
-
- engine = InvalidationEngine(
- lineage=store.lineage,
- stores={"belief": store.beliefs},
- conn=store.events._conn,
- lock=store.events._lock,
- )
-
- # Delete the source → hard invalidation
- store.events.delete(e.event_id)
- result = engine.on_event_deleted(e.event_id)
-
- assert len(result["hard_invalidated"]) == 1
- belief = store.beliefs.get(bid)
- assert belief["status"] == "invalidated"
-
- def test_soft_invalidation_sole_source(self, store):
- from dhee.core.invalidation import InvalidationEngine
-
- e = store.events.add(content="old fact", user_id="u1")
- bid = store.beliefs.add(user_id="u1", claim="old fact", confidence=0.8)
- store.lineage.add("belief", bid, e.event_id, contribution_weight=1.0)
-
- engine = InvalidationEngine(
- lineage=store.lineage,
- stores={"belief": store.beliefs},
- conn=store.events._conn,
- lock=store.events._lock,
- )
-
- # Correct the source → soft invalidation
- store.events.correct(e.event_id, "new fact")
- result = engine.on_event_corrected(e.event_id)
-
- assert len(result["soft_invalidated"]) == 1
- assert len(result["jobs_enqueued"]) >= 1
- belief = store.beliefs.get(bid)
- assert belief["status"] == "stale"
-
- def test_partial_invalidation_minor_source(self, store):
- from dhee.core.invalidation import InvalidationEngine
-
- e1 = store.events.add(content="main fact", user_id="u1")
- e2 = store.events.add(content="supporting detail", user_id="u1")
-
- bid = store.beliefs.add(user_id="u1", claim="combined claim", confidence=0.8)
- store.lineage.add("belief", bid, e1.event_id, contribution_weight=0.8)
- store.lineage.add("belief", bid, e2.event_id, contribution_weight=0.2)
-
- engine = InvalidationEngine(
- lineage=store.lineage,
- stores={"belief": store.beliefs},
- conn=store.events._conn,
- lock=store.events._lock,
- )
-
- # Correct the minor source (weight=0.2 < 0.3 threshold)
- store.events.correct(e2.event_id, "updated detail")
- result = engine.on_event_corrected(e2.event_id)
-
- assert len(result["partial_invalidated"]) == 1
- belief = store.beliefs.get(bid)
- assert belief["status"] == "suspect"
-
- def test_partial_escalates_on_high_weight(self, store):
- from dhee.core.invalidation import InvalidationEngine
-
- e1 = store.events.add(content="main", user_id="u1")
- e2 = store.events.add(content="secondary", user_id="u1")
-
- bid = store.beliefs.add(user_id="u1", claim="test", confidence=0.8)
- store.lineage.add("belief", bid, e1.event_id, contribution_weight=0.6)
- store.lineage.add("belief", bid, e2.event_id, contribution_weight=0.4)
-
- engine = InvalidationEngine(
- lineage=store.lineage,
- stores={"belief": store.beliefs},
- conn=store.events._conn,
- lock=store.events._lock,
- )
-
- # Correct the secondary source (weight=0.4 >= 0.3 threshold) → soft
- store.events.correct(e2.event_id, "updated secondary")
- result = engine.on_event_corrected(e2.event_id)
-
- assert len(result["soft_invalidated"]) == 1 # escalated to soft
- assert len(result["partial_invalidated"]) == 0
-
-
-# =========================================================================
-# Phase 8: Conflicts
-# =========================================================================
-
-class TestConflicts:
-
- def test_create_and_auto_resolve(self, conn_lock):
- from dhee.core.conflicts import ConflictStore
-
- conn, lock = conn_lock
- cs = ConflictStore(conn, lock)
-
- # Clear confidence gap → auto-resolve
- result = cs.create(
- "belief_contradiction",
- "belief", "b1", "belief", "b2",
- side_a_confidence=0.95,
- side_b_confidence=0.1,
- )
- assert result["resolution_status"] == "auto_resolved"
- assert result["auto_resolution"]["winner"] == "side_a"
-
- def test_no_auto_resolve_when_close(self, conn_lock):
- from dhee.core.conflicts import ConflictStore
-
- conn, lock = conn_lock
- cs = ConflictStore(conn, lock)
-
- result = cs.create(
- "belief_contradiction",
- "belief", "b1", "belief", "b2",
- side_a_confidence=0.6,
- side_b_confidence=0.5,
- )
- assert result["resolution_status"] == "open"
-
- def test_manual_resolve(self, conn_lock):
- from dhee.core.conflicts import ConflictStore
-
- conn, lock = conn_lock
- cs = ConflictStore(conn, lock)
-
- result = cs.create(
- "anchor_disagreement",
- "anchor", "a1", "anchor", "a2",
- )
- cid = result["conflict_id"]
- assert cs.resolve(cid, {"winner": "a1", "reason": "user chose"})
-
- conflict = cs.get(cid)
- assert conflict["resolution_status"] == "user_resolved"
-
- def test_has_open_conflicts(self, conn_lock):
- from dhee.core.conflicts import ConflictStore
-
- conn, lock = conn_lock
- cs = ConflictStore(conn, lock)
-
- cs.create("belief_contradiction", "belief", "b1", "belief", "b2")
- assert cs.has_open_conflicts("belief", "b1") is True
- assert cs.has_open_conflicts("belief", "b999") is False
-
- def test_count_open(self, conn_lock):
- from dhee.core.conflicts import ConflictStore
-
- conn, lock = conn_lock
- cs = ConflictStore(conn, lock)
-
- cs.create("belief_contradiction", "belief", "x1", "belief", "x2")
- cs.create("distillation_conflict", "insight", "i1", "insight", "i2")
- assert cs.count_open() == 2
-
-
-# =========================================================================
-# Phase 6: Read Model
-# =========================================================================
-
-class TestReadModel:
-
- def test_refresh_and_query(self, store):
- from dhee.core.read_model import ReadModel
-
- conn = store.events._conn
- lock = store.events._lock
- rm = ReadModel(conn, lock)
-
- # Populate
- store.events.add(content="raw fact 1", user_id="u1")
- store.events.add(content="raw fact 2", user_id="u1")
- store.beliefs.add(user_id="u1", claim="belief 1", confidence=0.8)
-
- counts = rm.refresh(
- "u1",
- events_store=store.events,
- beliefs_store=store.beliefs,
- )
- assert counts["raw_events"] == 2
- assert counts["beliefs"] == 1
-
- results = rm.query("u1")
- assert len(results) == 3
-
- # Filter by kind
- raw_only = rm.query("u1", source_kind="raw")
- assert len(raw_only) == 2
-
- def test_delta_overlay(self, store):
- from dhee.core.read_model import ReadModel
-
- conn = store.events._conn
- lock = store.events._lock
- rm = ReadModel(conn, lock)
-
- # Add event before refresh
- store.events.add(content="before refresh", user_id="u1")
- rm.refresh("u1", events_store=store.events)
-
- # Add event after refresh
- since = rm.last_refresh
- store.events.add(content="after refresh", user_id="u1")
-
- delta = rm.get_delta("u1", since, events_store=store.events)
- assert len(delta) == 1
- assert delta[0]["retrieval_text"] == "after refresh"
-
- def test_invalidated_excluded(self, store):
- from dhee.core.read_model import ReadModel
-
- conn = store.events._conn
- lock = store.events._lock
- rm = ReadModel(conn, lock)
-
- bid = store.beliefs.add(user_id="u1", claim="will invalidate")
- store.beliefs.set_status(bid, "invalidated")
-
- rm.refresh("u1", beliefs_store=store.beliefs)
- results = rm.query("u1")
- # Invalidated beliefs are skipped during refresh
- assert all(r["source_id"] != bid for r in results)
-
-
-# =========================================================================
-# Phase 9: Observability
-# =========================================================================
-
-class TestV3Health:
-
- def test_health_metrics(self, store):
- from dhee.core.v3_health import v3_health
-
- conn = store.events._conn
- lock = store.events._lock
-
- store.events.add(content="fact", user_id="u1")
- bid = store.beliefs.add(user_id="u1", claim="test")
- store.beliefs.set_status(bid, "stale")
-
- health = v3_health(conn, lock, user_id="u1")
-
- assert health["raw_events_active"] == 1
- assert health["derived_invalidation"]["beliefs"]["stale"] == 1
- assert "v3_warnings" in health
-
- def test_health_no_user_filter(self, store):
- from dhee.core.v3_health import v3_health
-
- conn = store.events._conn
- lock = store.events._lock
-
- store.events.add(content="a", user_id="u1")
- store.events.add(content="b", user_id="u2")
-
- health = v3_health(conn, lock)
- assert health["raw_events_active"] == 2
-
-
-# =========================================================================
-# Phase 10: Migration
-# =========================================================================
-
-class TestMigration:
-
- def test_dual_write(self, db_path):
- from dhee.core.v3_migration import V3MigrationBridge
-
- cs = CognitionStore(db_path=db_path)
- bridge = V3MigrationBridge(v3_store=cs)
-
- eid = bridge.on_remember("test fact", "u1", v2_memory_id="v2-123")
- assert eid is not None
-
- event = cs.events.get(eid)
- assert event.content == "test fact"
- assert event.metadata.get("v2_memory_id") == "v2-123"
- cs.close()
-
- def test_backfill(self, db_path):
- from dhee.core.v3_migration import V3MigrationBridge
-
- cs = CognitionStore(db_path=db_path)
- bridge = V3MigrationBridge(v3_store=cs)
-
- v2_memories = [
- {"memory": "fact 1", "id": "m1", "layer": "sml"},
- {"memory": "fact 2", "id": "m2", "layer": "lml"},
- {"memory": "fact 1", "id": "m3"}, # duplicate content
- ]
-
- stats = bridge.backfill_from_v2(v2_memories, user_id="u1")
- assert stats["created"] == 2
- assert stats["skipped_dedup"] == 1
- assert stats["total"] == 3
-
- # Idempotent — running again should skip all
- stats2 = bridge.backfill_from_v2(v2_memories, user_id="u1")
- assert stats2["created"] == 0
- # All 3 are deduped: 2 unique already exist + 1 duplicate content
- assert stats2["skipped_dedup"] == 3
- cs.close()
-
- def test_correction_bridge(self, db_path):
- from dhee.core.v3_migration import V3MigrationBridge
-
- cs = CognitionStore(db_path=db_path)
- bridge = V3MigrationBridge(v3_store=cs)
-
- # First add the original
- bridge.on_remember("I live in Ghazipur", "u1")
-
- # Then correct
- eid = bridge.on_correction("I live in Ghazipur", "I live in Bengaluru", "u1")
- assert eid is not None
-
- event = cs.events.get(eid)
- assert event.content == "I live in Bengaluru"
- assert event.supersedes_event_id is not None
- cs.close()
-
- def test_disabled_bridge(self):
- from dhee.core.v3_migration import V3MigrationBridge
- bridge = V3MigrationBridge(v3_store=None)
- assert bridge.on_remember("test", "u1") is None
- assert bridge.should_use_v3_read() is False
-
-
-# =========================================================================
-# Phase 3: Sparse to_dict
-# =========================================================================
-
-class TestSparseDict:
-
- def test_sparse_omits_empty(self):
- from dhee.core.engram import UniversalEngram
-
- e = UniversalEngram(
- id="test-1",
- raw_content="hello",
- strength=1.0,
- user_id="u1",
- )
- full = e.to_dict()
- sparse = e.to_dict(sparse=True)
-
- # Sparse should be smaller
- assert len(sparse) < len(full)
- # Should keep non-empty values
- assert sparse["id"] == "test-1"
- assert sparse["raw_content"] == "hello"
- # Should omit empty lists, None, empty strings
- assert "echo" not in sparse or sparse.get("echo") != []
-
- def test_full_preserves_all(self):
- from dhee.core.engram import UniversalEngram
-
- e = UniversalEngram(id="test-2", raw_content="x")
- full = e.to_dict()
- assert "echo" in full # even empty values
- assert "entities" in full
diff --git a/tests/test_v3_jobs.py b/tests/test_v3_jobs.py
deleted file mode 100644
index 6eff31a..0000000
--- a/tests/test_v3_jobs.py
+++ /dev/null
@@ -1,706 +0,0 @@
-"""Tests for Dhee v3 Sprint 2: Lease Manager, Job Registry, Distillation, Promotion.
-
-Covers:
-- LeaseManager: acquire, release, renew, expiry steal, cleanup
-- JobRegistry: registration, execution, idempotency, history, health
-- DistillationStore: submit, dedup, status transitions
-- PromotionEngine: validation, type-specific promotion, lineage, batch
-- ConsolidationEngine: feedback loop prevention
-"""
-
-import json
-import sqlite3
-import threading
-import time
-
-import pytest
-
-from dhee.core.storage import initialize_schema
-from dhee.core.lease_manager import LeaseManager
-from dhee.core.jobs import Job, JobRegistry, ApplyForgettingJob
-from dhee.core.distillation import (
- DistillationCandidate,
- DistillationStore,
- compute_idempotency_key,
- distill_belief_from_events,
- DERIVATION_VERSION,
-)
-from dhee.core.promotion import PromotionEngine, PromotionResult
-from dhee.core.derived_store import (
- BeliefStore,
- PolicyStore,
- InsightStore,
- HeuristicStore,
- DerivedLineageStore,
- CognitionStore,
-)
-
-
-# ---------------------------------------------------------------------------
-# Fixtures
-# ---------------------------------------------------------------------------
-
-@pytest.fixture
-def db_conn(tmp_path):
- """Shared connection + lock for all Sprint 2 tests."""
- db_path = str(tmp_path / "test_v3_sprint2.db")
- conn = sqlite3.connect(db_path, check_same_thread=False)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.execute("PRAGMA busy_timeout=5000")
- conn.row_factory = sqlite3.Row
- initialize_schema(conn)
- lock = threading.RLock()
- yield conn, lock
- conn.close()
-
-
-@pytest.fixture
-def lease_manager(db_conn):
- conn, lock = db_conn
- return LeaseManager(conn, lock, default_duration_seconds=10)
-
-
-@pytest.fixture
-def stores(db_conn):
- """All derived stores sharing one connection."""
- conn, lock = db_conn
- return {
- "beliefs": BeliefStore(conn, lock),
- "policies": PolicyStore(conn, lock),
- "insights": InsightStore(conn, lock),
- "heuristics": HeuristicStore(conn, lock),
- "lineage": DerivedLineageStore(conn, lock),
- "distillation": DistillationStore(conn, lock),
- }
-
-
-@pytest.fixture
-def promotion_engine(stores):
- return PromotionEngine(
- distillation=stores["distillation"],
- beliefs=stores["beliefs"],
- policies=stores["policies"],
- insights=stores["insights"],
- heuristics=stores["heuristics"],
- lineage=stores["lineage"],
- min_confidence=0.3,
- )
-
-
-# =========================================================================
-# LeaseManager Tests
-# =========================================================================
-
-class TestLeaseManager:
-
- def test_acquire_new(self, lease_manager):
- assert lease_manager.acquire("job:decay", "worker-1") is True
- assert lease_manager.is_held("job:decay") is True
- assert lease_manager.get_holder("job:decay") == "worker-1"
-
- def test_acquire_same_owner_renews(self, lease_manager):
- assert lease_manager.acquire("job:decay", "worker-1") is True
- assert lease_manager.acquire("job:decay", "worker-1") is True # renew
- assert lease_manager.get_holder("job:decay") == "worker-1"
-
- def test_acquire_different_owner_blocked(self, lease_manager):
- assert lease_manager.acquire("job:decay", "worker-1") is True
- assert lease_manager.acquire("job:decay", "worker-2") is False
- assert lease_manager.get_holder("job:decay") == "worker-1"
-
- def test_release(self, lease_manager):
- lease_manager.acquire("job:decay", "worker-1")
- assert lease_manager.release("job:decay", "worker-1") is True
- assert lease_manager.is_held("job:decay") is False
-
- def test_release_wrong_owner(self, lease_manager):
- lease_manager.acquire("job:decay", "worker-1")
- assert lease_manager.release("job:decay", "worker-2") is False
- assert lease_manager.is_held("job:decay") is True
-
- def test_release_nonexistent(self, lease_manager):
- assert lease_manager.release("nonexistent", "w1") is False
-
- def test_renew(self, lease_manager):
- lease_manager.acquire("job:decay", "worker-1")
- assert lease_manager.renew("job:decay", "worker-1") is True
-
- def test_renew_wrong_owner(self, lease_manager):
- lease_manager.acquire("job:decay", "worker-1")
- assert lease_manager.renew("job:decay", "worker-2") is False
-
- def test_acquire_expired_lease(self, lease_manager):
- """Expired lease can be stolen by another worker."""
- # Acquire with very short duration
- assert lease_manager.acquire("job:decay", "worker-1", duration_seconds=1) is True
- # Wait for expiry
- time.sleep(1.1)
- # Another worker can steal it
- assert lease_manager.acquire("job:decay", "worker-2") is True
- assert lease_manager.get_holder("job:decay") == "worker-2"
-
- def test_is_held_expired(self, lease_manager):
- lease_manager.acquire("job:x", "w1", duration_seconds=1)
- time.sleep(1.1)
- assert lease_manager.is_held("job:x") is False
- assert lease_manager.get_holder("job:x") is None
-
- def test_cleanup_expired(self, lease_manager):
- lease_manager.acquire("job:a", "w1", duration_seconds=1)
- lease_manager.acquire("job:b", "w1", duration_seconds=1)
- lease_manager.acquire("job:c", "w1", duration_seconds=300)
- time.sleep(1.1)
- cleaned = lease_manager.cleanup_expired()
- assert cleaned == 2 # a and b expired, c still held
-
- def test_multiple_locks_independent(self, lease_manager):
- lease_manager.acquire("job:a", "w1")
- lease_manager.acquire("job:b", "w2")
- assert lease_manager.get_holder("job:a") == "w1"
- assert lease_manager.get_holder("job:b") == "w2"
-
-
-# =========================================================================
-# JobRegistry Tests
-# =========================================================================
-
-class _TestJob(Job):
- name = "test_job"
-
- def execute(self, payload):
- return {"echo": payload.get("value", "none")}
-
-
-class _FailingJob(Job):
- name = "failing_job"
-
- def execute(self, payload):
- raise RuntimeError("intentional failure")
-
-
-class _IdempotentJob(Job):
- name = "idempotent_job"
-
- def execute(self, payload):
- return {"processed": True}
-
- def make_idempotency_key(self, payload):
- return f"idem:{payload.get('batch_id', '')}"
-
-
-class TestJobRegistry:
-
- def test_register_and_list(self, db_conn, lease_manager):
- conn, lock = db_conn
- registry = JobRegistry(conn, lock, lease_manager)
- registry.register(_TestJob)
- assert "test_job" in registry.list_registered()
-
- def test_run_success(self, db_conn, lease_manager):
- conn, lock = db_conn
- registry = JobRegistry(conn, lock, lease_manager)
- registry.register(_TestJob)
-
- result = registry.run("test_job", payload={"value": "hello"})
- assert result["status"] == "completed"
- assert result["result"]["echo"] == "hello"
- assert result["job_id"]
-
- def test_run_failure(self, db_conn, lease_manager):
- conn, lock = db_conn
- registry = JobRegistry(conn, lock, lease_manager)
- registry.register(_FailingJob)
-
- result = registry.run("failing_job")
- assert result["status"] == "failed"
- assert "intentional failure" in result["error"]
-
- def test_run_unknown_job(self, db_conn, lease_manager):
- conn, lock = db_conn
- registry = JobRegistry(conn, lock, lease_manager)
-
- result = registry.run("nonexistent")
- assert result["status"] == "error"
-
- def test_idempotency(self, db_conn, lease_manager):
- conn, lock = db_conn
- registry = JobRegistry(conn, lock, lease_manager)
- registry.register(_IdempotentJob)
-
- r1 = registry.run("idempotent_job", payload={"batch_id": "b1"})
- assert r1["status"] == "completed"
-
- r2 = registry.run("idempotent_job", payload={"batch_id": "b1"})
- assert r2["status"] == "skipped_idempotent"
-
- # Different batch_id should run
- r3 = registry.run("idempotent_job", payload={"batch_id": "b2"})
- assert r3["status"] == "completed"
-
- def test_job_history(self, db_conn, lease_manager):
- conn, lock = db_conn
- registry = JobRegistry(conn, lock, lease_manager)
- registry.register(_TestJob)
-
- registry.run("test_job", payload={"n": 1})
- registry.run("test_job", payload={"n": 2})
-
- history = registry.get_job_history("test_job", limit=5)
- assert len(history) == 2
- assert history[0]["status"] == "completed"
-
- def test_health(self, db_conn, lease_manager):
- conn, lock = db_conn
- registry = JobRegistry(conn, lock, lease_manager)
- registry.register(_TestJob)
-
- health = registry.get_health()
- assert health["total_registered"] == 1
- assert "test_job" in health["job_status"]
- assert health["job_status"]["test_job"]["last_status"] == "never_run"
-
- registry.run("test_job")
- health = registry.get_health()
- assert health["job_status"]["test_job"]["last_status"] == "completed"
-
- def test_lease_prevents_concurrent(self, db_conn, lease_manager):
- """Two sequential runs of same job: first completes, second runs too
- (lease released). But if lease held, second is blocked."""
- conn, lock = db_conn
- registry = JobRegistry(conn, lock, lease_manager)
- registry.register(_TestJob)
-
- # First run — should succeed
- r1 = registry.run("test_job", owner_id="w1")
- assert r1["status"] == "completed"
-
- # Second run — lease was released, should succeed
- r2 = registry.run("test_job", owner_id="w2")
- assert r2["status"] == "completed"
-
- def test_run_all(self, db_conn, lease_manager):
- conn, lock = db_conn
- registry = JobRegistry(conn, lock, lease_manager)
- registry.register(_TestJob)
- registry.register(_IdempotentJob)
-
- results = registry.run_all()
- assert len(results) == 2
- completed = [r for r in results if r["status"] == "completed"]
- assert len(completed) == 2
-
-
-# =========================================================================
-# Distillation Tests
-# =========================================================================
-
-class TestDistillation:
-
- def test_idempotency_key(self):
- k1 = compute_idempotency_key(["e1", "e2"], 1, "belief:u1:test")
- k2 = compute_idempotency_key(["e2", "e1"], 1, "belief:u1:test") # sorted
- assert k1 == k2 # Same regardless of order
-
- k3 = compute_idempotency_key(["e1", "e2"], 2, "belief:u1:test") # diff version
- assert k1 != k3
-
- def test_candidate_auto_key(self):
- c = DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1", "e2"],
- target_type="belief",
- canonical_key="belief:u1:test",
- payload={"claim": "test"},
- )
- assert c.idempotency_key
- assert len(c.idempotency_key) == 24
-
- def test_submit_and_get(self, stores):
- ds = stores["distillation"]
- candidate = DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1"],
- target_type="belief",
- canonical_key="belief:u1:test",
- payload={"user_id": "u1", "claim": "test claim"},
- confidence=0.6,
- )
- result = ds.submit(candidate)
- assert result == "c1"
-
- fetched = ds.get("c1")
- assert fetched is not None
- assert fetched["target_type"] == "belief"
- assert fetched["status"] == "pending_validation"
-
- def test_submit_dedup(self, stores):
- ds = stores["distillation"]
-
- c1 = DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1"],
- target_type="belief",
- canonical_key="belief:u1:test",
- payload={"claim": "test"},
- )
- c2 = DistillationCandidate(
- candidate_id="c2",
- source_event_ids=["e1"], # same source + same key
- target_type="belief",
- canonical_key="belief:u1:test",
- payload={"claim": "test"},
- )
-
- assert ds.submit(c1) == "c1"
- assert ds.submit(c2) is None # dedup
-
- def test_submit_after_reject(self, stores):
- ds = stores["distillation"]
-
- c1 = DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1"],
- target_type="belief",
- canonical_key="belief:u1:test",
- payload={"claim": "test"},
- )
- ds.submit(c1)
- ds.set_status("c1", "rejected")
-
- # Same idempotency key, but rejected — should allow resubmit
- c2 = DistillationCandidate(
- candidate_id="c2",
- source_event_ids=["e1"],
- target_type="belief",
- canonical_key="belief:u1:test",
- payload={"claim": "test v2"},
- )
- assert ds.submit(c2) == "c2"
-
- def test_get_pending(self, stores):
- ds = stores["distillation"]
-
- ds.submit(DistillationCandidate(
- candidate_id="c1", source_event_ids=["e1"],
- target_type="belief", canonical_key="k1",
- payload={"claim": "a"}, confidence=0.8,
- ))
- ds.submit(DistillationCandidate(
- candidate_id="c2", source_event_ids=["e2"],
- target_type="policy", canonical_key="k2",
- payload={"name": "b"}, confidence=0.5,
- ))
-
- pending = ds.get_pending()
- assert len(pending) == 2
-
- beliefs_only = ds.get_pending("belief")
- assert len(beliefs_only) == 1
-
- def test_distill_belief_from_events(self):
- events = [
- {"event_id": "e1", "content": "User prefers dark mode"},
- {"event_id": "e2", "content": "User prefers dark mode"},
- ]
- candidate = distill_belief_from_events(events, user_id="u1")
- assert candidate is not None
- assert candidate.target_type == "belief"
- assert candidate.confidence == 0.5 # 0.3 + 0.1 * 2
- assert len(candidate.source_event_ids) == 2
-
- def test_distill_empty_events(self):
- assert distill_belief_from_events([], user_id="u1") is None
-
-
-# =========================================================================
-# Promotion Tests
-# =========================================================================
-
-class TestPromotion:
-
- def test_promote_belief(self, stores, promotion_engine):
- ds = stores["distillation"]
-
- ds.submit(DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1", "e2"],
- target_type="belief",
- canonical_key="belief:u1:test",
- payload={"user_id": "u1", "claim": "Python is great", "domain": "tech"},
- confidence=0.7,
- ))
-
- result = promotion_engine.promote_pending("belief")
- assert result.to_dict()["promoted"] == 1
-
- # Verify belief was created
- beliefs = stores["beliefs"].list_by_user("u1")
- assert len(beliefs) == 1
- assert beliefs[0]["claim"] == "Python is great"
-
- # Verify lineage was written
- lineage = stores["lineage"].get_sources("belief", beliefs[0]["belief_id"])
- assert len(lineage) == 2 # from e1 and e2
-
- # Verify candidate was marked promoted
- candidate = ds.get("c1")
- assert candidate["status"] == "promoted"
- assert candidate["promoted_id"] == beliefs[0]["belief_id"]
-
- def test_promote_policy(self, stores, promotion_engine):
- ds = stores["distillation"]
-
- ds.submit(DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1"],
- target_type="policy",
- canonical_key="policy:u1:blame",
- payload={
- "user_id": "u1",
- "name": "git blame first",
- "condition": {"task_types": ["bug_fix"]},
- "action": {"approach": "Run git blame"},
- },
- confidence=0.6,
- ))
-
- result = promotion_engine.promote_pending("policy")
- assert result.to_dict()["promoted"] == 1
-
- policies = stores["policies"].list_by_user("u1")
- assert len(policies) == 1
- assert policies[0]["name"] == "git blame first"
-
- def test_promote_insight(self, stores, promotion_engine):
- ds = stores["distillation"]
-
- ds.submit(DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1"],
- target_type="insight",
- canonical_key="insight:u1:tokens",
- payload={
- "user_id": "u1",
- "content": "Token expiry causes outages in production",
- "insight_type": "causal",
- },
- confidence=0.5,
- ))
-
- result = promotion_engine.promote_pending("insight")
- assert result.to_dict()["promoted"] == 1
-
- def test_promote_heuristic(self, stores, promotion_engine):
- ds = stores["distillation"]
-
- ds.submit(DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1"],
- target_type="heuristic",
- canonical_key="heuristic:u1:constrained",
- payload={
- "user_id": "u1",
- "content": "Start with the most constrained component first",
- "abstraction_level": "universal",
- },
- confidence=0.6,
- ))
-
- result = promotion_engine.promote_pending("heuristic")
- assert result.to_dict()["promoted"] == 1
-
- def test_reject_low_confidence(self, stores, promotion_engine):
- ds = stores["distillation"]
-
- ds.submit(DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1"],
- target_type="belief",
- canonical_key="belief:u1:weak",
- payload={"user_id": "u1", "claim": "maybe"},
- confidence=0.1, # Below min_confidence of 0.3
- ))
-
- result = promotion_engine.promote_pending("belief")
- assert result.to_dict()["rejected"] == 1
- assert result.to_dict()["promoted"] == 0
-
- def test_reject_empty_payload(self, stores, promotion_engine):
- ds = stores["distillation"]
-
- ds.submit(DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1"],
- target_type="belief",
- canonical_key="belief:u1:empty",
- payload={},
- confidence=0.8,
- ))
-
- result = promotion_engine.promote_pending("belief")
- assert result.to_dict()["rejected"] == 1
-
- def test_reject_short_claim(self, stores, promotion_engine):
- ds = stores["distillation"]
-
- ds.submit(DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1"],
- target_type="belief",
- canonical_key="belief:u1:short",
- payload={"user_id": "u1", "claim": "hi"}, # < 5 chars
- confidence=0.8,
- ))
-
- result = promotion_engine.promote_pending("belief")
- assert result.to_dict()["rejected"] == 1
-
- def test_promote_single(self, stores, promotion_engine):
- ds = stores["distillation"]
-
- ds.submit(DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1"],
- target_type="belief",
- canonical_key="belief:u1:single",
- payload={"user_id": "u1", "claim": "Test single promotion"},
- confidence=0.7,
- ))
-
- result = promotion_engine.promote_single("c1")
- assert result["status"] == "promoted"
- assert result["promoted_id"]
-
- def test_promote_single_nonexistent(self, promotion_engine):
- result = promotion_engine.promote_single("nonexistent")
- assert result["status"] == "error"
-
- def test_promote_single_already_promoted(self, stores, promotion_engine):
- ds = stores["distillation"]
-
- ds.submit(DistillationCandidate(
- candidate_id="c1",
- source_event_ids=["e1"],
- target_type="belief",
- canonical_key="belief:u1:double",
- payload={"user_id": "u1", "claim": "Test double promotion"},
- confidence=0.7,
- ))
-
- r1 = promotion_engine.promote_single("c1")
- assert r1["status"] == "promoted"
-
- r2 = promotion_engine.promote_single("c1")
- assert r2["status"] == "skipped"
-
- def test_batch_promotion(self, stores, promotion_engine):
- ds = stores["distillation"]
-
- for i in range(5):
- ds.submit(DistillationCandidate(
- candidate_id=f"c{i}",
- source_event_ids=[f"e{i}"],
- target_type="belief",
- canonical_key=f"belief:u1:batch{i}",
- payload={"user_id": "u1", "claim": f"Batch claim number {i}"},
- confidence=0.6,
- ))
-
- result = promotion_engine.promote_pending("belief", limit=10)
- assert result.to_dict()["promoted"] == 5
-
- beliefs = stores["beliefs"].list_by_user("u1")
- assert len(beliefs) == 5
-
-
-# =========================================================================
-# Consolidation Feedback Loop Tests
-# =========================================================================
-
-class TestConsolidationSafety:
-
- def test_should_promote_rejects_consolidated(self):
- """Signals with source='consolidated' must be rejected."""
- from dhee.core.consolidation import ConsolidationEngine
-
- # We can't easily instantiate ConsolidationEngine without full deps,
- # so test the logic directly by checking the source code contract.
- import inspect
- src = inspect.getsource(ConsolidationEngine._should_promote)
-
- # Verify the feedback loop guard is present
- assert "consolidated" in src, (
- "ConsolidationEngine._should_promote must check for "
- "'consolidated' source to prevent feedback loops"
- )
- assert "consolidated_from" in src, (
- "ConsolidationEngine._should_promote must check for "
- "'consolidated_from' metadata to prevent re-consolidation"
- )
-
- def test_promote_uses_infer_false(self):
- """_promote_to_passive must use infer=False to skip enrichment."""
- from dhee.core.consolidation import ConsolidationEngine
- import inspect
- src = inspect.getsource(ConsolidationEngine._promote_to_passive)
-
- assert "infer=False" in src, (
- "ConsolidationEngine._promote_to_passive must use infer=False "
- "to skip the LLM enrichment pipeline"
- )
-
- def test_promote_tags_provenance(self):
- """_promote_to_passive must tag consolidated provenance."""
- from dhee.core.consolidation import ConsolidationEngine
- import inspect
- src = inspect.getsource(ConsolidationEngine._promote_to_passive)
-
- assert '"source": "consolidated"' in src or "'source': 'consolidated'" in src, (
- "ConsolidationEngine._promote_to_passive must tag "
- "promoted memories with source='consolidated'"
- )
-
-
-# =========================================================================
-# AGI Loop Cleanup Tests
-# =========================================================================
-
-class TestAgiLoopCleanup:
-
- def test_no_phantom_imports(self):
- """agi_loop.py must not import non-existent engram_* packages."""
- import inspect
- from dhee.core import agi_loop
- src = inspect.getsource(agi_loop)
-
- phantom_packages = [
- "engram_reconsolidation",
- "engram_procedural",
- "engram_metamemory",
- "engram_prospective",
- "engram_working",
- "engram_failure",
- "engram_router",
- "engram_identity",
- "engram_heartbeat",
- "engram_policy",
- "engram_skills",
- "engram_spawn",
- "engram_resilience",
- ]
-
- for pkg in phantom_packages:
- assert pkg not in src, (
- f"agi_loop.py still references phantom package '{pkg}'. "
- f"All engram_* phantom imports must be removed."
- )
-
- def test_run_agi_cycle_api_preserved(self):
- """run_agi_cycle function must still exist (backward compat)."""
- from dhee.core.agi_loop import run_agi_cycle
- assert callable(run_agi_cycle)
-
- def test_get_system_health_api_preserved(self):
- """get_system_health function must still exist (backward compat)."""
- from dhee.core.agi_loop import get_system_health
- assert callable(get_system_health)
diff --git a/tests/test_v3_storage.py b/tests/test_v3_storage.py
deleted file mode 100644
index 344a4a8..0000000
--- a/tests/test_v3_storage.py
+++ /dev/null
@@ -1,738 +0,0 @@
-"""Tests for Dhee v3 event-sourced storage layer.
-
-Covers:
-- RawEventStore: add, dedup, correct, delete, supersedes chain
-- BeliefStore: CRUD, confidence updates, contradiction tracking
-- PolicyStore: CRUD, outcome recording, status transitions
-- AnchorStore: CRUD, field updates, filtering
-- InsightStore: CRUD, outcome recording
-- HeuristicStore: CRUD, outcome recording
-- DerivedLineageStore: source/dependent queries, contribution weights
-- CognitionStore: coordinator integration
-"""
-
-import os
-import sqlite3
-import tempfile
-
-import pytest
-
-from dhee.core.events import RawEventStore, RawMemoryEvent, EventStatus
-from dhee.core.derived_store import (
- BeliefStore,
- PolicyStore,
- AnchorStore,
- InsightStore,
- HeuristicStore,
- DerivedLineageStore,
- CognitionStore,
-)
-from dhee.core.storage import initialize_schema
-
-
-# ---------------------------------------------------------------------------
-# Fixtures
-# ---------------------------------------------------------------------------
-
-@pytest.fixture
-def db_path(tmp_path):
- return str(tmp_path / "test_v3.db")
-
-
-@pytest.fixture
-def event_store(db_path):
- store = RawEventStore(db_path=db_path)
- yield store
- store.close()
-
-
-@pytest.fixture
-def cognition_store(db_path):
- store = CognitionStore(db_path=db_path)
- yield store
- store.close()
-
-
-@pytest.fixture
-def shared_conn(db_path):
- """Shared connection + lock for derived store tests."""
- import threading
- conn = sqlite3.connect(db_path, check_same_thread=False)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.row_factory = sqlite3.Row
- initialize_schema(conn)
- lock = threading.RLock()
- yield conn, lock
- conn.close()
-
-
-# =========================================================================
-# RawEventStore Tests
-# =========================================================================
-
-class TestRawEventStore:
-
- def test_add_basic(self, event_store):
- event = event_store.add(content="User prefers dark mode", user_id="u1")
- assert event.event_id
- assert event.content == "User prefers dark mode"
- assert event.user_id == "u1"
- assert event.status == EventStatus.ACTIVE
- assert event.content_hash == RawMemoryEvent.compute_hash("User prefers dark mode")
-
- def test_add_with_metadata(self, event_store):
- event = event_store.add(
- content="test",
- user_id="u1",
- session_id="s1",
- source="mcp",
- metadata={"key": "value"},
- )
- assert event.session_id == "s1"
- assert event.source == "mcp"
- assert event.metadata == {"key": "value"}
-
- def test_dedup_same_content(self, event_store):
- e1 = event_store.add(content="same fact", user_id="u1")
- e2 = event_store.add(content="same fact", user_id="u1")
- assert e1.event_id == e2.event_id # dedup returns existing
-
- def test_dedup_different_users(self, event_store):
- e1 = event_store.add(content="shared fact", user_id="u1")
- e2 = event_store.add(content="shared fact", user_id="u2")
- assert e1.event_id != e2.event_id # different users = no dedup
-
- def test_get(self, event_store):
- e = event_store.add(content="test get", user_id="u1")
- fetched = event_store.get(e.event_id)
- assert fetched is not None
- assert fetched.content == "test get"
-
- def test_get_nonexistent(self, event_store):
- assert event_store.get("nonexistent") is None
-
- def test_get_by_hash(self, event_store):
- e = event_store.add(content="hash lookup", user_id="u1")
- found = event_store.get_by_hash(e.content_hash, "u1")
- assert found is not None
- assert found.event_id == e.event_id
-
- def test_correct(self, event_store):
- original = event_store.add(content="I live in Ghazipur", user_id="u1")
- correction = event_store.correct(
- original.event_id, "I live in Bengaluru"
- )
-
- assert correction.supersedes_event_id == original.event_id
- assert correction.content == "I live in Bengaluru"
- assert correction.status == EventStatus.ACTIVE
-
- # Original should now be 'corrected'
- old = event_store.get(original.event_id)
- assert old.status == EventStatus.CORRECTED
-
- def test_correct_nonexistent(self, event_store):
- with pytest.raises(ValueError, match="not found"):
- event_store.correct("nonexistent", "new content")
-
- def test_correct_already_corrected(self, event_store):
- e = event_store.add(content="old", user_id="u1")
- event_store.correct(e.event_id, "new")
- with pytest.raises(ValueError, match="Cannot correct"):
- event_store.correct(e.event_id, "newer")
-
- def test_delete(self, event_store):
- e = event_store.add(content="to delete", user_id="u1")
- assert event_store.delete(e.event_id) is True
- deleted = event_store.get(e.event_id)
- assert deleted.status == EventStatus.DELETED
-
- def test_delete_idempotent(self, event_store):
- e = event_store.add(content="to delete", user_id="u1")
- assert event_store.delete(e.event_id) is True
- assert event_store.delete(e.event_id) is False # already deleted
-
- def test_delete_nonexistent(self, event_store):
- with pytest.raises(ValueError, match="not found"):
- event_store.delete("nonexistent")
-
- def test_list_by_user(self, event_store):
- event_store.add(content="fact1", user_id="u1")
- event_store.add(content="fact2", user_id="u1")
- event_store.add(content="fact3", user_id="u2")
-
- u1_events = event_store.list_by_user("u1")
- assert len(u1_events) == 2
-
- u2_events = event_store.list_by_user("u2")
- assert len(u2_events) == 1
-
- def test_list_by_user_with_status(self, event_store):
- e1 = event_store.add(content="active", user_id="u1")
- e2 = event_store.add(content="will delete", user_id="u1")
- event_store.delete(e2.event_id)
-
- active = event_store.list_by_user("u1", status=EventStatus.ACTIVE)
- assert len(active) == 1
- assert active[0].content == "active"
-
- def test_supersedes_chain(self, event_store):
- e1 = event_store.add(content="v1", user_id="u1")
- e2 = event_store.correct(e1.event_id, "v2")
- # Can't correct e1 again (it's already corrected), but we can
- # correct e2 to get a 3-step chain
- e3 = event_store.correct(e2.event_id, "v3")
-
- chain = event_store.get_supersedes_chain(e3.event_id)
- assert len(chain) == 3
- assert chain[0].content == "v3"
- assert chain[1].content == "v2"
- assert chain[2].content == "v1"
-
- def test_count(self, event_store):
- event_store.add(content="a", user_id="u1")
- event_store.add(content="b", user_id="u1")
- event_store.add(content="c", user_id="u1")
-
- assert event_store.count("u1") == 3
- assert event_store.count("u1", status=EventStatus.ACTIVE) == 3
- assert event_store.count("u2") == 0
-
- def test_dedup_after_delete(self, event_store):
- """Deleted content should not block new addition of same content."""
- e1 = event_store.add(content="ephemeral", user_id="u1")
- event_store.delete(e1.event_id)
- # Adding same content again should create new event (old is deleted, not active)
- e2 = event_store.add(content="ephemeral", user_id="u1")
- assert e2.event_id != e1.event_id
-
-
-# =========================================================================
-# BeliefStore Tests
-# =========================================================================
-
-class TestBeliefStore:
-
- def test_add_and_get(self, shared_conn):
- conn, lock = shared_conn
- store = BeliefStore(conn, lock)
-
- bid = store.add(
- user_id="u1",
- claim="Python is dynamically typed",
- domain="programming",
- confidence=0.8,
- )
- belief = store.get(bid)
- assert belief is not None
- assert belief["claim"] == "Python is dynamically typed"
- assert belief["domain"] == "programming"
- assert belief["confidence"] == 0.8
- assert belief["status"] == "proposed"
-
- def test_update_confidence_auto_status(self, shared_conn):
- conn, lock = shared_conn
- store = BeliefStore(conn, lock)
-
- bid = store.add(user_id="u1", claim="test", confidence=0.5)
-
- # High confidence → held
- store.update_confidence(bid, 0.9)
- b = store.get(bid)
- assert b["status"] == "held"
- assert b["confidence"] == 0.9
- assert len(b["revisions"]) == 1
-
- # Low confidence → retracted
- store.update_confidence(bid, 0.05)
- b = store.get(bid)
- assert b["status"] == "retracted"
-
- def test_update_confidence_with_evidence(self, shared_conn):
- conn, lock = shared_conn
- store = BeliefStore(conn, lock)
-
- bid = store.add(user_id="u1", claim="test", confidence=0.5)
- store.update_confidence(
- bid, 0.7,
- evidence={"content": "saw it in docs", "supports": True},
- revision_reason="documentation found",
- )
-
- b = store.get(bid)
- assert len(b["evidence"]) == 1
- assert b["evidence"][0]["content"] == "saw it in docs"
-
- def test_contradiction(self, shared_conn):
- conn, lock = shared_conn
- store = BeliefStore(conn, lock)
-
- b1 = store.add(user_id="u1", claim="Earth is round", confidence=0.9)
- b2 = store.add(user_id="u1", claim="Earth is flat", confidence=0.3)
-
- store.add_contradiction(b1, b2)
-
- belief1 = store.get(b1)
- belief2 = store.get(b2)
- assert b2 in belief1["contradicts_ids"]
- assert b1 in belief2["contradicts_ids"]
- assert belief1["status"] == "challenged"
- assert belief2["status"] == "challenged"
-
- def test_list_by_user(self, shared_conn):
- conn, lock = shared_conn
- store = BeliefStore(conn, lock)
-
- store.add(user_id="u1", claim="a", domain="sci", confidence=0.9)
- store.add(user_id="u1", claim="b", domain="sci", confidence=0.3)
- store.add(user_id="u1", claim="c", domain="eng", confidence=0.7)
-
- sci = store.list_by_user("u1", domain="sci")
- assert len(sci) == 2
-
- high = store.list_by_user("u1", min_confidence=0.5)
- assert len(high) == 2
-
- def test_set_status(self, shared_conn):
- conn, lock = shared_conn
- store = BeliefStore(conn, lock)
-
- bid = store.add(user_id="u1", claim="test")
- assert store.set_status(bid, "stale") is True
- assert store.get(bid)["status"] == "stale"
-
- def test_get_by_invalidation_status(self, shared_conn):
- conn, lock = shared_conn
- store = BeliefStore(conn, lock)
-
- b1 = store.add(user_id="u1", claim="stale one")
- store.set_status(b1, "stale")
- b2 = store.add(user_id="u1", claim="active one")
-
- stale = store.get_by_invalidation_status("stale")
- assert len(stale) == 1
- assert stale[0]["belief_id"] == b1
-
-
-# =========================================================================
-# PolicyStore Tests
-# =========================================================================
-
-class TestPolicyStore:
-
- def test_add_and_get(self, shared_conn):
- conn, lock = shared_conn
- store = PolicyStore(conn, lock)
-
- pid = store.add(
- user_id="u1",
- name="Use git blame first",
- condition={"task_types": ["bug_fix"]},
- action={"approach": "Run git blame on failing file"},
- granularity="task",
- )
- policy = store.get(pid)
- assert policy is not None
- assert policy["name"] == "Use git blame first"
- assert policy["granularity"] == "task"
- assert policy["status"] == "proposed"
- assert policy["condition"]["task_types"] == ["bug_fix"]
-
- def test_record_outcome_success(self, shared_conn):
- conn, lock = shared_conn
- store = PolicyStore(conn, lock)
-
- pid = store.add(
- user_id="u1", name="test",
- condition={}, action={},
- )
-
- for _ in range(5):
- store.record_outcome(pid, success=True, baseline_score=0.5, actual_score=0.8)
-
- p = store.get(pid)
- assert p["apply_count"] == 5
- assert p["success_count"] == 5
- assert p["failure_count"] == 0
- assert p["status"] == "validated" # win_rate >= 0.6 after 5+
- assert p["utility"] > 0
-
- def test_record_outcome_deprecated(self, shared_conn):
- conn, lock = shared_conn
- store = PolicyStore(conn, lock)
-
- pid = store.add(
- user_id="u1", name="bad policy",
- condition={}, action={},
- )
-
- for _ in range(6):
- store.record_outcome(pid, success=False)
-
- p = store.get(pid)
- assert p["status"] == "deprecated"
-
- def test_record_outcome_utility_ema(self, shared_conn):
- conn, lock = shared_conn
- store = PolicyStore(conn, lock)
-
- pid = store.add(
- user_id="u1", name="ema test",
- condition={}, action={},
- )
-
- store.record_outcome(pid, success=True, baseline_score=0.5, actual_score=0.9)
- p = store.get(pid)
- # First delta = 0.4, utility = 0.3 * 0.4 + 0.7 * 0.0 = 0.12
- assert abs(p["utility"] - 0.12) < 0.01
-
- def test_list_by_user(self, shared_conn):
- conn, lock = shared_conn
- store = PolicyStore(conn, lock)
-
- store.add(user_id="u1", name="p1", condition={}, action={}, granularity="task")
- store.add(user_id="u1", name="p2", condition={}, action={}, granularity="step")
-
- task_policies = store.list_by_user("u1", granularity="task")
- assert len(task_policies) == 1
-
-
-# =========================================================================
-# AnchorStore Tests
-# =========================================================================
-
-class TestAnchorStore:
-
- def test_add_and_get(self, shared_conn):
- conn, lock = shared_conn
- store = AnchorStore(conn, lock)
-
- aid = store.add(
- user_id="u1",
- era="bengaluru_work",
- place="Bengaluru",
- place_type="city",
- activity="coding",
- )
- anchor = store.get(aid)
- assert anchor is not None
- assert anchor["era"] == "bengaluru_work"
- assert anchor["place"] == "Bengaluru"
- assert anchor["activity"] == "coding"
-
- def test_get_by_event(self, shared_conn):
- conn, lock = shared_conn
- store = AnchorStore(conn, lock)
-
- aid = store.add(user_id="u1", memory_event_id="evt-123")
- found = store.get_by_event("evt-123")
- assert found is not None
- assert found["anchor_id"] == aid
-
- def test_list_filtered(self, shared_conn):
- conn, lock = shared_conn
- store = AnchorStore(conn, lock)
-
- store.add(user_id="u1", era="school", place="Ghazipur")
- store.add(user_id="u1", era="work", place="Bengaluru")
- store.add(user_id="u1", era="school", place="Delhi")
-
- school = store.list_by_user("u1", era="school")
- assert len(school) == 2
-
- blr = store.list_by_user("u1", place="Bengaluru")
- assert len(blr) == 1
-
- def test_update_fields(self, shared_conn):
- conn, lock = shared_conn
- store = AnchorStore(conn, lock)
-
- aid = store.add(user_id="u1", era="old_era")
- assert store.update_fields(aid, era="new_era") is True
- assert store.get(aid)["era"] == "new_era"
-
- def test_update_fields_rejects_invalid(self, shared_conn):
- conn, lock = shared_conn
- store = AnchorStore(conn, lock)
-
- aid = store.add(user_id="u1")
- # user_id is not in the allowed update set
- assert store.update_fields(aid, user_id="hacker") is False
-
-
-# =========================================================================
-# InsightStore Tests
-# =========================================================================
-
-class TestInsightStore:
-
- def test_add_and_get(self, shared_conn):
- conn, lock = shared_conn
- store = InsightStore(conn, lock)
-
- iid = store.add(
- user_id="u1",
- content="Strict criteria + balanced scoring works best",
- insight_type="strategy",
- confidence=0.7,
- )
- insight = store.get(iid)
- assert insight["content"] == "Strict criteria + balanced scoring works best"
- assert insight["insight_type"] == "strategy"
-
- def test_record_outcome_success(self, shared_conn):
- conn, lock = shared_conn
- store = InsightStore(conn, lock)
-
- iid = store.add(user_id="u1", content="test", confidence=0.5)
- store.record_outcome(iid, success=True)
- i = store.get(iid)
- assert i["validation_count"] == 1
- assert i["confidence"] == 0.55 # 0.5 + 0.05
-
- def test_record_outcome_failure(self, shared_conn):
- conn, lock = shared_conn
- store = InsightStore(conn, lock)
-
- iid = store.add(user_id="u1", content="test", confidence=0.5)
- store.record_outcome(iid, success=False)
- i = store.get(iid)
- assert i["invalidation_count"] == 1
- assert i["confidence"] == 0.4 # 0.5 - 0.1
-
- def test_record_outcome_with_scores(self, shared_conn):
- conn, lock = shared_conn
- store = InsightStore(conn, lock)
-
- iid = store.add(user_id="u1", content="test", confidence=0.5)
- store.record_outcome(iid, success=True, baseline_score=0.5, actual_score=0.8)
- i = store.get(iid)
- # utility = 0.3 * 0.3 + 0.7 * 0.0 = 0.09
- assert abs(i["utility"] - 0.09) < 0.01
-
- def test_list_by_type(self, shared_conn):
- conn, lock = shared_conn
- store = InsightStore(conn, lock)
-
- store.add(user_id="u1", content="a", insight_type="warning")
- store.add(user_id="u1", content="b", insight_type="strategy")
- store.add(user_id="u1", content="c", insight_type="warning")
-
- warnings = store.list_by_user("u1", insight_type="warning")
- assert len(warnings) == 2
-
-
-# =========================================================================
-# HeuristicStore Tests
-# =========================================================================
-
-class TestHeuristicStore:
-
- def test_add_and_get(self, shared_conn):
- conn, lock = shared_conn
- store = HeuristicStore(conn, lock)
-
- hid = store.add(
- user_id="u1",
- content="Start with the most constrained component",
- abstraction_level="universal",
- )
- h = store.get(hid)
- assert h["content"] == "Start with the most constrained component"
- assert h["abstraction_level"] == "universal"
-
- def test_record_outcome(self, shared_conn):
- conn, lock = shared_conn
- store = HeuristicStore(conn, lock)
-
- hid = store.add(user_id="u1", content="test", confidence=0.5)
- store.record_outcome(hid, success=True, baseline_score=0.4, actual_score=0.7)
-
- h = store.get(hid)
- assert h["validation_count"] == 1
- assert h["confidence"] == 0.55
- assert abs(h["utility"] - 0.09) < 0.01 # 0.3 * 0.3 + 0.7 * 0
-
- def test_list_by_level(self, shared_conn):
- conn, lock = shared_conn
- store = HeuristicStore(conn, lock)
-
- store.add(user_id="u1", content="a", abstraction_level="specific")
- store.add(user_id="u1", content="b", abstraction_level="domain")
- store.add(user_id="u1", content="c", abstraction_level="universal")
-
- universal = store.list_by_user("u1", abstraction_level="universal")
- assert len(universal) == 1
-
-
-# =========================================================================
-# DerivedLineageStore Tests
-# =========================================================================
-
-class TestDerivedLineageStore:
-
- def test_add_and_get_sources(self, shared_conn):
- conn, lock = shared_conn
- store = DerivedLineageStore(conn, lock)
-
- store.add("belief", "b1", "evt1", contribution_weight=0.6)
- store.add("belief", "b1", "evt2", contribution_weight=0.4)
-
- sources = store.get_sources("belief", "b1")
- assert len(sources) == 2
- weights = {s["source_event_id"]: s["contribution_weight"] for s in sources}
- assert weights["evt1"] == 0.6
- assert weights["evt2"] == 0.4
-
- def test_get_dependents(self, shared_conn):
- conn, lock = shared_conn
- store = DerivedLineageStore(conn, lock)
-
- store.add("belief", "b1", "evt1")
- store.add("policy", "p1", "evt1")
- store.add("belief", "b2", "evt2")
-
- deps = store.get_dependents("evt1")
- assert len(deps) == 2
- types = {d["derived_type"] for d in deps}
- assert types == {"belief", "policy"}
-
- def test_add_batch(self, shared_conn):
- conn, lock = shared_conn
- store = DerivedLineageStore(conn, lock)
-
- ids = store.add_batch(
- "insight", "i1",
- ["evt1", "evt2", "evt3"],
- weights=[0.5, 0.3, 0.2],
- )
- assert len(ids) == 3
- assert store.get_source_count("insight", "i1") == 3
-
- def test_contribution_weight(self, shared_conn):
- conn, lock = shared_conn
- store = DerivedLineageStore(conn, lock)
-
- store.add("belief", "b1", "evt1", contribution_weight=0.7)
- w = store.get_contribution_weight("belief", "b1", "evt1")
- assert w == 0.7
-
- # Nonexistent
- assert store.get_contribution_weight("belief", "b1", "evt999") is None
-
- def test_delete_for_derived(self, shared_conn):
- conn, lock = shared_conn
- store = DerivedLineageStore(conn, lock)
-
- store.add("belief", "b1", "evt1")
- store.add("belief", "b1", "evt2")
- assert store.get_source_count("belief", "b1") == 2
-
- deleted = store.delete_for_derived("belief", "b1")
- assert deleted == 2
- assert store.get_source_count("belief", "b1") == 0
-
-
-# =========================================================================
-# CognitionStore Integration Tests
-# =========================================================================
-
-class TestCognitionStore:
-
- def test_full_lifecycle(self, cognition_store):
- """End-to-end: raw event → derived belief → lineage → invalidation status."""
- cs = cognition_store
-
- # 1. Store raw event
- event = cs.events.add(content="User prefers dark mode", user_id="u1")
- assert event.status == EventStatus.ACTIVE
-
- # 2. Derive a belief from it
- bid = cs.beliefs.add(
- user_id="u1",
- claim="User prefers dark mode",
- domain="preferences",
- confidence=0.8,
- source_memory_ids=[event.event_id],
- )
-
- # 3. Record lineage
- cs.lineage.add("belief", bid, event.event_id, contribution_weight=1.0)
-
- # 4. Verify lineage
- sources = cs.lineage.get_sources("belief", bid)
- assert len(sources) == 1
- assert sources[0]["source_event_id"] == event.event_id
-
- deps = cs.lineage.get_dependents(event.event_id)
- assert len(deps) == 1
- assert deps[0]["derived_id"] == bid
-
- # 5. Correct the raw event
- correction = cs.events.correct(
- event.event_id, "User prefers light mode"
- )
- assert correction.supersedes_event_id == event.event_id
-
- # 6. Mark belief as stale (soft invalidation would do this)
- cs.beliefs.set_status(bid, "stale")
- stale = cs.beliefs.get_by_invalidation_status("stale")
- assert len(stale) == 1
-
- def test_policy_lifecycle(self, cognition_store):
- cs = cognition_store
-
- # Create event + policy + lineage
- e = cs.events.add(content="git blame helped find bug", user_id="u1")
- pid = cs.policies.add(
- user_id="u1",
- name="git blame first",
- condition={"task_types": ["bug_fix"]},
- action={"approach": "Run git blame"},
- )
- cs.lineage.add("policy", pid, e.event_id)
-
- # Record outcomes → validated
- for _ in range(6):
- cs.policies.record_outcome(pid, success=True, baseline_score=0.4, actual_score=0.8)
-
- p = cs.policies.get(pid)
- assert p["status"] == "validated"
- assert p["utility"] > 0
-
- def test_multi_type_lineage(self, cognition_store):
- """One raw event feeds multiple derived types."""
- cs = cognition_store
-
- e = cs.events.add(content="auth tokens expire in prod", user_id="u1")
-
- bid = cs.beliefs.add(user_id="u1", claim="Tokens expire in prod")
- pid = cs.policies.add(
- user_id="u1", name="check token expiry",
- condition={}, action={},
- )
- iid = cs.insights.add(user_id="u1", content="Token expiry causes outages")
-
- cs.lineage.add("belief", bid, e.event_id)
- cs.lineage.add("policy", pid, e.event_id)
- cs.lineage.add("insight", iid, e.event_id)
-
- deps = cs.lineage.get_dependents(e.event_id)
- assert len(deps) == 3
- types = {d["derived_type"] for d in deps}
- assert types == {"belief", "policy", "insight"}
-
- def test_shared_connection(self, cognition_store):
- """All stores share the same connection — verify cross-store visibility."""
- cs = cognition_store
-
- # Write through events store
- e = cs.events.add(content="test visibility", user_id="u1")
-
- # Write through beliefs store
- bid = cs.beliefs.add(user_id="u1", claim="test")
-
- # Lineage can see both (same connection)
- cs.lineage.add("belief", bid, e.event_id)
- sources = cs.lineage.get_sources("belief", bid)
- assert len(sources) == 1
diff --git a/tests/test_vector_store_factory.py b/tests/test_vector_store_factory.py
new file mode 100644
index 0000000..2c2238d
--- /dev/null
+++ b/tests/test_vector_store_factory.py
@@ -0,0 +1,14 @@
+from dhee.utils.factory import _normalize_sqlite_vec_config
+
+
+def test_normalize_sqlite_vec_config_keeps_file_paths():
+ cfg = {"path": "/tmp/dhee/sqlite_vec.db", "collection_name": "x"}
+ normalized = _normalize_sqlite_vec_config(cfg)
+ assert normalized["path"] == "/tmp/dhee/sqlite_vec.db"
+ assert normalized["collection_name"] == "x"
+
+
+def test_normalize_sqlite_vec_config_converts_directory_paths():
+ cfg = {"path": "/tmp/dhee/zvec"}
+ normalized = _normalize_sqlite_vec_config(cfg)
+ assert normalized["path"] == "/tmp/dhee/zvec/sqlite_vec.db"