diff --git a/dhee/__init__.py b/dhee/__init__.py index 3172da0..f912e24 100644 --- a/dhee/__init__.py +++ b/dhee/__init__.py @@ -24,7 +24,7 @@ from dhee.memory.smart import SmartMemory from dhee.memory.main import FullMemory from dhee.simple import Dhee, Engram -from dhee.adapters.base import DheePlugin +from dhee.plugin import DheePlugin from dhee.core.category import CategoryProcessor, Category, CategoryType, CategoryMatch from dhee.core.echo import EchoProcessor, EchoDepth, EchoResult from dhee.configs.base import MemoryConfig, FadeMemConfig, EchoMemConfig, CategoryMemConfig, ScopeConfig @@ -32,7 +32,7 @@ # Default: CoreMemory (lightest, zero-config) Memory = CoreMemory -__version__ = "3.0.1" +__version__ = "3.1.0" __all__ = [ # Memory classes "Engram", diff --git a/dhee/adapters/__init__.py b/dhee/adapters/__init__.py deleted file mode 100644 index 85a7879..0000000 --- a/dhee/adapters/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Dhee adapters — universal plugin interface for any agent framework. - -Available adapters: - - DheePlugin: Base universal plugin (remember/recall/context/checkpoint) - - OpenAIToolAdapter: OpenAI function calling (tools= parameter) - - get_dhee_tools: LangChain BaseTool wrappers - - get_autogen_functions: AutoGen v0.2 callables + schemas - - generate_snapshot: Frozen system prompt for non-tool-calling agents -""" - -from dhee.adapters.base import DheePlugin - -__all__ = ["DheePlugin"] diff --git a/dhee/adapters/autogen.py b/dhee/adapters/autogen.py deleted file mode 100644 index 8d4b4e1..0000000 --- a/dhee/adapters/autogen.py +++ /dev/null @@ -1,268 +0,0 @@ -"""AutoGen adapter — wraps DheePlugin tools as AutoGen-callable functions. - -Supports both AutoGen v0.2 (register_for_llm/register_for_execution) and -the newer AG2 / AutoGen 0.4+ patterns. 
- -Usage with AutoGen v0.2: - from dhee import DheePlugin - from dhee.adapters.autogen import get_autogen_functions, register_dhee_tools - - plugin = DheePlugin() - - # Option 1: Get callables + schemas for manual registration - functions = get_autogen_functions(plugin) - - # Option 2: Auto-register on an assistant + executor pair - register_dhee_tools(plugin, assistant=assistant, executor=user_proxy) - -Usage with AG2 / AutoGen 0.4+: - from dhee.adapters.autogen import get_autogen_tool_specs - - specs = get_autogen_tool_specs(plugin) - # Pass to ConversableAgent(tools=specs) -""" - -from __future__ import annotations - -import json -import logging -from typing import Any, Callable, Dict, List, Optional, Tuple - -logger = logging.getLogger(__name__) - - -# --------------------------------------------------------------------------- -# Tool callables -# --------------------------------------------------------------------------- - -def _make_callables(plugin: Any) -> Dict[str, Callable]: - """Create plain callables wrapping DheePlugin methods.""" - - def remember(content: str, user_id: str = "default") -> str: - """Store a fact, preference, or observation to memory.""" - result = plugin.remember(content=content, user_id=user_id) - return json.dumps(result, default=str) - - def recall(query: str, user_id: str = "default", limit: int = 5) -> str: - """Search memory for relevant facts.""" - results = plugin.recall(query=query, user_id=user_id, limit=limit) - return json.dumps(results, default=str) - - def context( - task_description: str = "", user_id: str = "default", - ) -> str: - """HyperAgent session bootstrap. 
Returns full cognition context.""" - result = plugin.context( - task_description=task_description or None, user_id=user_id, - ) - return json.dumps(result, default=str) - - def checkpoint( - summary: str, - task_type: str = "", - outcome_score: float = -1.0, - what_worked: str = "", - what_failed: str = "", - remember_to: str = "", - ) -> str: - """Save session state and learnings.""" - kwargs: Dict[str, Any] = {"summary": summary} - if task_type: - kwargs["task_type"] = task_type - if outcome_score >= 0: - kwargs["outcome_score"] = outcome_score - if what_worked: - kwargs["what_worked"] = what_worked - if what_failed: - kwargs["what_failed"] = what_failed - if remember_to: - kwargs["remember_to"] = remember_to - result = plugin.checkpoint(**kwargs) - return json.dumps(result, default=str) - - return { - "dhee_remember": remember, - "dhee_recall": recall, - "dhee_context": context, - "dhee_checkpoint": checkpoint, - } - - -# --------------------------------------------------------------------------- -# AutoGen v0.2 schemas -# --------------------------------------------------------------------------- - -_AUTOGEN_SCHEMAS: List[Dict[str, Any]] = [ - { - "name": "dhee_remember", - "description": ( - "Store a fact, preference, or observation to memory. " - "Zero LLM calls, one embedding call." - ), - "parameters": { - "type": "object", - "properties": { - "content": { - "type": "string", - "description": "The fact to remember", - }, - "user_id": { - "type": "string", - "description": "User identifier (default: 'default')", - "default": "default", - }, - }, - "required": ["content"], - }, - }, - { - "name": "dhee_recall", - "description": ( - "Search memory for relevant facts. Returns top-K ranked by relevance." 
- ), - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "What you're trying to remember", - }, - "user_id": { - "type": "string", - "description": "User identifier", - "default": "default", - }, - "limit": { - "type": "integer", - "description": "Max results (default: 5)", - "default": 5, - }, - }, - "required": ["query"], - }, - }, - { - "name": "dhee_context", - "description": ( - "HyperAgent session bootstrap. Returns performance, insights, " - "intentions, warnings, heuristics, and memories." - ), - "parameters": { - "type": "object", - "properties": { - "task_description": { - "type": "string", - "description": "What you're about to work on", - "default": "", - }, - "user_id": { - "type": "string", - "description": "User identifier", - "default": "default", - }, - }, - }, - }, - { - "name": "dhee_checkpoint", - "description": ( - "Save session state and learnings. Records outcomes, synthesizes " - "insights, stores intentions." - ), - "parameters": { - "type": "object", - "properties": { - "summary": { - "type": "string", - "description": "What you were working on", - }, - "task_type": { - "type": "string", - "description": "Task category (e.g., 'bug_fix')", - "default": "", - }, - "outcome_score": { - "type": "number", - "description": "0.0-1.0 outcome score (-1 to skip)", - "default": -1.0, - }, - "what_worked": { - "type": "string", - "description": "Approach that worked", - "default": "", - }, - "what_failed": { - "type": "string", - "description": "Approach that failed", - "default": "", - }, - "remember_to": { - "type": "string", - "description": "Future intention: 'remember to X when Y'", - "default": "", - }, - }, - "required": ["summary"], - }, - }, -] - - -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - -def get_autogen_functions( - plugin: Any, -) -> List[Tuple[Callable, 
Dict[str, Any]]]: - """Get (callable, schema) pairs for AutoGen v0.2 registration. - - Returns: - List of (function, schema_dict) tuples ready for - register_for_llm / register_for_execution. - """ - callables = _make_callables(plugin) - return [ - (callables[schema["name"]], schema) - for schema in _AUTOGEN_SCHEMAS - ] - - -def register_dhee_tools( - plugin: Any, - assistant: Any, - executor: Any, -) -> None: - """Register Dhee tools on an AutoGen v0.2 assistant + executor pair. - - Args: - plugin: A DheePlugin instance. - assistant: An AssistantAgent (or ConversableAgent) for LLM. - executor: A UserProxyAgent (or ConversableAgent) for execution. - """ - callables = _make_callables(plugin) - - for schema in _AUTOGEN_SCHEMAS: - name = schema["name"] - fn = callables[name] - - # Register for LLM (tool definition) - assistant.register_for_llm( - name=name, - description=schema["description"], - )(fn) - - # Register for execution - executor.register_for_execution(name=name)(fn) - - -def get_autogen_tool_specs(plugin: Any) -> List[Dict[str, Any]]: - """Get tool specs for AG2 / AutoGen 0.4+ ConversableAgent(tools=...). - - Returns a list of dicts with 'function' and 'schema' keys. - """ - callables = _make_callables(plugin) - return [ - {"function": callables[schema["name"]], "schema": schema} - for schema in _AUTOGEN_SCHEMAS - ] diff --git a/dhee/adapters/langchain.py b/dhee/adapters/langchain.py deleted file mode 100644 index a13a4a6..0000000 --- a/dhee/adapters/langchain.py +++ /dev/null @@ -1,236 +0,0 @@ -"""LangChain adapter — wraps DheePlugin tools as LangChain BaseTool instances. 
- -Usage: - from dhee import DheePlugin - from dhee.adapters.langchain import get_dhee_tools - - plugin = DheePlugin() - tools = get_dhee_tools(plugin) - - # Use with any LangChain agent: - agent = create_react_agent(llm, tools) - - # Or pick individual tools: - remember_tool, recall_tool, context_tool, checkpoint_tool = tools -""" - -from __future__ import annotations - -import json -import logging -from typing import Any, Dict, List, Optional, Type - -logger = logging.getLogger(__name__) - -# Lazy import — LangChain is optional -_HAS_LANGCHAIN = None - - -def _check_langchain() -> bool: - global _HAS_LANGCHAIN - if _HAS_LANGCHAIN is None: - try: - from langchain_core.tools import BaseTool # noqa: F401 - _HAS_LANGCHAIN = True - except ImportError: - _HAS_LANGCHAIN = False - return _HAS_LANGCHAIN - - -def _get_base_classes(): - """Import LangChain base classes (raises ImportError if not installed).""" - from langchain_core.tools import BaseTool - from langchain_core.callbacks import CallbackManagerForToolRun - try: - from pydantic import BaseModel, Field - except ImportError: - from langchain_core.pydantic_v1 import BaseModel, Field - return BaseTool, CallbackManagerForToolRun, BaseModel, Field - - -# --------------------------------------------------------------------------- -# Tool implementations -# --------------------------------------------------------------------------- - -def _make_remember_tool(plugin: Any): - BaseTool, CallbackManagerForToolRun, BaseModel, Field = _get_base_classes() - - class RememberInput(BaseModel): - content: str = Field(description="The fact, preference, or observation to remember") - user_id: Optional[str] = Field(default=None, description="User identifier") - - class DheeRemember(BaseTool): - name: str = "dhee_remember" - description: str = ( - "Store a fact, preference, or observation to memory. " - "Zero LLM calls, one embedding call. Fast." 
- ) - args_schema: Type[BaseModel] = RememberInput - _plugin: Any = None - - def __init__(self, plugin: Any, **kwargs): - super().__init__(**kwargs) - self._plugin = plugin - - def _run( - self, - content: str, - user_id: Optional[str] = None, - run_manager: Optional[CallbackManagerForToolRun] = None, - ) -> str: - result = self._plugin.remember(content=content, user_id=user_id) - return json.dumps(result, default=str) - - return DheeRemember(plugin=plugin) - - -def _make_recall_tool(plugin: Any): - BaseTool, CallbackManagerForToolRun, BaseModel, Field = _get_base_classes() - - class RecallInput(BaseModel): - query: str = Field(description="What you're trying to remember") - user_id: Optional[str] = Field(default=None, description="User identifier") - limit: int = Field(default=5, description="Maximum results to return") - - class DheeRecall(BaseTool): - name: str = "dhee_recall" - description: str = ( - "Search memory for relevant facts. Returns top-K results " - "ranked by relevance. Zero LLM calls, one embedding." - ) - args_schema: Type[BaseModel] = RecallInput - _plugin: Any = None - - def __init__(self, plugin: Any, **kwargs): - super().__init__(**kwargs) - self._plugin = plugin - - def _run( - self, - query: str, - user_id: Optional[str] = None, - limit: int = 5, - run_manager: Optional[CallbackManagerForToolRun] = None, - ) -> str: - results = self._plugin.recall(query=query, user_id=user_id, limit=limit) - return json.dumps(results, default=str) - - return DheeRecall(plugin=plugin) - - -def _make_context_tool(plugin: Any): - BaseTool, CallbackManagerForToolRun, BaseModel, Field = _get_base_classes() - - class ContextInput(BaseModel): - task_description: Optional[str] = Field( - default=None, description="What you're about to work on", - ) - user_id: Optional[str] = Field(default=None, description="User identifier") - - class DheeContext(BaseTool): - name: str = "dhee_context" - description: str = ( - "HyperAgent session bootstrap. 
Returns performance snapshots, " - "insights, intentions, warnings, heuristics, and relevant memories. " - "Call once at the start of a task." - ) - args_schema: Type[BaseModel] = ContextInput - _plugin: Any = None - - def __init__(self, plugin: Any, **kwargs): - super().__init__(**kwargs) - self._plugin = plugin - - def _run( - self, - task_description: Optional[str] = None, - user_id: Optional[str] = None, - run_manager: Optional[CallbackManagerForToolRun] = None, - ) -> str: - result = self._plugin.context( - task_description=task_description, user_id=user_id, - ) - return json.dumps(result, default=str) - - return DheeContext(plugin=plugin) - - -def _make_checkpoint_tool(plugin: Any): - BaseTool, CallbackManagerForToolRun, BaseModel, Field = _get_base_classes() - - class CheckpointInput(BaseModel): - summary: str = Field(description="What you were working on") - task_type: Optional[str] = Field( - default=None, description="Task category (e.g., 'bug_fix')", - ) - outcome_score: Optional[float] = Field( - default=None, description="0.0-1.0 outcome score", - ) - what_worked: Optional[str] = Field( - default=None, description="Approach that worked", - ) - what_failed: Optional[str] = Field( - default=None, description="Approach that failed", - ) - remember_to: Optional[str] = Field( - default=None, description="Future intention: 'remember to X when Y'", - ) - - class DheeCheckpoint(BaseTool): - name: str = "dhee_checkpoint" - description: str = ( - "Save session state and learnings. Records outcomes, " - "synthesizes insights from what worked/failed, stores intentions." 
- ) - args_schema: Type[BaseModel] = CheckpointInput - _plugin: Any = None - - def __init__(self, plugin: Any, **kwargs): - super().__init__(**kwargs) - self._plugin = plugin - - def _run( - self, - summary: str, - task_type: Optional[str] = None, - outcome_score: Optional[float] = None, - what_worked: Optional[str] = None, - what_failed: Optional[str] = None, - remember_to: Optional[str] = None, - run_manager: Optional[CallbackManagerForToolRun] = None, - ) -> str: - result = self._plugin.checkpoint( - summary=summary, task_type=task_type, - outcome_score=outcome_score, what_worked=what_worked, - what_failed=what_failed, remember_to=remember_to, - ) - return json.dumps(result, default=str) - - return DheeCheckpoint(plugin=plugin) - - -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - -def get_dhee_tools(plugin: Any) -> List[Any]: - """Create LangChain tool instances from a DheePlugin. - - Returns: - [DheeRemember, DheeRecall, DheeContext, DheeCheckpoint] - - Raises: - ImportError: If langchain-core is not installed. - """ - if not _check_langchain(): - raise ImportError( - "langchain-core is required for LangChain integration. " - "Install it with: pip install langchain-core" - ) - - return [ - _make_remember_tool(plugin), - _make_recall_tool(plugin), - _make_context_tool(plugin), - _make_checkpoint_tool(plugin), - ] diff --git a/dhee/adapters/openai_funcs.py b/dhee/adapters/openai_funcs.py deleted file mode 100644 index 0702dfa..0000000 --- a/dhee/adapters/openai_funcs.py +++ /dev/null @@ -1,183 +0,0 @@ -"""OpenAI function calling adapter for DheePlugin. - -Generates tool definitions compatible with: - - OpenAI Chat Completions API (tools parameter) - - Any OpenAI-compatible API (Ollama, vLLM, LiteLLM, etc.) 
- -Usage with the OpenAI SDK: - from dhee.adapters.openai_funcs import OpenAIToolAdapter - - adapter = OpenAIToolAdapter(plugin) - response = client.chat.completions.create( - model="gpt-4", - messages=messages, - tools=adapter.tool_definitions(), - ) - - # Execute the function call - for call in response.choices[0].message.tool_calls: - result = adapter.execute(call.function.name, json.loads(call.function.arguments)) -""" - -from __future__ import annotations - -import json -import logging -from typing import Any, Callable, Dict, List, Optional - -logger = logging.getLogger(__name__) - - -class OpenAIToolAdapter: - """Wraps DheePlugin as OpenAI-compatible function calling tools. - - Provides tool_definitions() for the API request and execute() for - dispatching tool calls from the response. - """ - - def __init__(self, plugin: Any): - """ - Args: - plugin: A DheePlugin instance. - """ - self._plugin = plugin - self._dispatchers: Dict[str, Callable] = { - "remember": self._exec_remember, - "recall": self._exec_recall, - "context": self._exec_context, - "checkpoint": self._exec_checkpoint, - "session_start": self._exec_session_start, - "session_end": self._exec_session_end, - } - - def tool_definitions(self, include_session: bool = False) -> List[Dict[str, Any]]: - """Return OpenAI-format tool definitions. - - Args: - include_session: If True, also includes session_start and - session_end as callable tools. - """ - tools = self._plugin.as_openai_functions() - - if include_session: - tools.extend([ - { - "type": "function", - "function": { - "name": "session_start", - "description": ( - "Start a Dhee cognition session. Returns a frozen context " - "block. Call once at the beginning of a task." 
- ), - "parameters": { - "type": "object", - "properties": { - "task_description": { - "type": "string", - "description": "What you're about to work on", - }, - }, - }, - }, - }, - { - "type": "function", - "function": { - "name": "session_end", - "description": ( - "End the current Dhee session and save learnings." - ), - "parameters": { - "type": "object", - "properties": { - "summary": { - "type": "string", - "description": "What you accomplished", - }, - "outcome_score": { - "type": "number", - "description": "0.0-1.0 outcome score", - }, - "what_worked": { - "type": "string", - "description": "Approach that worked", - }, - "what_failed": { - "type": "string", - "description": "Approach that failed", - }, - }, - "required": ["summary"], - }, - }, - }, - ]) - - return tools - - def execute(self, function_name: str, arguments: Dict[str, Any]) -> str: - """Execute a tool call and return the JSON-encoded result. - - This is the glue between OpenAI's tool_call response and DheePlugin. - - Args: - function_name: The function name from the tool call. - arguments: Parsed JSON arguments from the tool call. - - Returns: - JSON string suitable for a tool message. 
- """ - dispatcher = self._dispatchers.get(function_name) - if not dispatcher: - return json.dumps({"error": f"Unknown function: {function_name}"}) - - try: - result = dispatcher(arguments) - return json.dumps(result, default=str, ensure_ascii=False) - except Exception as e: - logger.warning("Tool execution failed for %s: %s", function_name, e) - return json.dumps({"error": str(e)}) - - def _exec_remember(self, args: Dict[str, Any]) -> Any: - return self._plugin.remember( - content=args["content"], - user_id=args.get("user_id"), - ) - - def _exec_recall(self, args: Dict[str, Any]) -> Any: - return self._plugin.recall( - query=args["query"], - user_id=args.get("user_id"), - limit=args.get("limit", 5), - ) - - def _exec_context(self, args: Dict[str, Any]) -> Any: - return self._plugin.context( - task_description=args.get("task_description"), - user_id=args.get("user_id"), - ) - - def _exec_checkpoint(self, args: Dict[str, Any]) -> Any: - return self._plugin.checkpoint( - summary=args["summary"], - task_type=args.get("task_type"), - outcome_score=args.get("outcome_score"), - what_worked=args.get("what_worked"), - what_failed=args.get("what_failed"), - remember_to=args.get("remember_to"), - trigger_keywords=args.get("trigger_keywords"), - ) - - def _exec_session_start(self, args: Dict[str, Any]) -> Any: - prompt = self._plugin.session_start( - task_description=args.get("task_description"), - ) - return {"system_prompt": prompt} - - def _exec_session_end(self, args: Dict[str, Any]) -> Any: - return self._plugin.session_end( - summary=args["summary"], - outcome_score=args.get("outcome_score"), - what_worked=args.get("what_worked"), - what_failed=args.get("what_failed"), - ) diff --git a/dhee/adapters/system_prompt.py b/dhee/adapters/system_prompt.py deleted file mode 100644 index 60b39f9..0000000 --- a/dhee/adapters/system_prompt.py +++ /dev/null @@ -1,245 +0,0 @@ -"""Frozen snapshot generator — renders DheePlugin context as a system prompt. 
- -For agents that don't support tool calling (e.g., simple prompt → completion -workflows, humanoid robot controllers, voice assistants), this module generates -a self-contained system prompt block that includes the full HyperContext. - -The "frozen snapshot" pattern (from NousResearch Hermes Agent architecture): - 1. At session start, load HyperContext into the system prompt. - 2. During the session, the system prompt is NEVER mutated — this preserves - LLM KV-cache / prefix caches for fast inference. - 3. At session end, new knowledge is written to storage for next time. - -Usage: - from dhee import DheePlugin - from dhee.adapters.system_prompt import generate_snapshot, SnapshotConfig - - plugin = DheePlugin() - prompt = generate_snapshot(plugin, task="fixing auth bug") - - # Or with custom config: - config = SnapshotConfig( - include_memories=True, - include_heuristics=True, - max_memories=10, - include_tool_instructions=True, - ) - prompt = generate_snapshot(plugin, task="fixing auth bug", config=config) -""" - -from __future__ import annotations - -import logging -import textwrap -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - -logger = logging.getLogger(__name__) - - -@dataclass -class SnapshotConfig: - """Controls what goes into the frozen snapshot.""" - - include_performance: bool = True - include_warnings: bool = True - include_insights: bool = True - include_intentions: bool = True - include_contrasts: bool = True - include_heuristics: bool = True - include_memories: bool = True - include_tool_instructions: bool = False - include_hive: bool = False - - max_performance: int = 5 - max_warnings: int = 5 - max_insights: int = 5 - max_intentions: int = 5 - max_contrasts: int = 3 - max_heuristics: int = 3 - max_memories: int = 10 - - # Prefix/suffix for wrapping the snapshot - header: str = "## Dhee Cognition Context (Frozen Snapshot)" - footer: str = "" - - -# Tool usage instructions (for agents that CAN call tools 
after loading snapshot) -_TOOL_INSTRUCTIONS = """\ -### Available Memory Tools -- **remember(content)** — store a new fact/observation -- **recall(query)** — search memory for relevant facts -- **context(task)** — load full HyperContext (already loaded above) -- **checkpoint(summary, ...)** — save session state and learnings - -Use `remember` proactively when you learn new facts. Use `recall` before -answering questions that may depend on stored knowledge. Use `checkpoint` -at natural breakpoints and at session end. -""" - - -def generate_snapshot( - plugin: Any, - task: Optional[str] = None, - user_id: Optional[str] = None, - config: Optional[SnapshotConfig] = None, - hive: Optional[Any] = None, -) -> str: - """Generate a frozen system prompt snapshot from DheePlugin. - - Args: - plugin: A DheePlugin instance. - task: Current task description. - user_id: User identifier. - config: Snapshot configuration. Defaults to include everything. - hive: Optional HiveMemory instance for multi-agent context. - - Returns: - A complete system prompt block as a string. 
- """ - cfg = config or SnapshotConfig() - ctx = plugin.context(task_description=task, user_id=user_id) - - parts: List[str] = [cfg.header] - - if task: - parts.append(f"\n**Current task:** {task}") - - # Performance - if cfg.include_performance: - perf = ctx.get("performance", [])[:cfg.max_performance] - if perf: - parts.append("\n### Performance History") - for p in perf: - trend = p.get("trend", 0) - direction = "improving" if trend > 0 else "declining" if trend < 0 else "stable" - parts.append( - f"- **{p['task_type']}**: avg={p['avg_score']:.2f}, " - f"trend={p['trend']:+.3f} ({direction}), " - f"attempts={p['total_attempts']}" - ) - - # Warnings - if cfg.include_warnings: - warnings = ctx.get("warnings", [])[:cfg.max_warnings] - if warnings: - parts.append("\n### Warnings") - for w in warnings: - parts.append(f"- {w}") - - # Insights - if cfg.include_insights: - insights = ctx.get("insights", [])[:cfg.max_insights] - if insights: - parts.append("\n### Insights from Past Work") - for i in insights: - parts.append(f"- [{i.get('type', 'general')}] {i['content']}") - - # Intentions (triggered reminders) - if cfg.include_intentions: - intentions = ctx.get("intentions", [])[:cfg.max_intentions] - if intentions: - parts.append("\n### Triggered Reminders") - for i in intentions: - parts.append(f"- {i['description']}") - - # Contrastive evidence - if cfg.include_contrasts: - contrasts = ctx.get("contrasts", [])[:cfg.max_contrasts] - if contrasts: - parts.append("\n### Contrastive Evidence (Do / Avoid)") - for c in contrasts: - do_text = c.get("do", "")[:200] - avoid_text = c.get("avoid", "")[:200] - parts.append(f"- **Do:** {do_text}") - parts.append(f" **Avoid:** {avoid_text}") - confidence = c.get("confidence") - if confidence is not None: - parts.append(f" *confidence: {confidence:.1%}*") - - # Heuristics - if cfg.include_heuristics: - heuristics = ctx.get("heuristics", [])[:cfg.max_heuristics] - if heuristics: - parts.append("\n### Learned Heuristics") - for h in 
heuristics: - level = h.get("level", "domain") - text = h.get("heuristic", "")[:250] - parts.append(f"- [{level}] {text}") - - # Memories - if cfg.include_memories: - memories = ctx.get("memories", [])[:cfg.max_memories] - if memories: - parts.append("\n### Relevant Memories") - for m in memories: - mem_text = m.get("memory", "")[:250] - score = m.get("score", 0) - if mem_text: - parts.append(f"- {mem_text}") - if score > 0: - parts.append(f" *(relevance: {score:.2f})*") - - # Hive context - if cfg.include_hive and hive: - try: - hive_block = hive.get_context_block(limit=3) - hive_insights = hive_block.get("hive_insights", []) - hive_heuristics = hive_block.get("hive_heuristics", []) - - if hive_insights or hive_heuristics: - parts.append("\n### Hive Knowledge (from other agents)") - for hi in hive_insights: - parts.append( - f"- [insight from {hi['source']}] " - f"{hi['content'].get('content', '')[:150]}" - ) - for hh in hive_heuristics: - parts.append( - f"- [heuristic from {hh['source']}] " - f"{hh['content'].get('heuristic', '')[:150]}" - ) - except Exception as e: - logger.debug("Hive context failed: %s", e) - - # Tool instructions - if cfg.include_tool_instructions: - parts.append("\n" + _TOOL_INSTRUCTIONS.strip()) - - # Meta - meta = ctx.get("meta", {}) - if meta: - meta_parts = [] - for key in ["insight_count", "intention_count", "contrast_count", "heuristic_count"]: - val = meta.get(key, 0) - if val > 0: - meta_parts.append(f"{key.replace('_count', '')}s: {val}") - if meta_parts: - parts.append(f"\n*Loaded: {', '.join(meta_parts)}*") - - if cfg.footer: - parts.append(f"\n{cfg.footer}") - - return "\n".join(parts) - - -def generate_minimal_snapshot( - plugin: Any, - task: Optional[str] = None, - user_id: Optional[str] = None, -) -> str: - """Generate a minimal snapshot — just warnings, intentions, and top memories. - - Suitable for edge/embedded agents with tight context budgets. 
- """ - cfg = SnapshotConfig( - include_performance=False, - include_insights=False, - include_contrasts=False, - include_heuristics=False, - max_warnings=3, - max_intentions=3, - max_memories=3, - header="## Dhee Context (Minimal)", - ) - return generate_snapshot(plugin, task=task, user_id=user_id, config=cfg) diff --git a/dhee/api/__init__.py b/dhee/api/__init__.py deleted file mode 100644 index 017dcb4..0000000 --- a/dhee/api/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Engram REST API module.""" - -from dhee.api.app import app -from dhee.api.server import run - -__all__ = ["app", "run"] diff --git a/dhee/api/app.py b/dhee/api/app.py deleted file mode 100644 index f49f382..0000000 --- a/dhee/api/app.py +++ /dev/null @@ -1,198 +0,0 @@ -"""Engram core REST API — lightweight handoff endpoints (no auth required). - -These endpoints mirror the enterprise ``/v1/handoff/*`` routes but delegate -directly to ``engram.core.kernel`` without session/token enforcement. -They are intended for local development and for the ``prompt_context.py`` hook -which fires as a subprocess with no auth context. 
-""" - -from __future__ import annotations - -import logging -import os -from typing import Any, Dict, List, Optional - -from fastapi import FastAPI, Query -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, Field - -logger = logging.getLogger(__name__) - -app = FastAPI( - title="Engram Core API", - version="0.1.0", - description="Lightweight handoff + memory endpoints.", -) - -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_methods=["*"], - allow_headers=["*"], -) - - -# --------------------------------------------------------------------------- -# Health -# --------------------------------------------------------------------------- - -@app.get("/health") -async def health(): - return {"status": "ok"} - - -# --------------------------------------------------------------------------- -# Request / response schemas -# --------------------------------------------------------------------------- - -class CheckpointRequest(BaseModel): - task_summary: Optional[str] = None - event_type: str = "hook_checkpoint" - agent_id: str = "claude-code" - context_snapshot: Optional[str] = None - repo_path: Optional[str] = None - status: Optional[str] = None - decisions_made: Optional[List[str]] = None - files_touched: Optional[List[str]] = None - todos_remaining: Optional[List[str]] = None - blockers: Optional[List[str]] = None - key_commands: Optional[List[str]] = None - test_results: Optional[str] = None - - -class RecoverRequest(BaseModel): - repo_path: str - agent_id: str = "claude-code" - - -class SessionDigestRequest(BaseModel): - task_summary: str - repo: Optional[str] = None - status: str = "active" - agent_id: str = "claude-code" - decisions_made: Optional[List[str]] = None - files_touched: Optional[List[str]] = None - todos_remaining: Optional[List[str]] = None - blockers: Optional[List[str]] = None - key_commands: Optional[List[str]] = None - test_results: Optional[str] = None - - -# 
--------------------------------------------------------------------------- -# Handoff endpoints -# --------------------------------------------------------------------------- - -@app.post("/v1/handoff/checkpoint") -async def handoff_checkpoint(request: CheckpointRequest): - """Receive a lightweight checkpoint from the hook or an agent. - - Creates a dhee-bus session (if needed) and writes a checkpoint snapshot. - """ - from dhee.core.kernel import _get_bus - - bus = None - try: - bus = _get_bus() - - # Find or create a session for this agent - session = bus.get_session( - agent_id=request.agent_id, - repo=request.repo_path, - ) - if session is None: - sid = bus.save_session( - agent_id=request.agent_id, - repo=request.repo_path, - status=request.status or "active", - task_summary=request.task_summary or "", - ) - else: - sid = session["id"] - # Update task_summary if provided - updates: Dict[str, Any] = {} - if request.task_summary: - updates["task_summary"] = request.task_summary - if request.status: - updates["status"] = request.status - if updates: - bus.update_session(sid, **updates) - - snapshot = { - "event_type": request.event_type, - "task_summary": request.task_summary, - "context_snapshot": request.context_snapshot, - "files_touched": request.files_touched or [], - "key_commands": request.key_commands or [], - "decisions_made": request.decisions_made or [], - "todos_remaining": request.todos_remaining or [], - "blockers": request.blockers or [], - "test_results": request.test_results, - } - cid = bus.checkpoint(sid, request.agent_id, snapshot) - - return {"status": "ok", "session_id": sid, "checkpoint_id": cid} - - except Exception as exc: - logger.exception("Checkpoint failed") - return {"status": "error", "detail": str(exc)} - finally: - if bus is not None: - try: - bus.close() - except Exception: - pass - - -@app.get("/v1/handoff/sessions/last") -async def handoff_last_session( - agent_id: Optional[str] = Query(default=None), - repo: Optional[str] = 
Query(default=None), - fallback_log_recovery: bool = Query(default=True), -): - """Get the last session, falling back to JSONL log parsing.""" - from dhee.core.kernel import get_last_session - - session = get_last_session( - agent_id=agent_id or "mcp-server", - repo=repo, - fallback_log_recovery=fallback_log_recovery, - ) - if session is None: - return {"status": "no_session", "message": "No previous session found."} - return session - - -@app.post("/v1/handoff/recover") -async def handoff_recover(request: RecoverRequest): - """Direct log recovery — parse JSONL logs without checking bus first.""" - from dhee.core.log_parser import find_latest_log, parse_conversation_log - - log_path = find_latest_log(request.repo_path) - if log_path is None: - return {"status": "no_logs", "message": "No conversation logs found."} - - digest = parse_conversation_log(log_path) - if digest.get("message_count", 0) == 0: - return {"status": "empty_log", "message": "Log file was empty."} - - return digest - - -@app.post("/v1/handoff/sessions/digest") -async def save_handoff_digest(request: SessionDigestRequest): - """Save a session digest (lightweight, no auth).""" - from dhee.core.kernel import save_session_digest - - result = save_session_digest( - task_summary=request.task_summary, - agent_id=request.agent_id, - repo=request.repo, - status=request.status, - decisions_made=request.decisions_made, - files_touched=request.files_touched, - todos_remaining=request.todos_remaining, - blockers=request.blockers, - key_commands=request.key_commands, - test_results=request.test_results, - ) - return result diff --git a/dhee/api/server.py b/dhee/api/server.py deleted file mode 100644 index adc2e34..0000000 --- a/dhee/api/server.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Engram core API server runner. - -Starts the lightweight FastAPI app from ``engram.api.app`` on the configured -host/port. This is the standalone server — the enterprise version adds auth, -governance, and more endpoints on top. 
-""" - -from __future__ import annotations - - -def run(): - """Run the Engram core API server.""" - import argparse - import uvicorn - - parser = argparse.ArgumentParser(description="Engram Core API Server") - parser.add_argument("--host", default="127.0.0.1", help="Host to bind to") - parser.add_argument("--port", type=int, default=8100, help="Port to listen on") - parser.add_argument("--reload", action="store_true", help="Enable auto-reload") - args = parser.parse_args() - - print(f"Starting Engram Core API on http://{args.host}:{args.port}") - print(f"Docs at http://{args.host}:{args.port}/docs") - - uvicorn.run( - "engram.api.app:app", - host=args.host, - port=args.port, - reload=args.reload, - ) - - -if __name__ == "__main__": - run() diff --git a/dhee/api/static/dashboard.html b/dhee/api/static/dashboard.html deleted file mode 100644 index fc0bc8b..0000000 --- a/dhee/api/static/dashboard.html +++ /dev/null @@ -1,736 +0,0 @@ - - - - - -Engram Memory Visualizer - - - - - - - - -
-
- - -
- -
- -
-
-
-

Memory Layers

-

Top Categories

-

Decay History

-
-
- - -
-
-
- -
- -
-
- - -
-
-
- - -
-
-
- - -
-
-
- - -
-
-
-
-
-
"""HippoCamp benchmark runner for Dhee.

This runner targets the official HippoCamp release artifacts on Hugging Face.
It supports:

* ``gold``: the released ``HippoCamp_Gold`` parsed-text setting. Useful for
  ablations, but easier than the paper's default raw-file benchmark.
* ``raw``: raw file-tree ingestion using Dhee's own deterministic parsers and
  metadata extraction. This follows the raw-file exposure boundary, but remains
  only partially multimodal unless augmented with OCR / ASR / vision.

Usage:
    python -m dhee.benchmarks.hippocamp \
        --config adam_subset \
        --mode raw \
        --embedder-provider simple \
        --llm-provider mock \
        --answer-strategy extractive
"""

from __future__ import annotations

import argparse
import json
import logging
import math
import os
import re
import time
from dataclasses import dataclass
from pathlib import Path, PurePosixPath
from statistics import mean
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

from dhee.benchmarks.longmemeval import build_memory
from dhee.benchmarks.raw_extractors import raw_file_to_items
from dhee.memory.utils import strip_code_fences
from dhee.utils.factory import LLMFactory

logger = logging.getLogger("dhee.benchmarks.hippocamp")

# Hugging Face dataset repo holding the official HippoCamp release artifacts.
DEFAULT_REPO_ID = "MMMem-org/HippoCamp"
# Canonical abstention string the benchmark expects for unanswerable questions.
NO_INFO = "No information available"

# First "{...}" blob in an LLM reply (re.S so it may span multiple lines).
_JSON_BLOCK_RE = re.compile(r"\{.*\}", re.S)
# Lowercase alphanumeric runs, used for token-level F1 matching.
_TOKEN_RE = re.compile(r"[a-z0-9]+")
_WHITESPACE_RE = re.compile(r"\s+")


def _require_huggingface_hub():
    """Import and return ``(hf_hub_download, list_repo_files)``.

    Raises an actionable ImportError when the optional ``huggingface_hub``
    dependency is not installed.
    """
    try:
        from huggingface_hub import hf_hub_download, list_repo_files
    except ImportError as exc:
        raise ImportError(
            "HippoCamp benchmark support requires 'huggingface_hub'. "
            "Install it with `pip install 'dhee[benchmarks]'` or `pip install huggingface_hub`."
        ) from exc
    return hf_hub_download, list_repo_files


@dataclass(frozen=True)
class HippoCampConfigSpec:
    """Repo-relative addresses for one HippoCamp profile/split."""

    # Config key used on the CLI (e.g. "adam_subset").
    name: str
    # Persona profile name inside the dataset repo (e.g. "Adam").
    profile: str
    # Repo-relative path of the QA manifest JSON.
    manifest_path: str
    # Repo-relative prefix under which the environment file tree lives.
    environment_prefix: str


# All released profile/split combinations, keyed by CLI config name.
CONFIG_SPECS: Dict[str, HippoCampConfigSpec] = {
    "adam_fullset": HippoCampConfigSpec(
        name="adam_fullset",
        profile="Adam",
        manifest_path="Adam/Fullset/Adam.json",
        environment_prefix="Adam/Fullset/Adam/",
    ),
    "adam_subset": HippoCampConfigSpec(
        name="adam_subset",
        profile="Adam",
        manifest_path="Adam/Subset/Adam_Subset.json",
        environment_prefix="Adam/Subset/Adam_Subset/",
    ),
    "bei_fullset": HippoCampConfigSpec(
        name="bei_fullset",
        profile="Bei",
        manifest_path="Bei/Fullset/Bei.json",
        environment_prefix="Bei/Fullset/Bei/",
    ),
    "bei_subset": HippoCampConfigSpec(
        name="bei_subset",
        profile="Bei",
        manifest_path="Bei/Subset/Bei_Subset.json",
        environment_prefix="Bei/Subset/Bei_Subset/",
    ),
    "victoria_fullset": HippoCampConfigSpec(
        name="victoria_fullset",
        profile="Victoria",
        manifest_path="Victoria/Fullset/Victoria.json",
        environment_prefix="Victoria/Fullset/Victoria/",
    ),
    "victoria_subset": HippoCampConfigSpec(
        name="victoria_subset",
        profile="Victoria",
        manifest_path="Victoria/Subset/Victoria_Subset.json",
        environment_prefix="Victoria/Subset/Victoria_Subset/",
    ),
}


def _configure_logging(level: str) -> None:
    """Initialise root logging; unknown level names fall back to INFO."""
    logging.basicConfig(
        level=getattr(logging, level.upper(), logging.INFO),
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )


def _persist_json(path: Path, payload: Any) -> None:
    """Write *payload* as pretty-printed UTF-8 JSON, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)


def _append_jsonl(path: Path, payload: Dict[str, Any]) -> None:
    """Append one JSON object as a single JSONL line, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as handle:
        handle.write(json.dumps(payload, ensure_ascii=False) + "\n")
_load_checkpoint(checkpoint_path: Path) -> Tuple[List[Dict[str, Any]], set]: + """Load completed records from a checkpoint JSONL file. + + Returns (records_list, set_of_completed_question_ids). + """ + records: List[Dict[str, Any]] = [] + completed_ids: set = set() + if not checkpoint_path.exists(): + return records, completed_ids + for line in checkpoint_path.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line: + continue + try: + record = json.loads(line) + except json.JSONDecodeError: + continue + if isinstance(record, dict) and record.get("id"): + records.append(record) + completed_ids.add(str(record["id"])) + logger.info( + "Resumed from checkpoint: %d completed questions loaded from %s", + len(completed_ids), + checkpoint_path, + ) + return records, completed_ids + + +def _memory_has_data(history_db_path: Path, user_id: str) -> bool: + """Check if the SQLite memory DB already has indexed data for the user.""" + if not history_db_path.exists(): + return False + try: + import sqlite3 as _sqlite3 + + conn = _sqlite3.connect(str(history_db_path)) + cursor = conn.execute( + "SELECT count(*) FROM memories WHERE user_id = ?", (user_id,) + ) + count = cursor.fetchone()[0] + conn.close() + return count > 0 + except Exception: + return False + + +def _load_env_file(env_path: Path) -> int: + if not env_path.exists() or not env_path.is_file(): + return 0 + loaded = 0 + for raw_line in env_path.read_text(encoding="utf-8", errors="ignore").splitlines(): + line = raw_line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, value = line.split("=", 1) + key = key.strip() + value = value.strip().strip('"').strip("'") + if not key or key in os.environ: + continue + os.environ[key] = value + loaded += 1 + return loaded + + +def _resolve_config(name: str) -> HippoCampConfigSpec: + try: + return CONFIG_SPECS[name] + except KeyError as exc: + supported = ", ".join(sorted(CONFIG_SPECS)) + raise ValueError(f"Unsupported 
def _normalize_answer(text: str) -> str:
    """Strip code fences and answer prefixes, then collapse whitespace."""
    out = strip_code_fences(str(text or "")).strip()
    out = out.replace("Answer:", "").replace("Final answer:", "").strip()
    return _WHITESPACE_RE.sub(" ", out)


def _normalize_for_match(text: str) -> str:
    """Lowercase, drop punctuation, and squeeze whitespace for comparison."""
    folded = _normalize_answer(text).lower()
    folded = re.sub(r"[^a-z0-9\s]", " ", folded)
    return _WHITESPACE_RE.sub(" ", folded).strip()


def _tokenize(text: str) -> List[str]:
    """Split normalized text into lowercase alphanumeric tokens."""
    return _TOKEN_RE.findall(_normalize_for_match(text))


def _preview_text(text: str, *, limit: int = 160) -> str:
    """Collapse whitespace and ellipsize to at most *limit* characters."""
    flat = _WHITESPACE_RE.sub(" ", str(text or "")).strip()
    if len(flat) <= limit:
        return flat
    return flat[: max(0, limit - 3)].rstrip() + "..."


def token_f1(prediction: str, gold: str) -> float:
    """Multiset token-overlap F1 between *prediction* and *gold*."""
    pred_tokens = _tokenize(prediction)
    gold_tokens = _tokenize(gold)
    # Both empty counts as a perfect match; exactly one empty is a miss.
    if not pred_tokens and not gold_tokens:
        return 1.0
    if not pred_tokens or not gold_tokens:
        return 0.0

    counts_pred: Dict[str, int] = {}
    for tok in pred_tokens:
        counts_pred[tok] = counts_pred.get(tok, 0) + 1
    counts_gold: Dict[str, int] = {}
    for tok in gold_tokens:
        counts_gold[tok] = counts_gold.get(tok, 0) + 1

    overlap = sum(min(cnt, counts_gold.get(tok, 0)) for tok, cnt in counts_pred.items())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)


def exact_match(prediction: str, gold: str) -> bool:
    """True when both strings normalize to the same matchable form."""
    return _normalize_for_match(prediction) == _normalize_for_match(gold)


def file_retrieval_metrics(predicted_paths: Sequence[str], gold_paths: Sequence[str]) -> Dict[str, float]:
    """Set-based precision/recall/F1 over predicted vs. gold file paths."""
    predicted = {str(path).strip() for path in predicted_paths if str(path).strip()}
    gold = {str(path).strip() for path in gold_paths if str(path).strip()}

    # Both empty: vacuously perfect.  Exactly one empty: total miss.
    if not predicted and not gold:
        return {"precision": 1.0, "recall": 1.0, "f1": 1.0}
    if not predicted or not gold:
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}

    hits = len(predicted & gold)
    precision = hits / len(predicted)
    recall = hits / len(gold)
    f1 = 0.0 if precision + recall == 0 else (2 * precision * recall) / (precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}
predicted: + return {"precision": 1.0, "recall": 1.0, "f1": 1.0} + if not predicted: + return {"precision": 0.0, "recall": 0.0, "f1": 0.0} + if not gold: + return {"precision": 0.0, "recall": 0.0, "f1": 0.0} + + true_positives = len(predicted & gold) + precision = true_positives / len(predicted) + recall = true_positives / len(gold) + if precision + recall == 0: + f1 = 0.0 + else: + f1 = (2 * precision * recall) / (precision + recall) + return {"precision": precision, "recall": recall, "f1": f1} + + +def _extract_json_block(text: str) -> Optional[Dict[str, Any]]: + cleaned = strip_code_fences(str(text or "")).strip() + match = _JSON_BLOCK_RE.search(cleaned) + if not match: + return None + try: + payload = json.loads(match.group(0)) + except json.JSONDecodeError: + return None + return payload if isinstance(payload, dict) else None + + +def _make_gold_repo_path(profile: str, relative_path: str) -> str: + rel = PurePosixPath(str(relative_path)) + return f"HippoCamp_Gold/{profile}/{rel.with_suffix('.json').as_posix()}" + + +def _relative_path_from_environment(spec: HippoCampConfigSpec, repo_path: str) -> str: + if not str(repo_path).startswith(spec.environment_prefix): + raise ValueError(f"Path '{repo_path}' is outside '{spec.environment_prefix}'") + return str(repo_path)[len(spec.environment_prefix):] + + +def _render_file_header( + *, + profile: str, + config_name: str, + relative_path: str, + file_info: Dict[str, Any], + summary: str, +) -> str: + header_lines = [ + "[HippoCamp Gold File]", + f"Profile: {profile}", + f"Config: {config_name}", + f"Relative Path: {relative_path}", + ] + for key, label in ( + ("file_modality", "Modality"), + ("file_type", "File Type"), + ("creation_date", "Created"), + ("modification_date", "Modified"), + ("location", "Location"), + ("latitude", "Latitude"), + ("longitude", "Longitude"), + ): + value = str(file_info.get(key) or "").strip() + if value: + header_lines.append(f"{label}: {value}") + summary_text = str(summary or 
def _segment_label(segment: Dict[str, Any], index: int) -> str:
    """Render a bracketed label like ``[segment=2, page=5]`` for one segment."""
    labels = [f"segment={index}"]
    # Only include locator keys that carry a real value.
    for key in ("page", "timestamp", "frame", "sheet", "row"):
        value = segment.get(key)
        if value not in (None, "", []):
            labels.append(f"{key}={value}")
    return "[" + ", ".join(labels) + "]"


def gold_document_to_items(
    *,
    doc: Dict[str, Any],
    profile: str,
    config_name: str,
    relative_path: str,
    chunk_chars: int,
) -> List[Dict[str, Any]]:
    """Convert one HippoCamp_Gold document into memory-index items.

    Segments are concatenated into chunks of roughly ``chunk_chars``
    characters (minimum 400); each chunk carries the shared file header
    plus per-chunk metadata.  Documents with no usable segments fall back
    to the summary, or failing that the relative path, so every file
    yields at least one item.
    """
    file_info = dict(doc.get("file_info") or {})
    summary = str(doc.get("summary") or "").strip()
    header = _render_file_header(
        profile=profile,
        config_name=config_name,
        relative_path=relative_path,
        file_info=file_info,
        summary=summary,
    )

    raw_segments = list(doc.get("segments") or [])
    segment_blocks: List[str] = []
    for index, segment in enumerate(raw_segments, start=1):
        content = str(segment.get("content") or "").strip()
        if not content:
            continue
        segment_blocks.append(f"{_segment_label(segment, index)}\n{content}")

    # Fallbacks guarantee at least one block per document.
    if not segment_blocks:
        if summary:
            segment_blocks = [f"[segment=1]\n{summary}"]
        else:
            segment_blocks = [f"[segment=1]\n{relative_path}"]

    items: List[Dict[str, Any]] = []
    current_blocks: List[str] = []
    current_len = 0
    chunk_index = 1
    target_chars = max(400, int(chunk_chars))

    def flush_current() -> None:
        # Emit the accumulated blocks as one item, then reset the accumulator.
        nonlocal current_blocks, current_len, chunk_index
        if not current_blocks:
            return
        body = "\n\n".join(current_blocks)
        content = f"{header}\nChunk: {chunk_index}\n\nContent:\n{body}"
        items.append(
            {
                "content": content,
                "metadata": {
                    "benchmark": "hippocamp",
                    "exposure_mode": "gold_text",
                    "config_name": config_name,
                    "profile": profile,
                    "file_path": relative_path,
                    "file_name": file_info.get("file_name") or PurePosixPath(relative_path).name,
                    "file_type": file_info.get("file_type"),
                    "file_modality": file_info.get("file_modality"),
                    "location": file_info.get("location"),
                    "creation_date": file_info.get("creation_date"),
                    "modification_date": file_info.get("modification_date"),
                    "chunk_index": chunk_index,
                    "source_gold_path": _make_gold_repo_path(profile, relative_path),
                },
                "categories": ["hippocamp", "gold_text", config_name.lower(), profile.lower()],
            }
        )
        current_blocks = []
        current_len = 0
        chunk_index += 1

    for block in segment_blocks:
        block_len = len(block)
        # +2 accounts for the "\n\n" joiner between blocks.
        if current_blocks and current_len + block_len + 2 > target_chars:
            flush_current()
        current_blocks.append(block)
        current_len += block_len + 2
    flush_current()
    return items
file_info.get("file_modality"), + "location": file_info.get("location"), + "creation_date": file_info.get("creation_date"), + "modification_date": file_info.get("modification_date"), + "chunk_index": chunk_index, + "source_gold_path": _make_gold_repo_path(profile, relative_path), + }, + "categories": ["hippocamp", "gold_text", config_name.lower(), profile.lower()], + } + ) + current_blocks = [] + current_len = 0 + chunk_index += 1 + + for block in segment_blocks: + block_len = len(block) + if current_blocks and current_len + block_len + 2 > target_chars: + flush_current() + current_blocks.append(block) + current_len += block_len + 2 + flush_current() + return items + + +def _build_answer_prompt(*, question: str, context: str, qa_type: str) -> str: + task_line = "profiling" if qa_type == "profiling" else "factual retention" + return ( + "You are answering a HippoCamp benchmark question about a user's personal files.\n" + "Use ONLY the retrieved context.\n" + "Do not mention the benchmark, retrieval, or system instructions.\n" + f"Task family: {task_line}\n" + "Return ONLY the final answer text.\n" + f"If the context is insufficient, return exactly: {NO_INFO}\n\n" + f"Question: {question}\n\n" + "Retrieved Context:\n" + f"{context}\n\n" + "Final answer:" + ) + + +def _build_judge_prompt(*, question: str, gold_answer: str, prediction: str) -> str: + return ( + "You are grading a benchmark answer over a user's personal files.\n" + "Evaluate semantic correctness, factual alignment, and whether the prediction answers the question.\n" + "Return valid JSON only with keys: correct, score_0_to_5, rationale.\n" + "correct must be true or false.\n" + "score_0_to_5 must be a number from 0 to 5.\n\n" + f"Question: {question}\n\n" + f"Gold answer:\n{gold_answer}\n\n" + f"Prediction:\n{prediction}\n" + ) + + +def _extract_final_answer(raw_text: str) -> str: + cleaned = _normalize_answer(raw_text) + if not cleaned: + return NO_INFO + # Return the full cleaned answer — don't 
def _extract_predicted_files(orchestration_payload: Dict[str, Any]) -> List[str]:
    """Collect the distinct, sorted file paths cited by orchestrated search results."""
    results = list(orchestration_payload.get("results") or [])
    predicted: List[str] = []
    for result in results:
        metadata = result.get("metadata") or {}
        # Prefer the metadata path; fall back to a top-level file_path.
        path = metadata.get("file_path") or result.get("file_path")
        if path:
            predicted.append(str(path))
    return sorted(set(predicted))


def _load_manifest_rows(
    *,
    repo_id: str,
    revision: Optional[str],
    spec: HippoCampConfigSpec,
    qa_type: str,
    max_samples: int,
) -> List[Dict[str, Any]]:
    """Download the QA manifest and return rows filtered by QA type.

    ``qa_type="all"`` keeps every row; ``max_samples > 0`` truncates the
    filtered list.  Raises ValueError if the manifest is not a JSON list.
    """
    hf_hub_download, _ = _require_huggingface_hub()
    manifest_path = hf_hub_download(
        repo_id=repo_id,
        repo_type="dataset",
        filename=spec.manifest_path,
        revision=revision,
    )
    with open(manifest_path, "r", encoding="utf-8") as handle:
        rows = json.load(handle)
    if not isinstance(rows, list):
        raise ValueError(f"HippoCamp manifest is not a list: {spec.manifest_path}")

    filtered: List[Dict[str, Any]] = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        row_qa_type = str(row.get("QA_type") or "").strip()
        if qa_type != "all" and row_qa_type != qa_type:
            continue
        filtered.append(row)
    if max_samples > 0:
        filtered = filtered[: max_samples]
    return filtered


def _list_environment_repo_files(
    *,
    repo_id: str,
    revision: Optional[str],
    spec: HippoCampConfigSpec,
) -> List[str]:
    """List repo files under the config's environment prefix, sorted."""
    _, hf_list_repo_files = _require_huggingface_hub()
    files = hf_list_repo_files(repo_id=repo_id, repo_type="dataset", revision=revision)
    return sorted(path for path in files if str(path).startswith(spec.environment_prefix))


def _build_environment_items(
    *,
    repo_id: str,
    revision: Optional[str],
    spec: HippoCampConfigSpec,
    max_environment_files: int,
    chunk_chars: int,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Build gold-mode index items for every environment file.

    For each raw environment file, downloads its parsed HippoCamp_Gold
    twin and converts it into chunked memory items.  Files whose gold
    twin cannot be fetched are recorded in ``missing_gold_files`` instead
    of aborting the run.  Returns ``(items, metadata)``.
    """
    hf_hub_download, _ = _require_huggingface_hub()
    repo_files = _list_environment_repo_files(repo_id=repo_id, revision=revision, spec=spec)
    if max_environment_files > 0:
        repo_files = repo_files[: max_environment_files]

    items: List[Dict[str, Any]] = []
    missing_gold: List[str] = []

    for repo_path in repo_files:
        relative_path = _relative_path_from_environment(spec, repo_path)
        gold_repo_path = _make_gold_repo_path(spec.profile, relative_path)
        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                repo_type="dataset",
                filename=gold_repo_path,
                revision=revision,
            )
        except Exception:
            # Gold twin missing from the release — record and continue.
            missing_gold.append(relative_path)
            continue

        with open(local_path, "r", encoding="utf-8") as handle:
            document = json.load(handle)
        if not isinstance(document, dict):
            logger.warning("Skipping malformed gold file: %s", gold_repo_path)
            continue
        items.extend(
            gold_document_to_items(
                doc=document,
                profile=spec.profile,
                config_name=spec.name,
                relative_path=relative_path,
                chunk_chars=chunk_chars,
            )
        )

    metadata = {
        "environment_repo_files": len(repo_files),
        "environment_index_items": len(items),
        "missing_gold_files": missing_gold,
    }
    return items, metadata
def _build_raw_environment_items(
    *,
    repo_id: str,
    revision: Optional[str],
    spec: HippoCampConfigSpec,
    max_environment_files: int,
    chunk_chars: int,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Build raw-mode index items by parsing environment files directly.

    Each raw file is downloaded and handed to ``raw_file_to_items``
    (Dhee's deterministic extractors).  Per-mode extraction counts are
    accumulated, and files that yielded only metadata (no text content)
    are listed in ``metadata_only_files``.  Returns ``(items, metadata)``.
    """
    hf_hub_download, _ = _require_huggingface_hub()
    repo_files = _list_environment_repo_files(repo_id=repo_id, revision=revision, spec=spec)
    if max_environment_files > 0:
        repo_files = repo_files[: max_environment_files]

    items: List[Dict[str, Any]] = []
    # Histogram of extraction modes reported by raw_file_to_items.
    mode_counts: Dict[str, int] = {}
    metadata_only_files: List[str] = []

    for repo_path in repo_files:
        relative_path = _relative_path_from_environment(spec, repo_path)
        local_path = Path(
            hf_hub_download(
                repo_id=repo_id,
                repo_type="dataset",
                filename=repo_path,
                revision=revision,
            )
        )
        result = raw_file_to_items(
            local_path=local_path,
            relative_path=relative_path,
            profile=spec.profile,
            config_name=spec.name,
            chunk_chars=chunk_chars,
        )
        mode_counts[result.mode] = mode_counts.get(result.mode, 0) + 1
        if result.mode == "metadata_only":
            metadata_only_files.append(relative_path)
        items.extend(result.items)

    metadata = {
        "environment_repo_files": len(repo_files),
        "environment_index_items": len(items),
        "raw_extraction_modes": mode_counts,
        "metadata_only_files": metadata_only_files,
    }
    return items, metadata
config_name=spec.name, + chunk_chars=chunk_chars, + ) + mode_counts[result.mode] = mode_counts.get(result.mode, 0) + 1 + if result.mode == "metadata_only": + metadata_only_files.append(relative_path) + items.extend(result.items) + + metadata = { + "environment_repo_files": len(repo_files), + "environment_index_items": len(items), + "raw_extraction_modes": mode_counts, + "metadata_only_files": metadata_only_files, + } + return items, metadata + + +def _make_llm( + *, + provider: str, + model: Optional[str], + max_tokens: int, + timeout: int, + temperature: float, + top_p: float, + enable_thinking: bool = False, +) -> Any: + config: Dict[str, Any] = { + "max_tokens": max(32, int(max_tokens)), + "timeout": max(1, int(timeout)), + "temperature": temperature, + "top_p": top_p, + "enable_thinking": bool(enable_thinking), + } + if model: + config["model"] = model + return LLMFactory.create(provider, config) + + +def _maybe_judge( + *, + judge_llm: Optional[Any], + question: str, + gold_answer: str, + prediction: str, +) -> Optional[Dict[str, Any]]: + if judge_llm is None: + return None + prompt = _build_judge_prompt(question=question, gold_answer=gold_answer, prediction=prediction) + try: + raw = str(judge_llm.generate(prompt)).strip() + except Exception as exc: + logger.warning("Judge generation failed: %s", exc) + return None + payload = _extract_json_block(raw) + if not payload: + return None + score = payload.get("score_0_to_5") + try: + score_value = max(0.0, min(5.0, float(score))) + except (TypeError, ValueError): + score_value = None + return { + "correct": bool(payload.get("correct")), + "score_0_to_5": score_value, + "rationale": str(payload.get("rationale") or "").strip(), + "raw": raw, + } + + +def _generate_prediction( + *, + answer_llm: Optional[Any], + answer_strategy: str, + question: str, + qa_type: str, + context: str, + reduced_answer: str, +) -> str: + strategy = answer_strategy + if strategy == "auto": + strategy = "llm" if answer_llm is not None else 
"extractive" + + if strategy == "extractive": + return _normalize_answer(reduced_answer) or NO_INFO + + if answer_llm is None: + return _normalize_answer(reduced_answer) or NO_INFO + + prompt = _build_answer_prompt(question=question, context=context, qa_type=qa_type) + try: + raw = str(answer_llm.generate(prompt)).strip() + except Exception as exc: + logger.warning("Answer generation failed: %s", exc) + return _normalize_answer(reduced_answer) or NO_INFO + + prediction = _extract_final_answer(raw) + if prediction == NO_INFO and reduced_answer: + fallback = _normalize_answer(reduced_answer) + if fallback: + return fallback + return prediction + + +def _score_record( + *, + row: Dict[str, Any], + prediction: str, + predicted_files: Sequence[str], + judge_payload: Optional[Dict[str, Any]], +) -> Dict[str, Any]: + gold_answer = str(row.get("answer") or "").strip() + gold_files = [str(path) for path in row.get("file_path") or []] + retrieval = file_retrieval_metrics(predicted_files, gold_files) + record: Dict[str, Any] = { + "id": str(row.get("id") or ""), + "qa_type": str(row.get("QA_type") or "").strip(), + "profiling_type": str(row.get("profiling_type") or "").strip(), + "question": str(row.get("question") or "").strip(), + "gold_answer": gold_answer, + "prediction": prediction, + "exact_match": exact_match(prediction, gold_answer), + "token_f1": round(token_f1(prediction, gold_answer), 4), + "gold_files": gold_files, + "predicted_files": list(predicted_files), + "file_precision": round(retrieval["precision"], 4), + "file_recall": round(retrieval["recall"], 4), + "file_f1": round(retrieval["f1"], 4), + "agent_cap": row.get("agent_cap") or {}, + } + if judge_payload is not None: + record["judge_correct"] = bool(judge_payload.get("correct")) + record["judge_score_0_to_5"] = judge_payload.get("score_0_to_5") + record["judge_rationale"] = judge_payload.get("rationale") + return record + + +def _family_accuracy(records: Sequence[Dict[str, Any]], family: str) -> 
Optional[float]: + by_subcategory: Dict[str, List[bool]] = {} + for record in records: + if "judge_correct" not in record: + continue + labels = ((record.get("agent_cap") or {}).get(family) or []) + for label in labels: + by_subcategory.setdefault(str(label), []).append(bool(record["judge_correct"])) + if not by_subcategory: + return None + sub_scores = [sum(values) / len(values) for values in by_subcategory.values() if values] + return mean(sub_scores) if sub_scores else None + + +def _summarize_records(records: Sequence[Dict[str, Any]]) -> Dict[str, Any]: + if not records: + return {"count": 0} + + summary: Dict[str, Any] = { + "count": len(records), + "exact_match": round(mean(1.0 if record["exact_match"] else 0.0 for record in records), 4), + "token_f1": round(mean(float(record["token_f1"]) for record in records), 4), + "file_precision": round(mean(float(record["file_precision"]) for record in records), 4), + "file_recall": round(mean(float(record["file_recall"]) for record in records), 4), + "file_f1": round(mean(float(record["file_f1"]) for record in records), 4), + } + if any("judge_correct" in record for record in records): + judged = [record for record in records if "judge_correct" in record] + summary["judge_accuracy"] = round( + mean(1.0 if record["judge_correct"] else 0.0 for record in judged), + 4, + ) + judge_scores = [ + float(record["judge_score_0_to_5"]) + for record in judged + if record.get("judge_score_0_to_5") is not None + ] + if judge_scores: + summary["judge_avg_score_0_to_5"] = round(mean(judge_scores), 4) + summary["judge_avg_score_0_to_10"] = round(mean(score * 2.0 for score in judge_scores), 4) + + capability_breakdown = {} + for family in ("search", "evidence_perception", "reasoning"): + family_score = _family_accuracy(judged, family) + if family_score is not None: + capability_breakdown[family] = round(family_score, 4) + if capability_breakdown: + summary["capability_accuracy"] = capability_breakdown + return summary + + +def 
_emit_progress(progress_path: Optional[Path], event: str, **payload: Any) -> None: + if progress_path is None: + return + row = { + "ts": round(time.time(), 3), + "event": event, + **payload, + } + try: + _append_jsonl(progress_path, row) + except Exception as exc: + logger.warning("Failed to append progress event '%s' to %s: %s", event, progress_path, exc) + + +def _run_config(args: argparse.Namespace, spec: HippoCampConfigSpec) -> Dict[str, Any]: + output_root = Path(args.output_dir) + history_db_path = output_root / f"{spec.name}.sqlite" + progress_path = output_root / f"{spec.name}.progress.jsonl" + checkpoint_path = output_root / f"{spec.name}.checkpoint.jsonl" + user_id = f"hippocamp_{spec.name}" + # Don't nuke progress — append to it for resumed runs. + # Load checkpoint of previously completed questions. + resumed_records, completed_ids = _load_checkpoint(checkpoint_path) + + memory = build_memory( + llm_provider=args.llm_provider, + embedder_provider=args.embedder_provider, + vector_store_provider=args.vector_store_provider, + embedding_dims=args.embedding_dims, + history_db_path=str(history_db_path), + llm_model=args.llm_model, + llm_timeout=args.llm_timeout, + embedder_model=args.embedder_model, + full_potential=not args.minimal, + defer_enrichment=args.defer_enrichment, + enable_rerank=args.enable_rerank, + rerank_model=args.rerank_model, + rerank_config=None, + enable_episodic_index=not args.disable_episodic_index, + enable_hierarchical_retrieval=not args.disable_hierarchical_retrieval, + enable_orchestrated_search=not args.disable_orchestrated_search, + cost_guardrail_strict=not args.no_cost_guardrail_strict, + ) + + answer_llm = None + if args.answer_provider not in {"", "none", "mock"}: + answer_llm = _make_llm( + provider=args.answer_provider, + model=args.answer_model, + max_tokens=args.answer_max_tokens, + timeout=args.answer_timeout, + temperature=args.answer_temperature, + top_p=args.answer_top_p, + enable_thinking=args.answer_enable_thinking, 
+ ) + + judge_llm = None + if args.judge_provider: + judge_llm = _make_llm( + provider=args.judge_provider, + model=args.judge_model, + max_tokens=args.judge_max_tokens, + timeout=args.judge_timeout, + temperature=args.judge_temperature, + top_p=args.judge_top_p, + enable_thinking=args.judge_enable_thinking, + ) + + rows = _load_manifest_rows( + repo_id=args.repo_id, + revision=args.revision, + spec=spec, + qa_type=args.qa_type, + max_samples=args.max_samples, + ) + + # When resuming with an existing index, skip the expensive environment download/parse. + can_skip_env = bool(completed_ids) and _memory_has_data(history_db_path, user_id) + if can_skip_env: + environment_items: List[Dict[str, Any]] = [] + environment_meta: Dict[str, Any] = {"environment_repo_files": 0, "environment_index_items": 0, "resumed": True} + logger.info("RESUME: Skipping environment download — memory DB already populated for %s", spec.name) + elif args.mode == "gold": + environment_items, environment_meta = _build_environment_items( + repo_id=args.repo_id, + revision=args.revision, + spec=spec, + max_environment_files=args.max_environment_files, + chunk_chars=args.chunk_chars, + ) + elif args.mode == "raw": + environment_items, environment_meta = _build_raw_environment_items( + repo_id=args.repo_id, + revision=args.revision, + spec=spec, + max_environment_files=args.max_environment_files, + chunk_chars=args.chunk_chars, + ) + else: + raise ValueError(f"Unsupported mode: {args.mode}") + + logger.info("Live progress %s -> %s", spec.name, progress_path) + _emit_progress( + progress_path, + "config_start", + config=spec.name, + profile=spec.profile, + mode=args.mode, + qa_type=args.qa_type, + question_count=len(rows), + environment_repo_files=environment_meta["environment_repo_files"], + environment_index_items=environment_meta["environment_index_items"], + ) + # Skip re-indexing if the memory DB already has data for this user (resume mode). 
+ skip_indexing = bool(completed_ids) and _memory_has_data(history_db_path, user_id) + if skip_indexing: + logger.info( + "RESUME: Skipping indexing for %s — memory DB exists with data, %d questions already completed", + spec.name, + len(completed_ids), + ) + index_seconds = 0.0 + else: + logger.info( + "Indexing %s: %d files -> %d memory items", + spec.name, + environment_meta["environment_repo_files"], + environment_meta["environment_index_items"], + ) + _emit_progress( + progress_path, + "index_start", + config=spec.name, + environment_repo_files=environment_meta["environment_repo_files"], + environment_index_items=environment_meta["environment_index_items"], + ) + t0 = time.time() + memory.delete_all(user_id=user_id) + memory.add_batch(items=environment_items, user_id=user_id) + index_seconds = time.time() - t0 + _emit_progress( + progress_path, + "index_done", + config=spec.name, + index_seconds=round(index_seconds, 4), + environment=environment_meta, + ) + + if args.enrich_after_ingest: + try: + memory.enrich_pending(user_id=user_id, batch_size=args.enrich_batch_size, max_batches=args.enrich_max_batches) + _emit_progress( + progress_path, + "enrichment_done", + config=spec.name, + batch_size=args.enrich_batch_size, + max_batches=args.enrich_max_batches, + ) + except Exception as exc: + logger.warning("Post-ingest enrichment failed for %s: %s", spec.name, exc) + _emit_progress( + progress_path, + "enrichment_failed", + config=spec.name, + error=str(exc), + ) + + records: List[Dict[str, Any]] = list(resumed_records) + skipped_count = 0 + total_rows = len(rows) + for index, row in enumerate(rows, start=1): + row_id = str(row.get("id") or "") + if row_id and row_id in completed_ids: + skipped_count += 1 + continue + question = str(row.get("question") or "").strip() + qa_type = str(row.get("QA_type") or "").strip() or "factual_retention" + _emit_progress( + progress_path, + "question_start", + config=spec.name, + question_index=index, + question_count=total_rows, 
+ qa_type=qa_type, + question=_preview_text(question, limit=240), + ) + payload = memory.search_orchestrated( + query=question, + user_id=user_id, + question_type=f"hippocamp-{qa_type}", + question_date="", + limit=args.top_k, + orchestration_mode=args.answer_orchestration_mode, + base_search_limit=args.top_k, + base_context_limit=args.answer_context_top_k, + search_cap=args.search_cap, + context_cap=args.context_cap, + map_max_candidates=args.map_max_candidates, + map_max_chars=args.map_max_chars, + keyword_search=True, + hybrid_alpha=0.7, + include_evidence=True, + evidence_strategy=args.evidence_strategy, + evidence_max_chars=args.evidence_max_chars, + evidence_context_lines=args.evidence_context_lines, + max_context_chars=args.max_context_chars, + rerank=args.enable_rerank, + orchestrator_llm=memory.llm if args.answer_orchestration_mode != "off" else None, + reflection_max_hops=1, + ) + context = str(payload.get("context") or "").strip() + reduced_answer = str(payload.get("reduced_answer") or "").strip() + predicted_files = _extract_predicted_files(payload) + _emit_progress( + progress_path, + "search_done", + config=spec.name, + question_index=index, + question_count=total_rows, + search_result_count=len(payload.get("results") or []), + predicted_files=predicted_files, + reduced_answer_preview=_preview_text(reduced_answer), + ) + _emit_progress( + progress_path, + "answer_start", + config=spec.name, + question_index=index, + question_count=total_rows, + strategy=args.answer_strategy, + provider=args.answer_provider or None, + model=args.answer_model or None, + ) + prediction = _generate_prediction( + answer_llm=answer_llm, + answer_strategy=args.answer_strategy, + question=question, + qa_type=qa_type, + context=context, + reduced_answer=reduced_answer, + ) + _emit_progress( + progress_path, + "answer_done", + config=spec.name, + question_index=index, + question_count=total_rows, + prediction_preview=_preview_text(prediction), + ) + if judge_llm is not None: + 
_emit_progress( + progress_path, + "judge_start", + config=spec.name, + question_index=index, + question_count=total_rows, + provider=args.judge_provider or None, + model=args.judge_model or None, + ) + judge_payload = _maybe_judge( + judge_llm=judge_llm, + question=question, + gold_answer=str(row.get("answer") or ""), + prediction=prediction, + ) + if judge_llm is not None: + _emit_progress( + progress_path, + "judge_done", + config=spec.name, + question_index=index, + question_count=total_rows, + judge_correct=None if judge_payload is None else bool(judge_payload.get("correct")), + judge_score_0_to_5=None if judge_payload is None else judge_payload.get("score_0_to_5"), + judge_rationale_preview="" if judge_payload is None else _preview_text(judge_payload.get("rationale") or ""), + ) + record = _score_record( + row=row, + prediction=prediction, + predicted_files=predicted_files, + judge_payload=judge_payload, + ) + record["search_result_count"] = len(payload.get("results") or []) + record["question_index"] = index + records.append(record) + # Checkpoint: persist each completed record so we can resume. 
+ _append_jsonl(checkpoint_path, record) + _emit_progress( + progress_path, + "question_done", + config=spec.name, + question_index=index, + question_count=total_rows, + exact_match=bool(record["exact_match"]), + token_f1=float(record["token_f1"]), + file_f1=float(record["file_f1"]), + judge_correct=record.get("judge_correct"), + judge_score_0_to_5=record.get("judge_score_0_to_5"), + ) + + if args.print_every > 0 and index % args.print_every == 0: + logger.info("Progress %s: %d/%d", spec.name, index, len(rows)) + + if skipped_count > 0: + logger.info( + "RESUME: Skipped %d already-completed questions, processed %d new for %s", + skipped_count, + len(records) - len(resumed_records), + spec.name, + ) + overall = _summarize_records(records) + by_type: Dict[str, Dict[str, Any]] = {} + for qa_type in ("profiling", "factual_retention"): + subset = [record for record in records if record.get("qa_type") == qa_type] + if subset: + by_type[qa_type] = _summarize_records(subset) + + _emit_progress( + progress_path, + "config_done", + config=spec.name, + overall=overall, + by_type=by_type, + progress_jsonl=str(progress_path), + ) + + return { + "config": spec.name, + "profile": spec.profile, + "mode": args.mode, + "repo_id": args.repo_id, + "revision": args.revision, + "qa_type_filter": args.qa_type, + "answer_strategy": args.answer_strategy, + "llm_provider": args.llm_provider, + "llm_model": args.llm_model, + "answer_provider": args.answer_provider or None, + "answer_model": args.answer_model or None, + "judge_provider": args.judge_provider or None, + "judge_model": args.judge_model or None, + "index_seconds": round(index_seconds, 4), + "environment": environment_meta, + "progress_jsonl": str(progress_path), + "overall": overall, + "by_type": by_type, + "records": records, + } + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run Dhee on the HippoCamp benchmark release.") + parser.add_argument("--repo-id", default=DEFAULT_REPO_ID) + 
parser.add_argument("--revision", default=None) + parser.add_argument( + "--config", + action="append", + dest="configs", + choices=sorted(CONFIG_SPECS), + help="HippoCamp config(s) to run. May be passed multiple times.", + ) + parser.add_argument("--mode", choices=["gold", "raw"], default="raw") + parser.add_argument("--qa-type", choices=["all", "profiling", "factual_retention"], default="all") + parser.add_argument("--max-samples", type=int, default=-1) + parser.add_argument("--max-environment-files", type=int, default=-1) + parser.add_argument("--chunk-chars", type=int, default=2600) + parser.add_argument("--output-dir", default="runs/hippocamp") + parser.add_argument("--print-every", type=int, default=10) + + # --- Full-power Dhee defaults: NVIDIA embedder + reranker, all features ON --- + parser.add_argument("--llm-provider", default="nvidia") + parser.add_argument("--llm-model", default="meta/llama-3.3-70b-instruct") + parser.add_argument("--llm-timeout", type=int, default=240) + parser.add_argument("--embedder-provider", default="nvidia") + parser.add_argument("--embedder-model", default="nvidia/llama-nemotron-embed-vl-1b-v2") + parser.add_argument("--embedding-dims", type=int, default=2048) + parser.add_argument("--vector-store-provider", choices=["memory", "sqlite_vec"], default="memory") + + parser.add_argument("--minimal", action="store_true", default=False) + # defer_enrichment=True: fast 0-LLM ingestion; batch enrichment runs after ingest via enrich_after_ingest + parser.add_argument("--defer-enrichment", dest="defer_enrichment", action="store_true", default=True) + parser.add_argument("--no-defer-enrichment", dest="defer_enrichment", action="store_false") + parser.add_argument("--enrich-after-ingest", dest="enrich_after_ingest", action="store_true", default=True) + parser.add_argument("--no-enrich-after-ingest", dest="enrich_after_ingest", action="store_false") + parser.add_argument("--enrich-batch-size", type=int, default=10) + 
parser.add_argument("--enrich-max-batches", type=int, default=200) + parser.add_argument("--enable-rerank", dest="enable_rerank", action="store_true", default=True) + parser.add_argument("--disable-rerank", dest="enable_rerank", action="store_false") + parser.add_argument("--rerank-model", default="nvidia/llama-3.2-nv-rerankqa-1b-v2") + parser.add_argument("--disable-episodic-index", action="store_true", default=False) + parser.add_argument("--disable-hierarchical-retrieval", action="store_true", default=False) + parser.add_argument("--disable-orchestrated-search", action="store_true", default=False) + parser.add_argument("--no-cost-guardrail-strict", action="store_true", default=False) + + parser.add_argument("--top-k", type=int, default=12) + parser.add_argument("--answer-context-top-k", type=int, default=8) + parser.add_argument("--search-cap", type=int, default=30) + parser.add_argument("--context-cap", type=int, default=20) + parser.add_argument("--map-max-candidates", type=int, default=8) + parser.add_argument("--map-max-chars", type=int, default=1400) + parser.add_argument("--max-context-chars", type=int, default=28000) + parser.add_argument("--evidence-strategy", choices=["full", "vector_text", "snippet"], default="snippet") + parser.add_argument("--evidence-max-chars", type=int, default=3500) + parser.add_argument("--evidence-context-lines", type=int, default=1) + parser.add_argument("--answer-orchestration-mode", choices=["off", "hybrid", "strict"], default="hybrid") + + parser.add_argument("--answer-strategy", choices=["auto", "llm", "extractive"], default="auto") + parser.add_argument("--answer-provider", default="nvidia") + parser.add_argument("--answer-model", default="meta/llama-3.3-70b-instruct") + parser.add_argument("--answer-timeout", type=int, default=240) + parser.add_argument("--answer-max-tokens", type=int, default=1024) + parser.add_argument("--answer-temperature", type=float, default=0.2) + parser.add_argument("--answer-top-p", type=float, 
default=0.7) + parser.add_argument("--answer-enable-thinking", dest="answer_enable_thinking", action="store_true", default=False) + parser.add_argument("--answer-disable-thinking", dest="answer_enable_thinking", action="store_false") + + parser.add_argument("--judge-provider", default="nvidia") + parser.add_argument("--judge-model", default="deepseek-ai/deepseek-v3.1-terminus") + parser.add_argument("--judge-timeout", type=int, default=60) + parser.add_argument("--judge-max-tokens", type=int, default=2048) + parser.add_argument("--judge-temperature", type=float, default=0.2) + parser.add_argument("--judge-top-p", type=float, default=0.7) + parser.add_argument("--judge-enable-thinking", dest="judge_enable_thinking", action="store_true", default=False) + parser.add_argument("--judge-disable-thinking", dest="judge_enable_thinking", action="store_false") + + parser.add_argument("--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"]) + args = parser.parse_args() + if not args.configs: + args.configs = ["adam_subset"] + return args + + +def main() -> None: + args = parse_args() + _configure_logging(args.log_level) + _load_env_file(Path.cwd() / ".env") + _load_env_file(Path(__file__).resolve().parents[2] / ".env") + + output_root = Path(args.output_dir) + output_root.mkdir(parents=True, exist_ok=True) + + started = time.time() + config_results = [] + for config_name in args.configs: + spec = _resolve_config(config_name) + logger.info("Running HippoCamp config=%s mode=%s qa_type=%s", spec.name, args.mode, args.qa_type) + config_results.append(_run_config(args, spec)) + + summary = { + "runner": "dhee.benchmarks.hippocamp", + "repo_id": args.repo_id, + "revision": args.revision, + "configs": args.configs, + "mode": args.mode, + "qa_type": args.qa_type, + "elapsed_seconds": round(time.time() - started, 4), + "config_results": config_results, + } + + summary_path = output_root / "summary.json" + predictions_path = output_root / "predictions.json" + 
flat_records = [] + for config_result in config_results: + for record in config_result.get("records", []): + flat_records.append( + { + "config": config_result["config"], + "profile": config_result["profile"], + **record, + } + ) + + _persist_json(summary_path, summary) + _persist_json(predictions_path, flat_records) + + print( + json.dumps( + { + "summary_json": str(summary_path), + "predictions_json": str(predictions_path), + "configs": args.configs, + "mode": args.mode, + "qa_type": args.qa_type, + "records": len(flat_records), + }, + indent=2, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/dhee/benchmarks/longmemeval.py b/dhee/benchmarks/longmemeval.py index e4a7984..35947d4 100644 --- a/dhee/benchmarks/longmemeval.py +++ b/dhee/benchmarks/longmemeval.py @@ -38,7 +38,6 @@ from dhee.core.answer_orchestration import ( is_low_confidence_answer, ) -from dhee.core.code_exec_counter import refine_count_with_code_exec logger = logging.getLogger(__name__) @@ -1771,25 +1770,7 @@ def run_longmemeval(args: argparse.Namespace) -> Dict[str, Any]: except Exception as reg_exc: logger.debug("Entity registry lookup failed for %s: %s", question_id, reg_exc) - # 2) Code-exec counting (1 LLM call + deterministic exec) code_exec_answer = None - try: - ce_llm = _code_exec_llm or _answer_llm or memory.llm - code_exec_answer = refine_count_with_code_exec( - llm=ce_llm, - question=query, - question_type=question_type, - retrieved_context=context, - draft_answer=hypothesis, - question_date=question_date, - ) - if code_exec_answer: - logger.info( - "Code-exec answer for %s: %r", - question_id, code_exec_answer, - ) - except Exception as ce_exc: - logger.debug("Code-exec counting failed for %s: %s", question_id, ce_exc) # 3) Pick best: code_exec > registry > existing refinement if code_exec_answer and not _is_refusal(code_exec_answer): diff --git a/dhee/benchmarks/raw_extractors.py b/dhee/benchmarks/raw_extractors.py new file mode 100644 index 0000000..2435dd4 --- /dev/null 
+++ b/dhee/benchmarks/raw_extractors.py @@ -0,0 +1,505 @@ +"""Raw-file extraction helpers for HippoCamp-style benchmarks. + +These extractors operate only on released raw files. They intentionally avoid +benchmark gold text and QA annotations. +""" + +from __future__ import annotations + +import csv +import json +import logging +import sqlite3 +import subprocess +import zipfile +from dataclasses import dataclass +from email import policy +from email.parser import BytesParser +from pathlib import Path, PurePosixPath +from typing import Any, Dict, List, Optional, Sequence +from xml.etree import ElementTree as ET + +logger = logging.getLogger(__name__) + + +TEXT_EXTENSIONS = {"txt", "md", "py", "log", "ics", "json", "csv"} +IMAGE_EXTENSIONS = {"png", "jpg", "jpeg", "gif"} +VIDEO_EXTENSIONS = {"mp4", "mkv"} +AUDIO_EXTENSIONS = {"mp3", "wav", "m4a", "aac"} + + +@dataclass +class RawExtractionResult: + items: List[Dict[str, Any]] + mode: str + notes: List[str] + + +def _run_command(args: Sequence[str]) -> str: + proc = subprocess.run( + list(args), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=False, + ) + if proc.returncode != 0: + raise RuntimeError(proc.stderr.strip() or f"Command failed: {' '.join(args)}") + return proc.stdout + + +def _chunk_blocks( + *, + header: str, + blocks: Sequence[str], + chunk_chars: int, + metadata: Dict[str, Any], + categories: Sequence[str], +) -> List[Dict[str, Any]]: + items: List[Dict[str, Any]] = [] + target_chars = max(400, int(chunk_chars)) + current_blocks: List[str] = [] + current_len = 0 + chunk_index = 1 + + def flush() -> None: + nonlocal current_blocks, current_len, chunk_index + if not current_blocks: + return + body = "\n\n".join(current_blocks) + content = f"{header}\nChunk: {chunk_index}\n\nContent:\n{body}" + item_meta = dict(metadata) + item_meta["chunk_index"] = chunk_index + items.append({"content": content, "metadata": item_meta, "categories": list(categories)}) + current_blocks = [] + 
current_len = 0 + chunk_index += 1 + + for raw_block in blocks: + block = str(raw_block or "").strip() + if not block: + continue + block_len = len(block) + if current_blocks and current_len + block_len + 2 > target_chars: + flush() + current_blocks.append(block) + current_len += block_len + 2 + flush() + return items + + +def _read_text_file(path: Path) -> str: + for encoding in ("utf-8", "utf-16", "latin-1"): + try: + return path.read_text(encoding=encoding) + except Exception: + continue + return path.read_text(encoding="utf-8", errors="ignore") + + +def _extract_email_text(path: Path) -> str: + with path.open("rb") as handle: + message = BytesParser(policy=policy.default).parse(handle) + parts: List[str] = [] + for header in ("subject", "from", "to", "cc", "date"): + value = message.get(header) + if value: + parts.append(f"{header.title()}: {value}") + body = message.get_body(preferencelist=("plain", "html")) + if body is not None: + try: + body_text = body.get_content() + except Exception: + body_text = "" + if body_text: + parts.extend(["", str(body_text)]) + return "\n".join(parts).strip() + + +def _extract_pdf_pages(path: Path) -> List[str]: + output = _run_command(["pdftotext", str(path), "-"]) + return [page.strip() for page in output.split("\f") if page.strip()] + + +def _extract_docx_text(path: Path) -> str: + try: + text = _run_command(["textutil", "-convert", "txt", "-stdout", str(path)]).strip() + if text: + return text + except Exception: + pass + return _extract_docx_ooxml_text(path) + + +def _extract_docx_ooxml_text(path: Path) -> str: + with zipfile.ZipFile(path) as archive: + names = [name for name in archive.namelist() if name.startswith("word/") and name.endswith(".xml")] + parts: List[str] = [] + for name in sorted(names): + if not any(key in name for key in ("document.xml", "header", "footer", "footnotes", "endnotes")): + continue + root = ET.fromstring(archive.read(name)) + texts = [elem.text for elem in root.iter() if 
elem.tag.endswith("}t") and elem.text] + if texts: + parts.append("\n".join(texts)) + return "\n\n".join(parts).strip() + + +def _extract_pptx_ooxml_text(path: Path) -> str: + with zipfile.ZipFile(path) as archive: + slides = [] + for name in sorted(archive.namelist()): + if not name.startswith("ppt/slides/slide") or not name.endswith(".xml"): + continue + root = ET.fromstring(archive.read(name)) + texts = [elem.text for elem in root.iter() if elem.tag.endswith("}t") and elem.text] + if texts: + slides.append("\n".join(texts)) + return "\n\n".join(f"[slide={idx + 1}]\n{text}" for idx, text in enumerate(slides)).strip() + + +def _col_ref_to_index(col_ref: str) -> int: + result = 0 + for ch in col_ref: + if not ch.isalpha(): + break + result = result * 26 + (ord(ch.upper()) - ord("A") + 1) + return max(1, result) + + +def _extract_xlsx_ooxml_text(path: Path) -> str: + with zipfile.ZipFile(path) as archive: + shared_strings: List[str] = [] + if "xl/sharedStrings.xml" in archive.namelist(): + root = ET.fromstring(archive.read("xl/sharedStrings.xml")) + for elem in root.iter(): + if elem.tag.endswith("}t") and elem.text is not None: + shared_strings.append(elem.text) + + sheet_names: Dict[str, str] = {} + if "xl/workbook.xml" in archive.namelist(): + root = ET.fromstring(archive.read("xl/workbook.xml")) + for elem in root.iter(): + if elem.tag.endswith("}sheet"): + rel_id = elem.attrib.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id", "") + if rel_id: + sheet_names[rel_id] = elem.attrib.get("name", rel_id) + + rel_map: Dict[str, str] = {} + if "xl/_rels/workbook.xml.rels" in archive.namelist(): + root = ET.fromstring(archive.read("xl/_rels/workbook.xml.rels")) + for elem in root.iter(): + if elem.tag.endswith("}Relationship"): + rel_id = elem.attrib.get("Id", "") + target = elem.attrib.get("Target", "") + if rel_id and target: + rel_map[target] = sheet_names.get(rel_id, PurePosixPath(target).stem) + + sheets_out: List[str] = [] + for name in 
sorted(archive.namelist()): + if not name.startswith("xl/worksheets/") or not name.endswith(".xml"): + continue + root = ET.fromstring(archive.read(name)) + rows_out: List[str] = [] + for row in root.iter(): + if not row.tag.endswith("}row"): + continue + values: Dict[int, str] = {} + for cell in row: + if not cell.tag.endswith("}c"): + continue + ref = cell.attrib.get("r", "") + col_idx = _col_ref_to_index(ref) if ref else (len(values) + 1) + cell_type = cell.attrib.get("t", "") + value_text = "" + value_elem = None + for child in cell: + if child.tag.endswith("}v"): + value_elem = child + break + if value_elem is not None and value_elem.text is not None: + raw_value = value_elem.text + if cell_type == "s": + try: + value_text = shared_strings[int(raw_value)] + except Exception: + value_text = raw_value + else: + value_text = raw_value + if value_text: + values[col_idx] = value_text + if values: + rows_out.append("\t".join(values[idx] for idx in sorted(values))) + if rows_out: + sheet_label = rel_map.get(name.replace("xl/", ""), Path(name).stem) + sheets_out.append(f"[sheet={sheet_label}]\n" + "\n".join(rows_out)) + return "\n\n".join(sheets_out).strip() + + +def _extract_ipynb_text(path: Path) -> str: + with path.open("r", encoding="utf-8") as handle: + notebook = json.load(handle) + parts: List[str] = [] + for idx, cell in enumerate(notebook.get("cells") or [], start=1): + source = cell.get("source") or [] + if isinstance(source, list): + text = "".join(str(part) for part in source) + else: + text = str(source) + text = text.strip() + if text: + parts.append(f"[{cell.get('cell_type', 'cell')} {idx}]\n{text}") + return "\n\n".join(parts).strip() + + +def _extract_sqlite_text(path: Path, max_tables: int = 6, max_rows: int = 10) -> str: + conn = sqlite3.connect(str(path)) + conn.row_factory = sqlite3.Row + try: + rows = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name" + ).fetchall() + parts: List[str] = 
[] + for row in rows[:max_tables]: + table = str(row["name"]) + escaped_table = table.replace("'", "''") + parts.append(f"[table={table}]") + schema = conn.execute(f"PRAGMA table_info('{escaped_table}')").fetchall() + if schema: + parts.append("columns: " + ", ".join(f"{col['name']}:{col['type']}" for col in schema)) + try: + sample_rows = conn.execute( + f"SELECT * FROM '{escaped_table}' LIMIT {int(max_rows)}" + ).fetchall() + except Exception: + sample_rows = [] + for sample in sample_rows: + payload = {key: sample[key] for key in sample.keys()} + parts.append(json.dumps(payload, ensure_ascii=False)) + parts.append("") + return "\n".join(parts).strip() + finally: + conn.close() + + +def _extract_csv_text(path: Path, max_rows: int = 100) -> str: + with path.open("r", encoding="utf-8", errors="ignore", newline="") as handle: + reader = csv.reader(handle) + rows = [] + for idx, row in enumerate(reader): + if idx >= max_rows: + break + rows.append("\t".join(str(cell) for cell in row)) + return "\n".join(rows).strip() + + +def _extract_media_metadata(path: Path) -> str: + try: + output = _run_command( + [ + "ffprobe", + "-v", + "error", + "-show_entries", + "format=filename,format_name,duration,size:stream=index,codec_type,codec_name,width,height,sample_rate,channels:stream_tags=language,title", + "-of", + "json", + str(path), + ] + ) + return json.dumps(json.loads(output), ensure_ascii=False, indent=2) + except Exception: + return "" + + +def _extract_subtitles(path: Path) -> str: + try: + output = _run_command(["ffmpeg", "-v", "error", "-i", str(path), "-map", "0:s:0", "-f", "srt", "-"]) + except Exception: + return "" + lines: List[str] = [] + for raw_line in output.splitlines(): + line = raw_line.strip() + if not line or line.isdigit() or "-->" in line: + continue + lines.append(line) + return "\n".join(lines).strip() + + +def _extract_image_metadata(path: Path) -> str: + parts: List[str] = [] + try: + from PIL import Image + + with Image.open(path) as image: + 
parts.append(f"format: {image.format}") + parts.append(f"size: {image.size[0]}x{image.size[1]}") + parts.append(f"mode: {image.mode}") + except Exception: + pass + return "\n".join(parts).strip() + + +def _raw_header(*, profile: str, config_name: str, relative_path: str, local_path: Path, mode: str) -> str: + stat = local_path.stat() + ext = local_path.suffix.lower().lstrip(".") + return "\n".join( + [ + "[HippoCamp Raw File]", + f"Profile: {profile}", + f"Config: {config_name}", + f"Relative Path: {relative_path}", + f"File Type: {ext}", + f"Extraction Mode: {mode}", + f"File Size Bytes: {stat.st_size}", + ] + ) + + +def _extract_audio_transcript(local_path: Path) -> Optional[str]: + """Transcribe audio using OpenAI whisper (local model). Returns None if not available.""" + try: + import whisper # type: ignore[import-untyped] + except ImportError: + logger.debug("whisper not installed — skipping audio transcription for %s", local_path.name) + return None + try: + model = whisper.load_model("base") + result = model.transcribe(str(local_path), fp16=False) + text = str(result.get("text", "")).strip() + if not text: + return None + return f"[audio_transcript]\n{text}" + except Exception as exc: + logger.warning("Whisper transcription failed for %s: %s", local_path.name, exc) + return None + + +def _extract_image_ocr(local_path: Path) -> Optional[str]: + """Extract text from image using pytesseract OCR. 
Returns None if not available.""" + try: + from PIL import Image # type: ignore[import-untyped] + import pytesseract # type: ignore[import-untyped] + except ImportError: + logger.debug("pytesseract/Pillow not installed — skipping OCR for %s", local_path.name) + return None + try: + img = Image.open(local_path) + text = pytesseract.image_to_string(img).strip() + if not text or len(text) < 3: + return None + return f"[ocr_text]\n{text}" + except Exception as exc: + logger.warning("OCR failed for %s: %s", local_path.name, exc) + return None + + +def raw_file_to_items( + *, + local_path: Path, + relative_path: str, + profile: str, + config_name: str, + chunk_chars: int, +) -> RawExtractionResult: + ext = local_path.suffix.lower().lstrip(".") + mode = "text" + notes: List[str] = [] + blocks: List[str] = [] + + try: + if ext in TEXT_EXTENSIONS: + blocks = [_extract_csv_text(local_path) if ext == "csv" else _read_text_file(local_path)] + elif ext == "eml": + blocks = [_extract_email_text(local_path)] + elif ext == "pdf": + blocks = [f"[page={idx + 1}]\n{page}" for idx, page in enumerate(_extract_pdf_pages(local_path))] + elif ext == "docx": + blocks = [_extract_docx_text(local_path)] + elif ext == "pptx": + blocks = [_extract_pptx_ooxml_text(local_path)] + elif ext == "xlsx": + blocks = [_extract_xlsx_ooxml_text(local_path)] + elif ext == "ipynb": + blocks = [_extract_ipynb_text(local_path)] + elif ext == "sqlite": + blocks = [_extract_sqlite_text(local_path)] + elif ext in VIDEO_EXTENSIONS: + subtitles = _extract_subtitles(local_path) + metadata = _extract_media_metadata(local_path) + if subtitles: + mode = "subtitle_text" + blocks = [subtitles] + if metadata: + blocks.append("[media_metadata]\n" + metadata) + else: + mode = "metadata_only" + blocks = ["[media_metadata]\n" + metadata] if metadata else [f"Raw {ext} video file."] + notes.append("No subtitle text extracted.") + elif ext in AUDIO_EXTENSIONS: + transcript = _extract_audio_transcript(local_path) + if 
transcript: + mode = "transcript" + blocks = [transcript] + metadata_text = _extract_media_metadata(local_path) + if metadata_text: + blocks.append("[media_metadata]\n" + metadata_text) + else: + mode = "metadata_only" + metadata_text = _extract_media_metadata(local_path) + blocks = ["[media_metadata]\n" + metadata_text] if metadata_text else [f"Raw {ext} audio file."] + notes.append("No speech-to-text available (whisper not installed or failed).") + elif ext in IMAGE_EXTENSIONS: + ocr_text = _extract_image_ocr(local_path) + if ocr_text: + mode = "ocr_text" + blocks = [ocr_text] + img_metadata = _extract_image_metadata(local_path) + if img_metadata: + blocks.append("[image_metadata]\n" + img_metadata) + else: + mode = "metadata_only" + img_metadata = _extract_image_metadata(local_path) + blocks = ["[image_metadata]\n" + img_metadata] if img_metadata else [f"Raw {ext} image file."] + notes.append("No OCR text extracted (pytesseract not installed or no text found).") + else: + mode = "metadata_only" + blocks = [f"Unsupported raw file type for text extraction: {ext or 'unknown'}"] + notes.append(f"Unsupported file extension: {ext or 'unknown'}") + except Exception as exc: + mode = "metadata_only" + blocks = [f"Extraction failed for {relative_path}: {exc}"] + notes.append(str(exc)) + + blocks = [block for block in blocks if str(block or "").strip()] + if not blocks: + blocks = [f"Empty extracted content for {relative_path}"] + notes.append("Empty extracted content.") + + header = _raw_header( + profile=profile, + config_name=config_name, + relative_path=relative_path, + local_path=local_path, + mode=mode, + ) + metadata = { + "benchmark": "hippocamp", + "exposure_mode": "raw_files_only", + "config_name": config_name, + "profile": profile, + "file_path": relative_path, + "file_name": local_path.name, + "file_type": ext or None, + "raw_extraction_mode": mode, + "raw_extraction_notes": notes, + } + categories = ["hippocamp", "raw_files_only", config_name.lower(), 
profile.lower()] + items = _chunk_blocks( + header=header, + blocks=blocks, + chunk_chars=chunk_chars, + metadata=metadata, + categories=categories, + ) + return RawExtractionResult(items=items, mode=mode, notes=notes) diff --git a/dhee/configs/base.py b/dhee/configs/base.py index c017ee9..e3e80ad 100644 --- a/dhee/configs/base.py +++ b/dhee/configs/base.py @@ -591,12 +591,6 @@ def _positive_int(cls, v: int) -> int: return max(1, int(v)) -def _get_teaching_config(): - """Lazy import to avoid circular dependency.""" - from dhee.teaching.config import TeachingConfig - return TeachingConfig() - - class MemoryConfig(BaseModel): vector_store: VectorStoreConfig = Field(default_factory=VectorStoreConfig) llm: LLMConfig = Field(default_factory=LLMConfig) @@ -631,7 +625,6 @@ class MemoryConfig(BaseModel): cost_guardrail: CostGuardrailConfig = Field(default_factory=CostGuardrailConfig) skill: SkillConfig = Field(default_factory=SkillConfig) task: TaskConfig = Field(default_factory=TaskConfig) - teaching: "TeachingConfig" = Field(default_factory=lambda: _get_teaching_config()) metamemory: MetamemoryInlineConfig = Field(default_factory=MetamemoryInlineConfig) prospective: ProspectiveInlineConfig = Field(default_factory=ProspectiveInlineConfig) procedural: ProceduralInlineConfig = Field(default_factory=ProceduralInlineConfig) @@ -675,12 +668,3 @@ def full(cls) -> "MemoryConfig": return full_config() -# Resolve forward reference for TeachingConfig -def _rebuild_memory_config(): - try: - from dhee.teaching.config import TeachingConfig - MemoryConfig.model_rebuild() - except ImportError: - pass - -_rebuild_memory_config() diff --git a/dhee/core/agi_loop.py b/dhee/core/agi_loop.py deleted file mode 100644 index 12ea159..0000000 --- a/dhee/core/agi_loop.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Dhee v3 — Cognitive Maintenance Cycle. - -Replaces the phantom AGI loop with honest, real maintenance operations. - -v2.2 had 8 steps, 6 of which imported non-existent engram_* packages. 
-v3 runs only what actually exists: - 1. Consolidation (active → passive, via safe consolidation engine) - 2. Decay (forgetting curves) - -Planned but not yet implemented (will be added as real Job classes): - - Anchor candidate resolution - - Distillation promotion - - Conflict scanning - - Stale intention cleanup - -The old API surface (run_agi_cycle, get_system_health) is preserved -for backward compatibility with existing callers. -""" - -from __future__ import annotations - -import logging -from datetime import datetime, timezone -from typing import Any, Dict, Optional - -logger = logging.getLogger(__name__) - - -def run_agi_cycle( - memory: Any, - user_id: str = "default", - context: Optional[str] = None, -) -> Dict[str, Any]: - """Run one maintenance cycle. Only executes real subsystems. - - Args: - memory: Dhee Memory instance - user_id: User identifier for scoped operations - context: Optional current context (reserved for future use) - - Returns: - Dict with status of each step - """ - now = datetime.now(timezone.utc).isoformat() - results: Dict[str, Any] = {"timestamp": now, "user_id": user_id} - - # Step 1: Consolidation — run distillation (episodic → semantic) - try: - if hasattr(memory, "_kernel") and memory._kernel: - consolidation = memory._kernel.sleep_cycle(user_id=user_id) - results["consolidation"] = {"status": "ok", "result": consolidation} - else: - results["consolidation"] = {"status": "skipped", "reason": "no kernel"} - except Exception as e: - results["consolidation"] = {"status": "error", "error": str(e)} - - # Step 2: Decay — apply forgetting curves - try: - decay_result = memory.apply_decay(scope={"user_id": user_id}) - results["decay"] = {"status": "ok", "result": decay_result} - except Exception as e: - results["decay"] = {"status": "error", "error": str(e)} - - # Compute summary - statuses = [ - v.get("status", "unknown") - for v in results.values() - if isinstance(v, dict) and "status" in v - ] - ok_count = statuses.count("ok") - 
error_count = statuses.count("error") - skipped_count = statuses.count("skipped") - - results["summary"] = { - "ok": ok_count, - "errors": error_count, - "skipped": skipped_count, - "total_subsystems": len(statuses), - } - - return results - - -def get_system_health(memory: Any, user_id: str = "default") -> Dict[str, Any]: - """Report health status across real cognitive subsystems. - - Only reports subsystems that actually exist — no phantom package checks. - """ - now = datetime.now(timezone.utc).isoformat() - systems: Dict[str, Dict] = {} - - # Core memory - try: - stats = memory.get_stats(user_id=user_id) - systems["core_memory"] = {"available": True, "stats": stats} - except Exception as e: - systems["core_memory"] = {"available": False, "error": str(e)} - - # Knowledge graph - systems["knowledge_graph"] = { - "available": ( - hasattr(memory, "knowledge_graph") - and memory.knowledge_graph is not None - ), - } - if systems["knowledge_graph"]["available"]: - try: - systems["knowledge_graph"]["stats"] = memory.knowledge_graph.stats() - except Exception: - pass - - # Cognition kernel - has_kernel = hasattr(memory, "_kernel") and memory._kernel is not None - systems["cognition_kernel"] = {"available": has_kernel} - if has_kernel: - try: - systems["cognition_kernel"]["stats"] = memory._kernel.cognition_health( - user_id=user_id - ) - except Exception: - pass - - # Active memory / consolidation - systems["consolidation"] = { - "available": ( - hasattr(memory, "_consolidation_engine") - and memory._consolidation_engine is not None - ), - } - - # v3 stores (if wired) - systems["v3_event_store"] = { - "available": hasattr(memory, "_event_store") and memory._event_store is not None, - } - - available = sum(1 for s in systems.values() if s.get("available")) - total = len(systems) - - return { - "timestamp": now, - "systems": systems, - "available": available, - "total": total, - "health_pct": round(available / total * 100, 1) if total else 0, - } diff --git 
a/dhee/core/anchor_resolver.py b/dhee/core/anchor_resolver.py deleted file mode 100644 index 0cf1dad..0000000 --- a/dhee/core/anchor_resolver.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Dhee v3 — Anchor Resolver: per-field candidates + confidence-weighted resolution. - -Makes context extraction fallible, revisable, and auditable. - -Instead of a single ContextAnchor with one confidence score, each field -(era, place, time_absolute, activity, etc.) gets competing candidates. -Resolution picks the best candidate per field. Re-anchoring is safe -because raw events are never touched. - -Design contract: - - Extraction produces candidates, not final truth - - Same memory can hold alternate candidate anchors - - Anchor correction does not mutate raw event history - - Resolution is deterministic: highest confidence per field wins - - Zero LLM calls — rule-based extraction + confidence scoring -""" - -from __future__ import annotations - -import json -import logging -import sqlite3 -import threading -import uuid -from contextlib import contextmanager -from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional - -logger = logging.getLogger(__name__) - - -def _utcnow_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -# Fields that can have competing candidates -ANCHOR_FIELDS = frozenset({ - "era", "place", "place_type", "place_detail", - "time_absolute", "time_range_start", "time_range_end", - "time_derivation", "activity", -}) - - -@dataclass -class AnchorCandidate: - """A proposed value for a single anchor field.""" - - candidate_id: str - anchor_id: str - field_name: str - field_value: str - confidence: float = 0.5 - extractor_source: str = "default" - source_event_ids: List[str] = field(default_factory=list) - derivation_version: int = 1 - status: str = "pending" # pending | accepted | rejected | superseded - - -class AnchorCandidateStore: - """Manages per-field anchor candidates in the 
database.""" - - def __init__(self, conn: sqlite3.Connection, lock: threading.RLock): - self._conn = conn - self._lock = lock - - @contextmanager - def _tx(self): - with self._lock: - try: - yield self._conn - self._conn.commit() - except Exception: - self._conn.rollback() - raise - - def submit( - self, - anchor_id: str, - field_name: str, - field_value: str, - *, - confidence: float = 0.5, - extractor_source: str = "default", - source_event_ids: Optional[List[str]] = None, - ) -> str: - """Submit a candidate for an anchor field. Returns candidate_id.""" - if field_name not in ANCHOR_FIELDS: - raise ValueError(f"Invalid anchor field: {field_name}") - - cid = str(uuid.uuid4()) - with self._tx() as conn: - conn.execute( - """INSERT INTO anchor_candidates - (candidate_id, anchor_id, field_name, field_value, - confidence, extractor_source, source_event_ids, - derivation_version, status, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, 1, 'pending', ?)""", - ( - cid, anchor_id, field_name, field_value, - confidence, extractor_source, - json.dumps(source_event_ids or []), - _utcnow_iso(), - ), - ) - return cid - - def get_candidates( - self, - anchor_id: str, - field_name: Optional[str] = None, - *, - status: Optional[str] = None, - ) -> List[Dict[str, Any]]: - """Get candidates for an anchor, optionally filtered by field and status.""" - query = "SELECT * FROM anchor_candidates WHERE anchor_id = ?" - params: list = [anchor_id] - if field_name: - query += " AND field_name = ?" - params.append(field_name) - if status: - query += " AND status = ?" - params.append(status) - query += " ORDER BY confidence DESC" - - with self._lock: - rows = self._conn.execute(query, params).fetchall() - return [self._row_to_dict(r) for r in rows] - - def set_status(self, candidate_id: str, status: str) -> bool: - with self._lock: - try: - result = self._conn.execute( - "UPDATE anchor_candidates SET status = ? 
WHERE candidate_id = ?", - (status, candidate_id), - ) - self._conn.commit() - return result.rowcount > 0 - except Exception: - self._conn.rollback() - raise - - @staticmethod - def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]: - source_ids = row["source_event_ids"] - if isinstance(source_ids, str): - try: - source_ids = json.loads(source_ids) - except (json.JSONDecodeError, TypeError): - source_ids = [] - return { - "candidate_id": row["candidate_id"], - "anchor_id": row["anchor_id"], - "field_name": row["field_name"], - "field_value": row["field_value"], - "confidence": row["confidence"], - "extractor_source": row["extractor_source"], - "source_event_ids": source_ids, - "derivation_version": row["derivation_version"], - "status": row["status"], - "created_at": row["created_at"], - } - - -class AnchorResolver: - """Resolves anchor fields from competing candidates. - - Resolution strategy: - 1. For each anchor field, find all pending/accepted candidates - 2. Pick the highest-confidence candidate per field - 3. Mark winner as 'accepted', losers as 'superseded' - 4. Update the anchor row with resolved values - - Re-resolution: when new candidates arrive (correction, new evidence), - call resolve() again — it re-evaluates all candidates. - """ - - def __init__( - self, - candidate_store: AnchorCandidateStore, - anchor_store: "AnchorStore", - ): - self.candidates = candidate_store - self.anchors = anchor_store - - def resolve(self, anchor_id: str) -> Dict[str, Any]: - """Resolve all fields for an anchor. Returns the resolved field values. - - Steps: - 1. Get all non-rejected candidates grouped by field - 2. For each field, pick highest confidence - 3. Mark winners as accepted, others as superseded - 4. 
Update anchor with resolved values - """ - resolved: Dict[str, str] = {} - resolution_details: Dict[str, Dict[str, Any]] = {} - - for field_name in ANCHOR_FIELDS: - candidates = self.candidates.get_candidates( - anchor_id, field_name - ) - # Filter to only pending/accepted (not rejected) - active = [ - c for c in candidates - if c["status"] in ("pending", "accepted") - ] - - if not active: - continue - - # Sort by confidence descending, then by created_at ascending (earlier = better tiebreak) - active.sort(key=lambda c: (-c["confidence"], c["created_at"])) - winner = active[0] - - resolved[field_name] = winner["field_value"] - resolution_details[field_name] = { - "value": winner["field_value"], - "confidence": winner["confidence"], - "source": winner["extractor_source"], - "candidate_id": winner["candidate_id"], - "competing_count": len(active), - } - - # Mark winner accepted, others superseded - for c in active: - if c["candidate_id"] == winner["candidate_id"]: - self.candidates.set_status(c["candidate_id"], "accepted") - else: - self.candidates.set_status(c["candidate_id"], "superseded") - - # Update anchor with resolved values - if resolved: - self.anchors.update_fields(anchor_id, **resolved) - - return { - "anchor_id": anchor_id, - "resolved_fields": resolved, - "details": resolution_details, - } - - def re_anchor( - self, - anchor_id: str, - field_name: str, - new_value: str, - *, - confidence: float = 0.9, - source: str = "user_correction", - source_event_ids: Optional[List[str]] = None, - ) -> Dict[str, Any]: - """Submit a correction for a specific field and re-resolve. - - The user says "no, the place was Bengaluru, not Ghazipur." - This submits a high-confidence candidate and re-runs resolution. 
- """ - # Submit the correction as a new candidate - cid = self.candidates.submit( - anchor_id=anchor_id, - field_name=field_name, - field_value=new_value, - confidence=confidence, - extractor_source=source, - source_event_ids=source_event_ids, - ) - - # Re-resolve just this field (and all others for consistency) - result = self.resolve(anchor_id) - result["correction_candidate_id"] = cid - return result - - def extract_and_submit( - self, - anchor_id: str, - content: str, - *, - source_event_ids: Optional[List[str]] = None, - ) -> List[str]: - """Rule-based extraction of anchor candidates from content. - - Extracts candidates for each field it can identify. Returns - list of candidate_ids. - - Zero LLM calls — keyword/pattern matching only. - """ - candidates_created: List[str] = [] - eids = source_event_ids or [] - lower = content.lower() - - # Activity detection - activity_keywords = { - "coding": ["coding", "programming", "debug", "commit", "deploy", "refactor"], - "meeting": ["meeting", "standup", "call", "sync", "discussion"], - "research": ["research", "reading", "paper", "study", "learn"], - "travel": ["travel", "flight", "airport", "train", "driving"], - "writing": ["writing", "blog", "document", "email", "report"], - } - for activity, keywords in activity_keywords.items(): - matches = sum(1 for kw in keywords if kw in lower) - if matches >= 1: - confidence = min(0.3 + 0.15 * matches, 0.85) - cid = self.candidates.submit( - anchor_id=anchor_id, - field_name="activity", - field_value=activity, - confidence=confidence, - extractor_source="keyword_activity", - source_event_ids=eids, - ) - candidates_created.append(cid) - - # Place type detection - place_types = { - "office": ["office", "workplace", "desk", "cubicle"], - "home": ["home", "house", "apartment", "flat"], - "school": ["school", "university", "college", "campus", "class"], - "travel": ["airport", "station", "hotel", "flight"], - } - for ptype, keywords in place_types.items(): - if any(kw in lower 
for kw in keywords): - cid = self.candidates.submit( - anchor_id=anchor_id, - field_name="place_type", - field_value=ptype, - confidence=0.6, - extractor_source="keyword_place_type", - source_event_ids=eids, - ) - candidates_created.append(cid) - - return candidates_created diff --git a/dhee/core/answer_orchestration.py b/dhee/core/answer_orchestration.py index 1fff68b..095c064 100644 --- a/dhee/core/answer_orchestration.py +++ b/dhee/core/answer_orchestration.py @@ -1,41 +1,24 @@ -"""Answer-time orchestration utilities for memory-heavy QA. +"""Query planning utilities for orchestrated search. -This module is benchmark-agnostic and can be reused by runtime APIs. -It provides: -- lightweight query-intent routing -- optional query rewriting for retrieval -- map stage (atomic fact extraction) -- deterministic reducers for high-leverage question types +Provides intent classification and query rewriting to improve retrieval. +No answer synthesis — Dhee retrieves context, the agent answers. """ from __future__ import annotations -import json -import logging import re from dataclasses import dataclass -from datetime import date, datetime, timezone from enum import Enum -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Optional -logger = logging.getLogger(__name__) - -_SESSION_ID_RE = re.compile(r"^Session ID:\s*(?P\S+)\s*$", re.MULTILINE) _RECENT_QUERY_RE = re.compile(r"\b(latest|most recent(?:ly)?|currently|current|recent(?:ly)?|as of|last)\b", re.I) -# Superlative patterns: "the most", "the least", "the first", "the last" -# These require ARGMAX/ARGMIN — not just listing set members. 
_SUPERLATIVE_RE = re.compile( r"\b(?:the\s+)?(?:most|least|fewest|highest|lowest|biggest|smallest|first|last" r"|(?:fly|flew|visit|use|eat|watch|play|buy|read|drive|travel)\w*\s+(?:the\s+)?most)\b", re.I, ) _LOW_CONFIDENCE_RE = re.compile( - r"\b(i\s+don['’]?t\s+know|not\s+enough\s+information|insufficient\s+information|unknown|cannot\s+determine)\b", - re.I, -) -_MONEY_RE = re.compile(r"[-+]?\$?\s*(\d{1,3}(?:,\d{3})*|\d+)(?:\.(\d+))?") -_DURATION_RE = re.compile( - r"([-+]?\d+(?:\.\d+)?)\s*(years?|months?|weeks?|days?|hours?|minutes?)", + r"\b(i\s+don['']?t\s+know|not\s+enough\s+information|insufficient\s+information|unknown|cannot\s+determine)\b", re.I, ) @@ -46,19 +29,17 @@ class AnswerIntent(str, Enum): DURATION = "duration" LATEST = "latest" SET_MEMBERS = "set_members" + ANALYSIS = "analysis" FREEFORM = "freeform" -_NUMERIC_INTENTS = {AnswerIntent.COUNT, AnswerIntent.MONEY_SUM, AnswerIntent.DURATION} - - @dataclass class QueryPlan: intent: AnswerIntent rewritten_query: str search_limit: int context_limit: int - should_map_reduce: bool + should_map_reduce: bool # kept for API compat; always False in new path def classify_answer_intent(question: str, question_type: str = "") -> AnswerIntent: @@ -68,36 +49,23 @@ def classify_answer_intent(question: str, question_type: str = "") -> AnswerInte if not q: return AnswerIntent.FREEFORM - # DURATION must be checked BEFORE money — "how much time did I spend" is duration, not money. if re.search(r"\b(how long|duration|elapsed|time spent|total years?|total months?)\b", q): return AnswerIntent.DURATION if re.search(r"\bhow much time\b", q): return AnswerIntent.DURATION - - # "How many days/months/weeks ago" or "how many days between X and Y" - # are temporal-duration questions, not counting questions. if re.search(r"\bhow many\s+(days?|weeks?|months?|years?|hours?|minutes?)\b", q): return AnswerIntent.DURATION - # Money: strict signals only. "spend/spent" + time words is DURATION (caught above). 
- # Exclude "days/hours spent" — that's DURATION, not money. money_signals = bool( re.search(r"\b(money|dollars?|usd|spent|spend|cost|price)\b", q) ) if money_signals and re.search(r"\b(how much|total|sum|spent|cost)\b", q): - # "total number of days spent" is DURATION, not money if re.search(r"\b(days?|weeks?|months?|years?|hours?|minutes?)\s+(spent|in)\b", q): return AnswerIntent.DURATION - # "what percentage" is FREEFORM, not money if re.search(r"\bpercentage\b", q): return AnswerIntent.FREEFORM return AnswerIntent.MONEY_SUM - # "How many [quantity-unit]" asks for a numeric VALUE, not a COUNT of distinct items. - # COUNT = enumerate distinct items (cities, books, sports, doctors) - # FREEFORM = read/compute a numeric value (points, pages, followers, views, copies) - # Allow up to 3 modifier words between "how many" and the unit: - # "how many Instagram followers" / "how many rare items" / "how many completed videos" _QUANTITY_UNITS = ( r"points?|dollars?|credits?|tokens?|calories?|miles?|steps?|pounds?" r"|kilograms?|grams?|liters?|gallons?|servings?|reps?|sets?" @@ -112,48 +80,61 @@ def classify_answer_intent(question: str, question_type: str = "") -> AnswerInte ) if _QUANTITY_UNITS_RE.search(q): return AnswerIntent.FREEFORM - - # "[noun] count" pattern: "page count", "word count", "calorie count", "step count" if re.search(r"\b(page|word|calorie|step|follower|subscriber|view|video|item)\s+count\b", q): return AnswerIntent.FREEFORM - # "What is the total number of [quantity]" needs arithmetic (sum), not item counting. _TOTAL_QUANTITY_RE = re.compile( r"\btotal\s+number\s+of\s+(?:\w+\s+){0,2}(" + _QUANTITY_UNITS + r")\b" ) if _TOTAL_QUANTITY_RE.search(q): return AnswerIntent.FREEFORM - # Knowledge-update questions need LATEST intent — must check BEFORE "how many" - # to prevent "How many times did X change?" 
from being routed to COUNT if "knowledge-update" in qtype: return AnswerIntent.LATEST - - # "How much" alone (without money signals) is a value question, not a count. - # "How much is the painting worth?" → FREEFORM - # "How much will I save?" → FREEFORM if re.search(r"\bhow much\b", q): return AnswerIntent.FREEFORM - if re.search(r"\b(how many|number of|count|total number)\b", q): return AnswerIntent.COUNT - # Superlative questions FIRST: "which X the most/least/first/last" - # Must come before generic LATEST check — "the most last month" is COUNT, not LATEST. if _SUPERLATIVE_RE.search(q): - # Frequency superlatives → COUNT (need argmax) if re.search(r"\b(the most|most often|most frequent)\b", q, re.I): return AnswerIntent.COUNT - # Temporal superlatives → LATEST (ordering by date) if re.search(r"\b(most recent|first|earliest|latest|newest|oldest)\b", q, re.I): return AnswerIntent.LATEST if _RECENT_QUERY_RE.search(q): return AnswerIntent.LATEST - if re.search(r"\b(which|what are|list|name all)\b", q): return AnswerIntent.SET_MEMBERS + if re.search( + r"\b(analy[sz]e|clarify|summarize|explain\b.*\b(reasoning|basis|context|approach))" + r"|\b(legal opinion|legal basis|legal aid)" + r"|\b(comprehensive|in[- ]?depth)\b", + q, + ): + return AnswerIntent.ANALYSIS + if re.search(r"\bhow should (i|we)\s+(reply|respond|draft|write|prepare)\b", q): + return AnswerIntent.ANALYSIS + if re.search( + r"\b(check\s+(my|the|our)\s+\w+|flag\s+any|missing\s+or\s+incorrect" + r"|inconsisten|verify|validate|cross[- ]?check)\b", + q, + ): + return AnswerIntent.ANALYSIS + if re.search( + r"\b(my\s+(routine|usual|typical|approach|process|habit|workflow))" + r"|\b(how\s+(do|did)\s+i\s+usually)\b" + r"|\b(what\s+(is|are)\s+my\s+\w*\s*(like|routine|habit|approach))\b" + r"|\b(can you\s+(take a look|check|help me))\b", + q, + ): + return AnswerIntent.ANALYSIS + if re.search(r"\b(categorize|categori[sz]e|filter|organize|sort|group)\b", q): + return AnswerIntent.ANALYSIS + if "hippocamp" in 
qtype and "profiling" in qtype: + return AnswerIntent.ANALYSIS + return AnswerIntent.FREEFORM @@ -163,7 +144,6 @@ def rewrite_query_for_intent(question: str, intent: AnswerIntent) -> str: return q if intent == AnswerIntent.COUNT: - # If the original question asks "which X the most", the count is for argmax if _SUPERLATIVE_RE.search(q): return ( f"{q}\nList each occurrence of each distinct item across ALL sessions. " @@ -182,7 +162,12 @@ def rewrite_query_for_intent(question: str, intent: AnswerIntent) -> str: ) if intent == AnswerIntent.SET_MEMBERS: return f"{q}\nList all distinct relevant items with deduplication." - # FREEFORM: add derivation instruction for computation questions + if intent == AnswerIntent.ANALYSIS: + return ( + f"{q}\nExtract all relevant facts, details, and context from the retrieved documents. " + f"Synthesize a comprehensive, grounded answer using ONLY information found in the evidence. " + f"Cite specific details (names, dates, amounts, file references) where available." + ) if re.search(r"\bwhat time\b", q, re.I): return ( f"{q}\nIf the exact answer is not stated directly, COMPUTE it from the " @@ -213,16 +198,14 @@ def build_query_plan( AnswerIntent.DURATION, AnswerIntent.LATEST, AnswerIntent.SET_MEMBERS, + AnswerIntent.ANALYSIS, } - # Also expand FREEFORM questions that need multi-fact derivation. - # "how many" questions may be classified FREEFORM (e.g., "how many items" - # hits the quantity-unit filter) but still need map-reduce to aggregate - # across multiple sessions. Same for multi-session aggregation patterns. 
if not should_expand and intent == AnswerIntent.FREEFORM: q_lower = question.lower() if re.search( r"\b(what time|what day|what date|at what age|how many|how much" - r"|total number|in total|all the|list all|what are all)\b", + r"|total number|in total|all the|list all|what are all" + r"|based on|according to|please help|can you help)\b", q_lower, ): should_expand = True @@ -241,702 +224,5 @@ def build_query_plan( ) -def build_map_candidates( - results: Sequence[Dict[str, Any]], - *, - max_candidates: int, - per_candidate_max_chars: int, -) -> List[Dict[str, str]]: - out: List[Dict[str, str]] = [] - for row in list(results)[: max(1, int(max_candidates))]: - metadata = row.get("metadata") or {} - session_id = str(metadata.get("session_id") or "").strip() - memory_text = str(row.get("memory") or "") - if not session_id and memory_text: - match = _SESSION_ID_RE.search(memory_text) - if match: - session_id = match.group("session_id") - session_date = str( - metadata.get("event_time") - or metadata.get("session_date") - or metadata.get("event_date") - or "" - ).strip() - evidence = str(row.get("evidence_text") or "").strip() or memory_text - if not evidence.strip(): - continue - - out.append( - { - "session_id": session_id or "unknown", - "session_date": session_date, - "text": _truncate_text(evidence, per_candidate_max_chars), - } - ) - return out - - -def _truncate_text(text: str, max_chars: int) -> str: - if max_chars <= 0: - return "" - text = str(text or "") - if len(text) <= max_chars: - return text - return text[: max_chars - 3].rstrip() + "..." 
- - -def _extract_json_payload(raw: str) -> Optional[Any]: - if not raw: - return None - raw = raw.strip() - - try: - return json.loads(raw) - except json.JSONDecodeError: - pass - - obj_match = re.search(r"\{[\s\S]*\}", raw) - if obj_match: - candidate = obj_match.group(0) - try: - return json.loads(candidate) - except json.JSONDecodeError: - pass - - arr_match = re.search(r"\[[\s\S]*\]", raw) - if arr_match: - candidate = arr_match.group(0) - try: - return json.loads(candidate) - except json.JSONDecodeError: - pass - - return None - - -def _normalize_bool(value: Any) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - return value.strip().lower() in {"1", "true", "yes", "y"} - return False - - -def _to_float(value: Any) -> Optional[float]: - if value is None: - return None - if isinstance(value, (int, float)): - return float(value) - text = str(value).strip() - if not text: - return None - text = text.replace(",", "") - if text.startswith("$"): - text = text[1:] - try: - return float(text) - except ValueError: - return None - - -def _parse_event_datetime(value: Any) -> Optional[datetime]: - if value is None: - return None - if isinstance(value, datetime): - dt = value - elif isinstance(value, date): - dt = datetime.combine(value, datetime.min.time()) - else: - text = str(value).strip() - if not text: - return None - if text.endswith("Z"): - text = text[:-1] + "+00:00" - try: - dt = datetime.fromisoformat(text) - except ValueError: - date_match = re.match(r"^(\d{4}-\d{2}-\d{2})", text) - if not date_match: - return None - try: - d = date.fromisoformat(date_match.group(1)) - except ValueError: - return None - dt = datetime.combine(d, datetime.min.time()) - - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - else: - dt = dt.astimezone(timezone.utc) - return dt - - -def extract_atomic_facts( - *, - llm: Any, - question: str, - question_type: str, - question_date: 
str, - candidates: Sequence[Dict[str, str]], -) -> List[Dict[str, Any]]: - if not candidates: - return [] - - candidate_blocks = [] - for idx, c in enumerate(candidates, start=1): - candidate_blocks.append( - "\n".join( - [ - f"[Candidate {idx}] session_id={c.get('session_id', 'unknown')} date={c.get('session_date', '')}", - c.get("text", ""), - ] - ) - ) - - prompt = ( - "You are a fact extraction engine for memory QA.\n" - "Extract only facts relevant to answering the question.\n" - "IMPORTANT: Deduplicate facts. If the same item/event appears in " - "multiple sessions, emit it ONCE with the canonical_key set.\n" - "canonical_key = a short lowercase identifier for the unique item " - "(e.g. 'boots_zara', 'blazer_dry_cleaning', 'project_alpha'). " - "Same real-world item across sessions MUST share the same canonical_key.\n" - "Return STRICT JSON only, no markdown.\n\n" - f"Question: {question}\n" - f"Question Type: {question_type or 'unknown'}\n" - f"Question Date: {question_date or 'unknown'}\n\n" - "Candidate Context:\n" - + "\n\n".join(candidate_blocks) - + "\n\n" - "Required JSON schema:\n" - "{\"facts\":[" - "{" - "\"session_id\":\"string\"," - "\"event_date\":\"YYYY-MM-DD or empty\"," - "\"subject\":\"string\"," - "\"predicate\":\"string\"," - "\"value\":\"string\"," - "\"numeric_value\":0," - "\"unit\":\"string\"," - "\"currency\":\"string\"," - "\"canonical_key\":\"unique_item_id (REQUIRED, lowercase, e.g. boots_zara)\"," - "\"relevant\":true" - "}" - "]}\n" - "Return an empty list if nothing relevant: {\"facts\":[]}" - ) - - try: - raw = str(llm.generate(prompt)).strip() - except Exception as exc: - logger.warning("Map-stage fact extraction failed: %s", exc) - return [] - - logger.info("Map-stage raw LLM response (first 800 chars): %s", raw[:800]) - payload = _extract_json_payload(raw) - if payload is None: - logger.warning("Map-stage payload parse failed. 
Raw response (first 500 chars): %s", raw[:500]) - return [] - - if isinstance(payload, list): - facts_raw = payload - elif isinstance(payload, dict): - facts_raw = payload.get("facts") - else: - facts_raw = None - - if not isinstance(facts_raw, list): - logger.warning("Map-stage facts_raw is not a list: %s", type(facts_raw)) - return [] - - logger.info("Map-stage extracted %d raw facts from LLM", len(facts_raw)) - facts: List[Dict[str, Any]] = [] - for row in facts_raw: - if not isinstance(row, dict): - continue - value = str(row.get("value") or "").strip() - subject = str(row.get("subject") or "").strip() - predicate = str(row.get("predicate") or "").strip() - if not value and not subject and not predicate: - continue - - facts.append( - { - "session_id": str(row.get("session_id") or "").strip(), - "event_date": str(row.get("event_date") or "").strip(), - "subject": subject, - "predicate": predicate, - "value": value, - "numeric_value": _to_float(row.get("numeric_value")), - "unit": str(row.get("unit") or "").strip().lower(), - "currency": str(row.get("currency") or "").strip().upper(), - "canonical_key": str(row.get("canonical_key") or "").strip(), - "relevant": _normalize_bool(row.get("relevant", True)), - } - ) - - logger.info("Map-stage final facts: %d (relevant=%d)", len(facts), sum(1 for f in facts if _normalize_bool(f.get("relevant", True)))) - if facts: - for i, f in enumerate(facts[:5]): - logger.info(" fact[%d]: subject=%s predicate=%s value=%s canonical_key=%s relevant=%s", - i, f.get("subject", ""), f.get("predicate", ""), f.get("value", ""), f.get("canonical_key", ""), f.get("relevant")) - return facts - - -def _extract_money_value(text: str) -> Optional[float]: - if not text: - return None - m = _MONEY_RE.search(text) - if not m: - return None - whole = m.group(1).replace(",", "") - frac = m.group(2) - try: - if frac: - return float(f"{whole}.{frac}") - return float(whole) - except ValueError: - return None - - -def _extract_duration_value(text: str) 
-> Optional[Tuple[float, str]]: - if not text: - return None - m = _DURATION_RE.search(text) - if not m: - return None - try: - return float(m.group(1)), m.group(2).lower() - except ValueError: - return None - - -def _normalize_unit(unit: str) -> str: - unit = str(unit or "").strip().lower() - if unit.endswith("s"): - unit = unit[:-1] - aliases = { - "yr": "year", - "yrs": "year", - "hr": "hour", - "hrs": "hour", - "min": "minute", - "mins": "minute", - } - return aliases.get(unit, unit) - - -def _duration_target_unit(question: str) -> str: - q = str(question or "").lower() - for unit in ("year", "month", "week", "day", "hour", "minute"): - if re.search(rf"\b{unit}s?\b", q): - return unit - return "day" - - -def _convert_duration(value: float, src_unit: str, dst_unit: str) -> Optional[float]: - src = _normalize_unit(src_unit) - dst = _normalize_unit(dst_unit) - to_days = { - "year": 365.0, - "month": 30.0, - "week": 7.0, - "day": 1.0, - "hour": 1.0 / 24.0, - "minute": 1.0 / 1440.0, - } - if src not in to_days or dst not in to_days: - return None - return (value * to_days[src]) / to_days[dst] - - -def reduce_atomic_facts( - *, - question: str, - intent: AnswerIntent, - facts: Sequence[Dict[str, Any]], -) -> Tuple[Optional[str], Dict[str, Any]]: - relevant = [f for f in facts if _normalize_bool(f.get("relevant", True))] - meta: Dict[str, Any] = { - "fact_count": len(facts), - "relevant_fact_count": len(relevant), - "intent": intent.value, - } - if not relevant: - return None, meta - - if intent == AnswerIntent.COUNT: - # Superlative COUNT (argmax): "which X the most" → return the most frequent VALUE - is_argmax = bool(_SUPERLATIVE_RE.search(question)) - if is_argmax: - # Count occurrences of each distinct value - value_counts: Dict[str, int] = {} - value_display: Dict[str, str] = {} # lowercase → original case - for f in relevant: - val = str(f.get("value") or "").strip() - if not val: - continue - key = val.lower() - value_counts[key] = value_counts.get(key, 0) + 1 
- if key not in value_display: - value_display[key] = val - if not value_counts: - return None, meta - # Find the value with highest count - best_key = max(value_counts, key=value_counts.get) - best_count = value_counts[best_key] - meta["argmax_value"] = value_display[best_key] - meta["argmax_count"] = best_count - meta["value_distribution"] = { - value_display[k]: c for k, c in value_counts.items() - } - return value_display[best_key], meta - - # Standard COUNT: count unique items - keys = set() - for f in relevant: - key = str(f.get("canonical_key") or "").strip().lower() - if not key: - # Build key from predicate + normalized value. - # Include predicate so "return boots" != "pick up boots". - val = str(f.get("value") or "").strip().lower() - pred = str(f.get("predicate") or "").strip().lower() - # Strip common prefixes that don't change identity - for prefix in ("new ", "a ", "an ", "the ", "my ", "pair of ", "some "): - if val.startswith(prefix): - val = val[len(prefix):] - parts = [p for p in (pred, val) if p] - key = " | ".join(parts) if parts else "" - if key: - keys.add(key) - if not keys: - return None, meta - meta["reduced_unique_keys"] = len(keys) - return str(len(keys)), meta - - if intent == AnswerIntent.MONEY_SUM: - values: List[float] = [] - for f in relevant: - amount = _to_float(f.get("numeric_value")) - if amount is None: - amount = _extract_money_value(str(f.get("value") or "")) - if amount is None: - continue - values.append(amount) - if not values: - return None, meta - total = sum(values) - meta["money_terms"] = len(values) - if abs(total - round(total)) < 1e-9: - return f"${int(round(total)):,}", meta - return f"${total:,.2f}", meta - - if intent == AnswerIntent.DURATION: - target = _duration_target_unit(question) - values: List[float] = [] - for f in relevant: - numeric = _to_float(f.get("numeric_value")) - unit = _normalize_unit(str(f.get("unit") or "")) - if numeric is None or not unit: - parsed = 
_extract_duration_value(str(f.get("value") or "")) - if parsed: - numeric, unit = parsed - unit = _normalize_unit(unit) - if numeric is None or not unit: - continue - converted = _convert_duration(float(numeric), unit, target) - if converted is None: - continue - values.append(converted) - if not values: - return None, meta - total = sum(values) - meta["duration_terms"] = len(values) - rounded = round(total, 2) - if abs(rounded - round(rounded)) < 1e-9: - rounded = int(round(rounded)) - unit_out = target if rounded == 1 else f"{target}s" - return f"{rounded} {unit_out}", meta - - if intent == AnswerIntent.LATEST: - dated: List[Tuple[datetime, Dict[str, Any]]] = [] - for f in relevant: - dt = _parse_event_datetime(f.get("event_date")) - if dt is not None: - dated.append((dt, f)) - if dated: - dated.sort(key=lambda x: x[0], reverse=True) - best = dated[0][1] - answer = str(best.get("value") or "").strip() - if answer: - return answer, meta - # fallback: first relevant value - for f in relevant: - answer = str(f.get("value") or "").strip() - if answer: - return answer, meta - return None, meta - - if intent == AnswerIntent.SET_MEMBERS: - values = [] - seen = set() - for f in relevant: - val = str(f.get("value") or "").strip() - if not val: - continue - key = val.lower() - if key in seen: - continue - seen.add(key) - values.append(val) - if values: - return ", ".join(values), meta - return None, meta - - return None, meta - - -def _extract_numeric_mentions(text: str) -> List[float]: - values: List[float] = [] - if not text: - return values - for match in _MONEY_RE.finditer(str(text)): - whole = (match.group(1) or "").replace(",", "") - frac = match.group(2) - if not whole: - continue - try: - number = float(f"{whole}.{frac}") if frac else float(whole) - except ValueError: - continue - values.append(number) - return values - - -def deterministic_inconsistency_check( - *, - question: str, - intent: AnswerIntent, - results: Sequence[Dict[str, Any]], - coverage: 
Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - """Detect deterministic evidence inconsistencies before map/reduce. - - This is intentionally cheap and LLM-free. - """ - reasons: List[str] = [] - coverage_payload = dict(coverage or {}) - coverage_sufficient = bool(coverage_payload.get("sufficient")) - if not coverage_sufficient: - reasons.append("coverage_insufficient") - - if not results: - reasons.append("no_results") - return {"inconsistent": True, "reasons": reasons} - - intent_value = intent.value if isinstance(intent, AnswerIntent) else str(intent or "") - top_rows = list(results)[:12] - - if intent_value in {"count", "set_members"}: - numeric_candidates = set() - for row in top_rows: - text = str(row.get("evidence_text") or row.get("memory") or "") - for value in _extract_numeric_mentions(text): - if value >= 0: - numeric_candidates.add(round(float(value), 3)) - if len(numeric_candidates) >= 2: - reasons.append("count_numeric_conflict") - - if intent_value == "money_sum": - amounts = set() - for row in top_rows: - text = str(row.get("evidence_text") or row.get("memory") or "") - for value in _extract_numeric_mentions(text): - amounts.add(round(float(value), 2)) - if len(amounts) >= 2: - reasons.append("money_terms_multiple") - - if intent_value == "duration": - units = set() - for row in top_rows: - text = str(row.get("evidence_text") or row.get("memory") or "") - parsed = _extract_duration_value(text) - if parsed: - units.add(_normalize_unit(parsed[1])) - if len(units) >= 2: - reasons.append("duration_unit_mixed") - - if intent_value == "latest": - dated_hits = int(coverage_payload.get("dated_fact_count", 0) or 0) - if dated_hits <= 0: - # fallback scan for explicit dates in evidence - has_date_like = False - for row in top_rows: - text = str(row.get("evidence_text") or row.get("memory") or "") - if re.search(r"\b\d{4}-\d{2}-\d{2}\b", text): - has_date_like = True - break - if not has_date_like: - reasons.append("latest_missing_dated_evidence") - - 
return {"inconsistent": bool(reasons), "reasons": reasons} - - -def render_fact_context(facts: Sequence[Dict[str, Any]], max_facts: int = 20) -> str: - lines: List[str] = [] - for f in list(facts)[: max(1, int(max_facts))]: - if not _normalize_bool(f.get("relevant", True)): - continue - parts = [] - if f.get("event_date"): - parts.append(f"date={f['event_date']}") - if f.get("session_id"): - parts.append(f"session={f['session_id']}") - label = " ".join(parts) - value = str(f.get("value") or "").strip() - subj = str(f.get("subject") or "").strip() - pred = str(f.get("predicate") or "").strip() - body = " | ".join(x for x in [subj, pred, value] if x) - if not body: - continue - if label: - lines.append(f"- [{label}] {body}") - else: - lines.append(f"- {body}") - return "\n".join(lines) - - def is_low_confidence_answer(answer: str) -> bool: return bool(_LOW_CONFIDENCE_RE.search(str(answer or "").strip())) - - -def should_override_with_reducer(intent: AnswerIntent) -> bool: - return intent in _NUMERIC_INTENTS or intent == AnswerIntent.LATEST - - -# --------------------------------------------------------------------------- -# Event-first reducer — zero LLM cost -# --------------------------------------------------------------------------- - - -def reduce_from_episodic_events( - *, - question: str, - intent: AnswerIntent, - events: Sequence[Dict[str, Any]], -) -> Tuple[Optional[str], Dict[str, Any]]: - """Reduce episodic events into an answer — zero LLM cost. - - Works directly from event dicts produced by episodic_index rather than - LLM-extracted atomic facts. Uses the same deterministic logic as - reduce_atomic_facts() but adapted for the event schema. 
- """ - meta: Dict[str, Any] = { - "event_count": len(events), - "intent": intent.value, - "source": "episodic_events", - } - if not events: - return None, meta - - if intent == AnswerIntent.COUNT: - keys: set = set() - for ev in events: - key = str(ev.get("canonical_key") or "").strip().lower() - if not key: - value = str(ev.get("value_text") or "").strip().lower() - if value: - key = value - if key: - keys.add(key) - if not keys: - return None, meta - meta["reduced_unique_keys"] = len(keys) - return str(len(keys)), meta - - if intent == AnswerIntent.MONEY_SUM: - values: List[float] = [] - for ev in events: - if str(ev.get("event_type") or "").lower() != "money": - continue - amount = _to_float(ev.get("value_num")) - if amount is None: - amount = _extract_money_value(str(ev.get("value_text") or "")) - if amount is not None: - values.append(amount) - if not values: - return None, meta - total = sum(values) - meta["money_terms"] = len(values) - if abs(total - round(total)) < 1e-9: - return f"${int(round(total)):,}", meta - return f"${total:,.2f}", meta - - if intent == AnswerIntent.DURATION: - target = _duration_target_unit(question) - values = [] - for ev in events: - if str(ev.get("event_type") or "").lower() != "duration": - continue - numeric = _to_float(ev.get("value_num")) - unit = _normalize_unit(str(ev.get("value_unit") or "")) - if numeric is None or not unit: - parsed = _extract_duration_value(str(ev.get("value_text") or "")) - if parsed: - numeric, unit = parsed - unit = _normalize_unit(unit) - if numeric is None or not unit: - continue - converted = _convert_duration(float(numeric), unit, target) - if converted is not None: - values.append(converted) - if not values: - return None, meta - total = sum(values) - meta["duration_terms"] = len(values) - rounded = round(total, 2) - if abs(rounded - round(rounded)) < 1e-9: - rounded = int(round(rounded)) - unit_out = target if rounded == 1 else f"{target}s" - return f"{rounded} {unit_out}", meta - - if intent 
== AnswerIntent.LATEST: - dated: List[Tuple[datetime, Dict[str, Any]]] = [] - for ev in events: - dt = _parse_event_datetime( - ev.get("normalized_time_start") or ev.get("event_time") - ) - if dt is not None: - dated.append((dt, ev)) - if dated: - dated.sort(key=lambda x: x[0], reverse=True) - best = dated[0][1] - answer = str(best.get("value_text") or "").strip() - if answer: - return answer, meta - # fallback: first event value - for ev in events: - answer = str(ev.get("value_text") or "").strip() - if answer: - return answer, meta - return None, meta - - if intent == AnswerIntent.SET_MEMBERS: - values_list: List[str] = [] - seen: set = set() - for ev in events: - val = str(ev.get("value_text") or "").strip() - if not val: - continue - key = val.lower() - if key in seen: - continue - seen.add(key) - values_list.append(val) - if values_list: - return ", ".join(values_list), meta - return None, meta - - return None, meta diff --git a/dhee/core/code_exec_counter.py b/dhee/core/code_exec_counter.py deleted file mode 100644 index 16f1a52..0000000 --- a/dhee/core/code_exec_counter.py +++ /dev/null @@ -1,266 +0,0 @@ -"""Sandboxed code-exec counting for deterministic aggregation. - -Instead of asking the LLM to count/sum items across multiple sessions, -this module has the LLM emit Python code that enumerates items, then -executes it in a restricted sandbox to produce a deterministic answer. -""" - -from __future__ import annotations - -import io -import logging -import re -import threading -from datetime import date, datetime, timedelta -from typing import Any, Dict, List, Optional - -logger = logging.getLogger(__name__) - -# Allowed builtins for the sandbox — no file I/O, no imports, no exec/eval. -# Includes datetime types for date arithmetic (e.g., days between two dates). 
-_SAFE_BUILTINS = { - "len": len, - "sum": sum, - "max": max, - "min": min, - "sorted": sorted, - "set": set, - "list": list, - "dict": dict, - "int": int, - "float": float, - "str": str, - "range": range, - "enumerate": enumerate, - "zip": zip, - "print": print, # will be redirected to StringIO - "abs": abs, - "round": round, - "tuple": tuple, - "True": True, - "False": False, - "None": None, - # Date/time types — safe, no I/O, needed for temporal arithmetic - "datetime": datetime, - "date": date, - "timedelta": timedelta, -} - -# Safe import lines that can be stripped before the blocked-pattern check. -# LLMs habitually emit these even when the types are already available. -_SAFE_IMPORT_RE = re.compile( - r"^\s*(?:from\s+datetime\s+import\s+[\w\s,]+|import\s+datetime)\s*$", - re.MULTILINE, -) - -# Patterns that indicate dangerous code. -# Only match import/from at statement-start (^) to avoid false positives in -# comments like "# data from session 1" or strings like "trip from NYC". -_BLOCKED_PATTERNS = re.compile( - r"\b(__\w+__|exec|eval|compile|globals|locals|getattr|setattr|delattr" - r"|subprocess|shutil|pathlib)\b" - r"|^import\s+(?!datetime\b)" # block `import X` unless X is datetime - r"|^from\s+(?!datetime\b)\w" # block `from X` unless X is datetime - r"|\bos\.\w" # block os.anything - r"|\bsys\.\w" # block sys.anything - r"|\bopen\s*\(", # block open() calls, not the word "open" - re.MULTILINE, -) - -_CODE_BLOCK_RE = re.compile(r"```(?:python)?\s*\n(.*?)```", re.DOTALL) -_ANSWER_RE = re.compile(r"^ANSWER:\s*(.+)$", re.MULTILINE) -_ITEMS_RE = re.compile(r"^ITEMS:\s*(.+)$", re.MULTILINE) - - -_MAX_CODE_EXEC_CONTEXT_CHARS = 12000 - - -def build_code_counting_prompt( - question: str, - retrieved_context: str, - question_date: str = "", -) -> str: - """Build a prompt that asks the LLM to generate Python counting code.""" - # Truncate context to fit smaller models' context windows - ctx = retrieved_context[:_MAX_CODE_EXEC_CONTEXT_CHARS] if 
len(retrieved_context) > _MAX_CODE_EXEC_CONTEXT_CHARS else retrieved_context - date_line = f"\nQuestion Date (today): {question_date}" if question_date else "" - return f"""You are a precise counting and date-arithmetic assistant. Read the question and context below, -then write a short Python script following this EXACT pattern: - -```python -# Step 1: Define ONE list of (label, value) tuples — one entry per distinct item -items = [ - ("item description from session X", numeric_value), - ("item description from session Y", numeric_value), -] - -# Step 2: Compute answer from the items list ONLY -answer = sum(v for _, v in items) # for sum/duration questions -# OR: answer = len(items) # for counting questions - -# Step 3: Print results — the answer MUST be derived from the items list above -print(f"ANSWER: {{answer}}") -print(f"ITEMS: {{[label for label, _ in items]}}") -``` - -CRITICAL RULES: -- Create exactly ONE list called `items` with ALL relevant entries from ALL sessions -- For "how many hours/days/weeks" questions: each tuple is ("description", hours_or_days_number), answer = sum() -- For "how many times/items" questions: each tuple is ("description", 1), answer = len() -- For "how much money" questions: each tuple is ("description", dollar_amount), answer = sum() -- For "how many days/months between X and Y" questions: use datetime(year, month, day) to compute exact differences -- For "how many months/days ago" questions: compute from the Question Date using datetime arithmetic -- The ANSWER line MUST be computed from the `items` list or datetime math — never hardcode it -- Available without import: datetime, date, timedelta, len, sum, min, max, abs, round -- Read EVERY session in the context — missing one entry means a wrong answer -- If the same event appears in multiple sessions, include it only ONCE (deduplicate) -- Add the unit to the ANSWER (e.g., "8 days", "140 hours", "$5850", "30 days") - -Question: {question}{date_line} - -Context: -{ctx} - -Write 
ONLY the Python code inside a code block: -```python -""" - - -def execute_counting_code(code: str, timeout: float = 5.0) -> Optional[Dict[str, Any]]: - """Execute counting code in a restricted sandbox. - - Returns {{"answer": str, "items": list}} or None on failure. - """ - # Strip safe import lines (datetime) before the blocked-pattern check. - # LLMs emit these habitually; the types are already in _SAFE_BUILTINS. - code = _SAFE_IMPORT_RE.sub("", code) - - # Validate: block dangerous patterns - if _BLOCKED_PATTERNS.search(code): - logger.warning("Code-exec blocked: dangerous pattern detected in: %s", code[:200]) - return None - - # Capture stdout - captured = io.StringIO() - - restricted_globals = {"__builtins__": {}} - for name, obj in _SAFE_BUILTINS.items(): - restricted_globals[name] = obj - - # Redirect print to captured output - def safe_print(*args, **kwargs): - kwargs["file"] = captured - print(*args, **kwargs) - - restricted_globals["print"] = safe_print - - result = {"completed": False, "error": None} - - def _run(): - try: - exec(code, restricted_globals) # noqa: S102 - result["completed"] = True - except Exception as e: - result["error"] = str(e) - - thread = threading.Thread(target=_run, daemon=True) - thread.start() - thread.join(timeout=timeout) - - if not result["completed"]: - if result["error"]: - logger.warning("Code-exec error: %s", result["error"]) - else: - logger.warning("Code-exec timeout after %.1fs", timeout) - return None - - output = captured.getvalue() - - # Parse ANSWER line - answer_match = _ANSWER_RE.search(output) - if not answer_match: - logger.debug("Code-exec: no ANSWER line in output: %r", output[:200]) - return None - - answer = answer_match.group(1).strip() - - # Parse ITEMS line (optional) - items: List[str] = [] - items_match = _ITEMS_RE.search(output) - if items_match: - items_raw = items_match.group(1).strip() - # Parse list by splitting on commas (avoid eval for safety) - inner = items_raw.strip("[]") - if inner: - items 
= [x.strip().strip("'\"") for x in inner.split(",") if x.strip().strip("'\"")] - - return {"answer": answer, "items": items} - - -def refine_count_with_code_exec( - *, - llm: Any, - question: str, - question_type: str, - retrieved_context: str, - draft_answer: str, - question_date: str = "", -) -> Optional[str]: - """Full pipeline: prompt LLM to generate code -> exec -> parse answer. - - Returns the refined answer string, or None if code-exec fails. - """ - prompt = build_code_counting_prompt(question, retrieved_context, question_date=question_date) - - try: - raw_response = str(llm.generate(prompt)).strip() - except Exception as e: - logger.warning("Code-exec LLM call failed: %s", e) - return None - - # Extract code block - code_match = _CODE_BLOCK_RE.search(raw_response) - if code_match: - code = code_match.group(1).strip() - else: - # Try treating the entire response as code if it looks like Python - lines = raw_response.strip().splitlines() - code_lines = [ln for ln in lines if not ln.startswith("```")] - if any(ln.strip().startswith(("items", "total", "count", "print", "#", "result")) for ln in code_lines): - code = "\n".join(code_lines) - else: - logger.debug("Code-exec: no code block found in LLM response") - return None - - result = execute_counting_code(code) - if not result: - return None - - answer = result["answer"] - items = result.get("items", []) - - logger.info( - "Code-exec result: answer=%r, items_count=%d, draft=%r", - answer, len(items), draft_answer, - ) - - # Cross-check: if we have items and an answer, verify consistency - if items and answer: - try: - answer_num = float(re.sub(r"[^\d.]", "", answer.split()[0])) - # For pure counting, items list length should match - q_lower = question.lower() - if any(w in q_lower for w in ("how many times", "how many", "number of")): - if abs(answer_num - len(items)) > 0.5 and not any( - w in q_lower for w in ("hours", "days", "weeks", "months", "minutes") - ): - # Items count is more trustworthy for pure 
counting - logger.debug( - "Code-exec cross-check: stated %s but %d items; using items count", - answer, len(items), - ) - answer = str(len(items)) - except (ValueError, IndexError): - pass - - return answer if answer else None diff --git a/dhee/core/conflicts.py b/dhee/core/conflicts.py deleted file mode 100644 index a8c7ca9..0000000 --- a/dhee/core/conflicts.py +++ /dev/null @@ -1,238 +0,0 @@ -"""Dhee v3 — Cognitive Conflict Store + Auto-Resolution. - -Tracks contradictions and disagreements between derived objects. -Conflicts are explicit rows, not silent resolution. - -Auto-resolution: if one side has confidence > 0.8 and the other < 0.3, -auto-resolve in favor of the high-confidence side. Otherwise, flag -for manual resolution. - -Conflict types: - - belief_contradiction: two beliefs claim opposing things - - anchor_disagreement: anchor candidate disagrees with resolved anchor - - distillation_conflict: candidate conflicts with promoted truth - - invalidation_dispute: partial invalidation verification disagrees - -Design contract: - - Every contradiction gets an explicit row - - Auto-resolution only when confidence gap > 0.5 - - Zero LLM calls — confidence comparison only -""" - -from __future__ import annotations - -import json -import logging -import sqlite3 -import threading -import uuid -from contextlib import contextmanager -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional - -logger = logging.getLogger(__name__) - - -def _utcnow_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -# Auto-resolution thresholds -AUTO_RESOLVE_HIGH = 0.8 -AUTO_RESOLVE_LOW = 0.3 -AUTO_RESOLVE_GAP = 0.5 - - -class ConflictStore: - """Manages cognitive conflicts in the database.""" - - def __init__(self, conn: sqlite3.Connection, lock: threading.RLock): - self._conn = conn - self._lock = lock - - @contextmanager - def _tx(self): - with self._lock: - try: - yield self._conn - self._conn.commit() - except Exception: - 
self._conn.rollback() - raise - - def create( - self, - conflict_type: str, - side_a_type: str, - side_a_id: str, - side_b_type: str, - side_b_id: str, - *, - side_a_confidence: Optional[float] = None, - side_b_confidence: Optional[float] = None, - ) -> Dict[str, Any]: - """Create a conflict. Attempts auto-resolution if confidence gap is clear. - - Returns dict with conflict_id, resolution_status, and auto_resolution details. - """ - cid = str(uuid.uuid4()) - now = _utcnow_iso() - - # Attempt auto-resolution - resolution_status = "open" - resolution_json = None - auto_confidence = None - - if side_a_confidence is not None and side_b_confidence is not None: - gap = abs(side_a_confidence - side_b_confidence) - if gap >= AUTO_RESOLVE_GAP: - if (side_a_confidence >= AUTO_RESOLVE_HIGH - and side_b_confidence <= AUTO_RESOLVE_LOW): - resolution_status = "auto_resolved" - auto_confidence = side_a_confidence - resolution_json = json.dumps({ - "winner": "side_a", - "winner_type": side_a_type, - "winner_id": side_a_id, - "reason": f"confidence gap: {side_a_confidence:.2f} vs {side_b_confidence:.2f}", - }) - elif (side_b_confidence >= AUTO_RESOLVE_HIGH - and side_a_confidence <= AUTO_RESOLVE_LOW): - resolution_status = "auto_resolved" - auto_confidence = side_b_confidence - resolution_json = json.dumps({ - "winner": "side_b", - "winner_type": side_b_type, - "winner_id": side_b_id, - "reason": f"confidence gap: {side_b_confidence:.2f} vs {side_a_confidence:.2f}", - }) - - with self._tx() as conn: - conn.execute( - """INSERT INTO cognitive_conflicts - (conflict_id, conflict_type, side_a_type, side_a_id, - side_b_type, side_b_id, detected_at, - resolution_status, resolution_json, - auto_resolution_confidence) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - cid, conflict_type, side_a_type, side_a_id, - side_b_type, side_b_id, now, - resolution_status, resolution_json, auto_confidence, - ), - ) - - result = { - "conflict_id": cid, - "conflict_type": conflict_type, - 
"resolution_status": resolution_status, - } - if resolution_status == "auto_resolved": - result["auto_resolution"] = json.loads(resolution_json) - return result - - def resolve( - self, - conflict_id: str, - resolution: Dict[str, Any], - *, - by: str = "user", - ) -> bool: - """Manually resolve a conflict. Returns True if updated.""" - status = "user_resolved" if by == "user" else "auto_resolved" - with self._tx() as conn: - result = conn.execute( - """UPDATE cognitive_conflicts - SET resolution_status = ?, resolution_json = ? - WHERE conflict_id = ? AND resolution_status = 'open'""", - (status, json.dumps(resolution), conflict_id), - ) - return result.rowcount > 0 - - def defer(self, conflict_id: str) -> bool: - """Defer a conflict for later resolution.""" - with self._tx() as conn: - result = conn.execute( - """UPDATE cognitive_conflicts - SET resolution_status = 'deferred' - WHERE conflict_id = ? AND resolution_status = 'open'""", - (conflict_id,), - ) - return result.rowcount > 0 - - def get(self, conflict_id: str) -> Optional[Dict[str, Any]]: - with self._lock: - row = self._conn.execute( - "SELECT * FROM cognitive_conflicts WHERE conflict_id = ?", - (conflict_id,), - ).fetchone() - return self._row_to_dict(row) if row else None - - def list_open(self, *, limit: int = 50) -> List[Dict[str, Any]]: - """Get all open (unresolved) conflicts.""" - with self._lock: - rows = self._conn.execute( - """SELECT * FROM cognitive_conflicts - WHERE resolution_status = 'open' - ORDER BY detected_at DESC - LIMIT ?""", - (limit,), - ).fetchall() - return [self._row_to_dict(r) for r in rows] - - def list_for_object( - self, object_type: str, object_id: str - ) -> List[Dict[str, Any]]: - """Get all conflicts involving a specific object.""" - with self._lock: - rows = self._conn.execute( - """SELECT * FROM cognitive_conflicts - WHERE (side_a_type = ? AND side_a_id = ?) - OR (side_b_type = ? AND side_b_id = ?) 
- ORDER BY detected_at DESC""", - (object_type, object_id, object_type, object_id), - ).fetchall() - return [self._row_to_dict(r) for r in rows] - - def count_open(self) -> int: - with self._lock: - row = self._conn.execute( - "SELECT COUNT(*) FROM cognitive_conflicts WHERE resolution_status = 'open'" - ).fetchone() - return row[0] if row else 0 - - def has_open_conflicts( - self, object_type: str, object_id: str - ) -> bool: - """Check if an object has any open conflicts (for retrieval penalty).""" - with self._lock: - row = self._conn.execute( - """SELECT 1 FROM cognitive_conflicts - WHERE resolution_status = 'open' - AND ((side_a_type = ? AND side_a_id = ?) - OR (side_b_type = ? AND side_b_id = ?)) - LIMIT 1""", - (object_type, object_id, object_type, object_id), - ).fetchone() - return row is not None - - @staticmethod - def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]: - resolution = row["resolution_json"] - if isinstance(resolution, str): - try: - resolution = json.loads(resolution) - except (json.JSONDecodeError, TypeError): - resolution = None - return { - "conflict_id": row["conflict_id"], - "conflict_type": row["conflict_type"], - "side_a_type": row["side_a_type"], - "side_a_id": row["side_a_id"], - "side_b_type": row["side_b_type"], - "side_b_id": row["side_b_id"], - "detected_at": row["detected_at"], - "resolution_status": row["resolution_status"], - "resolution": resolution, - "auto_resolution_confidence": row["auto_resolution_confidence"], - } diff --git a/dhee/core/consolidation.py b/dhee/core/consolidation.py deleted file mode 100644 index 7c07add..0000000 --- a/dhee/core/consolidation.py +++ /dev/null @@ -1,162 +0,0 @@ -""" -Consolidation Engine — promotes important active signals to passive memory. - -v3 FIX: Breaks the feedback loop identified in the architecture critique. 
- -Old behavior (DANGEROUS): - _promote_to_passive() called memory.add() → triggered full enrichment - pipeline → could create new active signals → infinite consolidation loop. - -New behavior (SAFE): - _promote_to_passive() calls memory.add() with infer=False AND tags - promoted memories with source="consolidated" provenance metadata. - _should_promote() rejects signals that were already consolidated - (prevents re-consolidation of promoted content). - -The enrichment pipeline is explicitly skipped for consolidated memories -because the content was already enriched when it entered active memory. -""" - -import logging -from typing import Any, Dict, Protocol, TYPE_CHECKING - -from dhee.configs.active import ActiveMemoryConfig - -if TYPE_CHECKING: - from dhee.memory.main import FullMemory - -logger = logging.getLogger(__name__) - - -class ActiveMemoryStore(Protocol): - """Structural contract for consolidation's active-memory dependency.""" - - def get_consolidation_candidates( - self, - *, - min_age_seconds: int, - min_reads: int, - ) -> list[Dict[str, Any]]: - ... - - def mark_consolidated(self, signal_ids: list[str]) -> None: - ... - - -class ConsolidationEngine: - """Promotes qualifying active signals into passive (Engram) memory.""" - - def __init__( - self, - active_store: ActiveMemoryStore, - memory: "FullMemory", - config: ActiveMemoryConfig, - ): - self.active = active_store - self.memory = memory - self.config = config - self.consolidation = config.consolidation - - def run_cycle(self) -> Dict[str, Any]: - """Run one consolidation cycle. 
Returns promotion stats.""" - candidates = self.active.get_consolidation_candidates( - min_age_seconds=self.config.consolidation_min_age_seconds, - min_reads=self.config.consolidation_min_reads, - ) - - promoted = [] - skipped = 0 - errors = 0 - feedback_loop_blocked = 0 - - for signal in candidates: - if not self._should_promote(signal): - skipped += 1 - continue - try: - self._promote_to_passive(signal) - promoted.append(signal["id"]) - except Exception: - logger.exception("Failed to promote signal %s", signal["id"]) - errors += 1 - - if promoted: - self.active.mark_consolidated(promoted) - - return { - "promoted": len(promoted), - "checked": len(candidates), - "skipped": skipped, - "errors": errors, - "feedback_loop_blocked": feedback_loop_blocked, - } - - def _should_promote(self, signal: Dict[str, Any]) -> bool: - """Determine if a signal qualifies for promotion to passive memory.""" - signal_type = signal.get("signal_type", "") - ttl_tier = signal.get("ttl_tier", "") - read_count = signal.get("read_count", 0) - - # v3 FIX: Block re-consolidation of already-consolidated content. - # This breaks the feedback loop where promoted content generates - # new active signals that get re-consolidated infinitely. 
- signal_metadata = signal.get("metadata", {}) - if isinstance(signal_metadata, dict): - if signal_metadata.get("source") == "consolidated": - return False - if signal_metadata.get("consolidated_from"): - return False - - # Also check the value field for consolidation markers - value = signal.get("value", "") - if isinstance(value, str) and "[consolidated]" in value.lower(): - return False - - # Directives always promote - if signal_type == "directive" and self.consolidation.directive_to_passive: - return True - - # Critical tier promotes - if ttl_tier == "critical" and self.consolidation.promote_critical: - return True - - # High-read signals promote - if ( - self.consolidation.promote_high_read - and read_count >= self.consolidation.promote_read_threshold - ): - return True - - return False - - def _promote_to_passive(self, signal: Dict[str, Any]) -> None: - """Add a signal's content to passive memory. - - v3 FIX: Uses infer=False to skip the LLM enrichment pipeline. - Tags with source="consolidated" to prevent re-consolidation. - """ - signal_type = signal.get("signal_type", "event") - user_id = signal.get("user_id", "default") - key = signal.get("key", "") - value = signal.get("value", "") - - # Build content string - content = f"[{key}] {value}" if key else value - - self.memory.add( - messages=content, - user_id=user_id, - metadata={ - # Provenance: identifies this as consolidated content - "source": "consolidated", - "consolidated_from": signal.get("id"), - "signal_key": key, - "signal_type": signal_type, - }, - immutable=(signal_type == "directive"), - initial_layer="lml" if signal_type == "directive" else "sml", - # v3 FIX: Skip enrichment pipeline entirely. - # Content was already enriched when it entered active memory. - # Re-enrichment would generate divergent facts/entities. 
- infer=False, - ) diff --git a/dhee/core/derived_store.py b/dhee/core/derived_store.py deleted file mode 100644 index 390492c..0000000 --- a/dhee/core/derived_store.py +++ /dev/null @@ -1,1145 +0,0 @@ -"""Dhee v3 — Type-specific derived cognition stores. - -Each derived type gets its own store class because they have different: -- Lifecycle rules (beliefs: Bayesian; policies: win-rate; insights: strength) -- Indexing needs (anchors: era/place; policies: granularity/utility) -- Invalidation behavior (beliefs: retract; policies: deprecate; anchors: re-resolve) -- Conflict semantics (beliefs: contradiction pairs; policies: approach conflicts) - -All stores share a common database connection (from RawEventStore) and -the derived_lineage table for traceability. - -Design contract: - - Every derived object has derivation_version + lineage_fingerprint - - Invalidation statuses (stale, suspect, invalidated) are orthogonal to - type-specific lifecycle statuses - - Zero LLM calls — pure storage and state transitions -""" - -from __future__ import annotations - -import hashlib -import json -import logging -import sqlite3 -import threading -import uuid -from contextlib import contextmanager -from dataclasses import dataclass, field -from datetime import datetime, timezone -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple - -logger = logging.getLogger(__name__) - - -def _utcnow_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -def _parse_json(value: Any, default: Any = None) -> Any: - if value is None: - return default if default is not None else [] - if isinstance(value, (dict, list)): - return value - try: - return json.loads(value) - except (json.JSONDecodeError, TypeError): - return default if default is not None else [] - - -def _compute_lineage_fingerprint(source_event_ids: List[str], version: int) -> str: - """Deterministic fingerprint from sorted source IDs + version.""" - payload = "|".join(sorted(source_event_ids)) + 
f"|v{version}" - return hashlib.sha256(payload.encode()).hexdigest()[:16] - - -# ========================================================================= -# Enums -# ========================================================================= - -class BeliefStatus(str, Enum): - PROPOSED = "proposed" - HELD = "held" - CHALLENGED = "challenged" - REVISED = "revised" - RETRACTED = "retracted" - # Invalidation statuses (from three-tier model) - STALE = "stale" - SUSPECT = "suspect" - INVALIDATED = "invalidated" - - -class PolicyStatus(str, Enum): - PROPOSED = "proposed" - ACTIVE = "active" - VALIDATED = "validated" - DEPRECATED = "deprecated" - STALE = "stale" - SUSPECT = "suspect" - INVALIDATED = "invalidated" - - -class PolicyGranularity(str, Enum): - TASK = "task" - STEP = "step" - - -class InsightType(str, Enum): - CAUSAL = "causal" - WARNING = "warning" - STRATEGY = "strategy" - PATTERN = "pattern" - - -class AbstractionLevel(str, Enum): - SPECIFIC = "specific" - DOMAIN = "domain" - UNIVERSAL = "universal" - - -class DerivedType(str, Enum): - BELIEF = "belief" - POLICY = "policy" - ANCHOR = "anchor" - INSIGHT = "insight" - HEURISTIC = "heuristic" - - -class DerivedStatus(str, Enum): - """Common invalidation statuses across all derived types.""" - ACTIVE = "active" - STALE = "stale" - SUSPECT = "suspect" - INVALIDATED = "invalidated" - - -# ========================================================================= -# Base store with shared connection management -# ========================================================================= - -class _DerivedStoreBase: - """Shared connection management for all derived stores. - - All stores share a single SQLite connection. The connection is - created externally (by RawEventStore or a coordinator) and passed in. 
- """ - - def __init__(self, conn: sqlite3.Connection, lock: threading.RLock): - self._conn = conn - self._lock = lock - - @contextmanager - def _tx(self): - with self._lock: - try: - yield self._conn - self._conn.commit() - except Exception: - self._conn.rollback() - raise - - -# ========================================================================= -# BeliefStore -# ========================================================================= - -class BeliefStore(_DerivedStoreBase): - """Confidence-tracked claims with Bayesian updates and contradiction detection.""" - - def add( - self, - user_id: str, - claim: str, - *, - domain: str = "general", - confidence: float = 0.5, - source_memory_ids: Optional[List[str]] = None, - source_episode_ids: Optional[List[str]] = None, - tags: Optional[List[str]] = None, - belief_id: Optional[str] = None, - ) -> str: - bid = belief_id or str(uuid.uuid4()) - now = _utcnow_iso() - smids = source_memory_ids or [] - seids = source_episode_ids or [] - fp = _compute_lineage_fingerprint(smids, 1) - - with self._tx() as conn: - conn.execute( - """INSERT INTO beliefs - (belief_id, user_id, claim, domain, status, confidence, - source_memory_ids, source_episode_ids, derivation_version, - lineage_fingerprint, created_at, updated_at, tags_json) - VALUES (?, ?, ?, ?, 'proposed', ?, ?, ?, 1, ?, ?, ?, ?)""", - ( - bid, user_id, claim, domain, confidence, - json.dumps(smids), json.dumps(seids), fp, - now, now, json.dumps(tags or []), - ), - ) - return bid - - def get(self, belief_id: str) -> Optional[Dict[str, Any]]: - with self._lock: - row = self._conn.execute( - "SELECT * FROM beliefs WHERE belief_id = ?", - (belief_id,), - ).fetchone() - return self._row_to_dict(row) if row else None - - def list_by_user( - self, - user_id: str, - *, - domain: Optional[str] = None, - status: Optional[str] = None, - min_confidence: float = 0.0, - limit: int = 50, - ) -> List[Dict[str, Any]]: - query = "SELECT * FROM beliefs WHERE user_id = ? 
AND confidence >= ?" - params: list = [user_id, min_confidence] - if domain: - query += " AND domain = ?" - params.append(domain) - if status: - query += " AND status = ?" - params.append(status) - query += " ORDER BY confidence DESC LIMIT ?" - params.append(limit) - - with self._lock: - rows = self._conn.execute(query, params).fetchall() - return [self._row_to_dict(r) for r in rows] - - def update_confidence( - self, - belief_id: str, - new_confidence: float, - *, - new_status: Optional[str] = None, - evidence: Optional[Dict[str, Any]] = None, - revision_reason: Optional[str] = None, - ) -> bool: - """Update belief confidence with optional evidence and revision tracking.""" - now = _utcnow_iso() - with self._tx() as conn: - row = conn.execute( - "SELECT confidence, status, evidence_json, revisions_json FROM beliefs WHERE belief_id = ?", - (belief_id,), - ).fetchone() - if not row: - return False - - old_conf = row["confidence"] - old_status = row["status"] - - # Append evidence - evidence_list = _parse_json(row["evidence_json"], []) - if evidence: - evidence_list.append(evidence) - - # Append revision - revisions = _parse_json(row["revisions_json"], []) - revisions.append({ - "timestamp": now, - "old_confidence": old_conf, - "new_confidence": new_confidence, - "old_status": old_status, - "new_status": new_status or old_status, - "reason": revision_reason or "confidence_update", - }) - - # Auto-derive status if not explicitly set - status = new_status - if not status: - if new_confidence >= 0.7: - status = BeliefStatus.HELD.value - elif new_confidence <= 0.1: - status = BeliefStatus.RETRACTED.value - else: - status = old_status - - conn.execute( - """UPDATE beliefs - SET confidence = ?, status = ?, evidence_json = ?, - revisions_json = ?, updated_at = ? 
- WHERE belief_id = ?""", - ( - new_confidence, status, json.dumps(evidence_list), - json.dumps(revisions), now, belief_id, - ), - ) - return True - - def add_contradiction(self, belief_a_id: str, belief_b_id: str) -> None: - """Link two beliefs as contradicting each other.""" - now = _utcnow_iso() - with self._tx() as conn: - for bid, other_id in [(belief_a_id, belief_b_id), (belief_b_id, belief_a_id)]: - row = conn.execute( - "SELECT contradicts_ids FROM beliefs WHERE belief_id = ?", - (bid,), - ).fetchone() - if row: - ids = _parse_json(row["contradicts_ids"], []) - if other_id not in ids: - ids.append(other_id) - conn.execute( - "UPDATE beliefs SET contradicts_ids = ?, status = 'challenged', updated_at = ? WHERE belief_id = ?", - (json.dumps(ids), now, bid), - ) - - def set_status(self, belief_id: str, status: str) -> bool: - with self._tx() as conn: - result = conn.execute( - "UPDATE beliefs SET status = ?, updated_at = ? WHERE belief_id = ?", - (status, _utcnow_iso(), belief_id), - ) - return result.rowcount > 0 - - def get_by_invalidation_status( - self, status: str, *, limit: int = 100 - ) -> List[Dict[str, Any]]: - """Get beliefs in stale/suspect/invalidated status for repair jobs.""" - with self._lock: - rows = self._conn.execute( - "SELECT * FROM beliefs WHERE status = ? 
LIMIT ?", - (status, limit), - ).fetchall() - return [self._row_to_dict(r) for r in rows] - - @staticmethod - def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]: - return { - "belief_id": row["belief_id"], - "user_id": row["user_id"], - "claim": row["claim"], - "domain": row["domain"], - "status": row["status"], - "confidence": row["confidence"], - "evidence": _parse_json(row["evidence_json"], []), - "revisions": _parse_json(row["revisions_json"], []), - "contradicts_ids": _parse_json(row["contradicts_ids"], []), - "source_memory_ids": _parse_json(row["source_memory_ids"], []), - "source_episode_ids": _parse_json(row["source_episode_ids"], []), - "derivation_version": row["derivation_version"], - "lineage_fingerprint": row["lineage_fingerprint"], - "created_at": row["created_at"], - "updated_at": row["updated_at"], - "tags": _parse_json(row["tags_json"], []), - } - - -# ========================================================================= -# PolicyStore -# ========================================================================= - -class PolicyStore(_DerivedStoreBase): - """Condition→action rules with utility tracking (D2Skill dual-granularity).""" - - def add( - self, - user_id: str, - name: str, - condition: Dict[str, Any], - action: Dict[str, Any], - *, - granularity: str = "task", - source_task_ids: Optional[List[str]] = None, - source_episode_ids: Optional[List[str]] = None, - tags: Optional[List[str]] = None, - policy_id: Optional[str] = None, - ) -> str: - pid = policy_id or str(uuid.uuid4()) - now = _utcnow_iso() - stids = source_task_ids or [] - seids = source_episode_ids or [] - fp = _compute_lineage_fingerprint(stids, 1) - - with self._tx() as conn: - conn.execute( - """INSERT INTO policies - (policy_id, user_id, name, granularity, status, - condition_json, action_json, source_task_ids, - source_episode_ids, derivation_version, - lineage_fingerprint, created_at, updated_at, tags_json) - VALUES (?, ?, ?, ?, 'proposed', ?, ?, ?, ?, 1, ?, ?, ?, ?)""", 
class PolicyStore(_DerivedStoreBase):
    """Condition→action rules with utility tracking (D2Skill dual-granularity)."""

    def add(
        self,
        user_id: str,
        name: str,
        condition: Dict[str, Any],
        action: Dict[str, Any],
        *,
        granularity: str = "task",
        source_task_ids: Optional[List[str]] = None,
        source_episode_ids: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
        policy_id: Optional[str] = None,
    ) -> str:
        """Insert a new policy in 'proposed' status; returns the policy_id."""
        pid = policy_id or str(uuid.uuid4())
        now = _utcnow_iso()
        stids = source_task_ids or []
        seids = source_episode_ids or []
        fp = _compute_lineage_fingerprint(stids, 1)

        with self._tx() as conn:
            conn.execute(
                """INSERT INTO policies
                   (policy_id, user_id, name, granularity, status,
                    condition_json, action_json, source_task_ids,
                    source_episode_ids, derivation_version,
                    lineage_fingerprint, created_at, updated_at, tags_json)
                   VALUES (?, ?, ?, ?, 'proposed', ?, ?, ?, ?, 1, ?, ?, ?, ?)""",
                (
                    pid, user_id, name, granularity,
                    json.dumps(condition), json.dumps(action),
                    json.dumps(stids), json.dumps(seids), fp,
                    now, now, json.dumps(tags or []),
                ),
            )
        return pid

    def get(self, policy_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a single policy as a dict, or None if not found."""
        with self._lock:
            row = self._conn.execute(
                "SELECT * FROM policies WHERE policy_id = ?",
                (policy_id,),
            ).fetchone()
            return self._row_to_dict(row) if row else None

    def list_by_user(
        self,
        user_id: str,
        *,
        granularity: Optional[str] = None,
        status: Optional[str] = None,
        limit: int = 50,
    ) -> List[Dict[str, Any]]:
        """List a user's policies (optionally filtered), highest utility first."""
        query = "SELECT * FROM policies WHERE user_id = ?"
        params: list = [user_id]
        if granularity:
            query += " AND granularity = ?"
            params.append(granularity)
        if status:
            query += " AND status = ?"
            params.append(status)
        query += " ORDER BY utility DESC LIMIT ?"
        params.append(limit)

        with self._lock:
            rows = self._conn.execute(query, params).fetchall()
            return [self._row_to_dict(r) for r in rows]

    def record_outcome(
        self,
        policy_id: str,
        success: bool,
        *,
        baseline_score: Optional[float] = None,
        actual_score: Optional[float] = None,
    ) -> Optional[float]:
        """Record an outcome for a policy. Returns the delta if scores provided.

        Updates apply_count, success/failure counts, and utility EMA.
        Auto-transitions status: validated (win_rate >= 0.6 after 5+) or
        deprecated (win_rate < 0.4 after 5+).

        Returns None if the policy does not exist; otherwise the delta
        (0.0 when no scores were given).
        """
        now = _utcnow_iso()
        with self._tx() as conn:
            row = conn.execute(
                """SELECT apply_count, success_count, failure_count,
                          utility, cumulative_delta, status
                   FROM policies WHERE policy_id = ?""",
                (policy_id,),
            ).fetchone()
            if not row:
                return None

            apply_count = row["apply_count"] + 1
            success_count = row["success_count"] + (1 if success else 0)
            failure_count = row["failure_count"] + (0 if success else 1)
            utility = row["utility"]
            cumulative = row["cumulative_delta"]
            status = row["status"]

            # Utility only moves when both scores are supplied.
            delta = 0.0
            if baseline_score is not None and actual_score is not None:
                delta = actual_score - baseline_score
                utility = 0.3 * delta + 0.7 * utility  # EMA alpha=0.3
                cumulative += delta

            # Auto-transition after enough data (invalidation states are sticky)
            if apply_count >= 5 and status not in ("stale", "suspect", "invalidated"):
                win_rate = (success_count + 1) / (apply_count + 2)  # Laplace
                if win_rate >= 0.6:
                    status = PolicyStatus.VALIDATED.value
                elif win_rate < 0.4:
                    status = PolicyStatus.DEPRECATED.value
                elif status == PolicyStatus.PROPOSED.value:
                    status = PolicyStatus.ACTIVE.value

            conn.execute(
                """UPDATE policies
                   SET apply_count = ?, success_count = ?, failure_count = ?,
                       utility = ?, last_delta = ?, cumulative_delta = ?,
                       status = ?, updated_at = ?
                   WHERE policy_id = ?""",
                (
                    apply_count, success_count, failure_count,
                    utility, delta, cumulative, status, now, policy_id,
                ),
            )
        return delta

    def set_status(self, policy_id: str, status: str) -> bool:
        """Set a policy's status directly. Returns True if a row changed."""
        with self._tx() as conn:
            result = conn.execute(
                "UPDATE policies SET status = ?, updated_at = ? WHERE policy_id = ?",
                (status, _utcnow_iso(), policy_id),
            )
            return result.rowcount > 0

    def get_by_invalidation_status(
        self, status: str, *, limit: int = 100
    ) -> List[Dict[str, Any]]:
        """Get policies in stale/suspect/invalidated status for repair jobs."""
        with self._lock:
            rows = self._conn.execute(
                "SELECT * FROM policies WHERE status = ? LIMIT ?",
                (status, limit),
            ).fetchall()
            return [self._row_to_dict(r) for r in rows]

    @staticmethod
    def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]:
        # JSON columns decoded so callers only see plain Python structures.
        return {
            "policy_id": row["policy_id"],
            "user_id": row["user_id"],
            "name": row["name"],
            "granularity": row["granularity"],
            "status": row["status"],
            "condition": _parse_json(row["condition_json"], {}),
            "action": _parse_json(row["action_json"], {}),
            "apply_count": row["apply_count"],
            "success_count": row["success_count"],
            "failure_count": row["failure_count"],
            "utility": row["utility"],
            "last_delta": row["last_delta"],
            "cumulative_delta": row["cumulative_delta"],
            "source_task_ids": _parse_json(row["source_task_ids"], []),
            "source_episode_ids": _parse_json(row["source_episode_ids"], []),
            "derivation_version": row["derivation_version"],
            "lineage_fingerprint": row["lineage_fingerprint"],
            "created_at": row["created_at"],
            "updated_at": row["updated_at"],
            "tags": _parse_json(row["tags_json"], []),
        }
LIMIT ?", - (status, limit), - ).fetchall() - return [self._row_to_dict(r) for r in rows] - - @staticmethod - def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]: - return { - "policy_id": row["policy_id"], - "user_id": row["user_id"], - "name": row["name"], - "granularity": row["granularity"], - "status": row["status"], - "condition": _parse_json(row["condition_json"], {}), - "action": _parse_json(row["action_json"], {}), - "apply_count": row["apply_count"], - "success_count": row["success_count"], - "failure_count": row["failure_count"], - "utility": row["utility"], - "last_delta": row["last_delta"], - "cumulative_delta": row["cumulative_delta"], - "source_task_ids": _parse_json(row["source_task_ids"], []), - "source_episode_ids": _parse_json(row["source_episode_ids"], []), - "derivation_version": row["derivation_version"], - "lineage_fingerprint": row["lineage_fingerprint"], - "created_at": row["created_at"], - "updated_at": row["updated_at"], - "tags": _parse_json(row["tags_json"], []), - } - - -# ========================================================================= -# AnchorStore -# ========================================================================= - -class AnchorStore(_DerivedStoreBase): - """Hierarchical context anchors (era/place/time/activity).""" - - def add( - self, - user_id: str, - *, - memory_event_id: Optional[str] = None, - era: Optional[str] = None, - place: Optional[str] = None, - place_type: Optional[str] = None, - place_detail: Optional[str] = None, - time_absolute: Optional[str] = None, - time_markers: Optional[List[str]] = None, - time_range_start: Optional[str] = None, - time_range_end: Optional[str] = None, - time_derivation: Optional[str] = None, - activity: Optional[str] = None, - session_id: Optional[str] = None, - session_position: int = 0, - anchor_id: Optional[str] = None, - ) -> str: - aid = anchor_id or str(uuid.uuid4()) - now = _utcnow_iso() - source_ids = [memory_event_id] if memory_event_id else [] - fp = 
_compute_lineage_fingerprint(source_ids, 1) - - with self._tx() as conn: - conn.execute( - """INSERT INTO anchors - (anchor_id, user_id, memory_event_id, era, place, - place_type, place_detail, time_absolute, - time_markers_json, time_range_start, time_range_end, - time_derivation, activity, session_id, session_position, - derivation_version, lineage_fingerprint, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1, ?, ?)""", - ( - aid, user_id, memory_event_id, era, place, - place_type, place_detail, time_absolute, - json.dumps(time_markers or []), - time_range_start, time_range_end, time_derivation, - activity, session_id, session_position, fp, now, - ), - ) - return aid - - def get(self, anchor_id: str) -> Optional[Dict[str, Any]]: - with self._lock: - row = self._conn.execute( - "SELECT * FROM anchors WHERE anchor_id = ?", - (anchor_id,), - ).fetchone() - return self._row_to_dict(row) if row else None - - def get_by_event(self, memory_event_id: str) -> Optional[Dict[str, Any]]: - with self._lock: - row = self._conn.execute( - "SELECT * FROM anchors WHERE memory_event_id = ?", - (memory_event_id,), - ).fetchone() - return self._row_to_dict(row) if row else None - - def list_by_user( - self, - user_id: str, - *, - era: Optional[str] = None, - place: Optional[str] = None, - activity: Optional[str] = None, - limit: int = 50, - ) -> List[Dict[str, Any]]: - query = "SELECT * FROM anchors WHERE user_id = ?" - params: list = [user_id] - if era: - query += " AND era = ?" - params.append(era) - if place: - query += " AND place = ?" - params.append(place) - if activity: - query += " AND activity = ?" - params.append(activity) - query += " ORDER BY created_at DESC LIMIT ?" - params.append(limit) - - with self._lock: - rows = self._conn.execute(query, params).fetchall() - return [self._row_to_dict(r) for r in rows] - - def update_fields(self, anchor_id: str, **fields: Any) -> bool: - """Update specific anchor fields. 
Only allows known anchor columns.""" - allowed = { - "era", "place", "place_type", "place_detail", - "time_absolute", "time_markers_json", "time_range_start", - "time_range_end", "time_derivation", "activity", - } - updates = {k: v for k, v in fields.items() if k in allowed} - if not updates: - return False - - set_clause = ", ".join(f"{k} = ?" for k in updates) - values = list(updates.values()) + [anchor_id] - - with self._tx() as conn: - result = conn.execute( - f"UPDATE anchors SET {set_clause} WHERE anchor_id = ?", - values, - ) - return result.rowcount > 0 - - @staticmethod - def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]: - return { - "anchor_id": row["anchor_id"], - "user_id": row["user_id"], - "memory_event_id": row["memory_event_id"], - "era": row["era"], - "place": row["place"], - "place_type": row["place_type"], - "place_detail": row["place_detail"], - "time_absolute": row["time_absolute"], - "time_markers": _parse_json(row["time_markers_json"], []), - "time_range_start": row["time_range_start"], - "time_range_end": row["time_range_end"], - "time_derivation": row["time_derivation"], - "activity": row["activity"], - "session_id": row["session_id"], - "session_position": row["session_position"], - "derivation_version": row["derivation_version"], - "lineage_fingerprint": row["lineage_fingerprint"], - "created_at": row["created_at"], - } - - -# ========================================================================= -# InsightStore -# ========================================================================= - -class InsightStore(_DerivedStoreBase): - """Synthesized causal hypotheses with strength tracking.""" - - def add( - self, - user_id: str, - content: str, - *, - insight_type: str = "pattern", - source_task_types: Optional[List[str]] = None, - confidence: float = 0.5, - tags: Optional[List[str]] = None, - insight_id: Optional[str] = None, - ) -> str: - iid = insight_id or str(uuid.uuid4()) - now = _utcnow_iso() - - with self._tx() as conn: - 
class InsightStore(_DerivedStoreBase):
    """Synthesized causal hypotheses with strength tracking."""

    def add(
        self,
        user_id: str,
        content: str,
        *,
        insight_type: str = "pattern",
        source_task_types: Optional[List[str]] = None,
        confidence: float = 0.5,
        tags: Optional[List[str]] = None,
        insight_id: Optional[str] = None,
    ) -> str:
        """Insert a new insight; returns the insight_id.

        Lineage fingerprint is stored empty ('') — insights link to sources
        via task types rather than event ids.
        """
        iid = insight_id or str(uuid.uuid4())
        now = _utcnow_iso()

        with self._tx() as conn:
            conn.execute(
                """INSERT INTO insights
                   (insight_id, user_id, content, insight_type,
                    source_task_types_json, confidence,
                    derivation_version, lineage_fingerprint,
                    created_at, tags_json)
                   VALUES (?, ?, ?, ?, ?, ?, 1, '', ?, ?)""",
                (
                    iid, user_id, content, insight_type,
                    json.dumps(source_task_types or []),
                    confidence, now, json.dumps(tags or []),
                ),
            )
        return iid

    def get(self, insight_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a single insight as a dict, or None if not found."""
        with self._lock:
            row = self._conn.execute(
                "SELECT * FROM insights WHERE insight_id = ?",
                (insight_id,),
            ).fetchone()
            return self._row_to_dict(row) if row else None

    def list_by_user(
        self,
        user_id: str,
        *,
        insight_type: Optional[str] = None,
        min_confidence: float = 0.0,
        status: Optional[str] = None,
        limit: int = 50,
    ) -> List[Dict[str, Any]]:
        """List a user's insights (optionally filtered), highest confidence first."""
        query = "SELECT * FROM insights WHERE user_id = ? AND confidence >= ?"
        params: list = [user_id, min_confidence]
        if insight_type:
            query += " AND insight_type = ?"
            params.append(insight_type)
        if status:
            query += " AND status = ?"
            params.append(status)
        query += " ORDER BY confidence DESC LIMIT ?"
        params.append(limit)

        with self._lock:
            rows = self._conn.execute(query, params).fetchall()
            return [self._row_to_dict(r) for r in rows]

    def record_outcome(
        self,
        insight_id: str,
        success: bool,
        *,
        baseline_score: Optional[float] = None,
        actual_score: Optional[float] = None,
    ) -> bool:
        """Record validation/invalidation outcome. Updates confidence + utility.

        Success bumps confidence by +0.05 (capped at 1.0); failure drops it
        by -0.1 (floored at 0.0). Utility EMA moves only when both scores
        are supplied. Returns False if the insight does not exist.

        NOTE(review): last_validated is refreshed on every outcome, even a
        failed one — confirm whether it should only move on success.
        """
        now = _utcnow_iso()
        with self._tx() as conn:
            row = conn.execute(
                """SELECT confidence, validation_count, invalidation_count,
                          utility, apply_count, status
                   FROM insights WHERE insight_id = ?""",
                (insight_id,),
            ).fetchone()
            if not row:
                return False

            conf = row["confidence"]
            v_count = row["validation_count"]
            i_count = row["invalidation_count"]
            utility = row["utility"]
            apply_count = row["apply_count"] + 1

            if success:
                v_count += 1
                conf = min(1.0, conf + 0.05)
            else:
                i_count += 1
                conf = max(0.0, conf - 0.1)

            if baseline_score is not None and actual_score is not None:
                delta = actual_score - baseline_score
                utility = 0.3 * delta + 0.7 * utility

            conn.execute(
                """UPDATE insights
                   SET confidence = ?, validation_count = ?,
                       invalidation_count = ?, utility = ?,
                       apply_count = ?, last_validated = ?,
                       status = ?
                   WHERE insight_id = ?""",
                (
                    conf, v_count, i_count, utility, apply_count, now,
                    row["status"],  # preserve current status
                    insight_id,
                ),
            )
        return True

    def set_status(self, insight_id: str, status: str) -> bool:
        """Set an insight's status directly. Returns True if a row changed."""
        with self._tx() as conn:
            result = conn.execute(
                "UPDATE insights SET status = ? WHERE insight_id = ?",
                (status, insight_id),
            )
            return result.rowcount > 0

    def get_by_invalidation_status(
        self, status: str, *, limit: int = 100
    ) -> List[Dict[str, Any]]:
        """Get insights in stale/suspect/invalidated status for repair jobs."""
        with self._lock:
            rows = self._conn.execute(
                "SELECT * FROM insights WHERE status = ? LIMIT ?",
                (status, limit),
            ).fetchall()
            return [self._row_to_dict(r) for r in rows]

    @staticmethod
    def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]:
        # JSON columns decoded so callers only see plain Python structures.
        return {
            "insight_id": row["insight_id"],
            "user_id": row["user_id"],
            "content": row["content"],
            "insight_type": row["insight_type"],
            "source_task_types": _parse_json(row["source_task_types_json"], []),
            "confidence": row["confidence"],
            "validation_count": row["validation_count"],
            "invalidation_count": row["invalidation_count"],
            "utility": row["utility"],
            "apply_count": row["apply_count"],
            "derivation_version": row["derivation_version"],
            "lineage_fingerprint": row["lineage_fingerprint"],
            "created_at": row["created_at"],
            "last_validated": row["last_validated"],
            "tags": _parse_json(row["tags_json"], []),
            "status": row["status"],
        }
LIMIT ?", - (status, limit), - ).fetchall() - return [self._row_to_dict(r) for r in rows] - - @staticmethod - def _row_to_dict(row: sqlite3.Row) -> Dict[str, Any]: - return { - "insight_id": row["insight_id"], - "user_id": row["user_id"], - "content": row["content"], - "insight_type": row["insight_type"], - "source_task_types": _parse_json(row["source_task_types_json"], []), - "confidence": row["confidence"], - "validation_count": row["validation_count"], - "invalidation_count": row["invalidation_count"], - "utility": row["utility"], - "apply_count": row["apply_count"], - "derivation_version": row["derivation_version"], - "lineage_fingerprint": row["lineage_fingerprint"], - "created_at": row["created_at"], - "last_validated": row["last_validated"], - "tags": _parse_json(row["tags_json"], []), - "status": row["status"], - } - - -# ========================================================================= -# HeuristicStore -# ========================================================================= - -class HeuristicStore(_DerivedStoreBase): - """Transferable reasoning patterns (ERL, 3 abstraction levels).""" - - def add( - self, - user_id: str, - content: str, - *, - abstraction_level: str = "specific", - source_task_types: Optional[List[str]] = None, - confidence: float = 0.5, - tags: Optional[List[str]] = None, - heuristic_id: Optional[str] = None, - ) -> str: - hid = heuristic_id or str(uuid.uuid4()) - now = _utcnow_iso() - - with self._tx() as conn: - conn.execute( - """INSERT INTO heuristics - (heuristic_id, user_id, content, abstraction_level, - source_task_types_json, confidence, - derivation_version, lineage_fingerprint, - created_at, tags_json) - VALUES (?, ?, ?, ?, ?, ?, 1, '', ?, ?)""", - ( - hid, user_id, content, abstraction_level, - json.dumps(source_task_types or []), - confidence, now, json.dumps(tags or []), - ), - ) - return hid - - def get(self, heuristic_id: str) -> Optional[Dict[str, Any]]: - with self._lock: - row = self._conn.execute( - "SELECT 
* FROM heuristics WHERE heuristic_id = ?", - (heuristic_id,), - ).fetchone() - return self._row_to_dict(row) if row else None - - def list_by_user( - self, - user_id: str, - *, - abstraction_level: Optional[str] = None, - min_confidence: float = 0.0, - status: Optional[str] = None, - limit: int = 50, - ) -> List[Dict[str, Any]]: - query = "SELECT * FROM heuristics WHERE user_id = ? AND confidence >= ?" - params: list = [user_id, min_confidence] - if abstraction_level: - query += " AND abstraction_level = ?" - params.append(abstraction_level) - if status: - query += " AND status = ?" - params.append(status) - query += " ORDER BY confidence DESC LIMIT ?" - params.append(limit) - - with self._lock: - rows = self._conn.execute(query, params).fetchall() - return [self._row_to_dict(r) for r in rows] - - def record_outcome( - self, - heuristic_id: str, - success: bool, - *, - baseline_score: Optional[float] = None, - actual_score: Optional[float] = None, - ) -> bool: - now = _utcnow_iso() - with self._tx() as conn: - row = conn.execute( - """SELECT confidence, validation_count, invalidation_count, - utility, last_delta, apply_count, status - FROM heuristics WHERE heuristic_id = ?""", - (heuristic_id,), - ).fetchone() - if not row: - return False - - conf = row["confidence"] - v_count = row["validation_count"] - i_count = row["invalidation_count"] - utility = row["utility"] - apply_count = row["apply_count"] + 1 - delta = 0.0 - - if success: - v_count += 1 - conf = min(1.0, conf + 0.05) - else: - i_count += 1 - conf = max(0.0, conf - 0.1) - - if baseline_score is not None and actual_score is not None: - delta = actual_score - baseline_score - utility = 0.3 * delta + 0.7 * utility - - conn.execute( - """UPDATE heuristics - SET confidence = ?, validation_count = ?, - invalidation_count = ?, utility = ?, - last_delta = ?, apply_count = ?, status = ? 
class DerivedLineageStore(_DerivedStoreBase):
    """Links derived objects to source raw events for traceability.

    Supports the three-tier invalidation model:
    - Given a source event, find all derived objects that depend on it
    - Given a derived object, find all source events it was built from
    - Contribution weight enables partial invalidation decisions
    """

    def add(
        self,
        derived_type: str,
        derived_id: str,
        source_event_id: str,
        *,
        contribution_weight: float = 1.0,
        lineage_id: Optional[str] = None,
    ) -> str:
        """Insert one lineage link; returns the lineage_id."""
        lid = lineage_id or str(uuid.uuid4())
        with self._tx() as conn:
            conn.execute(
                """INSERT INTO derived_lineage
                   (lineage_id, derived_type, derived_id,
                    source_event_id, contribution_weight)
                   VALUES (?, ?, ?, ?, ?)""",
                (lid, derived_type, derived_id, source_event_id, contribution_weight),
            )
        return lid

    def add_batch(
        self,
        derived_type: str,
        derived_id: str,
        source_event_ids: List[str],
        *,
        weights: Optional[List[float]] = None,
    ) -> List[str]:
        """Add multiple lineage links at once.

        When *weights* is omitted every source gets weight 1.0. All links
        are written in a single transaction; returns new lineage_ids in
        input order.
        """
        w = weights or [1.0] * len(source_event_ids)
        ids = []
        with self._tx() as conn:
            for eid, weight in zip(source_event_ids, w):
                lid = str(uuid.uuid4())
                conn.execute(
                    """INSERT INTO derived_lineage
                       (lineage_id, derived_type, derived_id,
                        source_event_id, contribution_weight)
                       VALUES (?, ?, ?, ?, ?)""",
                    (lid, derived_type, derived_id, eid, weight),
                )
                ids.append(lid)
        return ids

    def get_sources(
        self, derived_type: str, derived_id: str
    ) -> List[Dict[str, Any]]:
        """Get all source events for a derived object."""
        with self._lock:
            rows = self._conn.execute(
                """SELECT lineage_id, source_event_id, contribution_weight, created_at
                   FROM derived_lineage
                   WHERE derived_type = ? AND derived_id = ?""",
                (derived_type, derived_id),
            ).fetchall()
            return [
                {
                    "lineage_id": r["lineage_id"],
                    "source_event_id": r["source_event_id"],
                    "contribution_weight": r["contribution_weight"],
                    "created_at": r["created_at"],
                }
                for r in rows
            ]

    def get_dependents(
        self, source_event_id: str
    ) -> List[Dict[str, Any]]:
        """Get all derived objects that depend on a source event.

        This is the key query for invalidation cascades.
        """
        with self._lock:
            rows = self._conn.execute(
                """SELECT lineage_id, derived_type, derived_id,
                          contribution_weight, created_at
                   FROM derived_lineage
                   WHERE source_event_id = ?""",
                (source_event_id,),
            ).fetchall()
            return [
                {
                    "lineage_id": r["lineage_id"],
                    "derived_type": r["derived_type"],
                    "derived_id": r["derived_id"],
                    "contribution_weight": r["contribution_weight"],
                    "created_at": r["created_at"],
                }
                for r in rows
            ]

    def get_source_count(self, derived_type: str, derived_id: str) -> int:
        """Count source events for a derived object."""
        with self._lock:
            row = self._conn.execute(
                """SELECT COUNT(*) FROM derived_lineage
                   WHERE derived_type = ? AND derived_id = ?""",
                (derived_type, derived_id),
            ).fetchone()
            return row[0] if row else 0

    def get_contribution_weight(
        self, derived_type: str, derived_id: str, source_event_id: str
    ) -> Optional[float]:
        """Get the contribution weight of a specific source to a derived object.

        Used by partial invalidation to decide severity. Returns None when
        no such link exists.
        """
        with self._lock:
            row = self._conn.execute(
                """SELECT contribution_weight FROM derived_lineage
                   WHERE derived_type = ? AND derived_id = ? AND source_event_id = ?""",
                (derived_type, derived_id, source_event_id),
            ).fetchone()
            return row["contribution_weight"] if row else None

    def delete_for_derived(self, derived_type: str, derived_id: str) -> int:
        """Remove all lineage links for a derived object (e.g., before re-deriving).

        Returns the number of links removed.
        """
        with self._tx() as conn:
            result = conn.execute(
                "DELETE FROM derived_lineage WHERE derived_type = ? AND derived_id = ?",
                (derived_type, derived_id),
            )
            return result.rowcount
class CognitionStore:
    """Unified access to all v3 stores sharing a single SQLite connection.

    Usage:
        store = CognitionStore()  # or CognitionStore(db_path="...")
        store.events.add(content="...", user_id="...")
        store.beliefs.add(user_id="...", claim="...")
        store.lineage.add("belief", bid, event_id)
    """

    def __init__(self, db_path: Optional[str] = None):
        # Local import avoids a circular dependency between this module
        # and dhee.core.events at import time.
        from dhee.core.events import RawEventStore, _default_db_path

        self.db_path = db_path or _default_db_path()

        # RawEventStore owns the connection and schema initialization
        self.events = RawEventStore(self.db_path)

        # All derived stores share the same connection + lock
        # (reaches into RawEventStore's private attributes by design).
        conn = self.events._conn
        lock = self.events._lock

        self.beliefs = BeliefStore(conn, lock)
        self.policies = PolicyStore(conn, lock)
        self.anchors = AnchorStore(conn, lock)
        self.insights = InsightStore(conn, lock)
        self.heuristics = HeuristicStore(conn, lock)
        self.lineage = DerivedLineageStore(conn, lock)

    def close(self) -> None:
        """Close the shared connection via its owner (RawEventStore)."""
        self.events.close()
class EventStatus(str, Enum):
    """Lifecycle of a raw memory event: active until corrected or soft-deleted."""

    ACTIVE = "active"
    CORRECTED = "corrected"
    DELETED = "deleted"


@dataclass
class RawMemoryEvent:
    """In-memory representation of a raw memory event."""

    event_id: str
    user_id: str
    content: str
    content_hash: str
    status: EventStatus = EventStatus.ACTIVE
    session_id: Optional[str] = None
    source: str = "user"
    supersedes_event_id: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    created_at: Optional[str] = None

    @staticmethod
    def compute_hash(content: str) -> str:
        """SHA-256 content hash for dedup."""
        digest = hashlib.sha256(content.encode("utf-8"))
        return digest.hexdigest()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; ``status`` is flattened to its string value."""
        return dict(
            event_id=self.event_id,
            user_id=self.user_id,
            content=self.content,
            content_hash=self.content_hash,
            status=self.status.value,
            session_id=self.session_id,
            source=self.source,
            supersedes_event_id=self.supersedes_event_id,
            metadata=self.metadata,
            created_at=self.created_at,
        )
"created_at": self.created_at, - } - - -def _utcnow_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -def _default_db_path() -> str: - data_dir = os.environ.get("DHEE_DATA_DIR") or os.path.join( - os.path.expanduser("~"), ".dhee" - ) - return os.path.join(data_dir, "v3.db") - - -class RawEventStore: - """Immutable raw event storage backed by SQLite. - - Thread-safe via RLock. Follows the same connection pattern as - dhee/db/sqlite.py (_SQLiteBase). - """ - - def __init__(self, db_path: Optional[str] = None): - self.db_path = db_path or _default_db_path() - db_dir = os.path.dirname(self.db_path) - if db_dir: - os.makedirs(db_dir, exist_ok=True) - - self._conn = sqlite3.connect(self.db_path, check_same_thread=False) - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA busy_timeout=5000") - self._conn.execute("PRAGMA synchronous=FULL") - self._conn.execute("PRAGMA cache_size=-8000") - self._conn.execute("PRAGMA temp_store=MEMORY") - self._conn.row_factory = sqlite3.Row - self._lock = threading.RLock() - - # Initialize all v3 tables - initialize_schema(self._conn) - - def close(self) -> None: - with self._lock: - if self._conn: - try: - self._conn.close() - except Exception: - pass - self._conn = None # type: ignore[assignment] - - @contextmanager - def _tx(self): - """Yield connection under lock with commit/rollback.""" - with self._lock: - try: - yield self._conn - self._conn.commit() - except Exception: - self._conn.rollback() - raise - - # ------------------------------------------------------------------ - # Write operations - # ------------------------------------------------------------------ - - def add( - self, - content: str, - user_id: str, - *, - session_id: Optional[str] = None, - source: str = "user", - metadata: Optional[Dict[str, Any]] = None, - event_id: Optional[str] = None, - ) -> RawMemoryEvent: - """Store a new raw memory event. Returns the event (existing if dedup hit). 
- - Content-hash dedup: if identical content already exists for this user - and is active, returns the existing event instead of creating a duplicate. - """ - content_hash = RawMemoryEvent.compute_hash(content) - eid = event_id or str(uuid.uuid4()) - meta = metadata or {} - now = _utcnow_iso() - - with self._tx() as conn: - # Dedup check — same content, same user, still active - existing = conn.execute( - """SELECT event_id, user_id, content, content_hash, status, - session_id, source, supersedes_event_id, - metadata_json, created_at - FROM raw_memory_events - WHERE content_hash = ? AND user_id = ? AND status = 'active' - LIMIT 1""", - (content_hash, user_id), - ).fetchone() - - if existing: - return self._row_to_event(existing) - - conn.execute( - """INSERT INTO raw_memory_events - (event_id, user_id, session_id, created_at, content, - content_hash, source, status, metadata_json) - VALUES (?, ?, ?, ?, ?, ?, ?, 'active', ?)""", - ( - eid, user_id, session_id, now, content, - content_hash, source, json.dumps(meta), - ), - ) - - return RawMemoryEvent( - event_id=eid, - user_id=user_id, - content=content, - content_hash=content_hash, - status=EventStatus.ACTIVE, - session_id=session_id, - source=source, - metadata=meta, - created_at=now, - ) - - def correct( - self, - original_event_id: str, - new_content: str, - *, - source: str = "user_correction", - metadata: Optional[Dict[str, Any]] = None, - ) -> RawMemoryEvent: - """Correct an existing event. - - 1. Marks the original event as 'corrected' - 2. Creates a new event with supersedes_event_id pointing to original - 3. Returns the new event - - Raises ValueError if original event not found or not active. 
- """ - with self._tx() as conn: - original = conn.execute( - """SELECT event_id, user_id, session_id, status - FROM raw_memory_events WHERE event_id = ?""", - (original_event_id,), - ).fetchone() - - if not original: - raise ValueError(f"Event not found: {original_event_id}") - if original["status"] != "active": - raise ValueError( - f"Cannot correct event with status '{original['status']}': " - f"{original_event_id}" - ) - - # Mark original as corrected - conn.execute( - "UPDATE raw_memory_events SET status = 'corrected' WHERE event_id = ?", - (original_event_id,), - ) - - # Create correction event - new_id = str(uuid.uuid4()) - content_hash = RawMemoryEvent.compute_hash(new_content) - meta = metadata or {} - now = _utcnow_iso() - - conn.execute( - """INSERT INTO raw_memory_events - (event_id, user_id, session_id, created_at, content, - content_hash, source, status, supersedes_event_id, metadata_json) - VALUES (?, ?, ?, ?, ?, ?, ?, 'active', ?, ?)""", - ( - new_id, original["user_id"], original["session_id"], - now, new_content, content_hash, source, - original_event_id, json.dumps(meta), - ), - ) - - return RawMemoryEvent( - event_id=new_id, - user_id=original["user_id"], - content=new_content, - content_hash=content_hash, - status=EventStatus.ACTIVE, - session_id=original["session_id"], - source=source, - supersedes_event_id=original_event_id, - metadata=meta, - created_at=now, - ) - - def delete(self, event_id: str) -> bool: - """Soft-delete a raw event. Returns True if status changed. - - Marks the event as 'deleted'. Does NOT physically remove it. - Raises ValueError if event not found. 
- """ - with self._tx() as conn: - row = conn.execute( - "SELECT status FROM raw_memory_events WHERE event_id = ?", - (event_id,), - ).fetchone() - - if not row: - raise ValueError(f"Event not found: {event_id}") - if row["status"] == "deleted": - return False - - conn.execute( - "UPDATE raw_memory_events SET status = 'deleted' WHERE event_id = ?", - (event_id,), - ) - return True - - # ------------------------------------------------------------------ - # Read operations - # ------------------------------------------------------------------ - - def get(self, event_id: str) -> Optional[RawMemoryEvent]: - """Get a single event by ID.""" - with self._lock: - row = self._conn.execute( - """SELECT event_id, user_id, content, content_hash, status, - session_id, source, supersedes_event_id, - metadata_json, created_at - FROM raw_memory_events WHERE event_id = ?""", - (event_id,), - ).fetchone() - return self._row_to_event(row) if row else None - - def get_by_hash( - self, content_hash: str, user_id: str - ) -> Optional[RawMemoryEvent]: - """Get active event by content hash + user.""" - with self._lock: - row = self._conn.execute( - """SELECT event_id, user_id, content, content_hash, status, - session_id, source, supersedes_event_id, - metadata_json, created_at - FROM raw_memory_events - WHERE content_hash = ? AND user_id = ? AND status = 'active' - LIMIT 1""", - (content_hash, user_id), - ).fetchone() - return self._row_to_event(row) if row else None - - def list_by_user( - self, - user_id: str, - *, - status: Optional[EventStatus] = None, - limit: int = 100, - offset: int = 0, - ) -> List[RawMemoryEvent]: - """List events for a user, newest first.""" - with self._lock: - if status: - rows = self._conn.execute( - """SELECT event_id, user_id, content, content_hash, status, - session_id, source, supersedes_event_id, - metadata_json, created_at - FROM raw_memory_events - WHERE user_id = ? AND status = ? - ORDER BY created_at DESC - LIMIT ? 
OFFSET ?""", - (user_id, status.value, limit, offset), - ).fetchall() - else: - rows = self._conn.execute( - """SELECT event_id, user_id, content, content_hash, status, - session_id, source, supersedes_event_id, - metadata_json, created_at - FROM raw_memory_events - WHERE user_id = ? - ORDER BY created_at DESC - LIMIT ? OFFSET ?""", - (user_id, limit, offset), - ).fetchall() - return [self._row_to_event(r) for r in rows] - - def get_supersedes_chain(self, event_id: str) -> List[RawMemoryEvent]: - """Walk the supersedes chain from newest to oldest. - - Given an event that supersedes another, returns the full chain: - [newest_correction, ..., original_event] - """ - chain: List[RawMemoryEvent] = [] - seen: set = set() - current_id: Optional[str] = event_id - - while current_id and current_id not in seen: - seen.add(current_id) - event = self.get(current_id) - if not event: - break - chain.append(event) - current_id = event.supersedes_event_id - - return chain - - def count( - self, user_id: str, *, status: Optional[EventStatus] = None - ) -> int: - """Count events for a user.""" - with self._lock: - if status: - row = self._conn.execute( - "SELECT COUNT(*) FROM raw_memory_events WHERE user_id = ? AND status = ?", - (user_id, status.value), - ).fetchone() - else: - row = self._conn.execute( - "SELECT COUNT(*) FROM raw_memory_events WHERE user_id = ?", - (user_id,), - ).fetchone() - return row[0] if row else 0 - - def get_events_since( - self, user_id: str, since_iso: str, *, status: Optional[EventStatus] = None - ) -> List[RawMemoryEvent]: - """Get events created after a given ISO timestamp.""" - with self._lock: - if status: - rows = self._conn.execute( - """SELECT event_id, user_id, content, content_hash, status, - session_id, source, supersedes_event_id, - metadata_json, created_at - FROM raw_memory_events - WHERE user_id = ? AND created_at > ? AND status = ? 
- ORDER BY created_at ASC""", - (user_id, since_iso, status.value), - ).fetchall() - else: - rows = self._conn.execute( - """SELECT event_id, user_id, content, content_hash, status, - session_id, source, supersedes_event_id, - metadata_json, created_at - FROM raw_memory_events - WHERE user_id = ? AND created_at > ? - ORDER BY created_at ASC""", - (user_id, since_iso), - ).fetchall() - return [self._row_to_event(r) for r in rows] - - # ------------------------------------------------------------------ - # Internal - # ------------------------------------------------------------------ - - @staticmethod - def _row_to_event(row: sqlite3.Row) -> RawMemoryEvent: - meta_raw = row["metadata_json"] - if isinstance(meta_raw, str): - try: - meta = json.loads(meta_raw) - except (json.JSONDecodeError, TypeError): - meta = {} - elif isinstance(meta_raw, dict): - meta = meta_raw - else: - meta = {} - - return RawMemoryEvent( - event_id=row["event_id"], - user_id=row["user_id"], - content=row["content"], - content_hash=row["content_hash"], - status=EventStatus(row["status"]), - session_id=row["session_id"], - source=row["source"], - supersedes_event_id=row["supersedes_event_id"], - metadata=meta, - created_at=row["created_at"], - ) diff --git a/dhee/core/fusion_v3.py b/dhee/core/fusion_v3.py deleted file mode 100644 index 03732c7..0000000 --- a/dhee/core/fusion_v3.py +++ /dev/null @@ -1,303 +0,0 @@ -"""Dhee v3 — 5-Stage Weighted Reciprocal Rank Fusion Pipeline. - -Explicit ranking contract (zero LLM on hot path): - -Stage 1: Per-index retrieval (parallel, 0 LLM) -Stage 2: Score normalization (min-max within each index) -Stage 3: Weighted RRF (k=60, configurable weights per index) -Stage 4: Post-fusion adjustments (recency, confidence, staleness, conflicts) -Stage 5: Final ranking + dedup - -No reranker stage. If retrieval quality is insufficient, fix embeddings -or distillation, not the hot path. 
-""" - -from __future__ import annotations - -import logging -import math -from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Tuple - -logger = logging.getLogger(__name__) - - -@dataclass -class FusionConfig: - """Configuration for the 5-stage fusion pipeline.""" - - # Stage 1: Per-index top-K - raw_top_k: int = 20 - distilled_top_k: int = 15 - episodic_top_k: int = 10 - - # Stage 3: RRF weights - rrf_k: int = 60 # standard RRF constant - weight_distilled: float = 1.0 - weight_episodic: float = 0.7 - weight_raw: float = 0.5 - - # Stage 4: Adjustment parameters - recency_boost_max: float = 0.3 # max 30% boost for fresh raw - recency_decay_hours: float = 24.0 - confidence_floor: float = 0.5 # score *= 0.5 + 0.5 * confidence - staleness_penalty: float = 0.3 # stale/suspect get 70% penalty - contradiction_penalty: float = 0.5 # open conflicts get 50% penalty - - # Stage 5: Final output - final_top_n: int = 10 - - -@dataclass -class FusionCandidate: - """A candidate passing through the fusion pipeline.""" - - row_id: str - source_kind: str # raw | distilled | episodic - source_type: str # event | belief | policy | insight | heuristic - source_id: str - retrieval_text: str - raw_score: float = 0.0 # cosine similarity from index - normalized_score: float = 0.0 # after min-max normalization - rrf_score: float = 0.0 # after weighted RRF - adjusted_score: float = 0.0 # after post-fusion adjustments - confidence: float = 1.0 - utility: float = 0.0 - status: str = "active" - created_at: Optional[str] = None - has_open_conflicts: bool = False - # Lineage for dedup - lineage_event_ids: Optional[List[str]] = None - - -@dataclass -class FusionBreakdown: - """Loggable breakdown of how fusion produced its results.""" - - query: str - config: Dict[str, Any] - per_index_counts: Dict[str, int] - pre_adjustment_top5: List[Dict[str, Any]] - post_adjustment_top5: List[Dict[str, Any]] - dedup_removed: int - 
final_count: int - - def to_dict(self) -> Dict[str, Any]: - return { - "query": self.query[:100], - "per_index_counts": self.per_index_counts, - "pre_adjustment_top5": self.pre_adjustment_top5, - "post_adjustment_top5": self.post_adjustment_top5, - "dedup_removed": self.dedup_removed, - "final_count": self.final_count, - } - - -def _parse_iso(iso_str: Optional[str]) -> Optional[datetime]: - if not iso_str: - return None - try: - return datetime.fromisoformat(iso_str.replace("Z", "+00:00")) - except (ValueError, AttributeError): - return None - - -class RRFFusion: - """5-stage Weighted Reciprocal Rank Fusion pipeline. - - Usage: - fusion = RRFFusion(config) - results, breakdown = fusion.fuse( - raw_candidates=[...], - distilled_candidates=[...], - episodic_candidates=[...], - conflict_checker=lambda type, id: bool, - ) - """ - - def __init__(self, config: Optional[FusionConfig] = None): - self.config = config or FusionConfig() - - def fuse( - self, - raw_candidates: List[FusionCandidate], - distilled_candidates: List[FusionCandidate], - episodic_candidates: Optional[List[FusionCandidate]] = None, - *, - conflict_checker: Optional[Any] = None, - query: str = "", - ) -> Tuple[List[FusionCandidate], FusionBreakdown]: - """Run the full 5-stage fusion pipeline. 
- - Args: - raw_candidates: Candidates from raw_index with raw_score set - distilled_candidates: Candidates from distilled_index - episodic_candidates: Optional candidates from episodic_index - conflict_checker: callable(source_type, source_id) -> bool - query: The original query string (for logging) - - Returns: - (ranked_results, breakdown) - """ - cfg = self.config - episodic = episodic_candidates or [] - - # Stage 1: Trim to per-index top-K - raw = sorted(raw_candidates, key=lambda c: -c.raw_score)[:cfg.raw_top_k] - dist = sorted(distilled_candidates, key=lambda c: -c.raw_score)[:cfg.distilled_top_k] - epi = sorted(episodic, key=lambda c: -c.raw_score)[:cfg.episodic_top_k] - - per_index_counts = { - "raw": len(raw), "distilled": len(dist), "episodic": len(epi), - } - - # Stage 2: Min-max normalization within each index - self._normalize(raw) - self._normalize(dist) - self._normalize(epi) - - # Stage 3: Weighted RRF - # Build a combined dict: row_id → candidate, accumulating RRF score - combined: Dict[str, FusionCandidate] = {} - - for rank, c in enumerate(raw): - rrf = cfg.weight_raw / (cfg.rrf_k + rank + 1) - if c.row_id in combined: - combined[c.row_id].rrf_score += rrf - else: - c.rrf_score = rrf - combined[c.row_id] = c - - for rank, c in enumerate(dist): - rrf = cfg.weight_distilled / (cfg.rrf_k + rank + 1) - if c.row_id in combined: - combined[c.row_id].rrf_score += rrf - else: - c.rrf_score = rrf - combined[c.row_id] = c - - for rank, c in enumerate(epi): - rrf = cfg.weight_episodic / (cfg.rrf_k + rank + 1) - if c.row_id in combined: - combined[c.row_id].rrf_score += rrf - else: - c.rrf_score = rrf - combined[c.row_id] = c - - # Pre-adjustment snapshot - pre_sorted = sorted(combined.values(), key=lambda c: -c.rrf_score) - pre_top5 = [ - {"row_id": c.row_id, "kind": c.source_kind, "rrf": round(c.rrf_score, 6)} - for c in pre_sorted[:5] - ] - - # Stage 4: Post-fusion adjustments - now = datetime.now(timezone.utc) - for c in combined.values(): - score = 
c.rrf_score - - # Recency boost (raw only) - if c.source_kind == "raw" and c.created_at: - created = _parse_iso(c.created_at) - if created: - age_hours = max(0, (now - created).total_seconds() / 3600) - boost = 1.0 + cfg.recency_boost_max * math.exp( - -age_hours / cfg.recency_decay_hours - ) - score *= boost - - # Confidence normalization - score *= cfg.confidence_floor + (1.0 - cfg.confidence_floor) * c.confidence - - # Staleness penalty - if c.status in ("stale", "suspect"): - score *= cfg.staleness_penalty - - # Hard invalidation exclusion - if c.status == "invalidated": - score = 0.0 - - # Contradiction penalty - if conflict_checker and c.source_type and c.source_id: - try: - if conflict_checker(c.source_type, c.source_id): - c.has_open_conflicts = True - score *= cfg.contradiction_penalty - except Exception: - pass - - c.adjusted_score = score - - # Stage 5: Final ranking + dedup - ranked = sorted(combined.values(), key=lambda c: -c.adjusted_score) - - # Dedup: if raw and distilled of same content via lineage, keep distilled - seen_source_ids: Dict[str, FusionCandidate] = {} - deduped: List[FusionCandidate] = [] - dedup_removed = 0 - - for c in ranked: - sid = c.source_id - if sid in seen_source_ids: - existing = seen_source_ids[sid] - # Keep the distilled version - if c.source_kind == "distilled" and existing.source_kind == "raw": - deduped = [x for x in deduped if x.source_id != sid] - deduped.append(c) - seen_source_ids[sid] = c - dedup_removed += 1 - else: - dedup_removed += 1 - else: - seen_source_ids[sid] = c - deduped.append(c) - - final = deduped[:cfg.final_top_n] - - # Post-adjustment snapshot - post_top5 = [ - { - "row_id": c.row_id, "kind": c.source_kind, - "adjusted": round(c.adjusted_score, 6), - "conflicts": c.has_open_conflicts, - } - for c in final[:5] - ] - - breakdown = FusionBreakdown( - query=query, - config={ - "rrf_k": cfg.rrf_k, - "weights": { - "raw": cfg.weight_raw, - "distilled": cfg.weight_distilled, - "episodic": 
cfg.weight_episodic, - }, - }, - per_index_counts=per_index_counts, - pre_adjustment_top5=pre_top5, - post_adjustment_top5=post_top5, - dedup_removed=dedup_removed, - final_count=len(final), - ) - - return final, breakdown - - @staticmethod - def _normalize(candidates: List[FusionCandidate]) -> None: - """Min-max normalize raw_score within a candidate list.""" - if not candidates: - return - - scores = [c.raw_score for c in candidates] - min_s = min(scores) - max_s = max(scores) - spread = max_s - min_s - - if spread < 1e-9: - for c in candidates: - c.normalized_score = 1.0 if max_s > 0 else 0.0 - else: - for c in candidates: - c.normalized_score = (c.raw_score - min_s) / spread diff --git a/dhee/core/graph_evolution.py b/dhee/core/graph_evolution.py deleted file mode 100644 index 3192c53..0000000 --- a/dhee/core/graph_evolution.py +++ /dev/null @@ -1,600 +0,0 @@ -"""Evolving knowledge graph — versioned entities, personalized PageRank, schema-free extraction. - -Extends KnowledgeGraph (graph.py) with three capabilities: - -1. **Entity versioning**: Every entity mutation is stored as a version snapshot. - Queries can ask "what was X at time T?" and diffs show how entities evolve. - -2. **Personalized PageRank**: Per-user / per-agent importance scores over - the entity–memory graph. Guides retrieval toward what matters *to this user*. - -3. **Schema-free extraction**: Uses BuddhiMini (or any LLM) to discover - entity types at runtime, stored as EntityType.DYNAMIC with a type_label. 
-""" - -from __future__ import annotations - -import json -import logging -import os -import time -from collections import defaultdict -from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Set, Tuple - -from dhee.core.graph import ( - Entity, - EntityType, - KnowledgeGraph, - Relationship, - RelationType, -) - -logger = logging.getLogger(__name__) - - -# --------------------------------------------------------------------------- -# Entity Versioning -# --------------------------------------------------------------------------- - -@dataclass -class EntityVersion: - """A point-in-time snapshot of an entity's state.""" - - entity_name: str - version: int - timestamp: str # ISO-8601 - entity_type: str - type_label: Optional[str] = None # For DYNAMIC entities - aliases: List[str] = field(default_factory=list) - memory_ids: List[str] = field(default_factory=list) - metadata: Dict[str, Any] = field(default_factory=dict) - change_reason: str = "" # What triggered this version - - def to_dict(self) -> Dict[str, Any]: - return { - "entity_name": self.entity_name, - "version": self.version, - "timestamp": self.timestamp, - "entity_type": self.entity_type, - "type_label": self.type_label, - "aliases": self.aliases, - "memory_ids": self.memory_ids, - "metadata": self.metadata, - "change_reason": self.change_reason, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "EntityVersion": - return cls( - entity_name=data["entity_name"], - version=data["version"], - timestamp=data["timestamp"], - entity_type=data["entity_type"], - type_label=data.get("type_label"), - aliases=data.get("aliases", []), - memory_ids=data.get("memory_ids", []), - metadata=data.get("metadata", {}), - change_reason=data.get("change_reason", ""), - ) - - -def _now_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -class EntityVersionStore: - """Append-only version log for entity snapshots. 
- - Stored as JSONL on disk — one line per version, O(1) append. - """ - - def __init__(self, path: str): - self._path = path - self._versions: Dict[str, List[EntityVersion]] = defaultdict(list) - self._load() - - def _load(self) -> None: - if not os.path.exists(self._path): - return - try: - with open(self._path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - v = EntityVersion.from_dict(json.loads(line)) - self._versions[v.entity_name].append(v) - except (OSError, json.JSONDecodeError) as e: - logger.warning("Failed to load entity versions: %s", e) - - def record(self, entity: Entity, reason: str = "") -> EntityVersion: - """Snapshot the current state of an entity.""" - history = self._versions[entity.name] - version_num = (history[-1].version + 1) if history else 1 - - v = EntityVersion( - entity_name=entity.name, - version=version_num, - timestamp=_now_iso(), - entity_type=entity.entity_type.value, - type_label=entity.metadata.get("type_label"), - aliases=sorted(entity.aliases), - memory_ids=sorted(entity.memory_ids), - metadata=dict(entity.metadata), - change_reason=reason, - ) - - history.append(v) - self._append(v) - return v - - def _append(self, v: EntityVersion) -> None: - try: - os.makedirs(os.path.dirname(self._path) or ".", exist_ok=True) - with open(self._path, "a", encoding="utf-8") as f: - f.write(json.dumps(v.to_dict(), ensure_ascii=False) + "\n") - except OSError as e: - logger.warning("Failed to persist entity version: %s", e) - - def get_history(self, entity_name: str) -> List[EntityVersion]: - """All versions of an entity, oldest first.""" - return list(self._versions.get(entity_name, [])) - - def get_at_time(self, entity_name: str, iso_time: str) -> Optional[EntityVersion]: - """Get the entity version that was current at a given time.""" - history = self._versions.get(entity_name, []) - result = None - for v in history: - if v.timestamp <= iso_time: - result = v - else: - break - return result - 
- def diff(self, entity_name: str, v1: int, v2: int) -> Dict[str, Any]: - """Compute the difference between two versions of an entity.""" - history = self._versions.get(entity_name, []) - ver_map = {v.version: v for v in history} - old = ver_map.get(v1) - new = ver_map.get(v2) - if not old or not new: - return {"error": "version not found"} - - changes: Dict[str, Any] = {} - if old.entity_type != new.entity_type: - changes["entity_type"] = {"old": old.entity_type, "new": new.entity_type} - if old.type_label != new.type_label: - changes["type_label"] = {"old": old.type_label, "new": new.type_label} - - old_aliases = set(old.aliases) - new_aliases = set(new.aliases) - if old_aliases != new_aliases: - changes["aliases_added"] = sorted(new_aliases - old_aliases) - changes["aliases_removed"] = sorted(old_aliases - new_aliases) - - old_mids = set(old.memory_ids) - new_mids = set(new.memory_ids) - if old_mids != new_mids: - changes["memories_added"] = sorted(new_mids - old_mids) - changes["memories_removed"] = sorted(old_mids - new_mids) - - # Metadata diff (shallow) - for key in set(old.metadata) | set(new.metadata): - old_val = old.metadata.get(key) - new_val = new.metadata.get(key) - if old_val != new_val: - changes.setdefault("metadata", {})[key] = { - "old": old_val, "new": new_val, - } - - return { - "entity": entity_name, - "from_version": v1, - "to_version": v2, - "changes": changes, - } - - @property - def entity_count(self) -> int: - return len(self._versions) - - -# --------------------------------------------------------------------------- -# Personalized PageRank -# --------------------------------------------------------------------------- - -class PersonalizedPageRank: - """Per-user importance ranking over the entity–memory graph. - - Runs a standard power-iteration PageRank seeded from a user's - interaction history (memories they wrote, entities they mentioned). - Results are cached and refreshed when the graph changes. 
- """ - - def __init__( - self, - damping: float = 0.85, - iterations: int = 20, - tolerance: float = 1e-6, - ): - self.damping = damping - self.iterations = iterations - self.tolerance = tolerance - self._cache: Dict[str, Dict[str, float]] = {} # user_id -> {node -> score} - - def compute( - self, - graph: KnowledgeGraph, - seed_memory_ids: Optional[Set[str]] = None, - user_id: str = "default", - ) -> Dict[str, float]: - """Compute personalized PageRank for a user. - - Args: - graph: The knowledge graph. - seed_memory_ids: Memories that form the user's personalization - vector. If None, uses all memories. - user_id: Cache key. - - Returns: - Dict mapping node IDs (memory IDs and "entity:name") to scores. - """ - # Build adjacency from relationships - adj: Dict[str, Set[str]] = defaultdict(set) - all_nodes: Set[str] = set() - - for rel in graph.relationships: - adj[rel.source_id].add(rel.target_id) - adj[rel.target_id].add(rel.source_id) - all_nodes.add(rel.source_id) - all_nodes.add(rel.target_id) - - # Add entity nodes - for entity_name, entity in graph.entities.items(): - enode = f"entity:{entity_name}" - all_nodes.add(enode) - for mid in entity.memory_ids: - adj[mid].add(enode) - adj[enode].add(mid) - all_nodes.add(mid) - - if not all_nodes: - return {} - - n = len(all_nodes) - node_list = sorted(all_nodes) - node_idx = {node: i for i, node in enumerate(node_list)} - - # Personalization vector: uniform over seeds, zero elsewhere - personalization = [0.0] * n - if seed_memory_ids: - seeds_in_graph = [ - node_idx[mid] for mid in seed_memory_ids if mid in node_idx - ] - if seeds_in_graph: - weight = 1.0 / len(seeds_in_graph) - for idx in seeds_in_graph: - personalization[idx] = weight - else: - personalization = [1.0 / n] * n - else: - personalization = [1.0 / n] * n - - # Power iteration - scores = [1.0 / n] * n - for _ in range(self.iterations): - new_scores = [0.0] * n - for i, node in enumerate(node_list): - neighbors = adj.get(node, set()) - if not neighbors: 
- # Dangling node — distribute uniformly - share = scores[i] / n - for j in range(n): - new_scores[j] += share - else: - share = scores[i] / len(neighbors) - for nb in neighbors: - if nb in node_idx: - new_scores[node_idx[nb]] += share - - # Apply damping + personalization - for i in range(n): - new_scores[i] = ( - (1 - self.damping) * personalization[i] - + self.damping * new_scores[i] - ) - - # Check convergence - delta = sum(abs(new_scores[i] - scores[i]) for i in range(n)) - scores = new_scores - if delta < self.tolerance: - break - - result = {node_list[i]: scores[i] for i in range(n)} - self._cache[user_id] = result - return result - - def get_top_entities( - self, - graph: KnowledgeGraph, - user_id: str = "default", - seed_memory_ids: Optional[Set[str]] = None, - limit: int = 20, - ) -> List[Tuple[str, float]]: - """Get top-ranked entities for a user.""" - if user_id not in self._cache: - self.compute(graph, seed_memory_ids=seed_memory_ids, user_id=user_id) - - scores = self._cache.get(user_id, {}) - entity_scores = [ - (name.replace("entity:", ""), score) - for name, score in scores.items() - if name.startswith("entity:") - ] - entity_scores.sort(key=lambda x: x[1], reverse=True) - return entity_scores[:limit] - - def boost_retrieval( - self, - memory_ids: List[str], - user_id: str = "default", - ) -> Dict[str, float]: - """Get PageRank boost factors for a set of candidate memory IDs.""" - scores = self._cache.get(user_id, {}) - if not scores: - return {} - return {mid: scores.get(mid, 0.0) for mid in memory_ids} - - def invalidate(self, user_id: Optional[str] = None) -> None: - """Clear cached scores. 
Call when graph changes.""" - if user_id: - self._cache.pop(user_id, None) - else: - self._cache.clear() - - -# --------------------------------------------------------------------------- -# Schema-Free Entity Extraction -# --------------------------------------------------------------------------- - -_SCHEMA_FREE_PROMPT = """Extract entities from the following text. For each entity, provide: -- name: The entity name -- type: A descriptive type (e.g., "person", "technology", "framework", "metric", - "emotion", "event", "disease", "recipe" — any type that fits, not limited to a fixed set) -- relevance: How important this entity is to the text (0.0 to 1.0) - -Text: {content} - -Return a JSON array. Example: -[{{"name": "FastAPI", "type": "framework", "relevance": 0.9}}] - -Return ONLY the JSON array:""" - - -def extract_entities_schema_free( - content: str, - memory_id: str, - graph: KnowledgeGraph, - llm: Any = None, - min_relevance: float = 0.3, -) -> List[Entity]: - """Extract entities without a fixed type schema. - - Uses LLM to discover entity types at runtime. Discovered types are - stored as EntityType.DYNAMIC with a ``type_label`` in metadata. - - Falls back to graph.extract_entities() regex path if no LLM. 
- """ - if not llm: - return graph.extract_entities(content, memory_id, use_llm=False) - - prompt = _SCHEMA_FREE_PROMPT.format(content=content[:2000]) - - try: - response = llm.generate(prompt) - arr_start = response.find("[") - if arr_start < 0: - return graph.extract_entities(content, memory_id, use_llm=False) - - items, _ = json.JSONDecoder().raw_decode(response, arr_start) - except Exception as e: - logger.debug("Schema-free extraction failed (%s), falling back to regex", e) - return graph.extract_entities(content, memory_id, use_llm=False) - - # Known EntityType values (lowercase) - _known_types = {t.value for t in EntityType} - - entities: List[Entity] = [] - for item in items: - name = item.get("name", "").strip() - if not name: - continue - - relevance = float(item.get("relevance", 0.5)) - if relevance < min_relevance: - continue - - raw_type = item.get("type", "unknown").strip().lower() - - # Map to existing enum if possible; otherwise DYNAMIC - if raw_type in _known_types: - entity_type = EntityType(raw_type) - type_label = None - else: - entity_type = EntityType.DYNAMIC - type_label = raw_type - - entity = graph._get_or_create_entity(name, entity_type) - entity.memory_ids.add(memory_id) - entity.metadata["relevance"] = max( - entity.metadata.get("relevance", 0.0), relevance, - ) - if type_label: - entity.metadata["type_label"] = type_label - - entities.append(entity) - - graph.memory_entities[memory_id] = {e.name for e in entities} - return entities - - -# --------------------------------------------------------------------------- -# EvolvingGraph — wraps it all together -# --------------------------------------------------------------------------- - -class EvolvingGraph: - """Knowledge graph with entity versioning, PageRank, and schema-free extraction. - - Drop-in extension of KnowledgeGraph — delegates core graph operations - and adds evolution capabilities on top. 
- """ - - def __init__( - self, - data_dir: Optional[str] = None, - llm: Any = None, - damping: float = 0.85, - ): - self._data_dir = data_dir or os.path.join( - os.path.expanduser("~"), ".dhee", "graph", - ) - os.makedirs(self._data_dir, exist_ok=True) - - graph_path = os.path.join(self._data_dir, "graph.json") - self.graph = KnowledgeGraph.load(graph_path, llm=llm) - - self._versions = EntityVersionStore( - os.path.join(self._data_dir, "entity_versions.jsonl"), - ) - self._pagerank = PersonalizedPageRank(damping=damping) - self._llm = llm - - # ── Entity operations (versioned) ── - - def extract_and_version( - self, - content: str, - memory_id: str, - reason: str = "new_memory", - schema_free: bool = True, - ) -> List[Entity]: - """Extract entities from content and record versions for any changes.""" - if schema_free and self._llm: - entities = extract_entities_schema_free( - content, memory_id, self.graph, llm=self._llm, - ) - else: - entities = self.graph.extract_entities(content, memory_id) - - # Record version for each entity that was touched - for entity in entities: - self._versions.record(entity, reason=reason) - - # Invalidate PageRank caches (graph changed) - self._pagerank.invalidate() - - return entities - - def update_entity( - self, - entity_name: str, - updates: Dict[str, Any], - reason: str = "update", - ) -> Optional[Entity]: - """Update an entity's fields and record the version.""" - entity = self.graph.entities.get(entity_name) - if not entity: - return None - - if "entity_type" in updates: - entity.entity_type = EntityType(updates["entity_type"]) - if "aliases" in updates: - entity.aliases.update(updates["aliases"]) - if "metadata" in updates: - entity.metadata.update(updates["metadata"]) - - self._versions.record(entity, reason=reason) - self._pagerank.invalidate() - return entity - - def get_entity_history(self, entity_name: str) -> List[EntityVersion]: - return self._versions.get_history(entity_name) - - def get_entity_at_time( - self, 
entity_name: str, iso_time: str, - ) -> Optional[EntityVersion]: - return self._versions.get_at_time(entity_name, iso_time) - - def entity_diff( - self, entity_name: str, v1: int, v2: int, - ) -> Dict[str, Any]: - return self._versions.diff(entity_name, v1, v2) - - # ── PageRank ── - - def compute_pagerank( - self, - user_id: str = "default", - seed_memory_ids: Optional[Set[str]] = None, - ) -> Dict[str, float]: - return self._pagerank.compute( - self.graph, seed_memory_ids=seed_memory_ids, user_id=user_id, - ) - - def get_important_entities( - self, - user_id: str = "default", - seed_memory_ids: Optional[Set[str]] = None, - limit: int = 20, - ) -> List[Tuple[str, float]]: - return self._pagerank.get_top_entities( - self.graph, user_id=user_id, - seed_memory_ids=seed_memory_ids, limit=limit, - ) - - def pagerank_boost( - self, - memory_ids: List[str], - user_id: str = "default", - ) -> Dict[str, float]: - return self._pagerank.boost_retrieval(memory_ids, user_id=user_id) - - # ── Graph delegation ── - - def add_relationship(self, *args, **kwargs) -> Relationship: - rel = self.graph.add_relationship(*args, **kwargs) - self._pagerank.invalidate() - return rel - - def link_by_shared_entities(self, memory_id: str) -> List[Relationship]: - rels = self.graph.link_by_shared_entities(memory_id) - if rels: - self._pagerank.invalidate() - return rels - - def get_related_memories(self, *args, **kwargs): - return self.graph.get_related_memories(*args, **kwargs) - - def get_causal_chain(self, *args, **kwargs): - return self.graph.get_causal_chain(*args, **kwargs) - - def get_memory_graph(self, memory_id: str) -> Dict[str, Any]: - return self.graph.get_memory_graph(memory_id) - - # ── Persistence ── - - def save(self) -> None: - """Persist graph to disk. 
Version store auto-persists on append.""" - graph_path = os.path.join(self._data_dir, "graph.json") - self.graph.save(graph_path) - - def stats(self) -> Dict[str, Any]: - base = self.graph.stats() - base["versioned_entities"] = self._versions.entity_count - base["dynamic_entities"] = sum( - 1 for e in self.graph.entities.values() - if e.entity_type == EntityType.DYNAMIC - ) - return base diff --git a/dhee/core/invalidation.py b/dhee/core/invalidation.py deleted file mode 100644 index 2b64316..0000000 --- a/dhee/core/invalidation.py +++ /dev/null @@ -1,248 +0,0 @@ -"""Dhee v3 — Three-Tier Invalidation Engine. - -Graduated invalidation based on what happened to the source: - -1. Hard invalidation: source deleted → derived tombstoned -2. Soft invalidation: source corrected → derived marked stale, re-eval queued -3. Partial invalidation: one of N sources changed, contribution < 30% - → derived marked suspect with confidence penalty - -Design contract: - - Invalidation is async — marks status + enqueues repair jobs - - Never synchronously rewrites derived objects - - Type-aware: each derived type has its own invalidation response - - All cascades are traceable via maintenance_jobs - - Zero LLM calls -""" - -from __future__ import annotations - -import json -import logging -import uuid -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional - -logger = logging.getLogger(__name__) - - -def _utcnow_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -# Threshold: if a changed source contributed >= this fraction, -# escalate from partial to soft invalidation -PARTIAL_ESCALATION_THRESHOLD = 0.30 - - -class InvalidationEngine: - """Cascades invalidation from raw events to derived objects. 
- - Usage: - engine = InvalidationEngine(lineage, stores_map, job_enqueuer) - engine.on_event_corrected(event_id) # soft + partial - engine.on_event_deleted(event_id) # hard + partial - """ - - def __init__( - self, - lineage: "DerivedLineageStore", - stores: Dict[str, Any], - conn: "sqlite3.Connection", - lock: "threading.RLock", - ): - """ - Args: - lineage: DerivedLineageStore for tracing dependencies - stores: Map of derived_type → store instance, e.g. - {"belief": belief_store, "policy": policy_store, ...} - conn: Shared SQLite connection for enqueuing jobs - lock: Shared threading lock - """ - self.lineage = lineage - self.stores = stores - self._conn = conn - self._lock = lock - - def on_event_corrected(self, event_id: str) -> Dict[str, Any]: - """Handle a raw event being corrected (superseded by new event). - - For each dependent derived object: - - If sole source → soft invalidation (stale) - - If one of many → check contribution weight: - - >= 30% → soft invalidation - - < 30% → partial invalidation (suspect + confidence penalty) - """ - return self._cascade(event_id, mode="corrected") - - def on_event_deleted(self, event_id: str) -> Dict[str, Any]: - """Handle a raw event being deleted. 
- - For each dependent derived object: - - If sole source → hard invalidation (tombstone) - - If one of many → check contribution weight: - - >= 30% → soft invalidation (stale + re-eval) - - < 30% → partial invalidation (suspect + confidence penalty) - """ - return self._cascade(event_id, mode="deleted") - - def _cascade(self, event_id: str, mode: str) -> Dict[str, Any]: - """Core cascade logic for both correction and deletion.""" - dependents = self.lineage.get_dependents(event_id) - - result = { - "event_id": event_id, - "mode": mode, - "hard_invalidated": [], - "soft_invalidated": [], - "partial_invalidated": [], - "jobs_enqueued": [], - "errors": [], - } - - for dep in dependents: - dtype = dep["derived_type"] - did = dep["derived_id"] - weight = dep["contribution_weight"] - - try: - # How many total sources does this derived object have? - total_sources = self.lineage.get_source_count(dtype, did) - - if total_sources <= 1: - # Sole source — hard or soft depending on mode - if mode == "deleted": - self._hard_invalidate(dtype, did) - result["hard_invalidated"].append( - {"type": dtype, "id": did} - ) - else: # corrected - self._soft_invalidate(dtype, did) - job_id = self._enqueue_repair( - dtype, did, "repair_stale_derived" - ) - result["soft_invalidated"].append( - {"type": dtype, "id": did} - ) - result["jobs_enqueued"].append(job_id) - - elif weight >= PARTIAL_ESCALATION_THRESHOLD: - # Major contributor — treat as soft invalidation - self._soft_invalidate(dtype, did) - job_id = self._enqueue_repair( - dtype, did, "repair_stale_derived" - ) - result["soft_invalidated"].append( - {"type": dtype, "id": did, "weight": weight} - ) - result["jobs_enqueued"].append(job_id) - - else: - # Minor contributor — partial invalidation - self._partial_invalidate(dtype, did, weight) - job_id = self._enqueue_repair( - dtype, did, "verify_suspect_derived" - ) - result["partial_invalidated"].append( - {"type": dtype, "id": did, "weight": weight} - ) - 
result["jobs_enqueued"].append(job_id) - - except Exception as e: - logger.exception( - "Invalidation failed for %s:%s from event %s", - dtype, did, event_id, - ) - result["errors"].append({ - "type": dtype, "id": did, "error": str(e) - }) - - return result - - # ------------------------------------------------------------------ - # Invalidation tier implementations - # ------------------------------------------------------------------ - - def _hard_invalidate(self, dtype: str, did: str) -> None: - """Source gone, child unusable. Mark as tombstone.""" - store = self.stores.get(dtype) - if store and hasattr(store, "set_status"): - store.set_status(did, "invalidated") - logger.info("Hard invalidated %s:%s", dtype, did) - - def _soft_invalidate(self, dtype: str, did: str) -> None: - """Source changed, child needs re-evaluation.""" - store = self.stores.get(dtype) - if store and hasattr(store, "set_status"): - store.set_status(did, "stale") - logger.info("Soft invalidated %s:%s → stale", dtype, did) - - def _partial_invalidate( - self, dtype: str, did: str, weight: float - ) -> None: - """Minor source change. 
Mark suspect + confidence penalty.""" - store = self.stores.get(dtype) - if not store: - return - - # Apply confidence penalty proportional to contribution weight - if hasattr(store, "get") and hasattr(store, "update_confidence"): - obj = store.get(did) - if obj and "confidence" in obj: - penalty = weight * 0.5 # half the contribution weight - new_conf = max(0.05, obj["confidence"] - penalty) - store.update_confidence( - did, new_conf, - new_status="suspect", - revision_reason=f"partial_invalidation (weight={weight:.2f})", - ) - elif hasattr(store, "set_status"): - store.set_status(did, "suspect") - - logger.info( - "Partial invalidated %s:%s → suspect (weight=%.2f)", - dtype, did, weight, - ) - - # ------------------------------------------------------------------ - # Job enqueuing - # ------------------------------------------------------------------ - - def _enqueue_repair( - self, dtype: str, did: str, job_name: str - ) -> str: - """Enqueue a repair job for a derived object.""" - job_id = str(uuid.uuid4()) - now = _utcnow_iso() - payload = json.dumps({ - "derived_type": dtype, - "derived_id": did, - }) - idem_key = f"{job_name}:{dtype}:{did}" - - with self._lock: - try: - # Check idempotency — don't enqueue if already pending/running - existing = self._conn.execute( - """SELECT job_id FROM maintenance_jobs - WHERE idempotency_key = ? 
- AND status IN ('pending', 'running') - LIMIT 1""", - (idem_key,), - ).fetchone() - - if existing: - return existing["job_id"] - - self._conn.execute( - """INSERT INTO maintenance_jobs - (job_id, job_name, status, payload_json, - created_at, idempotency_key) - VALUES (?, ?, 'pending', ?, ?, ?)""", - (job_id, job_name, payload, now, idem_key), - ) - self._conn.commit() - return job_id - except Exception: - self._conn.rollback() - raise diff --git a/dhee/core/jobs.py b/dhee/core/jobs.py deleted file mode 100644 index 69f7e5f..0000000 --- a/dhee/core/jobs.py +++ /dev/null @@ -1,477 +0,0 @@ -"""Dhee v3 — Job Registry: named, idempotent, observable maintenance jobs. - -Replaces agi_loop.py's phantom subsystems with real, independently testable jobs. - -Each job: - - Has a unique name (e.g., "distill_episodic_to_semantic") - - Is idempotent: same input → same output, safe to retry - - Is observable: status, timing, retry count tracked in maintenance_jobs table - - Is leasable: acquires a lock before running (via LeaseManager) - - Returns a structured result dict - -Design contract: - - Jobs NEVER call memory.add() or any write-path that triggers enrichment - - Jobs write to derived stores + lineage only - - Jobs are cold-path: called by heartbeat/cron, never by hot-path remember/recall -""" - -from __future__ import annotations - -import json -import logging -import traceback -import uuid -from abc import ABC, abstractmethod -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Type - -from dhee.core.lease_manager import LeaseManager - -logger = logging.getLogger(__name__) - - -def _utcnow_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -class Job(ABC): - """Base class for all maintenance jobs.""" - - # Subclasses MUST set this - name: str = "" - - def __init__(self): - if not self.name: - raise ValueError(f"{self.__class__.__name__} must set 'name'") - - @abstractmethod - def execute(self, payload: Dict[str, Any]) -> 
Dict[str, Any]: - """Run the job. Returns a result dict. - - Args: - payload: Job-specific input parameters - - Returns: - Dict with job results (stored in maintenance_jobs.result_json) - - Raises: - Exception on failure (caught by JobRegistry, stored as error) - """ - ... - - def make_idempotency_key(self, payload: Dict[str, Any]) -> Optional[str]: - """Generate an idempotency key for dedup. Override if needed. - - Returns None to skip idempotency checking (every run creates a new job). - """ - return None - - -class JobRegistry: - """Manages job registration, scheduling, execution, and observability. - - Usage: - registry = JobRegistry(conn, lock, lease_manager) - registry.register(ApplyForgettingJob) - registry.register(DistillEpisodicJob) - - # Run a specific job - result = registry.run("apply_forgetting", payload={"user_id": "u1"}) - - # Run all due jobs (heartbeat) - results = registry.run_all(owner_id="worker-1") - """ - - def __init__( - self, - conn: "sqlite3.Connection", - lock: "threading.RLock", - lease_manager: LeaseManager, - ): - import sqlite3 - import threading - - self._conn = conn - self._lock = lock - self._lease = lease_manager - self._jobs: Dict[str, Job] = {} - - def register(self, job_class: Type[Job]) -> None: - """Register a job class. Instantiates it.""" - job = job_class() - if job.name in self._jobs: - logger.warning("Job %s already registered, replacing", job.name) - self._jobs[job.name] = job - - def list_registered(self) -> List[str]: - """List all registered job names.""" - return list(self._jobs.keys()) - - def run( - self, - job_name: str, - *, - payload: Optional[Dict[str, Any]] = None, - owner_id: str = "default-worker", - ) -> Dict[str, Any]: - """Run a single job by name. Acquires lease, executes, records result. 
- - Returns: - Dict with: job_id, job_name, status, result/error, timing - """ - if job_name not in self._jobs: - return { - "job_name": job_name, - "status": "error", - "error": f"Unknown job: {job_name}", - } - - job = self._jobs[job_name] - payload = payload or {} - - # Idempotency check - idem_key = job.make_idempotency_key(payload) - if idem_key: - with self._lock: - existing = self._conn.execute( - """SELECT job_id, status, result_json FROM maintenance_jobs - WHERE idempotency_key = ? AND status IN ('completed', 'running') - LIMIT 1""", - (idem_key,), - ).fetchone() - if existing: - return { - "job_id": existing["job_id"], - "job_name": job_name, - "status": "skipped_idempotent", - "existing_status": existing["status"], - } - - # Acquire lease - lock_id = f"job:{job_name}" - if not self._lease.acquire(lock_id, owner_id): - return { - "job_name": job_name, - "status": "skipped_locked", - "holder": self._lease.get_holder(lock_id), - } - - # Create job record - job_id = str(uuid.uuid4()) - now = _utcnow_iso() - - try: - with self._lock: - self._conn.execute( - """INSERT INTO maintenance_jobs - (job_id, job_name, status, payload_json, - created_at, started_at, idempotency_key) - VALUES (?, ?, 'running', ?, ?, ?, ?)""", - ( - job_id, job_name, json.dumps(payload), - now, now, idem_key, - ), - ) - self._conn.commit() - - # Execute - result = job.execute(payload) - completed_at = _utcnow_iso() - - with self._lock: - self._conn.execute( - """UPDATE maintenance_jobs - SET status = 'completed', result_json = ?, - completed_at = ? 
- WHERE job_id = ?""", - (json.dumps(result, default=str), completed_at, job_id), - ) - self._conn.commit() - - return { - "job_id": job_id, - "job_name": job_name, - "status": "completed", - "result": result, - "started_at": now, - "completed_at": completed_at, - } - - except Exception as e: - error_msg = f"{type(e).__name__}: {e}" - logger.exception("Job %s failed: %s", job_name, error_msg) - - with self._lock: - self._conn.execute( - """UPDATE maintenance_jobs - SET status = 'failed', error_message = ?, - completed_at = ?, - retry_count = retry_count + 1 - WHERE job_id = ?""", - (error_msg, _utcnow_iso(), job_id), - ) - self._conn.commit() - - return { - "job_id": job_id, - "job_name": job_name, - "status": "failed", - "error": error_msg, - } - - finally: - self._lease.release(lock_id, owner_id) - - def run_all( - self, *, owner_id: str = "default-worker" - ) -> List[Dict[str, Any]]: - """Run all registered jobs. Returns list of results.""" - results = [] - for name in self._jobs: - result = self.run(name, owner_id=owner_id) - results.append(result) - return results - - def get_job_history( - self, - job_name: str, - *, - limit: int = 10, - ) -> List[Dict[str, Any]]: - """Get recent execution history for a job.""" - with self._lock: - rows = self._conn.execute( - """SELECT job_id, job_name, status, payload_json, - result_json, error_message, - created_at, started_at, completed_at, - retry_count - FROM maintenance_jobs - WHERE job_name = ? 
- ORDER BY created_at DESC - LIMIT ?""", - (job_name, limit), - ).fetchall() - - return [ - { - "job_id": r["job_id"], - "job_name": r["job_name"], - "status": r["status"], - "payload": json.loads(r["payload_json"] or "{}"), - "result": json.loads(r["result_json"] or "null"), - "error": r["error_message"], - "created_at": r["created_at"], - "started_at": r["started_at"], - "completed_at": r["completed_at"], - "retry_count": r["retry_count"], - } - for r in rows - ] - - def get_health(self) -> Dict[str, Any]: - """Get health summary across all registered jobs.""" - health: Dict[str, Any] = { - "registered_jobs": list(self._jobs.keys()), - "total_registered": len(self._jobs), - "job_status": {}, - } - - for name in self._jobs: - with self._lock: - row = self._conn.execute( - """SELECT status, completed_at, error_message - FROM maintenance_jobs - WHERE job_name = ? - ORDER BY created_at DESC - LIMIT 1""", - (name,), - ).fetchone() - - if row: - health["job_status"][name] = { - "last_status": row["status"], - "last_completed": row["completed_at"], - "last_error": row["error_message"], - } - else: - health["job_status"][name] = {"last_status": "never_run"} - - return health - - -# ========================================================================= -# Concrete Job Implementations -# ========================================================================= - -class ApplyForgettingJob(Job): - """Apply decay/forgetting curves to memory strengths. - - Replaces agi_loop step 2 (decay). 
- """ - - name = "apply_forgetting" - - def __init__(self): - super().__init__() - self._memory = None # Set by caller via set_context() - - def set_context(self, memory: Any) -> "ApplyForgettingJob": - self._memory = memory - return self - - def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._memory: - return {"status": "skipped", "reason": "no memory instance"} - - user_id = payload.get("user_id", "default") - try: - result = self._memory.apply_decay(scope={"user_id": user_id}) - return {"status": "ok", "decay_result": result} - except Exception as e: - return {"status": "error", "error": str(e)} - - -class RunConsolidationJob(Job): - """Run the cognition kernel's sleep_cycle (distillation). - - Replaces agi_loop step 1 (consolidate). - """ - - name = "run_consolidation" - - def __init__(self): - super().__init__() - self._kernel = None - - def set_context(self, kernel: Any) -> "RunConsolidationJob": - self._kernel = kernel - return self - - def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._kernel: - return {"status": "skipped", "reason": "no kernel instance"} - - user_id = payload.get("user_id", "default") - try: - result = self._kernel.sleep_cycle(user_id=user_id) - return {"status": "ok", "consolidation_result": result} - except Exception as e: - return {"status": "error", "error": str(e)} - - -class ExtractStepPoliciesJob(Job): - """Extract step-level policies from completed tasks. - - Replaces the inline policy extraction in record_learning_outcomes. 
- """ - - name = "extract_step_policies" - - def __init__(self): - super().__init__() - self._kernel = None - - def set_context(self, kernel: Any) -> "ExtractStepPoliciesJob": - self._kernel = kernel - return self - - def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._kernel: - return {"status": "skipped", "reason": "no kernel instance"} - - user_id = payload.get("user_id", "default") - task_id = payload.get("task_id") - - if not task_id: - return {"status": "skipped", "reason": "no task_id in payload"} - - try: - # Look up the task - task = self._kernel.task_manager.get(task_id) - if not task: - return {"status": "skipped", "reason": f"task {task_id} not found"} - - # Use existing policy extraction - policies_created = 0 - if hasattr(self._kernel, 'policy_manager') and self._kernel.policy_manager: - from dhee.core.task_state import TaskStatus - if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED): - # Extract via existing mechanisms - policies_created = self._kernel.policy_manager.extract_from_task( - task, user_id=user_id - ) if hasattr(self._kernel.policy_manager, 'extract_from_task') else 0 - - return {"status": "ok", "policies_created": policies_created} - except Exception as e: - return {"status": "error", "error": str(e)} - - def make_idempotency_key(self, payload: Dict[str, Any]) -> Optional[str]: - task_id = payload.get("task_id", "") - return f"step_policies:{task_id}" if task_id else None - - -class DetectFailurePatternsJob(Job): - """Run the FailurePatternDetector on terminal tasks. - - Replaces the inline pattern detection in record_learning_outcomes. 
- """ - - name = "detect_failure_patterns" - - def __init__(self): - super().__init__() - self._kernel = None - - def set_context(self, kernel: Any) -> "DetectFailurePatternsJob": - self._kernel = kernel - return self - - def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._kernel: - return {"status": "skipped", "reason": "no kernel instance"} - - user_id = payload.get("user_id", "default") - try: - from dhee.core.pattern_detector import ( - FailurePatternDetector, extract_features, - ) - from dhee.core.task_state import TaskStatus - - # Get terminal tasks - all_tasks = self._kernel.task_manager.list_tasks( - user_id=user_id, limit=200 - ) - terminal = [ - t for t in all_tasks - if t.status in (TaskStatus.COMPLETED, TaskStatus.FAILED) - ] - - if len(terminal) < 10: - return { - "status": "ok", "patterns_found": 0, - "reason": f"only {len(terminal)} terminal tasks (need 10+)", - } - - features = extract_features(terminal) - detector = FailurePatternDetector() - patterns = detector.detect_and_describe(features) - - stored = 0 - for pattern in patterns: - if hasattr(self._kernel, '_store_pattern_as_policy'): - policy = self._kernel._store_pattern_as_policy( - user_id, "detected", pattern, - ) - if policy: - stored += 1 - - return { - "status": "ok", - "terminal_tasks": len(terminal), - "patterns_found": len(patterns), - "patterns_stored": stored, - } - except ImportError: - return {"status": "skipped", "reason": "pattern_detector not available"} - except Exception as e: - return {"status": "error", "error": str(e)} diff --git a/dhee/core/lease_manager.py b/dhee/core/lease_manager.py deleted file mode 100644 index aeb70b6..0000000 --- a/dhee/core/lease_manager.py +++ /dev/null @@ -1,256 +0,0 @@ -"""Dhee v3 — SQLite Lease Manager for job concurrency control. - -Ensures that only one runner can execute a given maintenance job at a time. -Uses SQLite's BEGIN IMMEDIATE for atomic lease acquisition. 
- -Design contract: - - Leases are time-bounded (default 300s) - - Expired leases are automatically stolen - - Renew extends lease while holding it - - Release is explicit; stale leases cleaned on next acquire - - Zero external dependencies — pure SQLite -""" - -from __future__ import annotations - -import logging -import sqlite3 -import threading -import uuid -from contextlib import contextmanager -from datetime import datetime, timezone, timedelta -from typing import Optional - -logger = logging.getLogger(__name__) - -DEFAULT_LEASE_DURATION_SECONDS = 300 - - -def _utcnow() -> datetime: - return datetime.now(timezone.utc) - - -def _utcnow_iso() -> str: - return _utcnow().isoformat() - - -class LeaseManager: - """SQLite-based distributed lease manager. - - Each lock_id represents a named resource (e.g., a job name). - Only one owner can hold a lease at a time. Expired leases are - automatically reclaimed. - - Usage: - lm = LeaseManager(conn, lock) - acquired = lm.acquire("distill_batch", owner_id="worker-1") - if acquired: - try: - # do work - lm.renew("distill_batch", "worker-1") # extend if long - finally: - lm.release("distill_batch", "worker-1") - """ - - def __init__( - self, - conn: sqlite3.Connection, - lock: threading.RLock, - *, - default_duration_seconds: int = DEFAULT_LEASE_DURATION_SECONDS, - ): - self._conn = conn - self._lock = lock - self.default_duration = default_duration_seconds - - @contextmanager - def _tx(self): - with self._lock: - try: - yield self._conn - self._conn.commit() - except Exception: - self._conn.rollback() - raise - - def acquire( - self, - lock_id: str, - owner_id: str, - *, - duration_seconds: Optional[int] = None, - ) -> bool: - """Try to acquire a lease. Returns True if acquired. - - If the lock is held by another owner and not expired, returns False. - If the lock is expired, steals it (atomic via BEGIN IMMEDIATE). - If the lock is held by the same owner, renews it. 
- """ - duration = duration_seconds or self.default_duration - now = _utcnow() - expires = (now + timedelta(seconds=duration)).isoformat() - now_iso = now.isoformat() - - with self._tx() as conn: - row = conn.execute( - "SELECT owner_id, lease_expires_at FROM locks WHERE lock_id = ?", - (lock_id,), - ).fetchone() - - if row is None: - # No lock exists — create it - conn.execute( - """INSERT INTO locks (lock_id, owner_id, lease_expires_at, updated_at) - VALUES (?, ?, ?, ?)""", - (lock_id, owner_id, expires, now_iso), - ) - return True - - existing_owner = row["owner_id"] - existing_expires = row["lease_expires_at"] - - # Same owner — renew - if existing_owner == owner_id: - conn.execute( - "UPDATE locks SET lease_expires_at = ?, updated_at = ? WHERE lock_id = ?", - (expires, now_iso, lock_id), - ) - return True - - # Different owner — check if expired - try: - exp_dt = datetime.fromisoformat(existing_expires.replace("Z", "+00:00")) - except (ValueError, AttributeError): - exp_dt = _utcnow() # treat unparseable as expired - - if now >= exp_dt: - # Expired — steal the lease - conn.execute( - "UPDATE locks SET owner_id = ?, lease_expires_at = ?, updated_at = ? WHERE lock_id = ?", - (owner_id, expires, now_iso, lock_id), - ) - logger.info( - "Lease %s stolen from %s (expired %s) by %s", - lock_id, existing_owner, existing_expires, owner_id, - ) - return True - - # Not expired — someone else holds it - return False - - def release(self, lock_id: str, owner_id: str) -> bool: - """Release a lease. Returns True if successfully released. - - Only the current owner can release. 
Returns False if: - - Lock doesn't exist - - Lock is held by a different owner - """ - with self._tx() as conn: - row = conn.execute( - "SELECT owner_id FROM locks WHERE lock_id = ?", - (lock_id,), - ).fetchone() - - if not row or row["owner_id"] != owner_id: - return False - - conn.execute("DELETE FROM locks WHERE lock_id = ?", (lock_id,)) - return True - - def renew( - self, - lock_id: str, - owner_id: str, - *, - duration_seconds: Optional[int] = None, - ) -> bool: - """Extend a lease. Returns True if renewed. - - Only the current owner can renew. Returns False if: - - Lock doesn't exist - - Lock is held by a different owner - - Lock has already expired (use acquire to re-take) - """ - duration = duration_seconds or self.default_duration - now = _utcnow() - expires = (now + timedelta(seconds=duration)).isoformat() - now_iso = now.isoformat() - - with self._tx() as conn: - row = conn.execute( - "SELECT owner_id, lease_expires_at FROM locks WHERE lock_id = ?", - (lock_id,), - ).fetchone() - - if not row or row["owner_id"] != owner_id: - return False - - # Check not expired - try: - exp_dt = datetime.fromisoformat( - row["lease_expires_at"].replace("Z", "+00:00") - ) - except (ValueError, AttributeError): - return False - - if now >= exp_dt: - return False # expired — must re-acquire - - conn.execute( - "UPDATE locks SET lease_expires_at = ?, updated_at = ? 
WHERE lock_id = ?", - (expires, now_iso, lock_id), - ) - return True - - def is_held(self, lock_id: str) -> bool: - """Check if a lock is currently held (not expired).""" - with self._lock: - row = self._conn.execute( - "SELECT lease_expires_at FROM locks WHERE lock_id = ?", - (lock_id,), - ).fetchone() - - if not row: - return False - - try: - exp_dt = datetime.fromisoformat( - row["lease_expires_at"].replace("Z", "+00:00") - ) - except (ValueError, AttributeError): - return False - - return _utcnow() < exp_dt - - def get_holder(self, lock_id: str) -> Optional[str]: - """Get the current holder of a lock, or None if unheld/expired.""" - with self._lock: - row = self._conn.execute( - "SELECT owner_id, lease_expires_at FROM locks WHERE lock_id = ?", - (lock_id,), - ).fetchone() - - if not row: - return None - - try: - exp_dt = datetime.fromisoformat( - row["lease_expires_at"].replace("Z", "+00:00") - ) - except (ValueError, AttributeError): - return None - - if _utcnow() >= exp_dt: - return None - - return row["owner_id"] - - def cleanup_expired(self) -> int: - """Remove all expired lease rows. Returns count removed.""" - now_iso = _utcnow_iso() - with self._tx() as conn: - result = conn.execute( - "DELETE FROM locks WHERE lease_expires_at < ?", - (now_iso,), - ) - return result.rowcount diff --git a/dhee/core/metrics.py b/dhee/core/metrics.py deleted file mode 100644 index 08a0f9f..0000000 --- a/dhee/core/metrics.py +++ /dev/null @@ -1,207 +0,0 @@ -"""Observability — structured logging and metrics for Dhee. - -Counters, histograms, and a @measure decorator for all critical paths. -Prometheus-compatible /metrics endpoint output. 
-""" - -import logging -import time -import threading -from collections import defaultdict -from dataclasses import dataclass, field -from functools import wraps -from typing import Any, Callable, Dict, List, Optional - -logger = logging.getLogger(__name__) - - -class Counter: - """Thread-safe monotonic counter.""" - - def __init__(self, name: str, description: str = ""): - self.name = name - self.description = description - self._value = 0 - self._lock = threading.Lock() - - def inc(self, amount: int = 1) -> None: - with self._lock: - self._value += amount - - @property - def value(self) -> int: - return self._value - - def to_prometheus(self) -> str: - return ( - f"# HELP {self.name} {self.description}\n" - f"# TYPE {self.name} counter\n" - f"{self.name} {self._value}\n" - ) - - -class Histogram: - """Thread-safe latency histogram with percentile tracking.""" - - def __init__(self, name: str, description: str = "", max_samples: int = 1000): - self.name = name - self.description = description - self._samples: List[float] = [] - self._max_samples = max_samples - self._lock = threading.Lock() - self._sum = 0.0 - self._count = 0 - - def observe(self, value: float) -> None: - with self._lock: - self._samples.append(value) - self._sum += value - self._count += 1 - if len(self._samples) > self._max_samples: - self._samples = self._samples[-self._max_samples:] - - @property - def count(self) -> int: - return self._count - - @property - def avg(self) -> float: - return self._sum / self._count if self._count else 0.0 - - def percentile(self, p: float) -> float: - with self._lock: - if not self._samples: - return 0.0 - sorted_samples = sorted(self._samples) - idx = int(len(sorted_samples) * p / 100) - return sorted_samples[min(idx, len(sorted_samples) - 1)] - - def to_prometheus(self) -> str: - p50 = self.percentile(50) - p95 = self.percentile(95) - p99 = self.percentile(99) - return ( - f"# HELP {self.name} {self.description}\n" - f"# TYPE {self.name} histogram\n" - 
f'{self.name}{{quantile="0.5"}} {p50:.4f}\n' - f'{self.name}{{quantile="0.95"}} {p95:.4f}\n' - f'{self.name}{{quantile="0.99"}} {p99:.4f}\n' - f"{self.name}_sum {self._sum:.4f}\n" - f"{self.name}_count {self._count}\n" - ) - - -class DheeMetrics: - """Centralized metrics registry for Dhee.""" - - _instance = None - _lock = threading.Lock() - - def __new__(cls): - with cls._lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - - def __init__(self): - if self._initialized: - return - self._initialized = True - - # Counters - self.memory_add = Counter("dhee_memory_add_total", "Total memories added") - self.memory_search = Counter("dhee_memory_search_total", "Total memory searches") - self.deterministic_resolve = Counter( - "dhee_deterministic_resolve_total", "Queries resolved deterministically via SQL" - ) - self.vector_search = Counter("dhee_vector_search_total", "Vector search operations") - self.cognitive_loop = Counter("dhee_cognitive_loop_total", "Cognitive decomposition runs") - self.llm_calls = Counter("dhee_llm_calls_total", "Total LLM API calls") - self.engram_extractions = Counter( - "dhee_engram_extractions_total", "Structured engram extractions" - ) - self.prospective_triggered = Counter( - "dhee_prospective_triggered_total", "Prospective scenes triggered" - ) - self.facts_stored = Counter("dhee_facts_stored_total", "Structured facts stored") - self.rerank_calls = Counter("dhee_rerank_calls_total", "Reranker invocations") - - # Histograms - self.add_latency = Histogram( - "dhee_memory_add_seconds", "Memory add latency in seconds" - ) - self.search_latency = Histogram( - "dhee_memory_search_seconds", "Memory search latency in seconds" - ) - self.resolve_latency = Histogram( - "dhee_deterministic_resolve_seconds", "Deterministic resolution latency" - ) - self.extraction_latency = Histogram( - "dhee_engram_extraction_seconds", "Engram extraction latency" - ) - 
self.cognitive_latency = Histogram( - "dhee_cognitive_loop_seconds", "Cognitive loop latency" - ) - self.llm_latency = Histogram("dhee_llm_call_seconds", "LLM call latency") - self.rerank_latency = Histogram("dhee_rerank_seconds", "Reranker latency") - - def to_prometheus(self) -> str: - """Render all metrics in Prometheus text format.""" - parts = [] - for attr_name in dir(self): - attr = getattr(self, attr_name) - if isinstance(attr, (Counter, Histogram)): - parts.append(attr.to_prometheus()) - return "\n".join(parts) - - def to_dict(self) -> Dict[str, Any]: - """Render metrics as a dict for JSON API.""" - result = {} - for attr_name in dir(self): - attr = getattr(self, attr_name) - if isinstance(attr, Counter): - result[attr.name] = attr.value - elif isinstance(attr, Histogram): - result[attr.name] = { - "count": attr.count, - "avg": round(attr.avg, 4), - "p50": round(attr.percentile(50), 4), - "p95": round(attr.percentile(95), 4), - "p99": round(attr.percentile(99), 4), - } - return result - - -def get_metrics() -> DheeMetrics: - """Get the singleton metrics instance.""" - return DheeMetrics() - - -def measure(counter_name: str = "", histogram_name: str = ""): - """Decorator to measure function execution. - - Usage: - @measure("memory_add", "add_latency") - def add(self, ...): ... 
- """ - def decorator(func: Callable) -> Callable: - @wraps(func) - def wrapper(*args, **kwargs): - metrics = get_metrics() - t0 = time.monotonic() - try: - result = func(*args, **kwargs) - return result - finally: - elapsed = time.monotonic() - t0 - if counter_name: - counter = getattr(metrics, counter_name, None) - if counter: - counter.inc() - if histogram_name: - histogram = getattr(metrics, histogram_name, None) - if histogram: - histogram.observe(elapsed) - return wrapper - return decorator diff --git a/dhee/core/promotion.py b/dhee/core/promotion.py deleted file mode 100644 index 99e0550..0000000 --- a/dhee/core/promotion.py +++ /dev/null @@ -1,316 +0,0 @@ -"""Dhee v3 — Promotion Pipeline: validate and promote distillation candidates. - -The promotion flow: - 1. Select pending candidates from distillation_candidates - 2. Validate (confidence threshold, conflict check) - 3. Promote transactionally into the target derived store - 4. Write lineage rows - 5. Mark candidate as promoted with promoted_id - -Design contract: - - Promotion is transactional: either fully committed or rolled back - - Every promoted object gets lineage rows linking to source events - - Idempotent: re-running on already-promoted candidates is a no-op - - Zero LLM calls — pure storage operations -""" - -from __future__ import annotations - -import json -import logging -import uuid -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional - -from dhee.core.derived_store import ( - BeliefStore, - PolicyStore, - InsightStore, - HeuristicStore, - DerivedLineageStore, -) -from dhee.core.distillation import DistillationStore - -logger = logging.getLogger(__name__) - - -def _utcnow_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -# Minimum confidence to promote a candidate -MIN_PROMOTION_CONFIDENCE = 0.3 - - -class PromotionResult: - """Result of a promotion batch.""" - - def __init__(self): - self.promoted: List[str] = [] - self.rejected: List[str] 
= [] - self.quarantined: List[str] = [] - self.skipped: List[str] = [] - self.errors: List[Dict[str, str]] = [] - - def to_dict(self) -> Dict[str, Any]: - return { - "promoted": len(self.promoted), - "rejected": len(self.rejected), - "quarantined": len(self.quarantined), - "skipped": len(self.skipped), - "errors": len(self.errors), - "promoted_ids": self.promoted, - } - - -class PromotionEngine: - """Validates and promotes distillation candidates into derived stores. - - Usage: - engine = PromotionEngine( - distillation=distillation_store, - beliefs=belief_store, - policies=policy_store, - insights=insight_store, - heuristics=heuristic_store, - lineage=lineage_store, - ) - result = engine.promote_pending(target_type="belief", limit=20) - """ - - def __init__( - self, - distillation: DistillationStore, - beliefs: BeliefStore, - policies: PolicyStore, - insights: InsightStore, - heuristics: HeuristicStore, - lineage: DerivedLineageStore, - *, - min_confidence: float = MIN_PROMOTION_CONFIDENCE, - ): - self.distillation = distillation - self.beliefs = beliefs - self.policies = policies - self.insights = insights - self.heuristics = heuristics - self.lineage = lineage - self.min_confidence = min_confidence - - self._promoters = { - "belief": self._promote_belief, - "policy": self._promote_policy, - "insight": self._promote_insight, - "heuristic": self._promote_heuristic, - } - - def promote_pending( - self, - target_type: Optional[str] = None, - *, - limit: int = 50, - ) -> PromotionResult: - """Promote all pending candidates of a given type. - - Args: - target_type: Filter by type (belief, policy, etc.) 
or None for all - limit: Max candidates to process - - Returns: - PromotionResult with counts and IDs - """ - result = PromotionResult() - candidates = self.distillation.get_pending(target_type, limit=limit) - - for candidate in candidates: - cid = candidate["candidate_id"] - ctype = candidate["target_type"] - - try: - # Validate - validation = self._validate(candidate) - - if validation == "reject": - self.distillation.set_status(cid, "rejected") - result.rejected.append(cid) - continue - - if validation == "quarantine": - self.distillation.set_status(cid, "quarantined") - result.quarantined.append(cid) - continue - - # Promote - promoter = self._promoters.get(ctype) - if not promoter: - logger.warning("No promoter for type: %s", ctype) - result.skipped.append(cid) - continue - - promoted_id = promoter(candidate) - if promoted_id: - # Write lineage - self._write_lineage( - ctype, promoted_id, candidate["source_event_ids"] - ) - # Mark candidate as promoted - self.distillation.set_status( - cid, "promoted", promoted_id=promoted_id - ) - result.promoted.append(promoted_id) - else: - result.skipped.append(cid) - - except Exception as e: - logger.exception( - "Failed to promote candidate %s: %s", cid, e - ) - result.errors.append({ - "candidate_id": cid, - "error": str(e), - }) - - return result - - def promote_single( - self, candidate_id: str - ) -> Dict[str, Any]: - """Promote a single candidate by ID.""" - candidate = self.distillation.get(candidate_id) - if not candidate: - return {"status": "error", "error": f"Candidate not found: {candidate_id}"} - - if candidate["status"] != "pending_validation": - return { - "status": "skipped", - "reason": f"Candidate status is '{candidate['status']}', not pending", - } - - ctype = candidate["target_type"] - validation = self._validate(candidate) - - if validation == "reject": - self.distillation.set_status(candidate_id, "rejected") - return {"status": "rejected", "reason": "validation_failed"} - - if validation == 
"quarantine": - self.distillation.set_status(candidate_id, "quarantined") - return {"status": "quarantined", "reason": "needs_review"} - - promoter = self._promoters.get(ctype) - if not promoter: - return {"status": "error", "error": f"No promoter for type: {ctype}"} - - promoted_id = promoter(candidate) - if promoted_id: - self._write_lineage(ctype, promoted_id, candidate["source_event_ids"]) - self.distillation.set_status( - candidate_id, "promoted", promoted_id=promoted_id - ) - return {"status": "promoted", "promoted_id": promoted_id} - - return {"status": "skipped", "reason": "promoter_returned_none"} - - # ------------------------------------------------------------------ - # Validation - # ------------------------------------------------------------------ - - def _validate(self, candidate: Dict[str, Any]) -> str: - """Validate a candidate. Returns: 'accept', 'reject', or 'quarantine'.""" - confidence = candidate.get("confidence", 0.0) - payload = candidate.get("payload", {}) - - # Hard reject: below minimum confidence - if confidence < self.min_confidence: - return "reject" - - # Hard reject: empty payload - if not payload: - return "reject" - - # Type-specific validation - target_type = candidate["target_type"] - - if target_type == "belief": - claim = payload.get("claim", "") - if not claim or len(claim.strip()) < 5: - return "reject" - - elif target_type == "policy": - if not payload.get("name") or not payload.get("condition"): - return "reject" - - elif target_type in ("insight", "heuristic"): - content = payload.get("content", "") - if not content or len(content.strip()) < 10: - return "reject" - - return "accept" - - # ------------------------------------------------------------------ - # Type-specific promoters - # ------------------------------------------------------------------ - - def _promote_belief(self, candidate: Dict[str, Any]) -> Optional[str]: - payload = candidate["payload"] - return self.beliefs.add( - user_id=payload["user_id"], - 
claim=payload["claim"], - domain=payload.get("domain", "general"), - confidence=candidate["confidence"], - source_memory_ids=candidate["source_event_ids"], - tags=payload.get("tags"), - ) - - def _promote_policy(self, candidate: Dict[str, Any]) -> Optional[str]: - payload = candidate["payload"] - return self.policies.add( - user_id=payload["user_id"], - name=payload["name"], - condition=payload.get("condition", {}), - action=payload.get("action", {}), - granularity=payload.get("granularity", "task"), - source_task_ids=candidate["source_event_ids"], - tags=payload.get("tags"), - ) - - def _promote_insight(self, candidate: Dict[str, Any]) -> Optional[str]: - payload = candidate["payload"] - return self.insights.add( - user_id=payload["user_id"], - content=payload["content"], - insight_type=payload.get("insight_type", "pattern"), - confidence=candidate["confidence"], - tags=payload.get("tags"), - ) - - def _promote_heuristic(self, candidate: Dict[str, Any]) -> Optional[str]: - payload = candidate["payload"] - return self.heuristics.add( - user_id=payload["user_id"], - content=payload["content"], - abstraction_level=payload.get("abstraction_level", "specific"), - confidence=candidate["confidence"], - tags=payload.get("tags"), - ) - - # ------------------------------------------------------------------ - # Lineage - # ------------------------------------------------------------------ - - def _write_lineage( - self, - derived_type: str, - derived_id: str, - source_event_ids: List[str], - ) -> None: - """Write lineage rows linking the promoted object to source events.""" - if not source_event_ids: - return - - # Equal weight distribution across sources - weight = 1.0 / len(source_event_ids) - self.lineage.add_batch( - derived_type, derived_id, source_event_ids, - weights=[weight] * len(source_event_ids), - ) diff --git a/dhee/core/proposition_context.py b/dhee/core/proposition_context.py deleted file mode 100644 index 7f7642f..0000000 --- 
a/dhee/core/proposition_context.py +++ /dev/null @@ -1,144 +0,0 @@ -"""Proposition-based context builder for memory QA. - -Instead of feeding 30K chars of raw conversation to an LLM, this module -builds ~3-8K chars of focused, structured facts with source citations -from episodic events and retrieval results. - -This improves answer accuracy for freeform questions by reducing noise -and letting even weak models answer correctly with clean input. -""" - -from __future__ import annotations - -import re -from typing import Any, Dict, List, Optional, Sequence - - -def build_proposition_context( - *, - events: Sequence[Dict[str, Any]], - results: Sequence[Dict[str, Any]], - question: str, - max_chars: int = 8000, -) -> str: - """Build LLM context from propositions + evidence snippets. - - Instead of 30K chars of raw conversation, produces ~3-8K chars of - focused, structured facts with source citations. - - Args: - events: Episodic events matched to the query. - results: Search results from the retrieval pipeline. - question: The user's question (used for relevance hints). - max_chars: Maximum output length. - - Returns: - A compact context string ready for LLM consumption. - """ - lines: List[str] = [] - remaining = max(1, int(max_chars)) - - # Section 1: Structured facts from episodic events. - if events: - lines.append("Structured Facts:") - remaining -= len(lines[-1]) + 1 - seen_keys: set = set() - for event in events: - value = str(event.get("value_text") or "").strip() - if not value: - continue - # Deduplicate by canonical key. 
- ckey = str(event.get("canonical_key") or "").strip().lower() - if ckey and ckey in seen_keys: - continue - if ckey: - seen_keys.add(ckey) - - session_id = str(event.get("session_id") or "").strip() - event_time = str(event.get("event_time") or "").strip() - event_type = str(event.get("event_type") or "fact").strip() - actor = str(event.get("actor_role") or event.get("actor_id") or "").strip() - - source_parts: List[str] = [] - if session_id: - source_parts.append(f"Session {session_id}") - if event_time: - # Show only date portion for readability. - date_part = event_time[:10] if len(event_time) >= 10 else event_time - source_parts.append(date_part) - source = " | ".join(source_parts) if source_parts else "unknown" - - fact_parts: List[str] = [] - if actor: - fact_parts.append(actor) - fact_parts.append(f"({event_type})") - fact_parts.append(value[:200]) - fact = " ".join(fact_parts) - - line = f"- [{source}] {fact}" - if len(line) + 1 > remaining: - break - lines.append(line) - remaining -= len(line) + 1 - - # Section 2: Evidence snippets from retrieval results. - evidence_results = list(results)[:5] - if evidence_results and remaining > 100: - lines.append("") - lines.append("Evidence Snippets:") - remaining -= 20 # header overhead - - for result in evidence_results: - evidence = str( - result.get("evidence_text") or result.get("memory") or "" - ).strip() - if not evidence: - continue - metadata = result.get("metadata") or {} - session_id = str(metadata.get("session_id") or "").strip() - session_date = str(metadata.get("session_date") or "").strip() - - header_parts: List[str] = [] - if session_id: - header_parts.append(f"Session {session_id}") - if session_date: - header_parts.append(session_date) - header = " | ".join(header_parts) if header_parts else "unknown" - - # Truncate evidence to fit budget. 
- budget = min(1500, remaining - 20) - if budget <= 50: - break - snippet = evidence[:budget] - - block = f"\n[{header}]\n{snippet}" - if len(block) + 1 > remaining: - break - lines.append(block) - remaining -= len(block) + 1 - - context = "\n".join(lines).strip() - return context[:max_chars] - - -def build_proposition_answer_prompt( - question: str, - prop_context: str, - question_date: str = "", -) -> str: - """Build a simplified answer prompt for proposition-based context. - - When the LLM receives clean, structured facts instead of 30K chars - of raw conversation, the prompt can be dramatically simpler. - """ - date_str = question_date or "Not specified" - return ( - "Answer this question using ONLY the facts below.\n\n" - f"Question: {question}\n" - f"Date: {date_str}\n\n" - f"Facts:\n{prop_context}\n\n" - "Answer concisely using the exact values from the facts above. " - "If the answer requires counting items, count the distinct items " - "listed in the facts. If the facts do not contain enough " - "information, say so." - ) diff --git a/dhee/core/read_model.py b/dhee/core/read_model.py deleted file mode 100644 index 9b97bc6..0000000 --- a/dhee/core/read_model.py +++ /dev/null @@ -1,327 +0,0 @@ -"""Dhee v3 — Read Model: materialized retrieval view + delta overlay. - -Writes are normalized across type-specific tables. Reads are fast via -a precomputed retrieval_view table plus a delta overlay of recent changes -not yet folded in. 
- -Design contract: - - retrieval_view is a real table (materialized), not a SQL VIEW - - Delta overlay covers raw events + derived objects created since last refresh - - Hot-path retrieval queries the view + delta, fuses results - - View refresh is a cold-path job (recompute_retrieval_view) - - Zero LLM calls -""" - -from __future__ import annotations - -import json -import logging -import sqlite3 -import threading -from contextlib import contextmanager -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional - -logger = logging.getLogger(__name__) - - -def _utcnow_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -# Schema for the materialized retrieval view -RETRIEVAL_VIEW_SCHEMA = """ -CREATE TABLE IF NOT EXISTS retrieval_view ( - row_id TEXT PRIMARY KEY, - source_kind TEXT NOT NULL CHECK (source_kind IN ('raw', 'distilled', 'episodic')), - source_type TEXT NOT NULL, - source_id TEXT NOT NULL, - user_id TEXT NOT NULL, - retrieval_text TEXT NOT NULL, - summary TEXT, - anchor_era TEXT, - anchor_place TEXT, - anchor_activity TEXT, - confidence REAL DEFAULT 1.0, - utility REAL DEFAULT 0.0, - status TEXT DEFAULT 'active', - created_at TEXT NOT NULL, - refreshed_at TEXT NOT NULL -); - -CREATE INDEX IF NOT EXISTS idx_rv_user_kind ON retrieval_view(user_id, source_kind); -CREATE INDEX IF NOT EXISTS idx_rv_source ON retrieval_view(source_type, source_id); -CREATE INDEX IF NOT EXISTS idx_rv_status ON retrieval_view(status) WHERE status != 'active'; -""" - - -class ReadModel: - """Materialized retrieval view with delta overlay. - - Usage: - model = ReadModel(conn, lock) - model.refresh(events, beliefs, policies, ...) 
# cold-path - results = model.query(user_id, limit=20) # hot-path - """ - - def __init__(self, conn: sqlite3.Connection, lock: threading.RLock): - self._conn = conn - self._lock = lock - self._ensure_schema() - self._last_refresh: Optional[str] = None - - def _ensure_schema(self) -> None: - with self._lock: - self._conn.executescript(RETRIEVAL_VIEW_SCHEMA) - self._conn.commit() - - @contextmanager - def _tx(self): - with self._lock: - try: - yield self._conn - self._conn.commit() - except Exception: - self._conn.rollback() - raise - - # ------------------------------------------------------------------ - # Cold-path: refresh the materialized view - # ------------------------------------------------------------------ - - def refresh( - self, - user_id: str, - *, - events_store: Optional[Any] = None, - beliefs_store: Optional[Any] = None, - policies_store: Optional[Any] = None, - insights_store: Optional[Any] = None, - heuristics_store: Optional[Any] = None, - anchors_store: Optional[Any] = None, - ) -> Dict[str, int]: - """Rebuild the retrieval view for a user. Cold-path operation. - - Returns counts of rows refreshed per source type. 
- """ - now = _utcnow_iso() - counts: Dict[str, int] = {} - - with self._tx() as conn: - # Clear existing rows for this user - conn.execute( - "DELETE FROM retrieval_view WHERE user_id = ?", - (user_id,), - ) - - # Raw events - if events_store: - from dhee.core.events import EventStatus - events = events_store.list_by_user( - user_id, status=EventStatus.ACTIVE, limit=5000 - ) - for e in events: - conn.execute( - """INSERT INTO retrieval_view - (row_id, source_kind, source_type, source_id, - user_id, retrieval_text, confidence, - status, created_at, refreshed_at) - VALUES (?, 'raw', 'event', ?, ?, ?, 1.0, 'active', ?, ?)""", - ( - f"raw:{e.event_id}", e.event_id, user_id, - e.content, e.created_at or now, now, - ), - ) - counts["raw_events"] = len(events) - - # Beliefs - if beliefs_store: - beliefs = beliefs_store.list_by_user(user_id, limit=1000) - for b in beliefs: - if b["status"] in ("invalidated",): - continue - conn.execute( - """INSERT INTO retrieval_view - (row_id, source_kind, source_type, source_id, - user_id, retrieval_text, summary, confidence, - utility, status, created_at, refreshed_at) - VALUES (?, 'distilled', 'belief', ?, ?, ?, ?, ?, 0.0, ?, ?, ?)""", - ( - f"belief:{b['belief_id']}", b["belief_id"], - user_id, b["claim"], - f"[{b['domain']}] {b['claim']}", - b["confidence"], b["status"], - b["created_at"], now, - ), - ) - counts["beliefs"] = len(beliefs) - - # Policies - if policies_store: - policies = policies_store.list_by_user(user_id, limit=500) - for p in policies: - if p["status"] in ("invalidated",): - continue - text = f"{p['name']}: {json.dumps(p['action'])}" - conn.execute( - """INSERT INTO retrieval_view - (row_id, source_kind, source_type, source_id, - user_id, retrieval_text, summary, confidence, - utility, status, created_at, refreshed_at) - VALUES (?, 'distilled', 'policy', ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - f"policy:{p['policy_id']}", p["policy_id"], - user_id, text, p["name"], - 1.0, p["utility"], p["status"], - p["created_at"], 
now, - ), - ) - counts["policies"] = len(policies) - - # Insights - if insights_store: - insights = insights_store.list_by_user(user_id, limit=500) - for i in insights: - if i["status"] in ("invalidated",): - continue - conn.execute( - """INSERT INTO retrieval_view - (row_id, source_kind, source_type, source_id, - user_id, retrieval_text, confidence, - utility, status, created_at, refreshed_at) - VALUES (?, 'distilled', 'insight', ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - f"insight:{i['insight_id']}", i["insight_id"], - user_id, i["content"], i["confidence"], - i["utility"], i["status"], - i["created_at"], now, - ), - ) - counts["insights"] = len(insights) - - # Heuristics - if heuristics_store: - heuristics = heuristics_store.list_by_user(user_id, limit=500) - for h in heuristics: - if h["status"] in ("invalidated",): - continue - conn.execute( - """INSERT INTO retrieval_view - (row_id, source_kind, source_type, source_id, - user_id, retrieval_text, confidence, - utility, status, created_at, refreshed_at) - VALUES (?, 'distilled', 'heuristic', ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - f"heuristic:{h['heuristic_id']}", h["heuristic_id"], - user_id, h["content"], h["confidence"], - h["utility"], h["status"], - h["created_at"], now, - ), - ) - counts["heuristics"] = len(heuristics) - - self._last_refresh = now - return counts - - # ------------------------------------------------------------------ - # Hot-path: query the view - # ------------------------------------------------------------------ - - def query( - self, - user_id: str, - *, - source_kind: Optional[str] = None, - source_type: Optional[str] = None, - status_exclude: Optional[List[str]] = None, - limit: int = 100, - ) -> List[Dict[str, Any]]: - """Query the retrieval view. Returns rows for downstream fusion.""" - query = "SELECT * FROM retrieval_view WHERE user_id = ?" - params: list = [user_id] - - if source_kind: - query += " AND source_kind = ?" 
- params.append(source_kind) - if source_type: - query += " AND source_type = ?" - params.append(source_type) - - excludes = status_exclude or ["invalidated"] - for s in excludes: - query += " AND status != ?" - params.append(s) - - query += " ORDER BY created_at DESC LIMIT ?" - params.append(limit) - - with self._lock: - rows = self._conn.execute(query, params).fetchall() - - return [ - { - "row_id": r["row_id"], - "source_kind": r["source_kind"], - "source_type": r["source_type"], - "source_id": r["source_id"], - "user_id": r["user_id"], - "retrieval_text": r["retrieval_text"], - "summary": r["summary"], - "anchor_era": r["anchor_era"], - "anchor_place": r["anchor_place"], - "anchor_activity": r["anchor_activity"], - "confidence": r["confidence"], - "utility": r["utility"], - "status": r["status"], - "created_at": r["created_at"], - } - for r in rows - ] - - def get_delta( - self, - user_id: str, - since_iso: str, - *, - events_store: Optional[Any] = None, - ) -> List[Dict[str, Any]]: - """Get raw events created since the last refresh. - - These haven't been folded into the retrieval_view yet. - Used by fusion to overlay recent changes on top of the materialized view. 
- """ - if not events_store: - return [] - - from dhee.core.events import EventStatus - recent = events_store.get_events_since( - user_id, since_iso, status=EventStatus.ACTIVE - ) - return [ - { - "row_id": f"delta:{e.event_id}", - "source_kind": "raw", - "source_type": "event", - "source_id": e.event_id, - "user_id": e.user_id, - "retrieval_text": e.content, - "summary": None, - "confidence": 1.0, - "utility": 0.0, - "status": "active", - "created_at": e.created_at, - } - for e in recent - ] - - @property - def last_refresh(self) -> Optional[str]: - return self._last_refresh - - def row_count(self, user_id: str) -> int: - with self._lock: - row = self._conn.execute( - "SELECT COUNT(*) FROM retrieval_view WHERE user_id = ?", - (user_id,), - ).fetchone() - return row[0] if row else 0 diff --git a/dhee/core/salience.py b/dhee/core/salience.py deleted file mode 100644 index 2746196..0000000 --- a/dhee/core/salience.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Salience tagging for memories. - -Computes emotional valence, arousal, and overall salience score. -High-salience memories decay slower and rank higher in search. 
-""" - -from __future__ import annotations - -import json -import logging -import re -from typing import Any, Dict, Optional - -logger = logging.getLogger(__name__) - -# Heuristic keyword lists for fast salience estimation -_POSITIVE_WORDS = frozenset({ - "love", "great", "excellent", "amazing", "wonderful", "happy", "success", - "perfect", "awesome", "fantastic", "brilliant", "enjoy", "solved", "fixed", - "win", "achieved", "celebrate", "milestone", "breakthrough", "promoted", -}) -_NEGATIVE_WORDS = frozenset({ - "hate", "terrible", "awful", "horrible", "fail", "failure", "crash", - "error", "bug", "broken", "angry", "frustrated", "blocked", "lost", - "dead", "killed", "disaster", "critical", "urgent", "emergency", -}) -_HIGH_AROUSAL_WORDS = frozenset({ - "urgent", "critical", "emergency", "asap", "immediately", "deadline", - "panic", "crash", "outage", "breaking", "alert", "important", "warning", - "danger", "production", "incident", "blocker", "showstopper", -}) - -_SALIENCE_PROMPT = """Rate the emotional content of this text. 
- -Text: {content} - -Respond in JSON: -{{"valence": 0.0, "arousal": 0.0, "reasoning": "..."}} - -valence: -1.0 (very negative) to 1.0 (very positive), 0.0 = neutral -arousal: 0.0 (calm/routine) to 1.0 (intense/urgent)""" - - -def compute_salience_heuristic(content: str) -> Dict[str, float]: - """Fast heuristic salience computation from keyword matching.""" - words = set(re.findall(r'\b\w+\b', content.lower())) - - pos_count = len(words & _POSITIVE_WORDS) - neg_count = len(words & _NEGATIVE_WORDS) - arousal_count = len(words & _HIGH_AROUSAL_WORDS) - - total_emotional = pos_count + neg_count - if total_emotional > 0: - valence = (pos_count - neg_count) / total_emotional - else: - valence = 0.0 - - arousal = min(1.0, arousal_count * 0.25) - - salience_score = min(1.0, (abs(valence) + arousal) / 2) - - return { - "sal_valence": round(valence, 3), - "sal_arousal": round(arousal, 3), - "sal_salience_score": round(salience_score, 3), - } - - -def compute_salience_llm(content: str, llm: Any) -> Dict[str, float]: - """LLM-based salience computation (slower, more accurate).""" - formatted = _SALIENCE_PROMPT.format(content=content) - - try: - response = llm.generate(formatted) - text = response if isinstance(response, str) else str(response) - start = text.find("{") - if start >= 0: - parsed, _ = json.JSONDecoder().raw_decode(text, start) - valence = max(-1.0, min(1.0, float(parsed.get("valence", 0.0)))) - arousal = max(0.0, min(1.0, float(parsed.get("arousal", 0.0)))) - salience_score = min(1.0, (abs(valence) + arousal) / 2) - return { - "sal_valence": round(valence, 3), - "sal_arousal": round(arousal, 3), - "sal_salience_score": round(salience_score, 3), - } - except Exception as e: - logger.warning("LLM salience computation failed: %s", e) - - return compute_salience_heuristic(content) - - -def compute_salience( - content: str, - llm: Optional[Any] = None, - use_llm: bool = False, -) -> Dict[str, float]: - """Compute salience for a memory's content. 
- - Returns dict with sal_valence, sal_arousal, sal_salience_score. - """ - if use_llm and llm: - return compute_salience_llm(content, llm) - return compute_salience_heuristic(content) - - -def salience_decay_modifier(salience_score: float) -> float: - """Compute decay rate modifier based on salience. - - High-salience memories decay slower. - Returns a multiplier for the decay lambda (< 1.0 means slower decay). - """ - return 1.0 - (salience_score * 0.5) diff --git a/dhee/core/storage.py b/dhee/core/storage.py deleted file mode 100644 index 08e1e08..0000000 --- a/dhee/core/storage.py +++ /dev/null @@ -1,428 +0,0 @@ -"""Dhee v3 — Schema DDL for the event-sourced cognition substrate. - -All tables live in a single SQLite database (v3.db). Schema is organized as: - -Layer 1 — Raw truth: - raw_memory_events Immutable source-of-truth memory events - -Layer 2 — Derived cognition (type-specific tables): - beliefs Confidence-tracked claims with Bayesian updates - policies Condition→action rules with utility tracking - anchors Hierarchical context (era/place/time/activity) - insights Synthesized causal hypotheses - heuristics Transferable reasoning patterns - -Layer 3 — Infrastructure: - derived_lineage Links derived objects → source raw events - maintenance_jobs Cold-path job registry - locks SQLite lease manager for job concurrency - cognitive_conflicts Contradiction/disagreement queue - anchor_candidates Per-field extraction candidates (Phase 2) - distillation_candidates Consolidation promotion candidates (Phase 4) - -All tables use TEXT PRIMARY KEY (UUIDs), ISO timestamps, and JSON for -nested structures. Follows existing Dhee conventions from dhee/db/sqlite.py. 
-""" - -# --------------------------------------------------------------------------- -# Layer 1: Raw truth -# --------------------------------------------------------------------------- - -RAW_MEMORY_EVENTS = """ -CREATE TABLE IF NOT EXISTS raw_memory_events ( - event_id TEXT PRIMARY KEY, - user_id TEXT NOT NULL, - session_id TEXT, - created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')), - content TEXT NOT NULL, - content_hash TEXT NOT NULL, - source TEXT DEFAULT 'user', - status TEXT NOT NULL DEFAULT 'active' - CHECK (status IN ('active', 'corrected', 'deleted')), - supersedes_event_id TEXT REFERENCES raw_memory_events(event_id), - metadata_json TEXT DEFAULT '{}' -); - -CREATE INDEX IF NOT EXISTS idx_rme_user_status - ON raw_memory_events(user_id, status); -CREATE INDEX IF NOT EXISTS idx_rme_content_hash - ON raw_memory_events(content_hash, user_id); -CREATE INDEX IF NOT EXISTS idx_rme_created - ON raw_memory_events(created_at DESC); -CREATE INDEX IF NOT EXISTS idx_rme_supersedes - ON raw_memory_events(supersedes_event_id) - WHERE supersedes_event_id IS NOT NULL; -""" - -# --------------------------------------------------------------------------- -# Layer 2: Derived cognition — type-specific tables -# --------------------------------------------------------------------------- - -BELIEFS = """ -CREATE TABLE IF NOT EXISTS beliefs ( - belief_id TEXT PRIMARY KEY, - user_id TEXT NOT NULL, - claim TEXT NOT NULL, - domain TEXT DEFAULT 'general', - status TEXT NOT NULL DEFAULT 'proposed' - CHECK (status IN ( - 'proposed', 'held', 'challenged', - 'revised', 'retracted', - 'stale', 'suspect', 'invalidated' - )), - confidence REAL NOT NULL DEFAULT 0.5, - evidence_json TEXT DEFAULT '[]', - revisions_json TEXT DEFAULT '[]', - contradicts_ids TEXT DEFAULT '[]', - source_memory_ids TEXT DEFAULT '[]', - source_episode_ids TEXT DEFAULT '[]', - derivation_version INTEGER NOT NULL DEFAULT 1, - lineage_fingerprint TEXT, - created_at TEXT NOT NULL DEFAULT 
(strftime('%Y-%m-%dT%H:%M:%f', 'now')), - updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')), - tags_json TEXT DEFAULT '[]' -); - -CREATE INDEX IF NOT EXISTS idx_beliefs_user_domain_status - ON beliefs(user_id, domain, status); -CREATE INDEX IF NOT EXISTS idx_beliefs_user_confidence - ON beliefs(user_id, confidence DESC); -CREATE INDEX IF NOT EXISTS idx_beliefs_status - ON beliefs(status) - WHERE status IN ('stale', 'suspect', 'invalidated'); -""" - -POLICIES = """ -CREATE TABLE IF NOT EXISTS policies ( - policy_id TEXT PRIMARY KEY, - user_id TEXT NOT NULL, - name TEXT NOT NULL, - granularity TEXT NOT NULL DEFAULT 'task' - CHECK (granularity IN ('task', 'step')), - status TEXT NOT NULL DEFAULT 'proposed' - CHECK (status IN ( - 'proposed', 'active', 'validated', 'deprecated', - 'stale', 'suspect', 'invalidated' - )), - condition_json TEXT NOT NULL DEFAULT '{}', - action_json TEXT NOT NULL DEFAULT '{}', - apply_count INTEGER NOT NULL DEFAULT 0, - success_count INTEGER NOT NULL DEFAULT 0, - failure_count INTEGER NOT NULL DEFAULT 0, - utility REAL NOT NULL DEFAULT 0.0, - last_delta REAL NOT NULL DEFAULT 0.0, - cumulative_delta REAL NOT NULL DEFAULT 0.0, - source_task_ids TEXT DEFAULT '[]', - source_episode_ids TEXT DEFAULT '[]', - derivation_version INTEGER NOT NULL DEFAULT 1, - lineage_fingerprint TEXT, - created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')), - updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')), - tags_json TEXT DEFAULT '[]' -); - -CREATE INDEX IF NOT EXISTS idx_policies_user_gran_status - ON policies(user_id, granularity, status); -CREATE INDEX IF NOT EXISTS idx_policies_user_utility - ON policies(user_id, utility DESC); -CREATE INDEX IF NOT EXISTS idx_policies_status - ON policies(status) - WHERE status IN ('stale', 'suspect', 'invalidated'); -""" - -ANCHORS = """ -CREATE TABLE IF NOT EXISTS anchors ( - anchor_id TEXT PRIMARY KEY, - user_id TEXT NOT NULL, - memory_event_id TEXT REFERENCES 
raw_memory_events(event_id), - era TEXT, - place TEXT, - place_type TEXT, - place_detail TEXT, - time_absolute TEXT, - time_markers_json TEXT DEFAULT '[]', - time_range_start TEXT, - time_range_end TEXT, - time_derivation TEXT, - activity TEXT, - session_id TEXT, - session_position INTEGER DEFAULT 0, - derivation_version INTEGER NOT NULL DEFAULT 1, - lineage_fingerprint TEXT, - created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')) -); - -CREATE INDEX IF NOT EXISTS idx_anchors_user_era_place - ON anchors(user_id, era, place, activity); -CREATE INDEX IF NOT EXISTS idx_anchors_user_time - ON anchors(user_id, time_range_start, time_range_end); -CREATE INDEX IF NOT EXISTS idx_anchors_event - ON anchors(memory_event_id) - WHERE memory_event_id IS NOT NULL; -""" - -INSIGHTS = """ -CREATE TABLE IF NOT EXISTS insights ( - insight_id TEXT PRIMARY KEY, - user_id TEXT NOT NULL, - content TEXT NOT NULL, - insight_type TEXT NOT NULL DEFAULT 'pattern' - CHECK (insight_type IN ( - 'causal', 'warning', 'strategy', 'pattern' - )), - source_task_types_json TEXT DEFAULT '[]', - confidence REAL NOT NULL DEFAULT 0.5, - validation_count INTEGER NOT NULL DEFAULT 0, - invalidation_count INTEGER NOT NULL DEFAULT 0, - utility REAL NOT NULL DEFAULT 0.0, - apply_count INTEGER NOT NULL DEFAULT 0, - derivation_version INTEGER NOT NULL DEFAULT 1, - lineage_fingerprint TEXT, - created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')), - last_validated TEXT, - tags_json TEXT DEFAULT '[]', - status TEXT NOT NULL DEFAULT 'active' - CHECK (status IN ( - 'active', 'stale', 'suspect', 'invalidated' - )) -); - -CREATE INDEX IF NOT EXISTS idx_insights_user_type_conf - ON insights(user_id, insight_type, confidence DESC); -CREATE INDEX IF NOT EXISTS idx_insights_user_utility - ON insights(user_id, utility DESC); -CREATE INDEX IF NOT EXISTS idx_insights_status - ON insights(status) - WHERE status IN ('stale', 'suspect', 'invalidated'); -""" - -HEURISTICS = """ -CREATE TABLE IF 
NOT EXISTS heuristics ( - heuristic_id TEXT PRIMARY KEY, - user_id TEXT NOT NULL, - content TEXT NOT NULL, - abstraction_level TEXT NOT NULL DEFAULT 'specific' - CHECK (abstraction_level IN ( - 'specific', 'domain', 'universal' - )), - source_task_types_json TEXT DEFAULT '[]', - confidence REAL NOT NULL DEFAULT 0.5, - validation_count INTEGER NOT NULL DEFAULT 0, - invalidation_count INTEGER NOT NULL DEFAULT 0, - utility REAL NOT NULL DEFAULT 0.0, - last_delta REAL NOT NULL DEFAULT 0.0, - apply_count INTEGER NOT NULL DEFAULT 0, - derivation_version INTEGER NOT NULL DEFAULT 1, - lineage_fingerprint TEXT, - created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')), - tags_json TEXT DEFAULT '[]', - status TEXT NOT NULL DEFAULT 'active' - CHECK (status IN ( - 'active', 'stale', 'suspect', 'invalidated' - )) -); - -CREATE INDEX IF NOT EXISTS idx_heuristics_user_level_conf - ON heuristics(user_id, abstraction_level, confidence DESC); -CREATE INDEX IF NOT EXISTS idx_heuristics_user_utility - ON heuristics(user_id, utility DESC); -CREATE INDEX IF NOT EXISTS idx_heuristics_status - ON heuristics(status) - WHERE status IN ('stale', 'suspect', 'invalidated'); -""" - -# --------------------------------------------------------------------------- -# Layer 3: Infrastructure -# --------------------------------------------------------------------------- - -DERIVED_LINEAGE = """ -CREATE TABLE IF NOT EXISTS derived_lineage ( - lineage_id TEXT PRIMARY KEY, - derived_type TEXT NOT NULL - CHECK (derived_type IN ( - 'belief', 'policy', 'anchor', - 'insight', 'heuristic' - )), - derived_id TEXT NOT NULL, - source_event_id TEXT NOT NULL REFERENCES raw_memory_events(event_id), - contribution_weight REAL NOT NULL DEFAULT 1.0, - created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')) -); - -CREATE INDEX IF NOT EXISTS idx_lineage_derived - ON derived_lineage(derived_type, derived_id); -CREATE INDEX IF NOT EXISTS idx_lineage_source - ON 
derived_lineage(source_event_id); -""" - -MAINTENANCE_JOBS = """ -CREATE TABLE IF NOT EXISTS maintenance_jobs ( - job_id TEXT PRIMARY KEY, - job_name TEXT NOT NULL, - status TEXT NOT NULL DEFAULT 'pending' - CHECK (status IN ( - 'pending', 'running', 'completed', - 'failed', 'cancelled' - )), - payload_json TEXT DEFAULT '{}', - result_json TEXT, - error_message TEXT, - created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')), - started_at TEXT, - completed_at TEXT, - retry_count INTEGER NOT NULL DEFAULT 0, - max_retries INTEGER NOT NULL DEFAULT 3, - idempotency_key TEXT -); - -CREATE INDEX IF NOT EXISTS idx_jobs_status_name - ON maintenance_jobs(status, job_name); -CREATE UNIQUE INDEX IF NOT EXISTS idx_jobs_idempotency - ON maintenance_jobs(idempotency_key) - WHERE idempotency_key IS NOT NULL; -""" - -LOCKS = """ -CREATE TABLE IF NOT EXISTS locks ( - lock_id TEXT PRIMARY KEY, - owner_id TEXT NOT NULL, - lease_expires_at TEXT NOT NULL, - updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')) -); -""" - -COGNITIVE_CONFLICTS = """ -CREATE TABLE IF NOT EXISTS cognitive_conflicts ( - conflict_id TEXT PRIMARY KEY, - conflict_type TEXT NOT NULL - CHECK (conflict_type IN ( - 'belief_contradiction', - 'anchor_disagreement', - 'distillation_conflict', - 'invalidation_dispute' - )), - side_a_type TEXT NOT NULL, - side_a_id TEXT NOT NULL, - side_b_type TEXT NOT NULL, - side_b_id TEXT NOT NULL, - detected_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')), - resolution_status TEXT NOT NULL DEFAULT 'open' - CHECK (resolution_status IN ( - 'open', 'auto_resolved', - 'user_resolved', 'deferred' - )), - resolution_json TEXT, - auto_resolution_confidence REAL -); - -CREATE INDEX IF NOT EXISTS idx_conflicts_status - ON cognitive_conflicts(resolution_status) - WHERE resolution_status = 'open'; -CREATE INDEX IF NOT EXISTS idx_conflicts_sides - ON cognitive_conflicts(side_a_type, side_a_id); -""" - -ANCHOR_CANDIDATES = """ -CREATE TABLE IF 
NOT EXISTS anchor_candidates ( - candidate_id TEXT PRIMARY KEY, - anchor_id TEXT NOT NULL REFERENCES anchors(anchor_id), - field_name TEXT NOT NULL, - field_value TEXT NOT NULL, - confidence REAL NOT NULL DEFAULT 0.5, - extractor_source TEXT NOT NULL DEFAULT 'default', - source_event_ids TEXT DEFAULT '[]', - derivation_version INTEGER NOT NULL DEFAULT 1, - status TEXT NOT NULL DEFAULT 'pending' - CHECK (status IN ( - 'pending', 'accepted', 'rejected', 'superseded' - )), - created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')) -); - -CREATE INDEX IF NOT EXISTS idx_anchor_cand_anchor_field - ON anchor_candidates(anchor_id, field_name, status); -""" - -DISTILLATION_CANDIDATES = """ -CREATE TABLE IF NOT EXISTS distillation_candidates ( - candidate_id TEXT PRIMARY KEY, - source_event_ids TEXT NOT NULL DEFAULT '[]', - derivation_version INTEGER NOT NULL DEFAULT 1, - confidence REAL NOT NULL DEFAULT 0.5, - canonical_key TEXT, - idempotency_key TEXT, - target_type TEXT NOT NULL - CHECK (target_type IN ( - 'belief', 'policy', 'insight', 'heuristic' - )), - payload_json TEXT NOT NULL DEFAULT '{}', - status TEXT NOT NULL DEFAULT 'pending_validation' - CHECK (status IN ( - 'pending_validation', 'promoted', - 'rejected', 'quarantined' - )), - promoted_id TEXT, - created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')) -); - -CREATE UNIQUE INDEX IF NOT EXISTS idx_distill_idempotency - ON distillation_candidates(idempotency_key) - WHERE idempotency_key IS NOT NULL AND status != 'rejected'; -CREATE INDEX IF NOT EXISTS idx_distill_status - ON distillation_candidates(status, target_type); -""" - -SCHEMA_VERSION = """ -CREATE TABLE IF NOT EXISTS v3_schema_version ( - version INTEGER PRIMARY KEY, - applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%f', 'now')), - description TEXT -); -""" - -# --------------------------------------------------------------------------- -# Ordered list of all DDL statements for initialization -# 
--------------------------------------------------------------------------- - -ALL_SCHEMAS = [ - # Version tracking - SCHEMA_VERSION, - # Layer 1 - RAW_MEMORY_EVENTS, - # Layer 2 - BELIEFS, - POLICIES, - ANCHORS, - INSIGHTS, - HEURISTICS, - # Layer 3 - DERIVED_LINEAGE, - MAINTENANCE_JOBS, - LOCKS, - COGNITIVE_CONFLICTS, - ANCHOR_CANDIDATES, - DISTILLATION_CANDIDATES, -] - -CURRENT_VERSION = 1 - - -def initialize_schema(conn: "sqlite3.Connection") -> None: - """Create all v3 tables if they don't exist. - - Idempotent — safe to call on every startup. - """ - for ddl in ALL_SCHEMAS: - conn.executescript(ddl) - - # Record schema version (idempotent) - existing = conn.execute( - "SELECT 1 FROM v3_schema_version WHERE version = ?", - (CURRENT_VERSION,), - ).fetchone() - if not existing: - conn.execute( - "INSERT INTO v3_schema_version (version, description) VALUES (?, ?)", - (CURRENT_VERSION, "Initial v3 event-sourced substrate"), - ) - conn.commit() diff --git a/dhee/core/v3_health.py b/dhee/core/v3_health.py deleted file mode 100644 index 7f1379a..0000000 --- a/dhee/core/v3_health.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Dhee v3 — Observability: expanded cognition_health with v3 metrics. - -Adds v3-specific metrics to the existing cognition_health() output: -- Stale/suspect/invalidated derived counts per type -- Pending conflict backlog -- Lease contention (active locks) -- Candidate promotion stats (promoted/rejected/quarantined) -- Retrieval view freshness -- Maintenance job health - -All metrics are pure SQL COUNT/aggregation queries. Zero LLM calls. 
-""" - -from __future__ import annotations - -import logging -from datetime import datetime, timezone -from typing import Any, Dict, Optional - -logger = logging.getLogger(__name__) - - -def _utcnow_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -def v3_health( - conn: "sqlite3.Connection", - lock: "threading.RLock", - *, - user_id: Optional[str] = None, -) -> Dict[str, Any]: - """Compute v3 substrate health metrics. - - Returns a dict suitable for merging into existing cognition_health() output. - """ - health: Dict[str, Any] = {} - - with lock: - # --- Raw event counts --- - if user_id: - row = conn.execute( - "SELECT COUNT(*) FROM raw_memory_events WHERE user_id = ? AND status = 'active'", - (user_id,), - ).fetchone() - health["raw_events_active"] = row[0] if row else 0 - - row = conn.execute( - "SELECT COUNT(*) FROM raw_memory_events WHERE user_id = ? AND status = 'corrected'", - (user_id,), - ).fetchone() - health["raw_events_corrected"] = row[0] if row else 0 - else: - row = conn.execute( - "SELECT COUNT(*) FROM raw_memory_events WHERE status = 'active'" - ).fetchone() - health["raw_events_active"] = row[0] if row else 0 - - # --- Derived object counts by invalidation status --- - derived_tables = { - "beliefs": "belief_id", - "policies": "policy_id", - "insights": "insight_id", - "heuristics": "heuristic_id", - } - invalidation_statuses = ("stale", "suspect", "invalidated") - derived_health: Dict[str, Dict[str, int]] = {} - - for table, _pk in derived_tables.items(): - counts: Dict[str, int] = {} - for status in invalidation_statuses: - try: - if user_id: - row = conn.execute( - f"SELECT COUNT(*) FROM {table} WHERE status = ? 
AND user_id = ?", - (status, user_id), - ).fetchone() - else: - row = conn.execute( - f"SELECT COUNT(*) FROM {table} WHERE status = ?", - (status,), - ).fetchone() - counts[status] = row[0] if row else 0 - except Exception: - counts[status] = -1 # table might not exist yet - derived_health[table] = counts - - health["derived_invalidation"] = derived_health - - # --- Conflict backlog --- - try: - row = conn.execute( - "SELECT COUNT(*) FROM cognitive_conflicts WHERE resolution_status = 'open'" - ).fetchone() - health["open_conflicts"] = row[0] if row else 0 - except Exception: - health["open_conflicts"] = -1 - - # --- Lease contention --- - try: - now_iso = _utcnow_iso() - row = conn.execute( - "SELECT COUNT(*) FROM locks WHERE lease_expires_at > ?", - (now_iso,), - ).fetchone() - health["active_leases"] = row[0] if row else 0 - except Exception: - health["active_leases"] = -1 - - # --- Candidate promotion stats --- - try: - promo_stats: Dict[str, int] = {} - for status in ("pending_validation", "promoted", "rejected", "quarantined"): - row = conn.execute( - "SELECT COUNT(*) FROM distillation_candidates WHERE status = ?", - (status,), - ).fetchone() - promo_stats[status] = row[0] if row else 0 - health["distillation_candidates"] = promo_stats - except Exception: - health["distillation_candidates"] = {} - - # --- Maintenance job health --- - try: - job_stats: Dict[str, int] = {} - for status in ("pending", "running", "completed", "failed"): - row = conn.execute( - "SELECT COUNT(*) FROM maintenance_jobs WHERE status = ?", - (status,), - ).fetchone() - job_stats[status] = row[0] if row else 0 - health["maintenance_jobs"] = job_stats - except Exception: - health["maintenance_jobs"] = {} - - # --- Retrieval view freshness --- - try: - row = conn.execute( - "SELECT MAX(refreshed_at) FROM retrieval_view" - ).fetchone() - health["retrieval_view_last_refresh"] = row[0] if row and row[0] else None - - row = conn.execute( - "SELECT COUNT(*) FROM retrieval_view" - ).fetchone() - 
health["retrieval_view_rows"] = row[0] if row else 0 - except Exception: - health["retrieval_view_last_refresh"] = None - health["retrieval_view_rows"] = 0 - - # --- Lineage coverage --- - try: - row = conn.execute( - "SELECT COUNT(DISTINCT derived_type || ':' || derived_id) FROM derived_lineage" - ).fetchone() - health["lineage_derived_objects"] = row[0] if row else 0 - - row = conn.execute( - "SELECT COUNT(DISTINCT source_event_id) FROM derived_lineage" - ).fetchone() - health["lineage_source_events"] = row[0] if row else 0 - except Exception: - health["lineage_derived_objects"] = 0 - health["lineage_source_events"] = 0 - - # --- Warnings --- - warnings: list = [] - - di = health.get("derived_invalidation", {}) - total_stale = sum( - counts.get("stale", 0) for counts in di.values() - if isinstance(counts, dict) - ) - total_suspect = sum( - counts.get("suspect", 0) for counts in di.values() - if isinstance(counts, dict) - ) - if total_stale > 10: - warnings.append(f"{total_stale} stale derived objects awaiting repair") - if total_suspect > 5: - warnings.append(f"{total_suspect} suspect derived objects need verification") - - oc = health.get("open_conflicts", 0) - if isinstance(oc, int) and oc > 5: - warnings.append(f"{oc} unresolved cognitive conflicts") - - jobs = health.get("maintenance_jobs", {}) - if isinstance(jobs, dict) and jobs.get("failed", 0) > 3: - warnings.append(f"{jobs['failed']} failed maintenance jobs") - - health["v3_warnings"] = warnings - - return health diff --git a/dhee/core/v3_migration.py b/dhee/core/v3_migration.py deleted file mode 100644 index 427770e..0000000 --- a/dhee/core/v3_migration.py +++ /dev/null @@ -1,238 +0,0 @@ -"""Dhee v3 — Migration Bridge: dual-write from v2 to v3. - -Phase 10 migration strategy: -1. Add raw event store without changing external API -2. Dual-write: old path + new raw event path -3. Backfill old engrams into raw + derived form - -This module provides the dual-write bridge and backfill utilities. 
-The external API (remember/recall/context/checkpoint) stays stable. - -Design contract: - - Old path continues to work — no breakage - - New path writes to v3 raw events in parallel - - Feature flag controls whether recall reads from v3 - - Backfill is idempotent and resumable - - Zero LLM calls -""" - -from __future__ import annotations - -import hashlib -import json -import logging -import os -import uuid -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional - -logger = logging.getLogger(__name__) - - -def _utcnow_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -# Feature flags (env vars) -def _flag(name: str, default: bool = False) -> bool: - val = os.environ.get(name, "").lower() - if val in ("1", "true", "yes"): - return True - if val in ("0", "false", "no"): - return False - return default - - -class V3MigrationBridge: - """Dual-write bridge between v2 (UniversalEngram) and v3 (event-sourced). - - Usage: - bridge = V3MigrationBridge(v3_store) - # In remember(): - bridge.on_remember(content, user_id, memory_id) - # In recall(): - if bridge.should_use_v3_read(): - results = bridge.recall_from_v3(query, user_id) - - Feature flags: - DHEE_V3_WRITE=1 → dual-write to v3 raw events (default: on) - DHEE_V3_READ=1 → read from v3 retrieval view (default: off) - """ - - def __init__( - self, - v3_store: Optional["CognitionStore"] = None, - ): - self._store = v3_store - self._write_enabled = _flag("DHEE_V3_WRITE", default=True) - self._read_enabled = _flag("DHEE_V3_READ", default=False) - - @property - def write_enabled(self) -> bool: - return self._write_enabled and self._store is not None - - @property - def read_enabled(self) -> bool: - return self._read_enabled and self._store is not None - - # ------------------------------------------------------------------ - # Dual-write hooks - # ------------------------------------------------------------------ - - def on_remember( - self, - content: str, - user_id: str, - *, 
- session_id: Optional[str] = None, - source: str = "user", - metadata: Optional[Dict[str, Any]] = None, - v2_memory_id: Optional[str] = None, - ) -> Optional[str]: - """Dual-write: store raw event alongside v2 memory. - - Returns the v3 event_id, or None if v3 write is disabled. - """ - if not self.write_enabled: - return None - - try: - meta = metadata or {} - if v2_memory_id: - meta["v2_memory_id"] = v2_memory_id - - event = self._store.events.add( - content=content, - user_id=user_id, - session_id=session_id, - source=source, - metadata=meta, - ) - return event.event_id - except Exception as e: - logger.warning("v3 dual-write failed (non-fatal): %s", e) - return None - - def on_correction( - self, - original_content: str, - new_content: str, - user_id: str, - ) -> Optional[str]: - """Handle a memory correction in v3. - - Finds the original event by content hash and creates a correction. - """ - if not self.write_enabled: - return None - - try: - content_hash = hashlib.sha256( - original_content.encode("utf-8") - ).hexdigest() - original = self._store.events.get_by_hash(content_hash, user_id) - if not original: - # Original not in v3 yet — just add the new content - event = self._store.events.add( - content=new_content, user_id=user_id, - source="user_correction", - ) - return event.event_id - - correction = self._store.events.correct( - original.event_id, new_content, - source="user_correction", - ) - return correction.event_id - except Exception as e: - logger.warning("v3 correction failed (non-fatal): %s", e) - return None - - # ------------------------------------------------------------------ - # Backfill: v2 engrams → v3 raw events - # ------------------------------------------------------------------ - - def backfill_from_v2( - self, - memories: List[Dict[str, Any]], - *, - user_id: str = "default", - batch_size: int = 100, - ) -> Dict[str, int]: - """Backfill v2 memories into v3 raw events. - - Idempotent: content-hash dedup prevents duplicates. 
- - Args: - memories: List of v2 memory dicts with at least 'memory' key - user_id: Default user ID if not in memory dict - batch_size: Process in batches for progress reporting - - Returns: - {"total": N, "created": M, "skipped_dedup": K, "errors": E} - """ - if not self._store: - return {"total": 0, "error": "v3 store not initialized"} - - stats = {"total": len(memories), "created": 0, "skipped_dedup": 0, "errors": 0} - - for i, mem in enumerate(memories): - content = mem.get("memory", mem.get("content", "")) - if not content: - stats["errors"] += 1 - continue - - uid = mem.get("user_id", user_id) - created_at = mem.get("created_at") - - try: - content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest() - existing = self._store.events.get_by_hash(content_hash, uid) - if existing: - stats["skipped_dedup"] += 1 - continue - - meta = {} - v2_id = mem.get("id", mem.get("memory_id")) - if v2_id: - meta["v2_memory_id"] = v2_id - if mem.get("layer"): - meta["v2_layer"] = mem["layer"] - if mem.get("strength"): - meta["v2_strength"] = mem["strength"] - - self._store.events.add( - content=content, - user_id=uid, - source="v2_backfill", - metadata=meta, - ) - stats["created"] += 1 - - except Exception as e: - logger.warning("Backfill error for memory %d: %s", i, e) - stats["errors"] += 1 - - return stats - - # ------------------------------------------------------------------ - # v3 read path (behind feature flag) - # ------------------------------------------------------------------ - - def should_use_v3_read(self) -> bool: - return self.read_enabled - - def get_v3_stats(self) -> Dict[str, Any]: - """Get basic stats about v3 state (for monitoring).""" - if not self._store: - return {"v3_available": False} - - try: - return { - "v3_available": True, - "write_enabled": self._write_enabled, - "read_enabled": self._read_enabled, - "event_count": self._store.events.count("default"), - } - except Exception as e: - return {"v3_available": True, "error": str(e)} diff --git 
a/dhee/decay/__init__.py b/dhee/decay/__init__.py deleted file mode 100644 index 0b7b506..0000000 --- a/dhee/decay/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Decay helpers for Engram v2.""" - -from dhee.decay.refcounts import RefCountManager - -__all__ = ["RefCountManager"] diff --git a/dhee/edge/__init__.py b/dhee/edge/__init__.py deleted file mode 100644 index 2e4fc3b..0000000 --- a/dhee/edge/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Dhee Edge — minimal-footprint cognition for hardware/humanoid deployment.""" - -from dhee.edge.edge_plugin import DheeEdge -from dhee.edge.edge_trainer import EdgeTrainer - -__all__ = ["DheeEdge", "EdgeTrainer"] diff --git a/dhee/edge/edge_plugin.py b/dhee/edge/edge_plugin.py deleted file mode 100644 index 2777aa2..0000000 --- a/dhee/edge/edge_plugin.py +++ /dev/null @@ -1,322 +0,0 @@ -"""DheeEdge — minimal cognition plugin for edge/hardware deployment. - -Designed for humanoid robots, IoT devices, and AI hardware products. -All computation runs locally — no cloud API calls, no internet required. 
- -Constraints: - - LLM: DheeModel (GGUF Q4, ~1.5GB) or mock fallback - - Embedder: ONNX MiniLM (22MB) or hash-based fallback - - Vector store: sqlite_vec (local file) - - RAM: <500MB working set - - No external API calls ever - -Adds embodiment hooks for hardware integration: - - on_sensor_input() — process sensor data into episodic memory - - on_action_result() — record action outcomes for environment learning - - predict_environment() — predict next state from memory patterns - -Usage: - from dhee.edge import DheeEdge - - d = DheeEdge(data_dir="/data/dhee") - d.remember("User prefers quiet mode after 10pm") - d.on_sensor_input("microphone", {"volume_db": 85, "duration": 3.0}) - d.on_action_result("reduce_volume", success=True, env_state={"volume_db": 40}) -""" - -from __future__ import annotations - -import json -import logging -import os -import time -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -from dhee.adapters.base import DheePlugin - -logger = logging.getLogger(__name__) - - -class DheeEdge(DheePlugin): - """Minimal cognition plugin for edge deployment. - - Forces all-offline providers. No API calls, no internet. - Extends DheePlugin with embodiment hooks for hardware. - - Args: - data_dir: Storage directory (required for edge — no temp dirs). - model_path: Path to GGUF model file for local LLM inference. - user_id: Default user ID. - """ - - def __init__( - self, - data_dir: Union[str, Path], - model_path: Optional[str] = None, - user_id: str = "default", - ): - # Force offline — never make API calls. - # Try persistent storage first; fall back to in-memory if - # sqlite_vec extension isn't available on this platform. 
- try: - super().__init__( - data_dir=data_dir, - provider="mock", - user_id=user_id, - in_memory=False, - offline=True, - ) - except (AttributeError, OSError) as e: - logger.debug("Persistent storage unavailable (%s), using in-memory", e) - super().__init__( - data_dir=data_dir, - provider="mock", - user_id=user_id, - in_memory=True, - offline=True, - ) - - # Embodiment state - self._sensor_history: List[Dict[str, Any]] = [] - self._action_history: List[Dict[str, Any]] = [] - self._environment_model: Dict[str, Any] = {} - - # Try to upgrade to local GGUF model - if model_path: - self._try_load_local_model(model_path) - - def _try_load_local_model(self, model_path: str) -> None: - """Attempt to load a local GGUF model for on-device LLM inference.""" - if not os.path.exists(model_path): - logger.debug("GGUF model not found at %s, using mock LLM", model_path) - return - try: - from dhee.llms.dhee import DheeLLM - self._engram._memory.llm = DheeLLM( - config={"model_path": model_path, "backend": "gguf"} - ) - logger.info("Loaded local GGUF model: %s", model_path) - except ImportError: - logger.debug("llama-cpp-python not available, using mock LLM") - except Exception as e: - logger.debug("Failed to load GGUF model: %s", e) - - # ------------------------------------------------------------------ - # Embodiment hooks (from Self-evolving Embodied AI, arXiv:2602.04411) - # ------------------------------------------------------------------ - - def on_sensor_input( - self, - sensor_type: str, - data: Dict[str, Any], - user_id: Optional[str] = None, - ) -> Dict[str, Any]: - """Process sensor data into episodic memory. - - Converts raw sensor readings into natural language memories that - can be recalled later. Tracks sensor patterns for environment - prediction. 
- - Args: - sensor_type: Type of sensor (e.g., "microphone", "camera", "imu") - data: Sensor data dict with readings and metadata - user_id: Override default user_id - - Returns: - {"stored": bool, "id": str, "description": str} - """ - uid = user_id or self._user_id - timestamp = data.get("timestamp", time.time()) - - # Build natural language description from sensor data - description = self._describe_sensor_data(sensor_type, data) - - # Store as memory - result = self.remember( - content=description, - user_id=uid, - metadata={ - "source": "sensor", - "sensor_type": sensor_type, - "timestamp": timestamp, - "raw_data": data, - }, - ) - - # Track in sensor history (bounded) - record = { - "sensor_type": sensor_type, - "data": data, - "timestamp": timestamp, - "description": description, - } - self._sensor_history.append(record) - if len(self._sensor_history) > 500: - self._sensor_history = self._sensor_history[-500:] - - # Update environment model - self._update_environment_model(sensor_type, data) - - result["description"] = description - return result - - def on_action_result( - self, - action: str, - success: bool, - env_state: Optional[Dict[str, Any]] = None, - user_id: Optional[str] = None, - ) -> Dict[str, Any]: - """Record action outcomes for environment self-prediction. - - Builds a causal model: action + context → outcome. Over time, - the system learns which actions work in which states. - - Args: - action: What the agent did (e.g., "reduce_volume", "move_forward") - success: Whether the action achieved its goal - env_state: Environment state after action - user_id: Override default user_id - """ - uid = user_id or self._user_id - - # Store as memory with outcome - outcome_word = "succeeded" if success else "failed" - content = f"Action '{action}' {outcome_word}" - if env_state: - state_summary = ", ".join(f"{k}={v}" for k, v in list(env_state.items())[:5]) - content += f". 
Environment state: {state_summary}" - - result = self.remember( - content=content, - user_id=uid, - metadata={ - "source": "action_result", - "action": action, - "success": success, - "env_state": env_state, - }, - ) - - # Track action history - record = { - "action": action, - "success": success, - "env_state": env_state, - "timestamp": time.time(), - } - self._action_history.append(record) - if len(self._action_history) > 500: - self._action_history = self._action_history[-500:] - - # Record outcome for performance tracking - task_type = f"action_{action}" - self._buddhi.record_outcome( - user_id=uid, - task_type=task_type, - score=1.0 if success else 0.0, - ) - - return result - - def predict_environment( - self, - current_state: Dict[str, Any], - proposed_action: Optional[str] = None, - ) -> Dict[str, Any]: - """Predict next environment state from memory patterns. - - Uses action history to estimate what will happen if a given - action is taken in the current state. - - Args: - current_state: Current environment state dict - proposed_action: Action being considered (optional) - - Returns: - {"prediction": str, "confidence": float, "similar_outcomes": list} - """ - # Build query from current state + proposed action - state_desc = ", ".join(f"{k}={v}" for k, v in list(current_state.items())[:5]) - query = f"environment state: {state_desc}" - if proposed_action: - query += f", action: {proposed_action}" - - # Search for similar past situations - similar = self.recall(query=query, limit=5) - - # Compute confidence from action history - confidence = 0.0 - outcomes = [] - if proposed_action and self._action_history: - matching = [ - a for a in self._action_history - if a["action"] == proposed_action - ] - if matching: - success_rate = sum(1 for a in matching if a["success"]) / len(matching) - confidence = success_rate - outcomes = matching[-3:] # last 3 similar actions - - # Simple prediction based on success rate - prediction = "unknown" - if confidence > 0.7: - 
prediction = f"Action '{proposed_action}' is likely to succeed (confidence: {confidence:.0%})" - elif confidence > 0.3: - prediction = f"Action '{proposed_action}' has mixed results (confidence: {confidence:.0%})" - elif confidence > 0 and proposed_action: - prediction = f"Action '{proposed_action}' has often failed (confidence: {confidence:.0%})" - - return { - "prediction": prediction, - "confidence": round(confidence, 3), - "similar_memories": similar[:3], - "recent_outcomes": [ - {"action": o["action"], "success": o["success"]} - for o in outcomes - ], - } - - def adapt_embodiment( - self, - capabilities: Dict[str, Any], - user_id: Optional[str] = None, - ) -> Dict[str, Any]: - """Update self-model when hardware capabilities change. - - Call when sensors are added/removed, actuators change, or the - physical form factor is modified. - - Args: - capabilities: New capability dict (e.g., {"has_camera": True, "arm_reach_cm": 60}) - """ - uid = user_id or self._user_id - cap_desc = ", ".join(f"{k}: {v}" for k, v in capabilities.items()) - content = f"Embodiment update: {cap_desc}" - return self.remember( - content=content, - user_id=uid, - metadata={"source": "embodiment_update", "capabilities": capabilities}, - ) - - # ------------------------------------------------------------------ - # Internal helpers - # ------------------------------------------------------------------ - - def _describe_sensor_data(self, sensor_type: str, data: Dict[str, Any]) -> str: - """Convert raw sensor data to natural language for memory storage.""" - readings = ", ".join( - f"{k}={v}" for k, v in data.items() - if k != "timestamp" and not isinstance(v, (dict, list)) - ) - return f"Sensor[{sensor_type}]: {readings}" - - def _update_environment_model( - self, sensor_type: str, data: Dict[str, Any], - ) -> None: - """Update the running environment model with new sensor data.""" - self._environment_model[sensor_type] = { - "last_reading": data, - "last_updated": time.time(), - } diff --git 
a/dhee/edge/edge_trainer.py b/dhee/edge/edge_trainer.py deleted file mode 100644 index 572bd9c..0000000 --- a/dhee/edge/edge_trainer.py +++ /dev/null @@ -1,508 +0,0 @@ -"""EdgeTrainer — on-device micro-training for edge deployments. - -Runs minimal LoRA fine-tuning directly on edge hardware (ARM CPU, low-RAM -devices). Designed for DheeEdge scenarios where the model needs to adapt -to its specific user/environment without cloud connectivity. - -Constraints: - - CPU-only training (no CUDA required) - - <2GB peak RAM during training - - LoRA rank 4-8 (tiny adapter, ~2MB) - - Micro-batches of 1-4 samples - - 10-50 gradient steps per cycle (not epochs) - -Training data sources (all local): - - Samskara signals from SamskaraCollector.get_training_data() - - Action outcome pairs from DheeEdge._action_history - - Sensor pattern correlations - -Architecture: - EdgeTrainer does NOT require torch/transformers at init. It checks - for their availability lazily. On devices without PyTorch, it logs - a warning and becomes a no-op. 
- -Usage: - from dhee.edge.edge_trainer import EdgeTrainer - - trainer = EdgeTrainer( - model_path="/data/models/dhee-2b-q4.gguf", - adapter_dir="/data/dhee/adapters", - ) - - # Check if training is possible on this device - if trainer.can_train: - result = trainer.micro_train(training_data) -""" - -from __future__ import annotations - -import json -import logging -import os -import time -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - -logger = logging.getLogger(__name__) - - -@dataclass -class MicroTrainResult: - """Result of a micro-training cycle.""" - - success: bool - steps_completed: int = 0 - samples_used: int = 0 - loss_start: float = 0.0 - loss_end: float = 0.0 - adapter_path: Optional[str] = None - duration_seconds: float = 0.0 - error: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - return { - "success": self.success, - "steps_completed": self.steps_completed, - "samples_used": self.samples_used, - "loss_start": round(self.loss_start, 4), - "loss_end": round(self.loss_end, 4), - "adapter_path": self.adapter_path, - "duration_seconds": round(self.duration_seconds, 2), - "error": self.error, - } - - -@dataclass -class EdgeTrainingConfig: - """Configuration for edge micro-training.""" - - lora_rank: int = 4 - lora_alpha: int = 8 - learning_rate: float = 2e-4 - max_steps: int = 30 - micro_batch_size: int = 2 - max_seq_len: int = 256 - gradient_accumulation_steps: int = 2 - warmup_steps: int = 3 - weight_decay: float = 0.01 - max_samples: int = 100 # Limit training data - - -class EdgeTrainer: - """On-device micro-training for edge deployments. - - Performs minimal LoRA fine-tuning on CPU with tight resource budgets. - Training is designed to be interruptible — partial progress is saved. 
- """ - - def __init__( - self, - model_path: Optional[str] = None, - adapter_dir: Optional[str] = None, - config: Optional[EdgeTrainingConfig] = None, - ): - """ - Args: - model_path: Path to the base model (GGUF or safetensors). - adapter_dir: Directory to save/load LoRA adapters. - config: Training hyperparameters. - """ - self._model_path = model_path - self._adapter_dir = adapter_dir or os.path.join( - os.path.expanduser("~"), ".dhee", "edge_adapters", - ) - self.config = config or EdgeTrainingConfig() - self._training_history: List[Dict[str, Any]] = [] - self._torch_available: Optional[bool] = None - - @property - def can_train(self) -> bool: - """Check if training is possible on this device.""" - if self._torch_available is None: - try: - import torch # noqa: F401 - self._torch_available = True - except ImportError: - self._torch_available = False - logger.info( - "PyTorch not available — edge training disabled. " - "Install with: pip install torch --index-url https://download.pytorch.org/whl/cpu" - ) - return self._torch_available and self._model_path is not None - - def micro_train( - self, - training_data: Dict[str, Any], - samskara_signals: Optional[Dict[str, Any]] = None, - ) -> MicroTrainResult: - """Run a micro-training cycle. - - Args: - training_data: Dict with keys: - - sft_samples: List of {"input": str, "output": str} dicts - - dpo_pairs: List of {"chosen": str, "rejected": str} dicts (optional) - samskara_signals: Optional vasana report for sample weighting. - - Returns: - MicroTrainResult with training metrics. 
- """ - if not self.can_train: - return MicroTrainResult( - success=False, - error="Training not available (missing PyTorch or model)", - ) - - start_time = time.time() - - # Prepare training samples - sft_samples = training_data.get("sft_samples", []) - if not sft_samples: - return MicroTrainResult( - success=False, - error="No training samples provided", - ) - - # Limit to max_samples - samples = sft_samples[:self.config.max_samples] - - # Weight samples by vasana if available - if samskara_signals: - samples = self._weight_samples(samples, samskara_signals) - - try: - result = self._run_lora_training(samples) - result.duration_seconds = time.time() - start_time - - # Record in history - self._training_history.append({ - "timestamp": time.time(), - "result": result.to_dict(), - }) - - # Persist training log - self._save_training_log() - - return result - except Exception as e: - logger.warning("Micro-training failed: %s", e) - return MicroTrainResult( - success=False, - duration_seconds=time.time() - start_time, - error=str(e), - ) - - def _run_lora_training(self, samples: List[Dict]) -> MicroTrainResult: - """Run LoRA fine-tuning with PyTorch. - - This is CPU-optimized: - - float32 (no mixed precision on CPU) - - Gradient checkpointing for memory efficiency - - Small LoRA rank (4) = tiny trainable parameter count - """ - import torch - from torch.utils.data import DataLoader, Dataset - - class TextDataset(Dataset): - def __init__(self, data: List[Dict]): - self.data = data - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - return item.get("input", ""), item.get("output", "") - - # Check if we can load the model for training - # For GGUF models, we need llama-cpp-python for inference but - # can't fine-tune them directly. We look for a safetensors/HF model. 
- model, tokenizer = self._load_model_for_training() - if model is None: - # Fallback: save training data for later batch processing - return self._save_for_deferred_training(samples) - - dataset = TextDataset(samples) - loader = DataLoader( - dataset, - batch_size=self.config.micro_batch_size, - shuffle=True, - ) - - # Apply LoRA - model = self._apply_lora(model) - model.train() - - # Optimizer - trainable_params = [p for p in model.parameters() if p.requires_grad] - optimizer = torch.optim.AdamW( - trainable_params, - lr=self.config.learning_rate, - weight_decay=self.config.weight_decay, - ) - - # Training loop - total_steps = 0 - losses = [] - accum_loss = 0.0 - - for step_idx in range(self.config.max_steps): - for batch_inputs, batch_outputs in loader: - # Tokenize - combined = [ - f"{inp} {out}" for inp, out in zip(batch_inputs, batch_outputs) - ] - encodings = tokenizer( - combined, - return_tensors="pt", - max_length=self.config.max_seq_len, - truncation=True, - padding=True, - ) - - # Forward pass - outputs = model( - input_ids=encodings["input_ids"], - attention_mask=encodings["attention_mask"], - labels=encodings["input_ids"], - ) - loss = outputs.loss / self.config.gradient_accumulation_steps - loss.backward() - accum_loss += loss.item() - - total_steps += 1 - - if total_steps % self.config.gradient_accumulation_steps == 0: - torch.nn.utils.clip_grad_norm_(trainable_params, 1.0) - optimizer.step() - optimizer.zero_grad() - losses.append(accum_loss) - accum_loss = 0.0 - - if total_steps >= self.config.max_steps: - break - - if total_steps >= self.config.max_steps: - break - - # Save adapter - adapter_path = self._save_adapter(model) - - return MicroTrainResult( - success=True, - steps_completed=total_steps, - samples_used=len(samples), - loss_start=losses[0] if losses else 0.0, - loss_end=losses[-1] if losses else 0.0, - adapter_path=adapter_path, - ) - - def _load_model_for_training(self): - """Load model + tokenizer for training. 
- - Returns (model, tokenizer) or (None, None) if not available. - """ - if not self._model_path: - return None, None - - # GGUF files can't be fine-tuned directly - if self._model_path.endswith(".gguf"): - # Check for a companion safetensors model - base_dir = os.path.dirname(self._model_path) - safetensors_path = os.path.join(base_dir, "model.safetensors") - config_path = os.path.join(base_dir, "config.json") - if not os.path.exists(config_path): - logger.info( - "GGUF model cannot be fine-tuned directly. " - "Saving training data for deferred processing." - ) - return None, None - model_dir = base_dir - else: - model_dir = self._model_path - - try: - from transformers import AutoModelForCausalLM, AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained( - model_dir, trust_remote_code=True, - ) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - model = AutoModelForCausalLM.from_pretrained( - model_dir, - torch_dtype="auto", - trust_remote_code=True, - ) - return model, tokenizer - except Exception as e: - logger.info("Model loading failed: %s", e) - return None, None - - def _apply_lora(self, model): - """Apply LoRA adapters to the model.""" - try: - from peft import LoraConfig, get_peft_model, TaskType - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - r=self.config.lora_rank, - lora_alpha=self.config.lora_alpha, - lora_dropout=0.0, # No dropout for micro-training - target_modules=["q_proj", "v_proj"], # Minimal target - ) - model = get_peft_model(model, lora_config) - trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) - total = sum(p.numel() for p in model.parameters()) - logger.info( - "LoRA applied: %d trainable / %d total params (%.2f%%)", - trainable, total, 100 * trainable / total, - ) - return model - except ImportError: - logger.warning( - "peft not available — training all parameters (not recommended " - "for edge). 
Install: pip install peft" - ) - return model - - def _save_adapter(self, model) -> str: - """Save the LoRA adapter to disk.""" - os.makedirs(self._adapter_dir, exist_ok=True) - adapter_name = f"adapter_{int(time.time())}" - adapter_path = os.path.join(self._adapter_dir, adapter_name) - - try: - if hasattr(model, "save_pretrained"): - model.save_pretrained(adapter_path) - else: - # Fallback: save state dict - import torch - torch.save( - {k: v for k, v in model.state_dict().items() if "lora" in k}, - os.path.join(adapter_path, "lora_weights.pt"), - ) - except Exception as e: - logger.warning("Adapter save failed: %s", e) - adapter_path = "" - - return adapter_path - - def _save_for_deferred_training( - self, samples: List[Dict], - ) -> MicroTrainResult: - """Save training data to disk for later batch processing. - - Used when the model format doesn't support direct fine-tuning - (e.g., GGUF without a companion HF model). - """ - os.makedirs(self._adapter_dir, exist_ok=True) - deferred_path = os.path.join( - self._adapter_dir, f"deferred_{int(time.time())}.jsonl", - ) - - with open(deferred_path, "w", encoding="utf-8") as f: - for sample in samples: - f.write(json.dumps(sample, ensure_ascii=False) + "\n") - - return MicroTrainResult( - success=True, - steps_completed=0, - samples_used=len(samples), - adapter_path=deferred_path, - error="deferred: saved training data for batch processing", - ) - - def _weight_samples( - self, - samples: List[Dict], - samskara_signals: Dict[str, Any], - ) -> List[Dict]: - """Weight training samples based on vasana degradation signals. - - Samples from degrading dimensions get 2x representation. 
- """ - degrading = set(samskara_signals.get("degrading_dimensions", [])) - if not degrading: - return samples - - weighted = [] - for sample in samples: - weighted.append(sample) - sample_type = sample.get("type", "") - if sample_type in degrading: - weighted.append(sample) # Duplicate for emphasis - - return weighted - - def _save_training_log(self) -> None: - """Persist training history to disk.""" - os.makedirs(self._adapter_dir, exist_ok=True) - log_path = os.path.join(self._adapter_dir, "training_log.jsonl") - try: - with open(log_path, "a", encoding="utf-8") as f: - entry = self._training_history[-1] - f.write(json.dumps(entry, ensure_ascii=False) + "\n") - except OSError: - pass - - # ------------------------------------------------------------------ - # Integration with DheeEdge - # ------------------------------------------------------------------ - - def train_from_edge( - self, - edge_plugin: Any, - samskara: Optional[Any] = None, - ) -> MicroTrainResult: - """Convenience: collect training data from a DheeEdge instance and train. 
- - Gathers: - - Action outcome pairs as SFT samples - - Samskara signals for weighting - """ - # Collect action-based SFT samples - sft_samples = [] - action_history = getattr(edge_plugin, "_action_history", []) - for record in action_history[-self.config.max_samples:]: - action = record.get("action", "") - success = record.get("success", False) - env = record.get("env_state", {}) - - env_desc = ", ".join(f"{k}={v}" for k, v in list(env.items())[:5]) if env else "unknown" - sft_samples.append({ - "input": f"[ACTION] state: {env_desc}, action: {action}", - "output": f"{'success' if success else 'failure'}", - "type": "action_prediction", - }) - - # Add samskara training data if available - samskara_data = {} - if samskara: - try: - samskara_data = samskara.get_training_data() - sft_samples.extend(samskara_data.get("sft_samples", [])) - except Exception: - pass - - if not sft_samples: - return MicroTrainResult( - success=False, error="No training data from edge", - ) - - return self.micro_train( - training_data={"sft_samples": sft_samples}, - samskara_signals=samskara_data, - ) - - def get_status(self) -> Dict[str, Any]: - """Get trainer status.""" - return { - "can_train": self.can_train, - "model_path": self._model_path, - "adapter_dir": self._adapter_dir, - "training_cycles": len(self._training_history), - "config": { - "lora_rank": self.config.lora_rank, - "max_steps": self.config.max_steps, - "learning_rate": self.config.learning_rate, - }, - } diff --git a/dhee/embeddings/nvidia.py b/dhee/embeddings/nvidia.py index eb3cad1..24c154e 100644 --- a/dhee/embeddings/nvidia.py +++ b/dhee/embeddings/nvidia.py @@ -20,12 +20,13 @@ def __init__(self, config: Optional[dict] = None): api_key = ( self.config.get("api_key") or os.getenv("NVIDIA_EMBEDDING_API_KEY") + or os.getenv("NVIDIA_EMBED_API_KEY") or os.getenv("NVIDIA_API_KEY") ) if not api_key: raise ValueError( "NVIDIA API key required. Set config['api_key'], " - "NVIDIA_EMBEDDING_API_KEY, or NVIDIA_API_KEY env var." 
+ "NVIDIA_EMBEDDING_API_KEY, NVIDIA_EMBED_API_KEY, or NVIDIA_API_KEY env var." ) base_url = self.config.get("base_url", "https://integrate.api.nvidia.com/v1") @@ -54,10 +55,15 @@ def _extra_body(self, memory_action: Optional[str] = None, count: int = 1) -> di def _truncate_if_needed(self, text: str) -> str: """Truncate text to stay within model token limits. - nv-embed-v1 has a 4096 token limit. Using ~3.5 chars/token as - a conservative estimate, cap at 14000 characters. + Model-aware defaults (conservative ~3.5 chars/token): + - nv-embed-v1: 4096 tokens → 14000 chars + - nemotron-embed: 8192 tokens → 26000 chars (use 24000 for safety) """ - max_chars = int(self.config.get("max_input_chars", 14000)) + if "nemotron-embed" in self.model: + default_max = 24000 + else: + default_max = 14000 + max_chars = int(self.config.get("max_input_chars", default_max)) if len(text) > max_chars: logger.debug("Truncating input from %d to %d chars for embedding", len(text), max_chars) return text[:max_chars] diff --git a/dhee/hive/__init__.py b/dhee/hive/__init__.py deleted file mode 100644 index a35f837..0000000 --- a/dhee/hive/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Dhee Hive — multi-agent shared cognition layer. - -Built on top of engram-bus for real-time agent-to-agent communication, -with CRDT-based sync for offline/edge scenarios. -""" - -from dhee.hive.hive_memory import HiveMemory - -__all__ = ["HiveMemory"] diff --git a/dhee/hive/hive_memory.py b/dhee/hive/hive_memory.py deleted file mode 100644 index 26453e2..0000000 --- a/dhee/hive/hive_memory.py +++ /dev/null @@ -1,526 +0,0 @@ -"""HiveMemory — multi-agent shared cognition on top of engram-bus. - -Enables multiple DheePlugin instances (across agents, processes, or machines) -to share and evolve a collective knowledge base: - - - **Shared Insights**: Cross-agent discoveries and patterns. - - **Shared Heuristics**: Abstract reasoning rules mined from any agent's trajectories. 
- - **Shared Skills**: Proven skills that any agent can adopt. - - **Collective Signals**: Aggregated samskara-like signals for hive-level evolution. - -Architecture: - Each agent runs a local DheePlugin. HiveMemory sits alongside it and - periodically publishes local discoveries to the bus, and subscribes to - discoveries from other agents. A quality gate ensures only validated - knowledge propagates. - - ┌──────────┐ bus.publish() ┌─────────────┐ - │ Agent A │ ────────────────────▶ │ engram-bus │ - │ DheePlugin│ ◀──────────────────── │ (pub/sub + │ - │ + Hive │ bus.subscribe() │ KV store) │ - └──────────┘ └──────┬──────┘ - │ - ┌──────────┐ │ - │ Agent B │ ◀───────────────────────────┘ - │ DheePlugin│ - │ + Hive │ - └──────────┘ -""" - -from __future__ import annotations - -import json -import logging -import time -import threading -from dataclasses import asdict, dataclass, field -from datetime import datetime, timezone -from typing import Any, Callable, Dict, List, Optional, Set - -logger = logging.getLogger(__name__) - - -def _now_iso() -> str: - return datetime.now(timezone.utc).isoformat() - - -# --------------------------------------------------------------------------- -# Shared knowledge item types -# --------------------------------------------------------------------------- - -@dataclass -class SharedItem: - """A piece of knowledge shared on the hive.""" - - id: str - kind: str # "insight" | "heuristic" | "skill" | "signal" - content: Dict[str, Any] - source_agent: str - timestamp: str = field(default_factory=_now_iso) - confidence: float = 0.5 - votes_up: int = 0 - votes_down: int = 0 - adopted_by: List[str] = field(default_factory=list) - - def to_dict(self) -> Dict[str, Any]: - return { - "id": self.id, - "kind": self.kind, - "content": self.content, - "source_agent": self.source_agent, - "timestamp": self.timestamp, - "confidence": self.confidence, - "votes_up": self.votes_up, - "votes_down": self.votes_down, - "adopted_by": self.adopted_by, - } - - 
@classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SharedItem": - return cls( - id=data["id"], - kind=data["kind"], - content=data.get("content", {}), - source_agent=data.get("source_agent", "unknown"), - timestamp=data.get("timestamp", _now_iso()), - confidence=data.get("confidence", 0.5), - votes_up=data.get("votes_up", 0), - votes_down=data.get("votes_down", 0), - adopted_by=data.get("adopted_by", []), - ) - - @property - def quality_score(self) -> float: - """Wilson score lower bound — conservative estimate of true quality.""" - n = self.votes_up + self.votes_down - if n == 0: - return self.confidence - p = self.votes_up / n - # Wilson score interval (simplified) - z = 1.96 # 95% confidence - denominator = 1 + z * z / n - centre = p + z * z / (2 * n) - spread = z * ((p * (1 - p) + z * z / (4 * n)) / n) ** 0.5 - return (centre - spread) / denominator - - -# --------------------------------------------------------------------------- -# Topics -# --------------------------------------------------------------------------- - -_TOPIC_SHARE = "dhee.hive.share" -_TOPIC_VOTE = "dhee.hive.vote" -_TOPIC_ADOPT = "dhee.hive.adopt" -_TOPIC_SYNC_REQUEST = "dhee.hive.sync.request" -_TOPIC_SYNC_RESPONSE = "dhee.hive.sync.response" -_NS_HIVE = "dhee_hive" - - -# --------------------------------------------------------------------------- -# HiveMemory -# --------------------------------------------------------------------------- - -class HiveMemory: - """Multi-agent shared cognition layer. - - Wraps an engram-bus instance to provide structured knowledge sharing - with quality gating and adoption tracking. - """ - - def __init__( - self, - agent_id: str, - bus: Any = None, - min_confidence_to_share: float = 0.4, - min_quality_to_adopt: float = 0.3, - auto_subscribe: bool = True, - ): - """ - Args: - agent_id: This agent's identifier on the hive. - bus: An engram_bus.Bus instance. If None, creates an in-memory bus. 
- min_confidence_to_share: Minimum confidence to publish to hive. - min_quality_to_adopt: Minimum quality_score to auto-adopt shared items. - auto_subscribe: Whether to subscribe to hive topics on init. - """ - self.agent_id = agent_id - self._min_share = min_confidence_to_share - self._min_adopt = min_quality_to_adopt - - # Local store of hive items (id -> SharedItem) - self._items: Dict[str, SharedItem] = {} - self._lock = threading.RLock() - self._on_receive_callbacks: List[Callable] = [] - - # Bus connection - if bus is None: - try: - from engram_bus import Bus - bus = Bus() - except ImportError: - logger.warning("engram-bus not available, hive runs in local-only mode") - bus = None - - self._bus = bus - - if bus and auto_subscribe: - self._subscribe() - - def _subscribe(self) -> None: - """Subscribe to hive topics on the bus.""" - if not self._bus: - return - self._bus.subscribe(_TOPIC_SHARE, self._on_share_received, agent=self.agent_id) - self._bus.subscribe(_TOPIC_VOTE, self._on_vote_received, agent=self.agent_id) - self._bus.subscribe(_TOPIC_ADOPT, self._on_adopt_received, agent=self.agent_id) - self._bus.subscribe( - _TOPIC_SYNC_REQUEST, self._on_sync_request, agent=self.agent_id, - ) - self._bus.register(self.agent_id, metadata={"type": "dhee_hive_member"}) - - # ------------------------------------------------------------------ - # Publishing - # ------------------------------------------------------------------ - - def share_insight( - self, - insight_id: str, - content: Dict[str, Any], - confidence: float = 0.5, - ) -> Optional[SharedItem]: - """Share an insight (from Buddhi reflect) with the hive.""" - return self._publish_item( - item_id=f"insight:{self.agent_id}:{insight_id}", - kind="insight", - content=content, - confidence=confidence, - ) - - def share_heuristic( - self, - heuristic_id: str, - content: Dict[str, Any], - confidence: float = 0.5, - ) -> Optional[SharedItem]: - """Share a distilled heuristic with the hive.""" - return 
self._publish_item( - item_id=f"heuristic:{self.agent_id}:{heuristic_id}", - kind="heuristic", - content=content, - confidence=confidence, - ) - - def share_skill( - self, - skill_id: str, - content: Dict[str, Any], - confidence: float = 0.5, - ) -> Optional[SharedItem]: - """Share a proven skill with the hive.""" - return self._publish_item( - item_id=f"skill:{self.agent_id}:{skill_id}", - kind="skill", - content=content, - confidence=confidence, - ) - - def share_signal( - self, - signal_type: str, - data: Dict[str, Any], - ) -> Optional[SharedItem]: - """Share an aggregated signal (e.g., vasana shift) with the hive.""" - return self._publish_item( - item_id=f"signal:{self.agent_id}:{signal_type}:{int(time.time())}", - kind="signal", - content={"signal_type": signal_type, **data}, - confidence=0.5, - ) - - def _publish_item( - self, - item_id: str, - kind: str, - content: Dict[str, Any], - confidence: float, - ) -> Optional[SharedItem]: - if confidence < self._min_share: - logger.debug( - "Not sharing %s (confidence %.2f < %.2f)", - item_id, confidence, self._min_share, - ) - return None - - item = SharedItem( - id=item_id, - kind=kind, - content=content, - source_agent=self.agent_id, - confidence=confidence, - ) - - with self._lock: - self._items[item.id] = item - - if self._bus: - self._bus.publish(_TOPIC_SHARE, item.to_dict(), agent=self.agent_id) - # Also store in bus KV for late joiners - self._bus.put( - f"hive:{item.id}", - json.dumps(item.to_dict()), - agent=self.agent_id, - namespace=_NS_HIVE, - ) - - return item - - # ------------------------------------------------------------------ - # Voting - # ------------------------------------------------------------------ - - def vote(self, item_id: str, upvote: bool = True) -> None: - """Vote on a shared item's quality.""" - with self._lock: - item = self._items.get(item_id) - if item: - if upvote: - item.votes_up += 1 - else: - item.votes_down += 1 - - if self._bus: - self._bus.publish( - _TOPIC_VOTE, - 
{"item_id": item_id, "upvote": upvote, "voter": self.agent_id}, - agent=self.agent_id, - ) - - def adopt(self, item_id: str) -> Optional[SharedItem]: - """Mark a shared item as adopted by this agent.""" - with self._lock: - item = self._items.get(item_id) - if not item: - return None - if self.agent_id not in item.adopted_by: - item.adopted_by.append(self.agent_id) - - if self._bus: - self._bus.publish( - _TOPIC_ADOPT, - {"item_id": item_id, "adopter": self.agent_id}, - agent=self.agent_id, - ) - return item - - # ------------------------------------------------------------------ - # Querying - # ------------------------------------------------------------------ - - def get_shared( - self, - kind: Optional[str] = None, - min_quality: Optional[float] = None, - limit: int = 20, - ) -> List[SharedItem]: - """Get shared items, optionally filtered by kind and quality.""" - min_q = min_quality if min_quality is not None else 0.0 - - with self._lock: - items = list(self._items.values()) - - if kind: - items = [i for i in items if i.kind == kind] - items = [i for i in items if i.quality_score >= min_q] - items.sort(key=lambda i: i.quality_score, reverse=True) - return items[:limit] - - def get_adoptable(self, limit: int = 10) -> List[SharedItem]: - """Get high-quality items not yet adopted by this agent.""" - with self._lock: - candidates = [ - item for item in self._items.values() - if self.agent_id not in item.adopted_by - and item.source_agent != self.agent_id - and item.quality_score >= self._min_adopt - ] - candidates.sort(key=lambda i: i.quality_score, reverse=True) - return candidates[:limit] - - def get_hive_stats(self) -> Dict[str, Any]: - """Get statistics about the hive.""" - with self._lock: - items = list(self._items.values()) - - by_kind: Dict[str, int] = {} - by_agent: Dict[str, int] = {} - total_votes = 0 - - for item in items: - by_kind[item.kind] = by_kind.get(item.kind, 0) + 1 - by_agent[item.source_agent] = by_agent.get(item.source_agent, 0) + 1 - 
total_votes += item.votes_up + item.votes_down - - return { - "total_items": len(items), - "by_kind": by_kind, - "by_agent": by_agent, - "total_votes": total_votes, - "avg_quality": ( - sum(i.quality_score for i in items) / len(items) - if items else 0.0 - ), - } - - # ------------------------------------------------------------------ - # Bus callbacks - # ------------------------------------------------------------------ - - def _on_share_received( - self, topic: str, data: Any, sender_agent: Optional[str], - ) -> None: - """Handle incoming shared item from another agent.""" - if sender_agent == self.agent_id: - return # Ignore own messages - - try: - item = SharedItem.from_dict(data) - except (TypeError, KeyError) as e: - logger.debug("Invalid shared item: %s", e) - return - - with self._lock: - if item.id not in self._items: - self._items[item.id] = item - - for cb in self._on_receive_callbacks: - try: - cb(item) - except Exception as e: - logger.debug("Hive receive callback error: %s", e) - - def _on_vote_received( - self, topic: str, data: Any, sender_agent: Optional[str], - ) -> None: - """Handle incoming vote from another agent.""" - if sender_agent == self.agent_id: - return - - item_id = data.get("item_id") - upvote = data.get("upvote", True) - - with self._lock: - item = self._items.get(item_id) - if item: - if upvote: - item.votes_up += 1 - else: - item.votes_down += 1 - - def _on_adopt_received( - self, topic: str, data: Any, sender_agent: Optional[str], - ) -> None: - """Handle adoption notification from another agent.""" - item_id = data.get("item_id") - adopter = data.get("adopter") - - with self._lock: - item = self._items.get(item_id) - if item and adopter and adopter not in item.adopted_by: - item.adopted_by.append(adopter) - - def _on_sync_request( - self, topic: str, data: Any, sender_agent: Optional[str], - ) -> None: - """Respond to sync request from another agent (e.g., edge coming online).""" - if sender_agent == self.agent_id or not 
self._bus: - return - - # Send all our items as a sync response - with self._lock: - payload = {item_id: item.to_dict() for item_id, item in self._items.items()} - - self._bus.publish( - _TOPIC_SYNC_RESPONSE, - {"items": payload, "responder": self.agent_id}, - agent=self.agent_id, - ) - - # ------------------------------------------------------------------ - # Sync (pull-based) - # ------------------------------------------------------------------ - - def request_sync(self) -> None: - """Request a full sync from other hive members (e.g., after coming online).""" - if not self._bus: - return - self._bus.subscribe( - _TOPIC_SYNC_RESPONSE, self._on_sync_response, agent=self.agent_id, - ) - self._bus.publish( - _TOPIC_SYNC_REQUEST, - {"requester": self.agent_id}, - agent=self.agent_id, - ) - - def _on_sync_response( - self, topic: str, data: Any, sender_agent: Optional[str], - ) -> None: - """Handle sync response — merge received items.""" - items_data = data.get("items", {}) - with self._lock: - for item_id, item_dict in items_data.items(): - if item_id not in self._items: - try: - self._items[item_id] = SharedItem.from_dict(item_dict) - except (TypeError, KeyError): - pass - - # ------------------------------------------------------------------ - # Callbacks - # ------------------------------------------------------------------ - - def on_receive(self, callback: Callable[[SharedItem], None]) -> None: - """Register a callback for when new items arrive from the hive.""" - self._on_receive_callbacks.append(callback) - - # ------------------------------------------------------------------ - # Export for DheePlugin integration - # ------------------------------------------------------------------ - - def get_context_block(self, limit: int = 5) -> Dict[str, Any]: - """Get hive knowledge formatted for HyperContext injection.""" - insights = self.get_shared(kind="insight", limit=limit) - heuristics = self.get_shared(kind="heuristic", limit=limit) - skills = 
self.get_shared(kind="skill", limit=limit) - - return { - "hive_insights": [ - { - "source": i.source_agent, - "content": i.content, - "quality": round(i.quality_score, 2), - } - for i in insights - ], - "hive_heuristics": [ - { - "source": h.source_agent, - "content": h.content, - "quality": round(h.quality_score, 2), - } - for h in heuristics - ], - "hive_skills": [ - { - "source": s.source_agent, - "name": s.content.get("name", s.id), - "quality": round(s.quality_score, 2), - } - for s in skills - ], - } - - def close(self) -> None: - """Unsubscribe and clean up.""" - # Bus cleanup is handled by the bus owner - self._on_receive_callbacks.clear() diff --git a/dhee/hive/sync.py b/dhee/hive/sync.py deleted file mode 100644 index 1a0e546..0000000 --- a/dhee/hive/sync.py +++ /dev/null @@ -1,436 +0,0 @@ -"""CRDT-based sync protocol for offline/edge Dhee nodes. - -When a DheeEdge instance operates offline (e.g., a humanoid robot in a -warehouse with no connectivity), it accumulates local hive items. On -reconnection, it needs to merge with the central hive without conflicts. - -This module implements: - 1. **LWW-Register** (Last-Writer-Wins) for individual shared items. - 2. **G-Counter** for vote counts (grow-only, merge = max per node). - 3. **OR-Set** (Observed-Remove) for adoption lists. - 4. **SyncEnvelope** — wire format for shipping CRDT state between nodes. 
- -Usage: - # On edge device: - state = CRDTState(node_id="edge-1") - state.set_item(shared_item) - state.increment_votes_up("item:123") - envelope = state.export_envelope() - - # Ship envelope (HTTP, BLE, serial, file drop — whatever works) - - # On hub: - hub_state = CRDTState(node_id="hub") - hub_state.merge(envelope) -""" - -from __future__ import annotations - -import json -import logging -import os -import time -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Set, Tuple - -logger = logging.getLogger(__name__) - - -def _hlc_now(node_id: str) -> str: - """Hybrid Logical Clock timestamp: |. - - Provides a globally-unique, monotonically-increasing timestamp even - if wall clocks disagree between nodes. Uses '|' separator so node_ids - can contain dashes. - """ - return f"{int(time.time() * 1000)}|{node_id}" - - -def _hlc_compare(a: str, b: str) -> int: - """Compare two HLC timestamps. Returns -1, 0, or 1.""" - a_ms, a_node = a.split("|", 1) - b_ms, b_node = b.split("|", 1) - a_int, b_int = int(a_ms), int(b_ms) - if a_int != b_int: - return -1 if a_int < b_int else 1 - if a_node < b_node: - return -1 - if a_node > b_node: - return 1 - return 0 - - -# --------------------------------------------------------------------------- -# LWW-Register: per-item state -# --------------------------------------------------------------------------- - -@dataclass -class LWWRegister: - """Last-Writer-Wins Register for a shared item's content.""" - - value: Dict[str, Any] - timestamp: str # HLC timestamp - node_id: str - - def merge(self, other: "LWWRegister") -> "LWWRegister": - """Merge two registers — latest timestamp wins.""" - if _hlc_compare(self.timestamp, other.timestamp) >= 0: - return self - return other - - def to_dict(self) -> Dict[str, Any]: - return { - "value": self.value, - "timestamp": self.timestamp, - "node_id": self.node_id, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "LWWRegister": - return cls( - 
value=data["value"], - timestamp=data["timestamp"], - node_id=data["node_id"], - ) - - -# --------------------------------------------------------------------------- -# G-Counter: grow-only counter per node -# --------------------------------------------------------------------------- - -@dataclass -class GCounter: - """Grow-only counter — each node has its own monotonic count.""" - - counts: Dict[str, int] = field(default_factory=dict) # node_id -> count - - @property - def value(self) -> int: - return sum(self.counts.values()) - - def increment(self, node_id: str, amount: int = 1) -> None: - self.counts[node_id] = self.counts.get(node_id, 0) + amount - - def merge(self, other: "GCounter") -> "GCounter": - """Merge = max of each node's count.""" - all_nodes = set(self.counts) | set(other.counts) - merged = GCounter() - for node in all_nodes: - merged.counts[node] = max( - self.counts.get(node, 0), - other.counts.get(node, 0), - ) - return merged - - def to_dict(self) -> Dict[str, int]: - return dict(self.counts) - - @classmethod - def from_dict(cls, data: Dict[str, int]) -> "GCounter": - return cls(counts=dict(data)) - - -# --------------------------------------------------------------------------- -# OR-Set: Observed-Remove Set (for adoption lists) -# --------------------------------------------------------------------------- - -@dataclass -class ORSet: - """Observed-Remove Set — supports both add and remove with convergence. - - Each element is tagged with a unique (node_id, seq) pair. Removes - only remove the tags that were observed, so concurrent adds win. 
- """ - - # element -> set of (node_id, seq) tags - _elements: Dict[str, Set[Tuple[str, int]]] = field(default_factory=lambda: {}) - _tombstones: Set[Tuple[str, int]] = field(default_factory=set) - _seq: int = 0 - - def add(self, element: str, node_id: str) -> None: - self._seq += 1 - tag = (node_id, self._seq) - if element not in self._elements: - self._elements[element] = set() - self._elements[element].add(tag) - - def remove(self, element: str) -> None: - tags = self._elements.pop(element, set()) - self._tombstones.update(tags) - - @property - def elements(self) -> Set[str]: - return { - elem for elem, tags in self._elements.items() - if tags - self._tombstones - } - - def merge(self, other: "ORSet") -> "ORSet": - """Merge two OR-Sets.""" - merged = ORSet() - merged._seq = max(self._seq, other._seq) - merged._tombstones = self._tombstones | other._tombstones - - all_elements = set(self._elements) | set(other._elements) - for elem in all_elements: - tags_a = self._elements.get(elem, set()) - tags_b = other._elements.get(elem, set()) - # Union of live tags minus all tombstones - live_tags = (tags_a | tags_b) - merged._tombstones - if live_tags: - merged._elements[elem] = live_tags - - return merged - - def to_dict(self) -> Dict[str, Any]: - return { - "elements": { - elem: [list(t) for t in tags] - for elem, tags in self._elements.items() - }, - "tombstones": [list(t) for t in self._tombstones], - "seq": self._seq, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ORSet": - s = cls() - s._seq = data.get("seq", 0) - s._tombstones = {tuple(t) for t in data.get("tombstones", [])} - s._elements = { - elem: {tuple(t) for t in tags} - for elem, tags in data.get("elements", {}).items() - } - return s - - -# --------------------------------------------------------------------------- -# Per-item CRDT state -# --------------------------------------------------------------------------- - -@dataclass -class ItemCRDT: - """CRDT state for a single shared hive 
item.""" - - item_id: str - content: LWWRegister # The item payload - votes_up: GCounter = field(default_factory=GCounter) - votes_down: GCounter = field(default_factory=GCounter) - adopted_by: ORSet = field(default_factory=ORSet) - - def merge(self, other: "ItemCRDT") -> "ItemCRDT": - assert self.item_id == other.item_id - return ItemCRDT( - item_id=self.item_id, - content=self.content.merge(other.content), - votes_up=self.votes_up.merge(other.votes_up), - votes_down=self.votes_down.merge(other.votes_down), - adopted_by=self.adopted_by.merge(other.adopted_by), - ) - - def to_dict(self) -> Dict[str, Any]: - return { - "item_id": self.item_id, - "content": self.content.to_dict(), - "votes_up": self.votes_up.to_dict(), - "votes_down": self.votes_down.to_dict(), - "adopted_by": self.adopted_by.to_dict(), - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ItemCRDT": - return cls( - item_id=data["item_id"], - content=LWWRegister.from_dict(data["content"]), - votes_up=GCounter.from_dict(data.get("votes_up", {})), - votes_down=GCounter.from_dict(data.get("votes_down", {})), - adopted_by=ORSet.from_dict(data.get("adopted_by", {})), - ) - - -# --------------------------------------------------------------------------- -# SyncEnvelope — wire format -# --------------------------------------------------------------------------- - -@dataclass -class SyncEnvelope: - """Wire format for CRDT state exchange between nodes.""" - - source_node: str - timestamp: str # HLC - items: Dict[str, Dict[str, Any]] # item_id -> ItemCRDT.to_dict() - - def to_bytes(self) -> bytes: - return json.dumps({ - "source_node": self.source_node, - "timestamp": self.timestamp, - "items": self.items, - }, ensure_ascii=False).encode("utf-8") - - @classmethod - def from_bytes(cls, data: bytes) -> "SyncEnvelope": - d = json.loads(data.decode("utf-8")) - return cls( - source_node=d["source_node"], - timestamp=d["timestamp"], - items=d.get("items", {}), - ) - - def to_dict(self) -> Dict[str, 
Any]: - return { - "source_node": self.source_node, - "timestamp": self.timestamp, - "items": self.items, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SyncEnvelope": - return cls( - source_node=data["source_node"], - timestamp=data["timestamp"], - items=data.get("items", {}), - ) - - -# --------------------------------------------------------------------------- -# CRDTState — per-node state manager -# --------------------------------------------------------------------------- - -class CRDTState: - """Manages CRDT state for a single node. - - Each node maintains its own CRDTState. On sync, nodes exchange - SyncEnvelopes and merge them — convergence is guaranteed by the - CRDT merge semantics (commutative, associative, idempotent). - """ - - def __init__(self, node_id: str, persist_path: Optional[str] = None): - self.node_id = node_id - self._items: Dict[str, ItemCRDT] = {} - self._persist_path = persist_path - if persist_path: - self._load() - - def set_item(self, item_id: str, content: Dict[str, Any]) -> None: - """Set or update an item's content (LWW).""" - ts = _hlc_now(self.node_id) - register = LWWRegister(value=content, timestamp=ts, node_id=self.node_id) - - if item_id in self._items: - self._items[item_id].content = self._items[item_id].content.merge(register) - else: - self._items[item_id] = ItemCRDT(item_id=item_id, content=register) - - self._auto_persist() - - def increment_votes_up(self, item_id: str, amount: int = 1) -> None: - if item_id in self._items: - self._items[item_id].votes_up.increment(self.node_id, amount) - self._auto_persist() - - def increment_votes_down(self, item_id: str, amount: int = 1) -> None: - if item_id in self._items: - self._items[item_id].votes_down.increment(self.node_id, amount) - self._auto_persist() - - def add_adopter(self, item_id: str, adopter: str) -> None: - if item_id in self._items: - self._items[item_id].adopted_by.add(adopter, self.node_id) - self._auto_persist() - - def 
export_envelope(self) -> SyncEnvelope: - """Export current state as a sync envelope.""" - return SyncEnvelope( - source_node=self.node_id, - timestamp=_hlc_now(self.node_id), - items={ - item_id: crdt.to_dict() - for item_id, crdt in self._items.items() - }, - ) - - def merge(self, envelope: SyncEnvelope) -> int: - """Merge a received envelope into local state. - - Returns number of items updated. - """ - updated = 0 - for item_id, item_dict in envelope.items.items(): - try: - remote = ItemCRDT.from_dict(item_dict) - except (TypeError, KeyError) as e: - logger.debug("Skipping malformed item %s: %s", item_id, e) - continue - - if item_id in self._items: - merged = self._items[item_id].merge(remote) - # Check if anything actually changed - if merged.to_dict() != self._items[item_id].to_dict(): - self._items[item_id] = merged - updated += 1 - else: - self._items[item_id] = remote - updated += 1 - - if updated: - self._auto_persist() - return updated - - def get_item(self, item_id: str) -> Optional[Dict[str, Any]]: - """Get resolved state of an item (content + vote totals + adopters).""" - crdt = self._items.get(item_id) - if not crdt: - return None - return { - "item_id": item_id, - "content": crdt.content.value, - "votes_up": crdt.votes_up.value, - "votes_down": crdt.votes_down.value, - "adopted_by": sorted(crdt.adopted_by.elements), - "last_updated": crdt.content.timestamp, - } - - def list_items(self) -> List[Dict[str, Any]]: - """List all items with resolved state.""" - return [ - self.get_item(item_id) - for item_id in sorted(self._items) - ] - - @property - def item_count(self) -> int: - return len(self._items) - - # ── Persistence ── - - def _auto_persist(self) -> None: - if self._persist_path: - self.save() - - def save(self, path: Optional[str] = None) -> None: - path = path or self._persist_path - if not path: - return - os.makedirs(os.path.dirname(path) or ".", exist_ok=True) - tmp = path + ".tmp" - data = { - "node_id": self.node_id, - "items": { - 
item_id: crdt.to_dict() - for item_id, crdt in self._items.items() - }, - } - with open(tmp, "w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False) - os.replace(tmp, path) - - def _load(self) -> None: - if not self._persist_path or not os.path.exists(self._persist_path): - return - try: - with open(self._persist_path, "r", encoding="utf-8") as f: - data = json.load(f) - for item_id, item_dict in data.get("items", {}).items(): - self._items[item_id] = ItemCRDT.from_dict(item_dict) - except (OSError, json.JSONDecodeError, KeyError) as e: - logger.warning("Failed to load CRDT state: %s", e) diff --git a/dhee/integrations/__init__.py b/dhee/integrations/__init__.py deleted file mode 100644 index 08bfb9a..0000000 --- a/dhee/integrations/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Third-party agent framework integrations for Engram.""" diff --git a/dhee/llms/nvidia.py b/dhee/llms/nvidia.py index 234a45d..a5eb5d1 100644 --- a/dhee/llms/nvidia.py +++ b/dhee/llms/nvidia.py @@ -21,13 +21,14 @@ def __init__(self, config: Optional[dict] = None): api_key = ( self.config.get("api_key") or os.getenv("NVIDIA_QWEN_API_KEY") + or os.getenv("NVIDIA_LLAMA_4_MAV_API_KEY") or os.getenv("LLAMA_API_KEY") or os.getenv("NVIDIA_API_KEY") ) if not api_key: raise ValueError( "NVIDIA API key required. Set config['api_key'], " - "NVIDIA_QWEN_API_KEY, LLAMA_API_KEY, or NVIDIA_API_KEY env var." + "NVIDIA_QWEN_API_KEY, NVIDIA_LLAMA_4_MAV_API_KEY, LLAMA_API_KEY, or NVIDIA_API_KEY env var." 
) base_url = self.config.get("base_url", "https://integrate.api.nvidia.com/v1") @@ -58,7 +59,11 @@ def generate(self, prompt: str) -> str: extra_kwargs = {} if self.enable_thinking: extra_kwargs["extra_body"] = { - "chat_template_kwargs": {"thinking": True} + "chat_template_kwargs": {"enable_thinking": True} + } + elif "gemma" in self.model.lower(): + extra_kwargs["extra_body"] = { + "chat_template_kwargs": {"enable_thinking": False} } use_stream = self.enable_thinking or self.config.get("stream", False) diff --git a/dhee/mcp_server.py b/dhee/mcp_server.py index 9634b6c..21ae6d4 100644 --- a/dhee/mcp_server.py +++ b/dhee/mcp_server.py @@ -48,6 +48,30 @@ logger = logging.getLogger(__name__) +def _default_user_id(args: Dict[str, Any]) -> str: + return str(args.get("user_id") or os.environ.get("DHEE_USER_ID") or "default") + + +def _default_agent_id(args: Dict[str, Any]) -> str: + return str(args.get("agent_id") or os.environ.get("DHEE_AGENT_ID") or "mcp-server") + + +def _default_requester_agent_id(args: Dict[str, Any]) -> str: + return str( + args.get("requester_agent_id") + or os.environ.get("DHEE_REQUESTER_AGENT_ID") + or _default_agent_id(args) + ) + + +def _default_source_app(args: Dict[str, Any]) -> str: + return str( + args.get("source_app") + or os.environ.get("DHEE_SOURCE_APP") + or _default_agent_id(args) + ) + + def _get_embedding_dims_for_model(model: str, provider: str) -> int: EMBEDDING_DIMS = { "models/text-embedding-005": 768, @@ -77,6 +101,8 @@ def get_memory_instance() -> FullMemory: os.environ.get("NVIDIA_API_KEY") or os.environ.get("NVIDIA_QWEN_API_KEY") or os.environ.get("NVIDIA_EMBEDDING_API_KEY") + or os.environ.get("NVIDIA_EMBED_API_KEY") + or os.environ.get("NVIDIA_LLAMA_4_MAV_API_KEY") ) def _env(key: str, default: str = "") -> str: @@ -172,6 +198,9 @@ def _env(key: str, default: str = "") -> str: embedding_model_dims=embedding_dims, fade=fade_config, ) + if hasattr(config, "enrichment"): + config.enrichment.defer_enrichment = True + 
config.enrichment.enable_unified = True return FullMemory(config) @@ -192,8 +221,9 @@ def get_buddhi(): """Lazy singleton for the Buddhi cognition layer.""" global _buddhi if _buddhi is None: + from dhee.configs.base import _dhee_data_dir from dhee.core.buddhi import Buddhi - _buddhi = Buddhi() + _buddhi = Buddhi(data_dir=os.path.join(_dhee_data_dir(), "buddhi")) return _buddhi @@ -205,11 +235,14 @@ def get_buddhi(): TOOLS = [ Tool( name="remember", - description="Quick-save a fact or preference to memory. Creates a staging proposal commit with source_app='claude-code' and infer=False by default.", + description="Quick-save a fact or preference to memory. Stores immediately with infer=False by default and uses the configured MCP agent/source identity when not provided.", inputSchema={ "type": "object", "properties": { "content": {"type": "string", "description": "The fact or preference to remember"}, + "user_id": {"type": "string", "description": "User identifier (defaults to DHEE_USER_ID or 'default')."}, + "agent_id": {"type": "string", "description": "Agent identifier (defaults to DHEE_AGENT_ID or 'mcp-server')."}, + "source_app": {"type": "string", "description": "Source application label (defaults to DHEE_SOURCE_APP or agent_id)."}, "categories": {"type": "array", "items": {"type": "string"}, "description": "Optional categories to tag this memory with (e.g., ['preferences', 'coding'])"}, "context": { "type": "array", @@ -574,10 +607,10 @@ def get_buddhi(): def _handle_remember(memory, args): return memory.add( messages=args.get("content", ""), - user_id="default", - agent_id="claude-code", + user_id=_default_user_id(args), + agent_id=_default_agent_id(args), categories=args.get("categories"), - source_app="claude-code", + source_app=_default_source_app(args), infer=False, context_messages=args.get("context"), ) @@ -596,7 +629,7 @@ def _handle_search_memory(memory, args): context_top_k = 10 result = memory.search_orchestrated( query=args.get("query", ""), - 
user_id=args.get("user_id", "default"), + user_id=_default_user_id(args), agent_id=args.get("agent_id"), categories=args.get("categories"), limit=limit, @@ -612,7 +645,7 @@ def _handle_search_memory(memory, args): else: result = memory.search( query=args.get("query", ""), - user_id=args.get("user_id", "default"), + user_id=_default_user_id(args), agent_id=args.get("agent_id"), limit=limit, categories=args.get("categories"), @@ -652,7 +685,7 @@ def _handle_get_all_memories(memory, args): except (ValueError, TypeError): limit = 50 result = memory.get_all( - user_id=args.get("user_id", "default"), + user_id=_default_user_id(args), agent_id=args.get("agent_id"), limit=limit, layer=args.get("layer"), @@ -691,8 +724,10 @@ def _handle_dhee_context(memory, args): def _handle_get_last_session(_memory, args): from dhee.core.kernel import get_last_session session = get_last_session( - agent_id=args.get("agent_id", "mcp-server"), + agent_id=args.get("agent_id"), repo=args.get("repo"), + user_id=_default_user_id(args), + requester_agent_id=_default_requester_agent_id(args), fallback_log_recovery=args.get("fallback_log_recovery", True), ) if session is None: @@ -704,7 +739,8 @@ def _handle_save_session_digest(_memory, args): from dhee.core.kernel import save_session_digest return save_session_digest( task_summary=args.get("task_summary", ""), - agent_id=args.get("agent_id", "claude-code"), + agent_id=_default_agent_id(args), + requester_agent_id=_default_requester_agent_id(args), repo=args.get("repo"), status=args.get("status", "active"), decisions_made=args.get("decisions_made"), @@ -718,7 +754,7 @@ def _handle_save_session_digest(_memory, args): def _handle_get_memory_stats(memory, args): return memory.get_stats( - user_id=args.get("user_id"), + user_id=args.get("user_id") or os.environ.get("DHEE_USER_ID"), agent_id=args.get("agent_id"), ) @@ -847,7 +883,7 @@ def _handle_think(memory, arguments: Dict[str, Any]) -> Dict[str, Any]: if hasattr(memory, "think"): result = 
memory.think( question=question, - user_id=user_id, + user_id=_default_user_id(arguments) if not arguments.get("user_id") else user_id, max_depth=max_depth, ) if hasattr(result, "to_dict"): @@ -858,7 +894,7 @@ def _handle_think(memory, arguments: Dict[str, Any]) -> Dict[str, Any]: def _handle_anticipate(memory, arguments: Dict[str, Any]) -> Dict[str, Any]: """Proactive intelligence — Buddhi checks intentions, insights, and scenes.""" - user_id = arguments.get("user_id", "default") + user_id = _default_user_id(arguments) context = arguments.get("context") buddhi = get_buddhi() @@ -894,7 +930,7 @@ def _handle_record_outcome(_memory, arguments: Dict[str, Any]) -> Dict[str, Any] """Record task outcome for performance tracking.""" task_type = arguments.get("task_type", "") score = float(arguments.get("score", 0.0)) - user_id = arguments.get("user_id", "default") + user_id = _default_user_id(arguments) metadata = arguments.get("metadata") if not task_type: @@ -916,7 +952,7 @@ def _handle_record_outcome(_memory, arguments: Dict[str, Any]) -> Dict[str, Any] def _handle_reflect(_memory, arguments: Dict[str, Any]) -> Dict[str, Any]: """Agent-triggered reflection — synthesize insights from experience.""" task_type = arguments.get("task_type", "") - user_id = arguments.get("user_id", "default") + user_id = _default_user_id(arguments) if not task_type: return {"error": "task_type is required"} @@ -938,7 +974,7 @@ def _handle_reflect(_memory, arguments: Dict[str, Any]) -> Dict[str, Any]: def _handle_store_intention(_memory, arguments: Dict[str, Any]) -> Dict[str, Any]: """Store a future trigger — prospective memory.""" description = arguments.get("description", "") - user_id = arguments.get("user_id", "default") + user_id = _default_user_id(arguments) if not description: return {"error": "description is required"} diff --git a/dhee/mcp_slim.py b/dhee/mcp_slim.py index 1013f99..607213b 100644 --- a/dhee/mcp_slim.py +++ b/dhee/mcp_slim.py @@ -38,7 +38,7 @@ def _get_plugin(): 
"""Create the DheePlugin singleton. Wraps Engram + Buddhi.""" global _plugin if _plugin is None: - from dhee.adapters.base import DheePlugin + from dhee.plugin import DheePlugin _plugin = DheePlugin() # Enable deferred enrichment on the underlying memory memory = _plugin.memory diff --git a/dhee/memory/main.py b/dhee/memory/main.py index 69089b3..a5b741b 100644 --- a/dhee/memory/main.py +++ b/dhee/memory/main.py @@ -33,7 +33,6 @@ from dhee.core.graph import KnowledgeGraph from dhee.core.scene import SceneProcessor from dhee.core.profile import ProfileProcessor -from dhee.core.answer_orchestration import extract_atomic_facts, reduce_atomic_facts from dhee.db.sqlite import SQLiteManager from dhee.exceptions import FadeMemValidationError from dhee.memory.base import MemoryBase @@ -430,8 +429,6 @@ def _orchestration_engine(self) -> OrchestrationEngine: profile_processor_fn=lambda: self.profile_processor, evolution_layer_fn=lambda: self.evolution_layer, llm_fn=lambda: self.llm, - extract_atomic_facts_fn=extract_atomic_facts, - reduce_atomic_facts_fn=reduce_atomic_facts, ) return self.__orchestration_engine @@ -577,7 +574,7 @@ def reranker(self): """Lazy-initialized neural reranker (only if enabled in config).""" rerank_cfg = getattr(self.config, "rerank", None) if self._reranker is None and rerank_cfg and rerank_cfg.enable_rerank: - from dhee.retrieval.reranker import create_reranker + from dhee.memory.reranker import create_reranker self._reranker = create_reranker({ "provider": rerank_cfg.provider, "model": rerank_cfg.model, diff --git a/dhee/memory/orchestration.py b/dhee/memory/orchestration.py index ff3295b..a6ab4e5 100644 --- a/dhee/memory/orchestration.py +++ b/dhee/memory/orchestration.py @@ -1,24 +1,17 @@ -"""Orchestration engine: map-reduce, episodic anchoring, hierarchical retrieval. +"""Orchestration engine: episodic anchoring, hierarchical retrieval, context assembly. 
-Extracted from memory/main.py — centralizes the orchestrated-search path so that -FullMemory.search_orchestrated() becomes a thin delegation wrapper. +Dhee's job: retrieve well, assemble context, return it. No answer synthesis. """ from __future__ import annotations import logging import re -import time from typing import Any, Callable, Dict, List, Optional, Tuple -from dhee.memory.cost import stable_hash_text from dhee.core.episodic_index import normalize_actor_id from dhee.core.answer_orchestration import ( - build_map_candidates, build_query_plan, - deterministic_inconsistency_check, - is_low_confidence_answer, - render_fact_context, ) logger = logging.getLogger(__name__) @@ -45,8 +38,6 @@ def __init__( profile_processor_fn: Callable, evolution_layer_fn: Callable, llm_fn: Callable, - extract_atomic_facts_fn: Callable, - reduce_atomic_facts_fn: Callable, ): self._config = config self._db = db @@ -59,71 +50,9 @@ def __init__( self._profile_processor_fn = profile_processor_fn self._evolution_layer_fn = evolution_layer_fn self._llm_fn = llm_fn - self._extract_atomic_facts_fn = extract_atomic_facts_fn - self._reduce_atomic_facts_fn = reduce_atomic_facts_fn # Internal state - self._reducer_cache: Dict[str, Dict[str, Any]] = {} self._guardrail_auto_disabled: bool = False - # -- Reducer cache helpers ------------------------------------------------ - - def _build_reducer_cache_key( - self, - *, - user_id: str, - intent_value: str, - query: str, - results: List[Dict[str, Any]], - ) -> str: - evidence_fingerprint_parts: List[str] = [] - for row in results[:30]: - mem_id = str(row.get("id") or "").strip() - score = float(row.get("composite_score", row.get("score", 0.0)) or 0.0) - evidence_fingerprint_parts.append(f"{mem_id}:{score:.4f}") - evidence_fingerprint = "|".join(evidence_fingerprint_parts) - base = "|".join( - [ - str(user_id or ""), - str(intent_value or ""), - stable_hash_text(query), - stable_hash_text(evidence_fingerprint), - ] - ) - return 
stable_hash_text(base) - - def _get_reducer_cache(self, cache_key: str) -> Optional[Dict[str, Any]]: - orch_cfg = getattr(self._config, "orchestration", None) - ttl_seconds = int(getattr(orch_cfg, "reducer_cache_ttl_seconds", 900) or 900) - record = self._reducer_cache.get(cache_key) - if not record: - return None - ts = float(record.get("ts", 0.0) or 0.0) - if ts <= 0.0: - return None - if (time.time() - ts) > max(1, ttl_seconds): - self._reducer_cache.pop(cache_key, None) - return None - return record - - def _put_reducer_cache( - self, - *, - cache_key: str, - reduced_answer: Optional[str], - facts: List[Dict[str, Any]], - ) -> None: - orch_cfg = getattr(self._config, "orchestration", None) - max_entries = int(getattr(orch_cfg, "reducer_cache_max_entries", 2048) or 2048) - self._reducer_cache[cache_key] = { - "ts": time.time(), - "reduced_answer": reduced_answer, - "facts": list(facts or []), - } - # Keep insertion-order bounded cache. - while len(self._reducer_cache) > max(1, max_entries): - oldest_key = next(iter(self._reducer_cache)) - self._reducer_cache.pop(oldest_key, None) - # -- Cost guardrail ------------------------------------------------------- def _enforce_write_cost_guardrail(self, *, user_id: Optional[str]) -> None: @@ -507,49 +436,15 @@ def search_orchestrated( limit=3, ) - ( - reduced_answer, - facts, - map_reduce_used, - reflection_hops, - llm_calls_used, - cache_hit, - orchestration_reasons, - results, - ) = self._execute_map_reduce( - query_plan=query_plan, - orchestrator_llm=orchestrator_llm, - results=results, - event_hits=event_hits, - coverage=coverage, - query=query, - question_type=question_type, - question_date=question_date, - mode=mode, - search_cap_value=search_cap_value, - map_max_candidates_value=map_max_candidates_value, - map_max_chars_value=map_max_chars_value, - reflection_max_hops=reflection_max_hops, - search_query=search_query, - search_limit=search_limit, - rerank=rerank, - keyword_search=keyword_search, - 
hybrid_alpha=hybrid_alpha, - include_evidence=include_evidence, - evidence_strategy=evidence_strategy, - evidence_max_chars=evidence_max_chars, - evidence_context_lines=evidence_context_lines, - user_id=user_id, - filters=filters, - categories=categories, - agent_id=agent_id, - run_id=run_id, - app_id=app_id, - ) - reason_codes.extend(orchestration_reasons) + # Dhee's job: retrieve and assemble context. Agent answers. + # No map-reduce, no triple extraction, no LLM calls at query time. + reduced_answer: Optional[str] = None + facts: List[Dict[str, Any]] = [] + map_reduce_used = False + reflection_hops = 0 + llm_calls_used = 0.0 + cache_hit = False - # Always use full retrieval context — proposition context (Phase 3) - # is deferred until episodic event coverage is proven reliable. context = self._build_orchestrated_context( results=results, event_hits=event_hits, @@ -558,13 +453,6 @@ def search_orchestrated( max_chars=max_context_chars, per_result_max_chars=evidence_max_chars, ) - if facts: - fact_context = render_fact_context(facts, max_facts=20) - if fact_context: - if mode == "strict": - context = "Canonical Facts:\n" + fact_context - else: - context = "Canonical Facts:\n" + fact_context + "\n\nRetrieved Context:\n" + context self._record_cost_fn( phase="query", @@ -577,22 +465,6 @@ def search_orchestrated( intent_coverage = float(coverage.get("intent_coverage", coverage.get("coverage_ratio", 0.0)) or 0.0) - # Dhee: Self-evolution — record answer generation signal - evolution_layer = self._evolution_layer_fn() - if evolution_layer and reduced_answer: - try: - source_ids = [r.get("id", "") for r in results[:context_limit] if r.get("id")] - source_texts = [r.get("memory", "") for r in results[:context_limit] if r.get("memory")] - evolution_layer.on_answer_generated( - query=query, - answer=str(reduced_answer), - source_memory_ids=source_ids, - source_texts=source_texts, - user_id=user_id or "default", - ) - except Exception as e: - logger.debug("Evolution answer 
hook skipped: %s", e) - return { "results": results[: max(1, int(limit))], "event_hits": event_hits, @@ -618,247 +490,3 @@ def search_orchestrated( "facts": facts, } - # -- Map-reduce execution ------------------------------------------------- - - def _execute_map_reduce( - self, - *, - query_plan: Any, - orchestrator_llm: Optional[Any], - results: List[Dict[str, Any]], - event_hits: Optional[List[Dict[str, Any]]] = None, - coverage: Optional[Dict[str, Any]], - query: str, - question_type: str, - question_date: str, - mode: str, - search_cap_value: int, - map_max_candidates_value: int, - map_max_chars_value: int, - reflection_max_hops: Optional[int], - search_query: str, - search_limit: int, - rerank: bool, - keyword_search: bool, - hybrid_alpha: float, - include_evidence: bool, - evidence_strategy: str, - evidence_max_chars: int, - evidence_context_lines: int, - user_id: str, - filters: Optional[Dict[str, Any]], - categories: Optional[List[str]], - agent_id: Optional[str], - run_id: Optional[str], - app_id: Optional[str], - ) -> Tuple[Optional[str], List[Dict[str, Any]], bool, int, float, bool, List[str], List[Dict[str, Any]]]: - """Execute map-reduce orchestration with optional reflection. - - Tries event-first reduction (zero LLM cost) before falling back - to LLM-based atomic fact extraction. 
- - Returns: - ( - reduced_answer, - facts, - map_reduce_used, - reflection_hops, - llm_calls_used, - cache_hit, - reason_codes, - updated_results, - ) - """ - reduced_answer: Optional[str] = None - facts: List[Dict[str, Any]] = [] - map_reduce_used = False - reflection_hops = 0 - llm_calls_used = 0.0 - cache_hit = False - reason_codes: List[str] = [] - active_orchestrator_llm = orchestrator_llm or self._llm_fn() - orch_cfg = getattr(self._config, "orchestration", None) - raw_max_query_llm_calls = getattr(orch_cfg, "max_query_llm_calls", 2) - try: - max_query_llm_calls = int(raw_max_query_llm_calls if raw_max_query_llm_calls is not None else 2) - except (TypeError, ValueError): - max_query_llm_calls = 2 - - coverage_sufficient = bool((coverage or {}).get("sufficient")) - if coverage_sufficient: - reason_codes.append("coverage_sufficient") - else: - reason_codes.append("coverage_insufficient") - - inconsistency = deterministic_inconsistency_check( - question=query, - intent=query_plan.intent, - results=results, - coverage=coverage, - ) - inconsistency_detected = bool(inconsistency.get("inconsistent")) - if inconsistency_detected: - reason_codes.extend(list(inconsistency.get("reasons") or [])) - - # NOTE: Event-first reduction (Phase 2) disabled — episodic events - # alone lack sufficient coverage for accurate multi-session counting. - # The LLM-based map-reduce path below is more reliable. 
- if mode == "strict": - mode_requires_map_reduce = True - else: - mode_requires_map_reduce = (not coverage_sufficient) or inconsistency_detected - - should_run_map_reduce = bool( - query_plan.should_map_reduce - and active_orchestrator_llm is not None - and results - and mode_requires_map_reduce - ) - if query_plan.should_map_reduce and active_orchestrator_llm is None: - reason_codes.append("no_orchestrator_llm") - if should_run_map_reduce and max_query_llm_calls <= 0: - reason_codes.append("query_llm_budget_exhausted") - should_run_map_reduce = False - - if should_run_map_reduce: - cache_key = self._build_reducer_cache_key( - user_id=user_id, - intent_value=query_plan.intent.value, - query=query, - results=results, - ) - cached = self._get_reducer_cache(cache_key) - if cached and str(cached.get("reduced_answer") or "").strip(): - cached_answer = str(cached.get("reduced_answer") or "").strip() - if not is_low_confidence_answer(cached_answer): - reduced_answer = cached_answer - facts = list(cached.get("facts") or []) - cache_hit = True - reason_codes.append("reducer_cache_hit") - - if not cache_hit: - map_candidates = build_map_candidates( - results, - max_candidates=map_max_candidates_value, - per_candidate_max_chars=map_max_chars_value, - ) - if llm_calls_used < float(max_query_llm_calls): - facts = self._extract_atomic_facts_fn( - llm=active_orchestrator_llm, - question=query, - question_type=question_type, - question_date=question_date, - candidates=map_candidates, - ) - reduced_answer, _ = self._reduce_atomic_facts_fn( - question=query, - intent=query_plan.intent, - facts=facts, - ) - llm_calls_used += 1.0 - map_reduce_used = True - reason_codes.append("map_reduce_executed") - if reduced_answer or facts: - self._put_reducer_cache( - cache_key=cache_key, - reduced_answer=reduced_answer, - facts=facts, - ) - else: - reason_codes.append("query_llm_budget_exhausted") - - max_hops = int( - reflection_max_hops - if reflection_max_hops is not None - else 
getattr(self._config.orchestration, "reflection_max_hops", 1) - ) - if ( - max_hops > 0 - and (not reduced_answer or is_low_confidence_answer(reduced_answer)) - and search_limit < search_cap_value - and llm_calls_used < float(max_query_llm_calls) - ): - reflection_hops = 1 - reason_codes.append("reflection_executed") - expanded_limit = min(search_cap_value, max(search_limit + 8, search_limit * 2)) - reflection_payload = self._search_fn( - query=search_query, - user_id=user_id, - agent_id=agent_id, - run_id=run_id, - app_id=app_id, - filters=filters, - categories=categories, - limit=expanded_limit, - rerank=rerank, - keyword_search=keyword_search, - hybrid_alpha=hybrid_alpha, - include_evidence=include_evidence, - evidence_strategy=evidence_strategy, - evidence_max_chars=evidence_max_chars, - evidence_context_lines=evidence_context_lines, - ) - reflected_results = list(reflection_payload.get("results", [])) - merged: Dict[str, Dict[str, Any]] = {} - for row in results + reflected_results: - memory_id = str(row.get("id") or "") - existing = merged.get(memory_id) - if not existing or float(row.get("composite_score", row.get("score", 0.0))) > float( - existing.get("composite_score", existing.get("score", 0.0)) - ): - merged[memory_id] = row - results = sorted( - merged.values(), - key=lambda row: float(row.get("composite_score", row.get("score", 0.0))), - reverse=True, - ) - map_candidates = build_map_candidates( - results, - max_candidates=map_max_candidates_value, - per_candidate_max_chars=map_max_chars_value, - ) - if llm_calls_used < float(max_query_llm_calls): - facts = self._extract_atomic_facts_fn( - llm=active_orchestrator_llm, - question=query, - question_type=question_type, - question_date=question_date, - candidates=map_candidates, - ) - reduced_answer, _ = self._reduce_atomic_facts_fn( - question=query, - intent=query_plan.intent, - facts=facts, - ) - llm_calls_used += 1.0 - map_reduce_used = True - if reduced_answer or facts: - self._put_reducer_cache( - 
cache_key=self._build_reducer_cache_key( - user_id=user_id, - intent_value=query_plan.intent.value, - query=query, - results=results, - ), - reduced_answer=reduced_answer, - facts=facts, - ) - else: - reason_codes.append("query_llm_budget_exhausted") - elif ( - max_hops > 0 - and (not reduced_answer or is_low_confidence_answer(reduced_answer)) - and search_limit < search_cap_value - ): - reason_codes.append("reflection_skipped_budget") - - return ( - reduced_answer, - facts, - map_reduce_used, - reflection_hops, - llm_calls_used, - cache_hit, - list(dict.fromkeys(reason_codes)), - results, - ) diff --git a/dhee/retrieval/reranker.py b/dhee/memory/reranker.py similarity index 100% rename from dhee/retrieval/reranker.py rename to dhee/memory/reranker.py diff --git a/dhee/memory/search_pipeline.py b/dhee/memory/search_pipeline.py index 4e4d1e6..5d76f45 100644 --- a/dhee/memory/search_pipeline.py +++ b/dhee/memory/search_pipeline.py @@ -453,7 +453,7 @@ def search( else: for mid in reecho_ids: self._reecho_memory(mid) - if agent_id: + if agent_id and hasattr(self._db, "add_memory_subscriber"): for mid in subscriber_ids: self._db.add_memory_subscriber(mid, f"agent:{agent_id}", ref_type="weak") diff --git a/dhee/memory/write_pipeline.py b/dhee/memory/write_pipeline.py index 602ce41..94fd173 100644 --- a/dhee/memory/write_pipeline.py +++ b/dhee/memory/write_pipeline.py @@ -467,7 +467,14 @@ def _add_llm_cost(input_tokens: float) -> None: content = explicit_intent.content blocked = detect_sensitive_categories(content) - allow_sensitive = bool(mem_metadata.get("allow_sensitive")) + # allow_sensitive: explicit caller opt-in, or caller explicitly provided + # the content (infer=False / user_provided=True). PII detection is a + # guardrail for agent-inferred memories from raw conversation, not for + # bulk corpus ingestion where the caller owns the content decision. 
+ allow_sensitive = ( + bool(mem_metadata.get("allow_sensitive")) + or bool(mem_metadata.get("user_provided")) + ) if blocked and not allow_sensitive: return { "event": "BLOCKED", @@ -477,7 +484,8 @@ def _add_llm_cost(input_tokens: float) -> None: } is_task_or_note = (mem_metadata or {}).get("memory_type") in ("task", "note") - if not explicit_remember and not is_task_or_note and is_ephemeral(content): + is_user_provided = bool(mem_metadata.get("user_provided")) + if not explicit_remember and not is_task_or_note and not is_user_provided and is_ephemeral(content): return { "event": "SKIP", "reason": "ephemeral", diff --git a/dhee/mini/__init__.py b/dhee/mini/__init__.py deleted file mode 100644 index df441ba..0000000 --- a/dhee/mini/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Dhee Mini — small trainable model for self-evolving cognition.""" - -from dhee.mini.buddhi_mini import BuddhiMini -from dhee.mini.trace_segmenter import TraceSegmenter, TrainingSpan, SpanType - -__all__ = ["BuddhiMini", "TraceSegmenter", "TrainingSpan", "SpanType"] diff --git a/dhee/mini/buddhi_mini.py b/dhee/mini/buddhi_mini.py deleted file mode 100644 index 8fa8cd7..0000000 --- a/dhee/mini/buddhi_mini.py +++ /dev/null @@ -1,414 +0,0 @@ -"""BuddhiMini — small trainable model for self-evolving cognition. - -NOT a separate model from DheeModel. This is DheeModel with 3 new task -heads + a trace-driven data pipeline that produces better training data. - -The self-evolution loop: - 1. Agent uses Dhee (remember/recall/context/checkpoint) - 2. Samskara collects 12 signal types per operation - 3. TraceSegmenter splits trajectories into [REASON]/[ACT]/[MEMORY_OP] - 4. When signals reach critical mass → Nididhyasana triggers - 5. ProgressiveTrainer runs: SFT → DPO → RL - 6. DheeModel updates weights (LoRA merge or GGUF export) - 7. 
Hot-swapped without restart - -Research basis: - - Structured Agent Distillation (arXiv:2505.13820): span-specific losses - - AgeMem (arXiv:2601.01885): memory ops as RL-optimized tool calls - - EvolveR (arXiv:2510.16079): offline distillation → online retrieval - -New task heads (added to DheeModel's existing 6): - [MEMORY_OP] — predict optimal memory operation for context - [HEURISTIC] — generate abstract heuristic from trajectory - [RETRIEVAL_JUDGE] — predict whether retrieval results are sufficient -""" - -from __future__ import annotations - -import json -import logging -import os -import time -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional - -from dhee.mini.trace_segmenter import TraceSegmenter, TrainingSpan, SpanType - -logger = logging.getLogger(__name__) - -# Training thresholds -_MIN_SFT_SAMPLES = 50 # minimum samples to trigger SFT -_MIN_DPO_PAIRS = 20 # minimum pairs to trigger DPO -_ACCUMULATION_WINDOW = 3600 # seconds between training checks - - -@dataclass -class TrainingBuffer: - """Accumulates training data between training cycles.""" - sft_samples: List[Dict[str, str]] = field(default_factory=list) - dpo_pairs: List[Dict[str, Any]] = field(default_factory=list) - trajectories_ingested: int = 0 - contrastive_pairs_ingested: int = 0 - last_train_time: float = 0.0 - - def to_dict(self) -> Dict[str, Any]: - return { - "sft_samples": len(self.sft_samples), - "dpo_pairs": len(self.dpo_pairs), - "trajectories_ingested": self.trajectories_ingested, - "contrastive_pairs_ingested": self.contrastive_pairs_ingested, - "last_train_time": self.last_train_time, - } - - -class BuddhiMini: - """Small trainable model for self-evolving cognition. - - Wraps the existing DheeModel (Qwen3.5-2B) and adds: - 1. Trace ingestion pipeline (trajectories → training spans) - 2. Training data accumulation with thresholds - 3. Progressive training trigger (SFT → DPO → RL) - 4. 
3 new inference task heads - - The model trains itself from the agent's own interaction traces. - No external training data needed. Pure self-evolution. - - Args: - data_dir: Directory for training data and checkpoints - model_size: Not used yet — reserved for future model variants - device: Device for inference (auto-detected if None) - """ - - def __init__( - self, - data_dir: Optional[str] = None, - model_size: str = "2B", - device: Optional[str] = None, - ): - self._data_dir = data_dir or os.path.join( - os.path.expanduser("~"), ".dhee", "mini" - ) - os.makedirs(self._data_dir, exist_ok=True) - - self._segmenter = TraceSegmenter() - self._buffer = TrainingBuffer() - self._model = None # lazy-loaded DheeLLM - self._device = device - - # Load persisted buffer if exists - self._load_buffer() - - # ------------------------------------------------------------------ - # Trace ingestion - # ------------------------------------------------------------------ - - def ingest_trajectory(self, trajectory) -> Dict[str, Any]: - """Ingest a trajectory and segment it into training spans. - - Called automatically by DheePlugin.end_trajectory() or manually. - - Returns: - {"spans": int, "sft_added": int, "dpo_ready": bool} - """ - spans = self._segmenter.segment(trajectory) - if not spans: - return {"spans": 0, "sft_added": 0, "dpo_ready": False} - - # Add successful spans to SFT buffer - sft_examples = self._segmenter.format_for_sft(spans) - self._buffer.sft_samples.extend(sft_examples) - self._buffer.trajectories_ingested += 1 - - # Store spans for DPO pairing later - self._save_spans(spans) - self._save_buffer() - - return { - "spans": len(spans), - "sft_added": len(sft_examples), - "dpo_ready": len(self._buffer.dpo_pairs) >= _MIN_DPO_PAIRS, - } - - def ingest_contrastive_pair( - self, - task_description: str, - success_approach: str, - failure_approach: str, - task_type: str = "general", - ) -> None: - """Ingest a contrastive pair for DPO training. 
- - Called when checkpoint() receives both what_worked and what_failed. - """ - self._buffer.dpo_pairs.append({ - "prompt": f"[TASK] {task_description}\n[TYPE] {task_type}", - "chosen": success_approach, - "rejected": failure_approach, - "span_type": "reflect", - }) - self._buffer.contrastive_pairs_ingested += 1 - self._save_buffer() - - # ------------------------------------------------------------------ - # Training control - # ------------------------------------------------------------------ - - def should_train(self) -> tuple: - """Check if enough data has accumulated for a training cycle. - - Returns: - (should_train: bool, reason: str) - """ - now = time.time() - if now - self._buffer.last_train_time < _ACCUMULATION_WINDOW: - return False, "Too soon since last training cycle" - - sft_ready = len(self._buffer.sft_samples) >= _MIN_SFT_SAMPLES - dpo_ready = len(self._buffer.dpo_pairs) >= _MIN_DPO_PAIRS - - if sft_ready and dpo_ready: - return True, f"Ready: {len(self._buffer.sft_samples)} SFT + {len(self._buffer.dpo_pairs)} DPO" - if sft_ready: - return True, f"SFT ready: {len(self._buffer.sft_samples)} samples" - if dpo_ready: - return True, f"DPO ready: {len(self._buffer.dpo_pairs)} pairs" - - return False, ( - f"Accumulating: {len(self._buffer.sft_samples)}/{_MIN_SFT_SAMPLES} SFT, " - f"{len(self._buffer.dpo_pairs)}/{_MIN_DPO_PAIRS} DPO" - ) - - def train_cycle(self, stage: str = "auto") -> Dict[str, Any]: - """Run one training cycle. - - Delegates to ProgressiveTrainer or Nididhyasana depending on - what's available. Returns training results. 
- - Args: - stage: "sft", "dpo", "progressive", or "auto" - """ - result: Dict[str, Any] = {"stage": stage, "status": "skipped"} - - try: - # Try to use Nididhyasana (existing auto-evolution loop) - from dheeModel.training.nididhyasana import NididhyasanaLoop - loop = NididhyasanaLoop(data_dir=self._data_dir) - - # Export training data in Nididhyasana format - training_data = self._export_training_data() - if not training_data: - result["status"] = "no_data" - return result - - # Run cycle - cycle_result = loop.run_cycle( - sft_data=training_data.get("sft", []), - dpo_data=training_data.get("dpo", []), - ) - result["status"] = "completed" - result["cycle"] = cycle_result - except ImportError: - logger.debug("Nididhyasana not available — storing data for manual training") - self._save_training_export() - result["status"] = "data_saved" - result["path"] = os.path.join(self._data_dir, "training_export.jsonl") - except Exception as e: - logger.debug("Training cycle failed: %s", e) - result["status"] = "error" - result["error"] = str(e) - - # Update buffer - self._buffer.last_train_time = time.time() - self._buffer.sft_samples = [] # Clear used samples - self._buffer.dpo_pairs = [] - self._save_buffer() - - return result - - # ------------------------------------------------------------------ - # Inference (edge-optimized task heads) - # ------------------------------------------------------------------ - - def classify_memory_op(self, context: str) -> str: - """Predict optimal memory operation for current context. 
- - Task head: [MEMORY_OP] - Returns: "store" | "retrieve" | "update" | "summarize" | "discard" | "none" - """ - model = self._get_model() - if model is None: - return self._heuristic_classify_memory_op(context) - - try: - response = model.generate_with_task( - task="MEMORY_OP", - prompt=context[:1000], - ) - op = response.strip().lower() - valid_ops = {"store", "retrieve", "update", "summarize", "discard", "none"} - return op if op in valid_ops else "none" - except Exception: - return self._heuristic_classify_memory_op(context) - - def generate_heuristic(self, trajectory_summary: str) -> str: - """Generate an abstract heuristic from a trajectory summary. - - Task head: [HEURISTIC] - Returns: A transferable reasoning pattern as natural language. - """ - model = self._get_model() - if model is None: - return f"From experience: {trajectory_summary[:200]}" - - try: - return model.generate_with_task( - task="HEURISTIC", - prompt=trajectory_summary[:2000], - ) - except Exception: - return f"From experience: {trajectory_summary[:200]}" - - def predict_retrieval_quality( - self, query: str, results: List[Dict[str, Any]], - ) -> float: - """Predict whether retrieval results are sufficient. 
- - Task head: [RETRIEVAL_JUDGE] - Returns: 0.0 (insufficient) to 1.0 (fully sufficient) - """ - model = self._get_model() - if model is None: - return self._heuristic_retrieval_quality(query, results) - - try: - results_text = "\n".join( - f"- {r.get('memory', '')[:100]} (score={r.get('score', 0):.2f})" - for r in results[:5] - ) - prompt = f"Query: {query}\nResults:\n{results_text}" - response = model.generate_with_task( - task="RETRIEVAL_JUDGE", - prompt=prompt, - ) - return max(0.0, min(1.0, float(response.strip()))) - except Exception: - return self._heuristic_retrieval_quality(query, results) - - # ------------------------------------------------------------------ - # Stats - # ------------------------------------------------------------------ - - def get_stats(self) -> Dict[str, Any]: - """Get BuddhiMini status.""" - should, reason = self.should_train() - return { - "buffer": self._buffer.to_dict(), - "should_train": should, - "train_reason": reason, - "model_loaded": self._model is not None, - "data_dir": self._data_dir, - } - - # ------------------------------------------------------------------ - # Internals - # ------------------------------------------------------------------ - - def _get_model(self): - """Lazy-load the DheeModel.""" - if self._model is not None: - return self._model - try: - from dhee.llms.dhee import DheeLLM - self._model = DheeLLM(config={"device": self._device} if self._device else {}) - return self._model - except Exception: - return None - - def _heuristic_classify_memory_op(self, context: str) -> str: - """Rule-based fallback for memory op classification.""" - cl = context.lower() - if any(w in cl for w in ["remember", "store", "save", "note"]): - return "store" - if any(w in cl for w in ["recall", "search", "find", "what did"]): - return "retrieve" - if any(w in cl for w in ["update", "change", "correct"]): - return "update" - if any(w in cl for w in ["forget", "delete", "remove"]): - return "discard" - if any(w in cl for w in 
["summarize", "consolidate", "compress"]): - return "summarize" - return "none" - - def _heuristic_retrieval_quality( - self, query: str, results: List[Dict[str, Any]], - ) -> float: - """Rule-based fallback for retrieval quality prediction.""" - if not results: - return 0.0 - top_score = results[0].get("score", 0) if results else 0 - count = len(results) - # Simple heuristic: score * coverage - coverage = min(count / 3.0, 1.0) - return round(min(top_score * coverage, 1.0), 3) - - def _export_training_data(self) -> Optional[Dict[str, List]]: - """Export accumulated buffer as training data.""" - if not self._buffer.sft_samples and not self._buffer.dpo_pairs: - return None - return { - "sft": list(self._buffer.sft_samples), - "dpo": list(self._buffer.dpo_pairs), - } - - def _save_training_export(self) -> None: - """Save training data to JSONL for manual training.""" - path = os.path.join(self._data_dir, "training_export.jsonl") - try: - with open(path, "a", encoding="utf-8") as f: - for sample in self._buffer.sft_samples: - f.write(json.dumps({"type": "sft", **sample}) + "\n") - for pair in self._buffer.dpo_pairs: - f.write(json.dumps({"type": "dpo", **pair}) + "\n") - except OSError as e: - logger.debug("Failed to save training export: %s", e) - - def _save_spans(self, spans: List[TrainingSpan]) -> None: - """Persist spans for later DPO pairing.""" - path = os.path.join(self._data_dir, "spans.jsonl") - try: - with open(path, "a", encoding="utf-8") as f: - for span in spans: - f.write(json.dumps(span.to_dict()) + "\n") - except OSError as e: - logger.debug("Failed to save spans: %s", e) - - def _save_buffer(self) -> None: - """Persist buffer metadata.""" - path = os.path.join(self._data_dir, "buffer.json") - try: - data = { - "sft_count": len(self._buffer.sft_samples), - "dpo_count": len(self._buffer.dpo_pairs), - "trajectories_ingested": self._buffer.trajectories_ingested, - "contrastive_pairs_ingested": self._buffer.contrastive_pairs_ingested, - "last_train_time": 
self._buffer.last_train_time, - } - with open(path, "w", encoding="utf-8") as f: - json.dump(data, f) - except OSError: - pass - - def _load_buffer(self) -> None: - """Load persisted buffer metadata.""" - path = os.path.join(self._data_dir, "buffer.json") - if not os.path.exists(path): - return - try: - with open(path, "r", encoding="utf-8") as f: - data = json.load(f) - self._buffer.trajectories_ingested = data.get("trajectories_ingested", 0) - self._buffer.contrastive_pairs_ingested = data.get("contrastive_pairs_ingested", 0) - self._buffer.last_train_time = data.get("last_train_time", 0.0) - except (OSError, json.JSONDecodeError): - pass diff --git a/dhee/mini/progressive_trainer.py b/dhee/mini/progressive_trainer.py deleted file mode 100644 index b6f082e..0000000 --- a/dhee/mini/progressive_trainer.py +++ /dev/null @@ -1,411 +0,0 @@ -"""Progressive Trainer — 3-stage training for BuddhiMini. - -Based on AgeMem (arXiv:2601.01885): memory ops as RL-optimized tool calls -with 3-stage progressive training for optimal learning. - -Stage 1 — SFT (Supervised Fine-Tuning): - Train on high-quality trajectory spans from TraceSegmenter. - Each span is a (task_context → [SPAN_TYPE] output) example. - Span-specific losses per Structured Agent Distillation (arXiv:2505.13820). - -Stage 2 — DPO (Direct Preference Optimization): - Train on contrastive pairs from ContrastiveStore. - Each pair: (task_context, success_approach, failure_approach). - The model learns to prefer successful reasoning patterns. - -Stage 3 — RL (Retrieval-Quality Reward): - Use retrieval quality as reward signal. - After the model updates, measure whether recall@K improves. - If yes, keep the update. If no, rollback. - -The trainer does NOT run training itself — it curates data and delegates -to the existing Nididhyasana/train.py pipeline. Its job is to decide -WHAT to train on and in WHAT order. 
-""" - -from __future__ import annotations - -import json -import logging -import os -import time -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional - -logger = logging.getLogger(__name__) - - -@dataclass -class TrainingStageResult: - """Result from a single training stage.""" - - stage: str # sft | dpo | rl - status: str # completed | skipped | error - samples_used: int = 0 - metrics: Dict[str, float] = field(default_factory=dict) - duration_seconds: float = 0.0 - error: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - d = { - "stage": self.stage, - "status": self.status, - "samples_used": self.samples_used, - "duration_seconds": round(self.duration_seconds, 1), - } - if self.metrics: - d["metrics"] = self.metrics - if self.error: - d["error"] = self.error - return d - - -@dataclass -class ProgressiveTrainingResult: - """Result from a full progressive training cycle.""" - - cycle_id: str - stages: List[TrainingStageResult] - total_duration: float = 0.0 - model_improved: bool = False - data_exported_path: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - return { - "cycle_id": self.cycle_id, - "stages": [s.to_dict() for s in self.stages], - "total_duration": round(self.total_duration, 1), - "model_improved": self.model_improved, - "data_exported_path": self.data_exported_path, - } - - -class ProgressiveTrainer: - """Curates and orders training data for the 3-stage progressive pipeline. - - This is the brain that decides what data to train on. The actual - training execution is delegated to Nididhyasana (which calls train.py). 
- - Usage: - trainer = ProgressiveTrainer(data_dir="/path/to/training") - result = trainer.run_cycle( - sft_data=[...], # from TraceSegmenter - dpo_data=[...], # from ContrastiveStore - samskara_data={...} # from Samskara.get_training_data() - ) - """ - - # Minimum data thresholds - MIN_SFT = 20 - MIN_DPO = 10 - - def __init__( - self, - data_dir: Optional[str] = None, - train_fn=None, - ): - self._dir = data_dir or os.path.join( - os.path.expanduser("~"), ".dhee", "progressive_training" - ) - os.makedirs(self._dir, exist_ok=True) - self._train_fn = train_fn # injected training function - self._cycle_count = 0 - self._history: List[Dict[str, Any]] = [] - self._load_history() - - def run_cycle( - self, - sft_data: Optional[List[Dict[str, str]]] = None, - dpo_data: Optional[List[Dict[str, str]]] = None, - samskara_data: Optional[Dict[str, Any]] = None, - ) -> ProgressiveTrainingResult: - """Run a full progressive training cycle: SFT → DPO → RL. - - Each stage is optional — skipped if insufficient data. 
- """ - cycle_id = f"prog_{self._cycle_count:04d}_{int(time.time())}" - self._cycle_count += 1 - cycle_dir = os.path.join(self._dir, cycle_id) - os.makedirs(cycle_dir, exist_ok=True) - - start = time.time() - stages: List[TrainingStageResult] = [] - - # Merge samskara SFT samples with explicit SFT data - all_sft = list(sft_data or []) - if samskara_data: - all_sft.extend(samskara_data.get("sft_samples", [])) - - # Merge samskara DPO pairs with explicit DPO data - all_dpo = list(dpo_data or []) - if samskara_data: - all_dpo.extend(samskara_data.get("dpo_pairs", [])) - - # Weight data by vasana degradation (focus on weak areas) - if samskara_data: - all_sft = self._weight_by_vasana(all_sft, samskara_data) - - # --- Stage 1: SFT --- - sft_result = self._run_sft(all_sft, cycle_dir) - stages.append(sft_result) - - # --- Stage 2: DPO --- - dpo_result = self._run_dpo(all_dpo, cycle_dir) - stages.append(dpo_result) - - # --- Stage 3: RL (evaluation-based) --- - rl_result = self._run_rl_eval(cycle_dir, sft_result, dpo_result) - stages.append(rl_result) - - total_duration = time.time() - start - completed_stages = [s for s in stages if s.status == "completed"] - - result = ProgressiveTrainingResult( - cycle_id=cycle_id, - stages=stages, - total_duration=total_duration, - model_improved=len(completed_stages) > 0, - data_exported_path=cycle_dir, - ) - - self._record_history(result) - return result - - # ------------------------------------------------------------------ - # Stage 1: SFT - # ------------------------------------------------------------------ - - def _run_sft( - self, data: List[Dict[str, str]], cycle_dir: str, - ) -> TrainingStageResult: - """Stage 1: Supervised Fine-Tuning on trajectory spans.""" - if len(data) < self.MIN_SFT: - return TrainingStageResult( - stage="sft", status="skipped", - metrics={"reason": f"insufficient data ({len(data)}/{self.MIN_SFT})"}, - ) - - start = time.time() - - # Curate: prioritize diverse span types - curated = 
self._curate_sft(data) - - # Export as train.jsonl - train_path = os.path.join(cycle_dir, "sft_train.jsonl") - self._write_jsonl(train_path, curated) - - # Try to run actual training - metrics = {} - if self._train_fn: - try: - train_result = self._train_fn( - data_dir=cycle_dir, - output_dir=os.path.join(cycle_dir, "sft_output"), - ) - metrics = train_result if isinstance(train_result, dict) else {} - except Exception as e: - return TrainingStageResult( - stage="sft", status="error", - samples_used=len(curated), - duration_seconds=time.time() - start, - error=str(e), - ) - else: - metrics["data_exported"] = train_path - - return TrainingStageResult( - stage="sft", status="completed", - samples_used=len(curated), - metrics=metrics, - duration_seconds=time.time() - start, - ) - - def _curate_sft(self, data: List[Dict[str, str]]) -> List[Dict[str, str]]: - """Curate SFT data: balance span types, cap per type.""" - by_type: Dict[str, List] = {} - for sample in data: - t = sample.get("type", "general") - by_type.setdefault(t, []).append(sample) - - # Take up to 50 per type for balance - curated = [] - for samples in by_type.values(): - curated.extend(samples[:50]) - - return curated - - # ------------------------------------------------------------------ - # Stage 2: DPO - # ------------------------------------------------------------------ - - def _run_dpo( - self, data: List[Dict[str, str]], cycle_dir: str, - ) -> TrainingStageResult: - """Stage 2: Direct Preference Optimization on contrastive pairs.""" - if len(data) < self.MIN_DPO: - return TrainingStageResult( - stage="dpo", status="skipped", - metrics={"reason": f"insufficient data ({len(data)}/{self.MIN_DPO})"}, - ) - - start = time.time() - - # Export as dpo_pairs.jsonl - dpo_path = os.path.join(cycle_dir, "dpo_pairs.jsonl") - self._write_jsonl(dpo_path, data) - - metrics = {} - if self._train_fn: - try: - train_result = self._train_fn( - data_dir=cycle_dir, - output_dir=os.path.join(cycle_dir, "dpo_output"), - 
dpo_mode=True, - ) - metrics = train_result if isinstance(train_result, dict) else {} - except Exception as e: - return TrainingStageResult( - stage="dpo", status="error", - samples_used=len(data), - duration_seconds=time.time() - start, - error=str(e), - ) - else: - metrics["data_exported"] = dpo_path - - return TrainingStageResult( - stage="dpo", status="completed", - samples_used=len(data), - metrics=metrics, - duration_seconds=time.time() - start, - ) - - # ------------------------------------------------------------------ - # Stage 3: RL (reward = retrieval quality) - # ------------------------------------------------------------------ - - def _run_rl_eval( - self, - cycle_dir: str, - sft_result: TrainingStageResult, - dpo_result: TrainingStageResult, - ) -> TrainingStageResult: - """Stage 3: RL evaluation — decide whether to keep updates. - - In practice, this stage verifies that the SFT/DPO changes - didn't degrade retrieval quality. It's a gate, not a trainer. - """ - start = time.time() - - # If neither SFT nor DPO actually trained, skip - if sft_result.status != "completed" and dpo_result.status != "completed": - return TrainingStageResult( - stage="rl", status="skipped", - metrics={"reason": "no prior stages completed"}, - ) - - # Compute aggregate quality from what we have - sft_samples = sft_result.samples_used - dpo_samples = dpo_result.samples_used - total_samples = sft_samples + dpo_samples - - # Simple quality heuristic: more diverse data = more likely to help - quality_estimate = min(1.0, total_samples / 100.0) - - metrics = { - "quality_estimate": round(quality_estimate, 3), - "sft_contribution": sft_samples, - "dpo_contribution": dpo_samples, - "verdict": "keep" if quality_estimate > 0.3 else "rollback", - } - - return TrainingStageResult( - stage="rl", status="completed", - metrics=metrics, - duration_seconds=time.time() - start, - ) - - # ------------------------------------------------------------------ - # Vasana-weighted data emphasis - # 
------------------------------------------------------------------ - - def _weight_by_vasana( - self, - data: List[Dict[str, str]], - samskara_data: Dict[str, Any], - ) -> List[Dict[str, str]]: - """Duplicate samples from degrading dimensions for emphasis. - - If retrieval_recall is degrading, duplicate RETRIEVAL_HIT samples. - Same weighting logic as Nididhyasana._curate_dataset(). - """ - degrading = set(samskara_data.get("degrading_dimensions", [])) - if not degrading: - return data - - # Map degrading dimensions to sample types - dim_to_type = { - "retrieval_recall": {"retrieval_hit", "retrieval_miss"}, - "retrieval_precision": {"retrieval_hit"}, - "answer_quality": {"answer_accepted", "answer_corrected"}, - "fact_extraction": {"extraction"}, - } - - emphasized_types = set() - for dim in degrading: - emphasized_types |= dim_to_type.get(dim, set()) - - # Duplicate matching samples (2x weight) - weighted = [] - for sample in data: - weighted.append(sample) - sample_type = sample.get("type", "").lower() - valence = sample.get("valence", "") - if sample_type in emphasized_types or valence == "klishta": - weighted.append(sample) # duplicate = 2x emphasis - - return weighted - - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - - def _write_jsonl(self, path: str, data: List[Dict]) -> None: - try: - with open(path, "w", encoding="utf-8") as f: - for item in data: - f.write(json.dumps(item, ensure_ascii=False) + "\n") - except OSError as e: - logger.debug("Failed to write %s: %s", path, e) - - def _record_history(self, result: ProgressiveTrainingResult) -> None: - self._history.append(result.to_dict()) - path = os.path.join(self._dir, "history.jsonl") - try: - with open(path, "a", encoding="utf-8") as f: - f.write(json.dumps(result.to_dict(), ensure_ascii=False) + "\n") - except OSError: - pass - - def _load_history(self) -> None: - path = os.path.join(self._dir, 
"history.jsonl") - if not os.path.exists(path): - return - try: - with open(path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line: - try: - self._history.append(json.loads(line)) - self._cycle_count += 1 - except json.JSONDecodeError: - continue - except OSError: - pass - - def get_stats(self) -> Dict[str, Any]: - return { - "cycles_completed": self._cycle_count, - "last_cycle": self._history[-1] if self._history else None, - } diff --git a/dhee/mini/trace_segmenter.py b/dhee/mini/trace_segmenter.py deleted file mode 100644 index 5522981..0000000 --- a/dhee/mini/trace_segmenter.py +++ /dev/null @@ -1,246 +0,0 @@ -"""Trace segmenter — converts agent trajectories into training spans. - -Based on Structured Agent Distillation (Liu et al., arXiv:2505.13820): -segments agent interaction traces into [REASON], [ACT], and [MEMORY_OP] -spans with span-specific training losses for more efficient learning. - -The key insight: different types of agent behavior (reasoning vs action -vs memory management) benefit from different training objectives. -Token-level distillation treats them uniformly and is less effective. -""" - -from __future__ import annotations - -import re -import uuid -from dataclasses import dataclass, field -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple - - -class SpanType(str, Enum): - """Type of training span — each gets its own loss weight.""" - - REASON = "reason" # Internal reasoning, planning, analysis - ACT = "act" # Tool calls, commands, actions taken - MEMORY_OP = "memory_op" # Memory operations (store, retrieve, update, summarize, discard) - REFLECT = "reflect" # Self-reflection, insight synthesis - OBSERVE = "observe" # Observation, reading results, understanding state - - -@dataclass -class TrainingSpan: - """A single segment of an agent trajectory for training. - - Each span has a type, the text content, and metadata about the - trajectory it came from. 
Spans from successful trajectories are - used for SFT; paired success/failure spans for DPO. - """ - id: str - span_type: SpanType - content: str # the text of this span - context_before: str # preceding context (for input) - trajectory_id: str - step_index: int - task_description: str - success: bool # was the overall trajectory successful? - metadata: Dict[str, Any] = field(default_factory=dict) - - def to_sft_example(self) -> Dict[str, str]: - """Format as SFT training example (input → output).""" - return { - "input": f"[TASK] {self.task_description}\n[CONTEXT] {self.context_before}", - "output": f"[{self.span_type.value.upper()}] {self.content}", - "type": self.span_type.value, - } - - def to_dict(self) -> Dict[str, Any]: - return { - "id": self.id, - "span_type": self.span_type.value, - "content": self.content, - "context_before": self.context_before[:500], - "trajectory_id": self.trajectory_id, - "step_index": self.step_index, - "task_description": self.task_description, - "success": self.success, - } - - -# Patterns for classifying steps into span types -_MEMORY_PATTERNS = re.compile( - r"(?:remember|recall|search|store|forget|update.*memor|delete.*memor|checkpoint)", - re.IGNORECASE, -) -_REASON_PATTERNS = re.compile( - r"(?:think|plan|analyze|consider|reason|decide|evaluate|assess|compare)", - re.IGNORECASE, -) -_REFLECT_PATTERNS = re.compile( - r"(?:reflect|insight|learned|worked|failed|improve|heuristic|takeaway)", - re.IGNORECASE, -) - - -class TraceSegmenter: - """Segments agent trajectories into typed training spans. 
- - Takes a Trajectory (from dhee.skills.trajectory) and produces - a list of TrainingSpan objects suitable for: - - SFT: train on successful spans - - DPO: pair successful/failed spans for preference learning - - RL: use retrieval quality as reward signal - - Usage: - from dhee.mini.trace_segmenter import TraceSegmenter - from dhee.skills.schema import Trajectory - - segmenter = TraceSegmenter() - spans = segmenter.segment(trajectory) - - # For SFT training - sft_data = segmenter.format_for_sft(spans) - - # For DPO training - dpo_data = segmenter.format_for_dpo(success_spans, failure_spans) - """ - - def segment(self, trajectory) -> List[TrainingSpan]: - """Segment a trajectory into typed training spans. - - Args: - trajectory: A Trajectory object from dhee.skills.schema - - Returns: - List of TrainingSpan objects - """ - spans: List[TrainingSpan] = [] - context_parts: List[str] = [] - - for i, step in enumerate(trajectory.steps): - # Classify the step - span_type = self._classify_step(step) - - # Build content from step - content = self._extract_content(step) - if not content: - continue - - # Context is everything before this step - context_before = "\n".join(context_parts[-3:]) # last 3 steps - - span = TrainingSpan( - id=str(uuid.uuid4()), - span_type=span_type, - content=content, - context_before=context_before, - trajectory_id=trajectory.id, - step_index=i, - task_description=trajectory.task_description, - success=trajectory.success, - metadata={ - "tool": getattr(step, "tool", ""), - "error": getattr(step, "error", None), - "duration_ms": getattr(step, "duration_ms", None), - }, - ) - spans.append(span) - - # Update rolling context - context_parts.append(f"[{span_type.value}] {content[:200]}") - - return spans - - def format_for_sft(self, spans: List[TrainingSpan]) -> List[Dict[str, str]]: - """Format successful spans as SFT training examples.""" - return [ - span.to_sft_example() - for span in spans - if span.success - ] - - def format_for_dpo( - self, - 
success_spans: List[TrainingSpan], - failure_spans: List[TrainingSpan], - ) -> List[Dict[str, Any]]: - """Create DPO training pairs from success/failure spans. - - Pairs are created by matching spans with the same span_type - and similar step_index from successful and failed trajectories. - """ - pairs = [] - - # Group by span type - success_by_type: Dict[str, List[TrainingSpan]] = {} - failure_by_type: Dict[str, List[TrainingSpan]] = {} - - for s in success_spans: - success_by_type.setdefault(s.span_type.value, []).append(s) - for f in failure_spans: - failure_by_type.setdefault(f.span_type.value, []).append(f) - - # Create pairs for each shared type - for span_type in set(success_by_type) & set(failure_by_type): - chosen_list = success_by_type[span_type] - rejected_list = failure_by_type[span_type] - - # Pair by position (zip truncates to shorter) - for chosen, rejected in zip(chosen_list, rejected_list): - pairs.append({ - "prompt": f"[TASK] {chosen.task_description}\n" - f"[CONTEXT] {chosen.context_before}", - "chosen": f"[{span_type.upper()}] {chosen.content}", - "rejected": f"[{span_type.upper()}] {rejected.content}", - "span_type": span_type, - }) - - return pairs - - def _classify_step(self, step) -> SpanType: - """Classify a trajectory step into a span type.""" - action = getattr(step, "action", "") - tool = getattr(step, "tool", "") - result_summary = getattr(step, "result_summary", "") - combined = f"{action} {tool} {result_summary}" - - # Memory operations - if _MEMORY_PATTERNS.search(combined): - return SpanType.MEMORY_OP - - # Reflection - if _REFLECT_PATTERNS.search(combined): - return SpanType.REFLECT - - # Reasoning (no tool call, just thinking) - if not tool and _REASON_PATTERNS.search(combined): - return SpanType.REASON - - # Tool call = action - if tool: - return SpanType.ACT - - # Observation (reading results) - if result_summary and not tool: - return SpanType.OBSERVE - - # Default to reasoning - return SpanType.REASON - - def 
_extract_content(self, step) -> str: - """Extract the text content from a trajectory step.""" - parts = [] - action = getattr(step, "action", "") - tool = getattr(step, "tool", "") - result_summary = getattr(step, "result_summary", "") - - if action: - parts.append(action) - if tool: - args = getattr(step, "args", {}) - args_str = ", ".join(f"{k}={v}" for k, v in list(args.items())[:3]) if args else "" - parts.append(f"tool={tool}({args_str})") - if result_summary: - parts.append(f"→ {result_summary[:300]}") - - return " | ".join(parts) diff --git a/dhee/adapters/base.py b/dhee/plugin.py similarity index 100% rename from dhee/adapters/base.py rename to dhee/plugin.py diff --git a/dhee/retrieval/__init__.py b/dhee/retrieval/__init__.py deleted file mode 100644 index 49ed407..0000000 --- a/dhee/retrieval/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Engram v2 retrieval components.""" - -try: - from dhee.retrieval.dual_search import DualSearchEngine -except ImportError: - DualSearchEngine = None - -from dhee.retrieval.reranker import NvidiaReranker, create_reranker - -__all__ = ["DualSearchEngine", "NvidiaReranker", "create_reranker"] diff --git a/dhee/simple.py b/dhee/simple.py index 376f9b0..cd3184d 100644 --- a/dhee/simple.py +++ b/dhee/simple.py @@ -41,17 +41,31 @@ def _detect_provider() -> str: """Detect which LLM/embedder provider to use based on environment.""" - if os.environ.get("GEMINI_API_KEY"): - return "gemini" if os.environ.get("OPENAI_API_KEY"): return "openai" - # Default to gemini if no key found (will fail later with clear error) - return "gemini" + if os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"): + return "gemini" + if ( + os.environ.get("NVIDIA_API_KEY") + or os.environ.get("NVIDIA_QWEN_API_KEY") + or os.environ.get("NVIDIA_EMBEDDING_API_KEY") + or os.environ.get("NVIDIA_EMBED_API_KEY") + or os.environ.get("NVIDIA_LLAMA_4_MAV_API_KEY") + ): + return "nvidia" + # Zero-config fallback: local simple embedder + mock LLM. 
+ return "mock" def _get_embedding_dims(provider: str) -> int: """Get embedding dimensions for provider.""" - return 3072 if provider == "gemini" else 1536 + if provider == "gemini": + return 3072 + if provider == "openai": + return 1536 + if provider == "nvidia": + return 2048 + return 384 def _has_api_key() -> bool: diff --git a/dhee/teaching/__init__.py b/dhee/teaching/__init__.py deleted file mode 100644 index 80309bf..0000000 --- a/dhee/teaching/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Teaching primitives for Engram — concepts, student models, and teaching memory. - -Domain objects stored as Engram memories with ``memory_type`` + metadata dict, -following the same pattern as ``engram.memory.tasks``. -""" - -from dhee.teaching.config import TeachingConfig -from dhee.teaching.concepts import ConceptStore -from dhee.teaching.student_model import StudentModel -from dhee.teaching.teaching_memory import TeachingMemory - -__all__ = [ - "TeachingConfig", - "ConceptStore", - "StudentModel", - "TeachingMemory", -] diff --git a/dhee/teaching/concepts.py b/dhee/teaching/concepts.py deleted file mode 100644 index 0f3ad5a..0000000 --- a/dhee/teaching/concepts.py +++ /dev/null @@ -1,307 +0,0 @@ -"""ConceptStore — curriculum concepts stored as Engram memories. - -Concepts are stored with ``memory_type="concept"`` under a shared -``user_id`` (the curriculum namespace). Prerequisites and cross-subject -links are represented via the knowledge graph (``RelationType.REQUIRES`` -and ``RelationType.RELATED_TO``). 
-""" - -from __future__ import annotations - -import json -import logging -import uuid -from typing import Any, Dict, List, Optional - -from dhee.teaching.config import TeachingConfig - -logger = logging.getLogger(__name__) - - -class ConceptStore: - """CRUD and graph operations over curriculum concepts.""" - - def __init__(self, memory: "CoreMemory", config: TeachingConfig | None = None): # noqa: F821 - self.memory = memory - self.config = config or TeachingConfig() - self._namespace = self.config.concept_namespace - - # ------------------------------------------------------------------ - # Create / update - # ------------------------------------------------------------------ - - def add_concept( - self, - concept_id: str, - name: str, - subject: str, - *, - difficulty: float = 0.5, - prerequisites: List[str] | None = None, - cross_subject_links: List[str] | None = None, - keywords: List[str] | None = None, - description: str = "", - ) -> Dict[str, Any]: - """Add a concept to the store (idempotent by ``concept_id``).""" - - # Check for existing concept with same concept_id - existing = self._find_by_concept_id(concept_id) - if existing: - return existing - - meta = { - "memory_type": "concept", - "concept_id": concept_id, - "subject": subject, - "difficulty": difficulty, - "prerequisites": prerequisites or [], - "cross_subject_links": cross_subject_links or [], - "keywords": keywords or [], - } - - content = f"{name}: {description}" if description else name - result = self.memory.add( - content=content, - user_id=self._namespace, - metadata=meta, - categories=[f"concept/{subject}"], - ) - - mem_id = self._extract_id(result) - if not mem_id: - logger.warning("Failed to store concept %s", concept_id) - return {"error": "Failed to store concept"} - - # Create prerequisite edges via knowledge graph - if prerequisites and hasattr(self.memory, "knowledge_graph"): - graph = self.memory.knowledge_graph - for prereq_id in prerequisites: - prereq = 
self._find_by_concept_id(prereq_id) - prereq_mem_id = prereq.get("memory_id") if prereq else None - if prereq_mem_id: - from dhee.core.graph import RelationType - - graph.add_relationship( - source_id=mem_id, - target_id=prereq_mem_id, - relation_type=RelationType.REQUIRES, - metadata={"concept_link": True}, - ) - - # Create cross-subject edges - if cross_subject_links and hasattr(self.memory, "knowledge_graph"): - graph = self.memory.knowledge_graph - for linked_id in cross_subject_links: - linked = self._find_by_concept_id(linked_id) - linked_mem_id = linked.get("memory_id") if linked else None - if linked_mem_id: - from dhee.core.graph import RelationType - - graph.add_relationship( - source_id=mem_id, - target_id=linked_mem_id, - relation_type=RelationType.RELATED_TO, - metadata={"cross_subject": True}, - ) - - return { - "memory_id": mem_id, - "concept_id": concept_id, - "name": name, - "subject": subject, - } - - # ------------------------------------------------------------------ - # Read - # ------------------------------------------------------------------ - - def get_concept(self, concept_id: str) -> Optional[Dict[str, Any]]: - """Retrieve a concept by its concept_id.""" - return self._find_by_concept_id(concept_id) - - def search_concepts( - self, - query: str, - subject: str | None = None, - limit: int = 10, - ) -> List[Dict[str, Any]]: - """Semantic search over concepts, optionally filtered by subject.""" - results = self.memory.search( - query=query, - user_id=self._namespace, - limit=limit * 2, - ) - - concepts = [] - for mem in results: - md = self._parse_metadata(mem) - if md.get("memory_type") != "concept": - continue - if subject and md.get("subject") != subject: - continue - concepts.append(self._format_concept(mem)) - if len(concepts) >= limit: - break - - return concepts - - def get_prerequisites( - self, - concept_id: str, - depth: int = 2, - ) -> List[Dict[str, Any]]: - """Traverse prerequisite graph to the given depth.""" - root = 
self._find_by_concept_id(concept_id) - if not root: - return [] - - root_mem_id = root.get("memory_id") - if not root_mem_id: - return [] - - # Try graph traversal first - if hasattr(self.memory, "knowledge_graph"): - from dhee.core.graph import RelationType - - graph = self.memory.knowledge_graph - related = graph.get_related_memories( - root_mem_id, - relation_types=[RelationType.REQUIRES], - max_depth=depth, - ) - prereqs = [] - for rel in related: - target_id = rel.get("target_id") or rel.get("memory_id") - if target_id: - mem = self.memory.get(target_id) - if mem: - prereqs.append(self._format_concept(mem)) - return prereqs - - # Fallback: use metadata-based prerequisites - md = self._parse_metadata(root) - prereq_ids = md.get("prerequisites", []) - prereqs = [] - for pid in prereq_ids: - c = self._find_by_concept_id(pid) - if c: - prereqs.append(c) - return prereqs - - def get_cross_subject_links( - self, - concept_id: str, - ) -> List[Dict[str, Any]]: - """Find concepts linked across subjects.""" - root = self._find_by_concept_id(concept_id) - if not root: - return [] - - root_mem_id = root.get("memory_id") - if not root_mem_id: - return [] - - if hasattr(self.memory, "knowledge_graph"): - from dhee.core.graph import RelationType - - graph = self.memory.knowledge_graph - related = graph.get_related_memories( - root_mem_id, - relation_types=[RelationType.RELATED_TO], - max_depth=1, - ) - links = [] - for rel in related: - target_id = rel.get("target_id") or rel.get("memory_id") - if target_id: - mem = self.memory.get(target_id) - if mem: - links.append(self._format_concept(mem)) - return links - - # Fallback - md = self._parse_metadata(root) - linked_ids = md.get("cross_subject_links", []) - return [c for pid in linked_ids if (c := self._find_by_concept_id(pid))] - - def link_concepts_cross_subject( - self, - concept_a_id: str, - concept_b_id: str, - ) -> bool: - """Create a cross-subject edge between two concepts.""" - a = 
self._find_by_concept_id(concept_a_id) - b = self._find_by_concept_id(concept_b_id) - if not a or not b: - return False - - a_mem_id = a.get("memory_id") - b_mem_id = b.get("memory_id") - if not a_mem_id or not b_mem_id: - return False - - if hasattr(self.memory, "knowledge_graph"): - from dhee.core.graph import RelationType - - self.memory.knowledge_graph.add_relationship( - source_id=a_mem_id, - target_id=b_mem_id, - relation_type=RelationType.RELATED_TO, - metadata={"cross_subject": True}, - ) - return True - - return False - - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - - def _find_by_concept_id(self, concept_id: str) -> Optional[Dict[str, Any]]: - """Find concept memory by concept_id in metadata.""" - if hasattr(self.memory, "db") and hasattr(self.memory.db, "get_all_memories"): - memories = self.memory.db.get_all_memories( - user_id=self._namespace, - memory_type="concept", - limit=500, - ) - for mem in memories: - md = self._parse_metadata(mem) - if md.get("concept_id") == concept_id: - return self._format_concept(mem) - return None - - @staticmethod - def _extract_id(result: Dict[str, Any]) -> Optional[str]: - if isinstance(result, dict): - results = result.get("results", []) - if results and isinstance(results, list): - first = results[0] - return first.get("id") or first.get("memory_id") - return result.get("id") or result.get("memory_id") - return None - - @staticmethod - def _parse_metadata(mem: Dict[str, Any]) -> Dict[str, Any]: - md = mem.get("metadata", {}) - if isinstance(md, str): - try: - md = json.loads(md) - except (json.JSONDecodeError, TypeError): - md = {} - return md - - @classmethod - def _format_concept(cls, mem: Dict[str, Any]) -> Dict[str, Any]: - md = cls._parse_metadata(mem) - return { - "memory_id": mem.get("id"), - "concept_id": md.get("concept_id", ""), - "name": mem.get("content", "").split(":")[0].strip(), - "subject": 
md.get("subject", ""), - "difficulty": md.get("difficulty", 0.5), - "prerequisites": md.get("prerequisites", []), - "cross_subject_links": md.get("cross_subject_links", []), - "keywords": md.get("keywords", []), - "strength": mem.get("strength", 1.0), - } diff --git a/dhee/teaching/config.py b/dhee/teaching/config.py deleted file mode 100644 index 152c8c6..0000000 --- a/dhee/teaching/config.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Configuration for teaching primitives.""" - -from __future__ import annotations - -from pydantic import BaseModel, field_validator - - -class TeachingConfig(BaseModel): - """Configuration for the teaching memory subsystem.""" - - enable_teaching: bool = False # Off by default (backward compat) - concept_namespace: str = "sensai_curriculum" - mastery_initial_score: float = 0.35 - mastery_increment: float = 0.08 - mastery_decrement_on_misconception: float = -0.15 - weak_concept_threshold: float = 0.45 - mastered_concept_threshold: float = 0.70 - - @field_validator( - "mastery_initial_score", - "mastery_increment", - "weak_concept_threshold", - "mastered_concept_threshold", - ) - @classmethod - def _clamp_unit_float(cls, v: float) -> float: - return min(1.0, max(0.0, float(v))) diff --git a/dhee/teaching/student_model.py b/dhee/teaching/student_model.py deleted file mode 100644 index 3c5346c..0000000 --- a/dhee/teaching/student_model.py +++ /dev/null @@ -1,372 +0,0 @@ -"""StudentModel — per-student profile and concept mastery via Engram. - -Profiles are stored as ``memory_type="student_profile"`` and per-concept -mastery as ``memory_type="concept_mastery"``. FadeMem handles mastery -decay naturally: accessing a concept mastery memory boosts its strength -(spaced repetition). 
-""" - -from __future__ import annotations - -import json -import logging -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional - -from dhee.teaching.config import TeachingConfig - -logger = logging.getLogger(__name__) - - -class StudentModel: - """Per-student learning profile and concept mastery tracker.""" - - def __init__(self, memory: "CoreMemory", config: TeachingConfig | None = None): # noqa: F821 - self.memory = memory - self.config = config or TeachingConfig() - - # ------------------------------------------------------------------ - # Profile - # ------------------------------------------------------------------ - - def get_or_create_profile(self, student_id: str) -> Dict[str, Any]: - """Get or create a student learning profile.""" - existing = self._find_profile(student_id) - if existing: - return existing - - now = datetime.now(timezone.utc).isoformat() - meta = { - "memory_type": "student_profile", - "student_id": student_id, - "learning_style": "unknown", - "interests": [], - "goals": [], - "effective_analogies": [], - "created_at": now, - "updated_at": now, - } - - content = f"Student profile for {student_id}" - result = self.memory.add( - content=content, - user_id=student_id, - metadata=meta, - categories=["student/profile"], - ) - - mem_id = self._extract_id(result) - return { - "memory_id": mem_id, - "student_id": student_id, - "learning_style": "unknown", - "interests": [], - "goals": [], - "effective_analogies": [], - } - - def update_profile( - self, - student_id: str, - updates: Dict[str, Any], - ) -> Dict[str, Any]: - """Merge updates into the student profile.""" - profile_mem = self._find_profile_raw(student_id) - if not profile_mem: - profile = self.get_or_create_profile(student_id) - profile_mem = self._find_profile_raw(student_id) - if not profile_mem: - return profile - - md = self._parse_metadata(profile_mem) - mem_id = profile_mem.get("id") - - # Merge updates - for key, value in updates.items(): - if 
key in ("interests", "goals", "effective_analogies") and isinstance(value, list): - existing = md.get(key, []) - merged = list(dict.fromkeys(existing + value)) # dedup, preserve order - md[key] = merged - else: - md[key] = value - - md["updated_at"] = datetime.now(timezone.utc).isoformat() - - if hasattr(self.memory, "db") and hasattr(self.memory.db, "update_memory"): - self.memory.db.update_memory(mem_id, {"metadata": json.dumps(md)}) - - return self._format_profile({"id": mem_id, "metadata": md}) - - # ------------------------------------------------------------------ - # Mastery - # ------------------------------------------------------------------ - - def get_mastery(self, student_id: str, concept_id: str) -> float: - """Get current mastery score. Accessing boosts strength (spaced repetition).""" - mastery_mem = self._find_mastery_raw(student_id, concept_id) - if not mastery_mem: - return self.config.mastery_initial_score - - md = self._parse_metadata(mastery_mem) - score = md.get("mastery_score", self.config.mastery_initial_score) - - # Access the memory to trigger FadeMem strength boost - mem_id = mastery_mem.get("id") - if mem_id and hasattr(self.memory, "db"): - self.memory.db.increment_access(mem_id) - - return score - - def update_mastery( - self, - student_id: str, - concept_id: str, - score_delta: float, - ) -> float: - """Update mastery score by delta. 
Returns new score.""" - mastery_mem = self._find_mastery_raw(student_id, concept_id) - - if not mastery_mem: - # Create new mastery record - new_score = max(0.0, min(1.0, self.config.mastery_initial_score + score_delta)) - now = datetime.now(timezone.utc).isoformat() - meta = { - "memory_type": "concept_mastery", - "student_id": student_id, - "concept_id": concept_id, - "mastery_score": new_score, - "history": [{"delta": score_delta, "score": new_score, "at": now}], - "created_at": now, - "updated_at": now, - } - self.memory.add( - content=f"Mastery: {student_id} / {concept_id} = {new_score:.2f}", - user_id=student_id, - metadata=meta, - categories=["student/mastery"], - ) - return new_score - - md = self._parse_metadata(mastery_mem) - old_score = md.get("mastery_score", self.config.mastery_initial_score) - new_score = max(0.0, min(1.0, old_score + score_delta)) - - now = datetime.now(timezone.utc).isoformat() - history = md.get("history", []) - history.append({"delta": score_delta, "score": new_score, "at": now}) - md["mastery_score"] = new_score - md["history"] = history[-50:] # keep last 50 entries - md["updated_at"] = now - - mem_id = mastery_mem.get("id") - if mem_id and hasattr(self.memory, "db"): - self.memory.db.update_memory(mem_id, {"metadata": json.dumps(md)}) - self.memory.db.increment_access(mem_id) - - return new_score - - def get_weak_concepts( - self, - student_id: str, - threshold: float | None = None, - ) -> List[Dict[str, Any]]: - """Get concepts below the weak threshold.""" - threshold = threshold or self.config.weak_concept_threshold - all_mastery = self._get_all_mastery(student_id) - return [ - m for m in all_mastery - if m["mastery_score"] < threshold - ] - - def get_decay_risk_concepts(self, student_id: str) -> List[Dict[str, Any]]: - """Get concepts whose mastery is decaying below the promotion threshold.""" - all_mastery = self._get_all_mastery(student_id) - results = [] - for m in all_mastery: - mem = self._find_mastery_raw(student_id, 
m["concept_id"]) - if not mem: - continue - # Check if Engram strength is decaying - strength = mem.get("strength", 1.0) - if strength < 0.5 and m["mastery_score"] >= self.config.weak_concept_threshold: - results.append({**m, "decay_strength": strength}) - return results - - def record_misconception( - self, - student_id: str, - concept_id: str, - misconception: str, - correction: str, - ) -> None: - """Record a misconception and apply mastery penalty.""" - self.update_mastery( - student_id, concept_id, self.config.mastery_decrement_on_misconception - ) - - meta = { - "memory_type": "misconception", - "student_id": student_id, - "concept_id": concept_id, - "misconception": misconception, - "correction": correction, - "recorded_at": datetime.now(timezone.utc).isoformat(), - } - self.memory.add( - content=f"Misconception ({concept_id}): {misconception} -> Correction: {correction}", - user_id=student_id, - metadata=meta, - categories=["student/misconception"], - ) - - def record_effective_analogy( - self, - student_id: str, - concept_id: str, - analogy: str, - success: bool = True, - ) -> None: - """Record an analogy that worked (or didn't) for a student.""" - meta = { - "memory_type": "effective_analogy", - "student_id": student_id, - "concept_id": concept_id, - "analogy": analogy, - "success": success, - "recorded_at": datetime.now(timezone.utc).isoformat(), - } - self.memory.add( - content=f"Analogy ({concept_id}): {analogy} [{'success' if success else 'failed'}]", - user_id=student_id, - metadata=meta, - categories=["student/analogy"], - ) - - # Also update profile with effective analogies - if success: - self.update_profile(student_id, { - "effective_analogies": [{"concept_id": concept_id, "analogy": analogy}], - }) - - def get_effective_analogies( - self, - student_id: str, - concept_id: str | None = None, - ) -> List[Dict[str, Any]]: - """Retrieve effective analogies for a student, optionally filtered by concept.""" - if not hasattr(self.memory, "db"): - return 
[] - - memories = self.memory.db.get_all_memories( - user_id=student_id, - memory_type="effective_analogy", - limit=100, - ) - - results = [] - for mem in memories: - md = self._parse_metadata(mem) - if md.get("memory_type") != "effective_analogy": - continue - if not md.get("success", True): - continue - if concept_id and md.get("concept_id") != concept_id: - continue - results.append({ - "concept_id": md.get("concept_id", ""), - "analogy": md.get("analogy", ""), - "success": md.get("success", True), - }) - return results - - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - - def _find_profile(self, student_id: str) -> Optional[Dict[str, Any]]: - raw = self._find_profile_raw(student_id) - return self._format_profile(raw) if raw else None - - def _find_profile_raw(self, student_id: str) -> Optional[Dict[str, Any]]: - if not hasattr(self.memory, "db"): - return None - memories = self.memory.db.get_all_memories( - user_id=student_id, - memory_type="student_profile", - limit=5, - ) - for mem in memories: - md = self._parse_metadata(mem) - if md.get("memory_type") == "student_profile": - return mem - return None - - def _find_mastery_raw( - self, student_id: str, concept_id: str - ) -> Optional[Dict[str, Any]]: - if not hasattr(self.memory, "db"): - return None - memories = self.memory.db.get_all_memories( - user_id=student_id, - memory_type="concept_mastery", - limit=500, - ) - for mem in memories: - md = self._parse_metadata(mem) - if ( - md.get("memory_type") == "concept_mastery" - and md.get("concept_id") == concept_id - ): - return mem - return None - - def _get_all_mastery(self, student_id: str) -> List[Dict[str, Any]]: - if not hasattr(self.memory, "db"): - return [] - memories = self.memory.db.get_all_memories( - user_id=student_id, - memory_type="concept_mastery", - limit=500, - ) - results = [] - for mem in memories: - md = self._parse_metadata(mem) - if 
md.get("memory_type") == "concept_mastery": - results.append({ - "concept_id": md.get("concept_id", ""), - "mastery_score": md.get("mastery_score", self.config.mastery_initial_score), - "memory_id": mem.get("id"), - }) - return results - - @staticmethod - def _extract_id(result: Dict[str, Any]) -> Optional[str]: - if isinstance(result, dict): - results = result.get("results", []) - if results and isinstance(results, list): - first = results[0] - return first.get("id") or first.get("memory_id") - return result.get("id") or result.get("memory_id") - return None - - @staticmethod - def _parse_metadata(mem: Dict[str, Any]) -> Dict[str, Any]: - md = mem.get("metadata", {}) - if isinstance(md, str): - try: - md = json.loads(md) - except (json.JSONDecodeError, TypeError): - md = {} - return md - - @classmethod - def _format_profile(cls, mem: Dict[str, Any]) -> Dict[str, Any]: - md = cls._parse_metadata(mem) - return { - "memory_id": mem.get("id"), - "student_id": md.get("student_id", ""), - "learning_style": md.get("learning_style", "unknown"), - "interests": md.get("interests", []), - "goals": md.get("goals", []), - "effective_analogies": md.get("effective_analogies", []), - } diff --git a/dhee/teaching/teaching_memory.py b/dhee/teaching/teaching_memory.py deleted file mode 100644 index 33b9e21..0000000 --- a/dhee/teaching/teaching_memory.py +++ /dev/null @@ -1,255 +0,0 @@ -"""TeachingMemory — specialized memory types for the teaching domain. - -Stores lesson episodes, concept explanations, comprehension checks, -misconceptions, and effective analogies as Engram memories. 
-""" - -from __future__ import annotations - -import json -import logging -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional - -from dhee.teaching.config import TeachingConfig - -logger = logging.getLogger(__name__) - - -class TeachingMemory: - """Domain-specific memory storage for teaching interactions.""" - - def __init__(self, memory: "CoreMemory", config: TeachingConfig | None = None): # noqa: F821 - self.memory = memory - self.config = config or TeachingConfig() - - # ------------------------------------------------------------------ - # Lesson episodes - # ------------------------------------------------------------------ - - def store_lesson_episode( - self, - student_id: str, - concept_id: str, - lesson_summary: str, - *, - segments_completed: int = 0, - session_id: str = "", - topic_name: str = "", - subject: str = "", - ) -> Dict[str, Any]: - """Store a completed (or partial) lesson episode.""" - now = datetime.now(timezone.utc).isoformat() - meta = { - "memory_type": "lesson_episode", - "student_id": student_id, - "concept_id": concept_id, - "session_id": session_id, - "segments_completed": segments_completed, - "topic_name": topic_name, - "subject": subject, - "recorded_at": now, - } - - content = f"Lesson ({topic_name or concept_id}): {lesson_summary}" - result = self.memory.add( - content=content, - user_id=student_id, - metadata=meta, - categories=[f"lesson/{subject}" if subject else "lesson"], - ) - - mem_id = self._extract_id(result) - return {"memory_id": mem_id, **meta} - - # ------------------------------------------------------------------ - # Concept explanations - # ------------------------------------------------------------------ - - def store_concept_explanation( - self, - student_id: str, - concept_id: str, - approach: str, - explanation_text: str, - success: bool = True, - ) -> Dict[str, Any]: - """Store an explanation attempt for a concept.""" - now = datetime.now(timezone.utc).isoformat() - meta = { 
- "memory_type": "concept_explanation", - "student_id": student_id, - "concept_id": concept_id, - "approach": approach, - "success": success, - "recorded_at": now, - } - - content = f"Explanation ({concept_id}, {approach}): {explanation_text[:500]}" - result = self.memory.add( - content=content, - user_id=student_id, - metadata=meta, - categories=["teaching/explanation"], - ) - - mem_id = self._extract_id(result) - return {"memory_id": mem_id, **meta} - - # ------------------------------------------------------------------ - # Comprehension checks - # ------------------------------------------------------------------ - - def store_comprehension_check( - self, - student_id: str, - concept_id: str, - question: str, - student_answer: str, - evaluation: Dict[str, Any], - ) -> Dict[str, Any]: - """Store a comprehension check result.""" - now = datetime.now(timezone.utc).isoformat() - level = evaluation.get("level", "unknown") - meta = { - "memory_type": "comprehension_check", - "student_id": student_id, - "concept_id": concept_id, - "question": question, - "student_answer": student_answer, - "level": level, - "confidence": evaluation.get("confidence", 0.5), - "misconception": evaluation.get("misconception"), - "recorded_at": now, - } - - content = ( - f"Check ({concept_id}): Q: {question[:200]} | " - f"A: {student_answer[:200]} | Level: {level}" - ) - result = self.memory.add( - content=content, - user_id=student_id, - metadata=meta, - categories=["teaching/check"], - ) - - mem_id = self._extract_id(result) - return {"memory_id": mem_id, **meta} - - # ------------------------------------------------------------------ - # Search - # ------------------------------------------------------------------ - - def search_past_explanations( - self, - student_id: str, - concept_id: str, - limit: int = 5, - ) -> List[Dict[str, Any]]: - """Find past explanations for a concept with a specific student.""" - results = self.memory.search( - query=f"explanation {concept_id}", - 
user_id=student_id, - limit=limit * 2, - ) - - explanations = [] - for mem in results: - md = self._parse_metadata(mem) - if ( - md.get("memory_type") == "concept_explanation" - and md.get("concept_id") == concept_id - ): - explanations.append({ - "memory_id": mem.get("id"), - "approach": md.get("approach", ""), - "success": md.get("success", True), - "content": mem.get("content", ""), - "strength": mem.get("strength", 1.0), - }) - if len(explanations) >= limit: - break - - return explanations - - def search_student_history( - self, - student_id: str, - query: str, - limit: int = 10, - ) -> List[Dict[str, Any]]: - """Search all teaching memories for a student (grounding context).""" - results = self.memory.search( - query=query, - user_id=student_id, - limit=limit, - ) - - return [ - { - "memory_id": mem.get("id"), - "content": mem.get("content", ""), - "memory_type": self._parse_metadata(mem).get("memory_type", ""), - "strength": mem.get("strength", 1.0), - } - for mem in results - ] - - def get_student_misconceptions( - self, - student_id: str, - concept_id: str | None = None, - limit: int = 20, - ) -> List[Dict[str, Any]]: - """Retrieve recorded misconceptions for a student.""" - if not hasattr(self.memory, "db"): - return [] - - memories = self.memory.db.get_all_memories( - user_id=student_id, - memory_type="misconception", - limit=limit * 2, - ) - - results = [] - for mem in memories: - md = self._parse_metadata(mem) - if md.get("memory_type") != "misconception": - continue - if concept_id and md.get("concept_id") != concept_id: - continue - results.append({ - "concept_id": md.get("concept_id", ""), - "misconception": md.get("misconception", ""), - "correction": md.get("correction", ""), - "recorded_at": md.get("recorded_at", ""), - }) - if len(results) >= limit: - break - - return results - - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - - @staticmethod 
- def _extract_id(result: Dict[str, Any]) -> Optional[str]: - if isinstance(result, dict): - results = result.get("results", []) - if results and isinstance(results, list): - first = results[0] - return first.get("id") or first.get("memory_id") - return result.get("id") or result.get("memory_id") - return None - - @staticmethod - def _parse_metadata(mem: Dict[str, Any]) -> Dict[str, Any]: - md = mem.get("metadata", {}) - if isinstance(md, str): - try: - md = json.loads(md) - except (json.JSONDecodeError, TypeError): - md = {} - return md diff --git a/dhee/utils/factory.py b/dhee/utils/factory.py index 24102db..89fa196 100644 --- a/dhee/utils/factory.py +++ b/dhee/utils/factory.py @@ -7,6 +7,22 @@ logger = logging.getLogger(__name__) +def _normalize_sqlite_vec_config(config: Dict[str, Any]) -> Dict[str, Any]: + """Coerce directory-style vector paths into sqlite-vec DB files.""" + normalized = dict(config or {}) + path = normalized.get("path") + if not path: + return normalized + + path = str(path) + root, ext = os.path.splitext(path) + if ext: + return normalized + + normalized["path"] = os.path.join(path, "sqlite_vec.db") + return normalized + + def _dhee_model_available() -> bool: """Check if local DheeModel GGUF is available.""" try: @@ -152,7 +168,7 @@ def create(cls, provider: str, config: Dict[str, Any]): if provider == "sqlite_vec": from dhee.vector_stores.sqlite_vec import SqliteVecStore - return SqliteVecStore(config) + return SqliteVecStore(_normalize_sqlite_vec_config(config)) if provider == "zvec": try: from dhee.vector_stores.zvec_store import ZvecStore @@ -161,7 +177,9 @@ def create(cls, provider: str, config: Dict[str, Any]): logger.warning("zvec not installed, falling back to sqlite_vec") try: from dhee.vector_stores.sqlite_vec import SqliteVecStore - return SqliteVecStore(config) + return SqliteVecStore( + _normalize_sqlite_vec_config(config) + ) except ImportError: logger.warning("sqlite_vec not installed, falling back to in-memory") from 
dhee.vector_stores.memory import InMemoryVectorStore diff --git a/pyproject.toml b/pyproject.toml index 09b62b6..0c301ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dhee" -version = "3.0.1" +version = "3.1.0" description = "Cognition layer for AI agents — persistent memory, performance tracking, and insight synthesis" readme = "README.md" requires-python = ">=3.9" @@ -44,6 +44,7 @@ local = ["llama-cpp-python>=0.3", "sentence-transformers>=3.0"] mcp = ["mcp>=1.0.0"] api = ["fastapi>=0.100.0", "uvicorn>=0.20.0"] bus = ["engram-bus>=0.1.0"] +benchmarks = ["huggingface_hub>=0.24.0"] # Edge/hardware deployment (offline, ONNX embedder) edge = ["onnxruntime>=1.16"] # Training (QLoRA fine-tuning) @@ -56,6 +57,7 @@ all = [ "fastapi>=0.100.0", "uvicorn>=0.20.0", "engram-bus>=0.1.0", + "huggingface_hub>=0.24.0", "llama-cpp-python>=0.3", "sentence-transformers>=3.0", ] @@ -63,6 +65,7 @@ dev = [ "pytest>=7.0.0", "pytest-asyncio>=0.23.0", "openai>=1.0.0", + "huggingface_hub>=0.24.0", "build>=1.0.0", "twine>=5.0.0", ] diff --git a/tests/test_auto_lifecycle.py b/tests/test_auto_lifecycle.py index 7411e6f..14f1be2 100644 --- a/tests/test_auto_lifecycle.py +++ b/tests/test_auto_lifecycle.py @@ -8,7 +8,7 @@ import pytest from dhee.memory.main import FullMemory -from dhee.adapters.base import DheePlugin +from dhee.plugin import DheePlugin from dhee.simple import Dhee, Engram diff --git a/tests/test_cognition_v3.py b/tests/test_cognition_v3.py index 1173400..bbca8bf 100644 --- a/tests/test_cognition_v3.py +++ b/tests/test_cognition_v3.py @@ -277,47 +277,7 @@ def test_rollback_on_poor_performance(self, tmpdir): # ═══════════════════════════════════════════════════════════════════════════ -# 5. 
Progressive Training — data flow -# ═══════════════════════════════════════════════════════════════════════════ - -class TestProgressiveTraining: - def test_training_cycle_with_data(self, tmpdir): - from dhee.mini.progressive_trainer import ProgressiveTrainer - - trainer = ProgressiveTrainer(data_dir=os.path.join(tmpdir, "training")) - - # Generate enough SFT data - sft_data = [ - {"input": f"[MEMORY_OP] Query {i}", "output": "store", "type": "memory_op"} - for i in range(25) - ] - dpo_data = [ - {"prompt": f"Task {i}", "chosen": "good approach", "rejected": "bad approach"} - for i in range(15) - ] - - result = trainer.run_cycle( - sft_data=sft_data, - dpo_data=dpo_data, - samskara_data={}, - ) - assert result.cycle_id - stage_names = [s.stage for s in result.stages] - assert "sft" in stage_names - assert "dpo" in stage_names - - def test_skips_with_insufficient_data(self, tmpdir): - from dhee.mini.progressive_trainer import ProgressiveTrainer - - trainer = ProgressiveTrainer(data_dir=os.path.join(tmpdir, "training")) - - result = trainer.run_cycle(sft_data=[], dpo_data=[], samskara_data={}) - for stage_result in result.stages: - assert stage_result.status in ("skipped", "completed") - - -# ═══════════════════════════════════════════════════════════════════════════ -# 6. Episode — lifecycle + selective forgetting +# 5. 
Episode — lifecycle + selective forgetting # ═══════════════════════════════════════════════════════════════════════════ class TestEpisode: diff --git a/tests/test_hippocamp_benchmark.py b/tests/test_hippocamp_benchmark.py new file mode 100644 index 0000000..3320f61 --- /dev/null +++ b/tests/test_hippocamp_benchmark.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import sqlite3 +import zipfile +import json +from pathlib import Path + +from dhee.benchmarks.hippocamp import ( + CONFIG_SPECS, + _emit_progress, + _make_gold_repo_path, + _preview_text, + _relative_path_from_environment, + exact_match, + file_retrieval_metrics, + gold_document_to_items, + token_f1, +) +from dhee.benchmarks.raw_extractors import raw_file_to_items + + +def test_relative_path_and_gold_path_mapping() -> None: + spec = CONFIG_SPECS["adam_subset"] + repo_path = "Adam/Subset/Adam_Subset/contractnli/Tazza-CAFFE-Confidentiality-Agreement.pdf" + + relative = _relative_path_from_environment(spec, repo_path) + assert relative == "contractnli/Tazza-CAFFE-Confidentiality-Agreement.pdf" + + gold_path = _make_gold_repo_path(spec.profile, relative) + assert gold_path == "HippoCamp_Gold/Adam/contractnli/Tazza-CAFFE-Confidentiality-Agreement.json" + + +def test_gold_document_chunking_keeps_metadata() -> None: + document = { + "file_info": { + "file_name": "Diary Entry.pdf", + "file_type": "pdf", + "file_modality": "document", + "creation_date": "2025-10-20 09:00:00", + "location": "Home", + }, + "summary": "This file summarizes a weekly routine.", + "segments": [ + {"page": 1, "content": "First page content " * 40}, + {"page": 2, "content": "Second page content " * 40}, + ], + } + + items = gold_document_to_items( + doc=document, + profile="Adam", + config_name="adam_subset", + relative_path="docs/Diary Entry.pdf", + chunk_chars=500, + ) + + assert len(items) >= 2 + assert items[0]["metadata"]["file_path"] == "docs/Diary Entry.pdf" + assert items[0]["metadata"]["profile"] == "Adam" + assert 
"Summary: This file summarizes a weekly routine." in items[0]["content"] + assert "Relative Path: docs/Diary Entry.pdf" in items[0]["content"] + + +def test_file_retrieval_metrics_deduplicate_predictions() -> None: + metrics = file_retrieval_metrics( + predicted_paths=["a.txt", "a.txt", "b.txt"], + gold_paths=["a.txt", "c.txt"], + ) + assert round(metrics["precision"], 4) == 0.5 + assert round(metrics["recall"], 4) == 0.5 + assert round(metrics["f1"], 4) == 0.5 + + +def test_answer_matching_helpers() -> None: + assert exact_match("Cursor", "cursor") + assert not exact_match("Cursor editor", "VS Code") + assert token_f1("Cursor editor", "Cursor") > 0.0 + assert token_f1("", "") == 1.0 + + +def test_progress_jsonl_event_emission(tmp_path: Path) -> None: + path = tmp_path / "adam_subset.progress.jsonl" + _emit_progress( + path, + "judge_done", + config="adam_subset", + question_index=3, + judge_correct=True, + judge_score_0_to_5=4.0, + judge_rationale_preview=_preview_text("This answer matches the gold evidence and resolves the question cleanly."), + ) + + lines = path.read_text(encoding="utf-8").splitlines() + assert len(lines) == 1 + payload = json.loads(lines[0]) + assert payload["event"] == "judge_done" + assert payload["config"] == "adam_subset" + assert payload["question_index"] == 3 + assert payload["judge_correct"] is True + assert payload["judge_score_0_to_5"] == 4.0 + assert "gold evidence" in payload["judge_rationale_preview"] + + +def test_raw_text_file_extraction(tmp_path: Path) -> None: + path = tmp_path / "note.txt" + path.write_text("alpha\nbeta\ngamma", encoding="utf-8") + + result = raw_file_to_items( + local_path=path, + relative_path="docs/note.txt", + profile="Adam", + config_name="adam_subset", + chunk_chars=1000, + ) + + assert result.mode == "text" + assert len(result.items) == 1 + assert "alpha" in result.items[0]["content"] + assert result.items[0]["metadata"]["exposure_mode"] == "raw_files_only" + + +def test_raw_sqlite_extraction(tmp_path: 
Path) -> None: + db_path = tmp_path / "sample.sqlite" + conn = sqlite3.connect(db_path) + try: + conn.execute("CREATE TABLE visits (url TEXT, title TEXT)") + conn.execute("INSERT INTO visits VALUES (?, ?)", ("https://example.com", "Example")) + conn.commit() + finally: + conn.close() + + result = raw_file_to_items( + local_path=db_path, + relative_path="Browser_History.sqlite", + profile="Adam", + config_name="adam_subset", + chunk_chars=2000, + ) + + assert result.mode == "text" + assert "table=visits" in result.items[0]["content"] + assert "https://example.com" in result.items[0]["content"] + + +def test_raw_xlsx_extraction_from_minimal_ooxml(tmp_path: Path) -> None: + xlsx_path = tmp_path / "plan.xlsx" + with zipfile.ZipFile(xlsx_path, "w") as archive: + archive.writestr( + "xl/workbook.xml", + """ + + + + + """, + ) + archive.writestr( + "xl/_rels/workbook.xml.rels", + """ + + + """, + ) + archive.writestr( + "xl/sharedStrings.xml", + """ + + Run + Tempo + """, + ) + archive.writestr( + "xl/worksheets/sheet1.xml", + """ + + + + 0 + 1 + + + 5 + 42 + + + """, + ) + + result = raw_file_to_items( + local_path=xlsx_path, + relative_path="plan.xlsx", + profile="Adam", + config_name="adam_subset", + chunk_chars=2000, + ) + + assert result.mode == "text" + content = result.items[0]["content"] + assert "sheet=Schedule" in content + assert "Run\tTempo" in content + assert "5\t42" in content diff --git a/tests/test_orchestration_core.py b/tests/test_orchestration_core.py index b1afbce..bd28f5f 100644 --- a/tests/test_orchestration_core.py +++ b/tests/test_orchestration_core.py @@ -79,201 +79,3 @@ def memory_instance(tmp_path): mem.close() -def test_search_orchestrated_skips_map_reduce_when_coverage_sufficient(memory_instance, monkeypatch) -> None: - mem = memory_instance - - def fake_search(**kwargs: Any) -> Dict[str, Any]: - return {"results": [{"id": "m1", "memory": "foo", "composite_score": 0.9}]} - - def fake_search_episodes(**kwargs: Any) -> Dict[str, Any]: - return 
{ - "results": [{"memory_id": "m1", "value_text": "foo", "event_type": "utterance"}], - "coverage": {"sufficient": True, "coverage_ratio": 1.0, "event_hit_count": 1, "unique_canonical_keys": 1}, - } - - monkeypatch.setattr(mem, "search", fake_search) - monkeypatch.setattr(mem, "search_episodes", fake_search_episodes) - - def fail_extract_atomic_facts(**kwargs: Any) -> List[Dict[str, Any]]: - raise AssertionError("map stage should be skipped when episodic coverage is sufficient") - - monkeypatch.setattr("dhee.memory.main.extract_atomic_facts", fail_extract_atomic_facts) - - payload = mem.search_orchestrated( - query="How many projects?", - user_id="u1", - question_type="multi-session", - orchestration_mode="hybrid", - orchestrator_llm=object(), - rerank=False, - ) - - assert payload["orchestration"]["map_reduce_used"] is False - assert payload["orchestration"]["reflection_hops"] == 0 - - -def test_search_orchestrated_reflection_hard_caps_to_one_hop(memory_instance, monkeypatch) -> None: - mem = memory_instance - search_calls: List[int] = [] - reduce_calls: List[int] = [] - - def fake_search(**kwargs: Any) -> Dict[str, Any]: - search_calls.append(1) - if len(search_calls) == 1: - return {"results": [{"id": "m1", "memory": "first", "composite_score": 0.8}]} - return {"results": [{"id": "m1", "memory": "first", "composite_score": 0.8}, {"id": "m2", "memory": "second", "composite_score": 0.7}]} - - def fake_search_episodes(**kwargs: Any) -> Dict[str, Any]: - return { - "results": [{"memory_id": "m1", "value_text": "first", "event_type": "utterance"}], - "coverage": {"sufficient": False, "coverage_ratio": 0.2, "event_hit_count": 1, "unique_canonical_keys": 1}, - } - - def fake_extract_atomic_facts(**kwargs: Any) -> List[Dict[str, Any]]: - return [{"value": "4", "relevant": True, "canonical_key": "k1"}] - - def fake_reduce_atomic_facts(**kwargs: Any): - reduce_calls.append(1) - if len(reduce_calls) == 1: - return "I don't know", {} - return "4", {} - - 
monkeypatch.setattr(mem, "search", fake_search) - monkeypatch.setattr(mem, "search_episodes", fake_search_episodes) - monkeypatch.setattr("dhee.memory.main.extract_atomic_facts", fake_extract_atomic_facts) - monkeypatch.setattr("dhee.memory.main.reduce_atomic_facts", fake_reduce_atomic_facts) - - payload = mem.search_orchestrated( - query="How many projects?", - user_id="u1", - question_type="multi-session", - orchestration_mode="hybrid", - orchestrator_llm=object(), - reflection_max_hops=1, - base_search_limit=4, - search_cap=20, - rerank=False, - ) - - assert payload["orchestration"]["map_reduce_used"] is True - assert payload["orchestration"]["reflection_hops"] == 1 - assert payload["reduced_answer"] == "4" - # Initial retrieval + exactly one reflection retrieval. - assert len(search_calls) == 2 - - -def test_search_orchestrated_inconsistency_can_trigger_map_reduce(memory_instance, monkeypatch) -> None: - mem = memory_instance - - def fake_search(**kwargs: Any) -> Dict[str, Any]: - return { - "results": [ - {"id": "m1", "memory": "I led 2 projects.", "evidence_text": "I led 2 projects.", "composite_score": 0.9}, - {"id": "m2", "memory": "I led 5 projects.", "evidence_text": "I led 5 projects.", "composite_score": 0.8}, - ] - } - - def fake_search_episodes(**kwargs: Any) -> Dict[str, Any]: - return { - "results": [ - {"memory_id": "m1", "value_text": "I led 2 projects", "event_type": "utterance"}, - {"memory_id": "m2", "value_text": "I led 5 projects", "event_type": "utterance"}, - ], - "coverage": { - "sufficient": True, - "coverage_ratio": 1.0, - "intent_coverage": 1.0, - "event_hit_count": 2, - "unique_canonical_keys": 2, - }, - } - - monkeypatch.setattr(mem, "search", fake_search) - monkeypatch.setattr(mem, "search_episodes", fake_search_episodes) - monkeypatch.setattr("dhee.memory.main.extract_atomic_facts", lambda **kwargs: [{"value": "5", "relevant": True}]) - monkeypatch.setattr("dhee.memory.main.reduce_atomic_facts", lambda **kwargs: ("5", {})) - - 
payload = mem.search_orchestrated( - query="How many projects have I led?", - user_id="u1", - question_type="multi-session", - orchestration_mode="hybrid", - orchestrator_llm=object(), - rerank=False, - ) - - assert payload["orchestration"]["map_reduce_used"] is True - assert "count_numeric_conflict" in payload["orchestration"]["reason_codes"] - - -def test_search_orchestrated_respects_query_llm_budget(memory_instance, monkeypatch) -> None: - mem = memory_instance - mem.config.orchestration.max_query_llm_calls = 0 - - def fake_search(**kwargs: Any) -> Dict[str, Any]: - return {"results": [{"id": "m1", "memory": "I led 4 projects.", "composite_score": 0.9}]} - - def fake_search_episodes(**kwargs: Any) -> Dict[str, Any]: - return { - "results": [{"memory_id": "m1", "value_text": "I led 4 projects", "event_type": "utterance"}], - "coverage": {"sufficient": False, "coverage_ratio": 0.2, "intent_coverage": 0.2, "event_hit_count": 1, "unique_canonical_keys": 1}, - } - - monkeypatch.setattr(mem, "search", fake_search) - monkeypatch.setattr(mem, "search_episodes", fake_search_episodes) - - payload = mem.search_orchestrated( - query="How many projects have I led?", - user_id="u1", - question_type="multi-session", - orchestration_mode="hybrid", - orchestrator_llm=object(), - rerank=False, - ) - - assert payload["orchestration"]["map_reduce_used"] is False - assert "query_llm_budget_exhausted" in payload["orchestration"]["reason_codes"] - - -def test_search_orchestrated_reducer_cache_hit(memory_instance, monkeypatch) -> None: - mem = memory_instance - extract_calls = {"count": 0} - - def fake_search(**kwargs: Any) -> Dict[str, Any]: - return {"results": [{"id": "m1", "memory": "I led 4 projects.", "evidence_text": "I led 4 projects.", "composite_score": 0.9}]} - - def fake_search_episodes(**kwargs: Any) -> Dict[str, Any]: - return { - "results": [{"memory_id": "m1", "value_text": "I led 4 projects", "event_type": "utterance"}], - "coverage": {"sufficient": False, 
"coverage_ratio": 0.2, "intent_coverage": 0.2, "event_hit_count": 1, "unique_canonical_keys": 1}, - } - - def fake_extract_atomic_facts(**kwargs: Any) -> List[Dict[str, Any]]: - extract_calls["count"] += 1 - return [{"value": "4", "relevant": True, "canonical_key": "projects"}] - - monkeypatch.setattr(mem, "search", fake_search) - monkeypatch.setattr(mem, "search_episodes", fake_search_episodes) - monkeypatch.setattr("dhee.memory.main.extract_atomic_facts", fake_extract_atomic_facts) - monkeypatch.setattr("dhee.memory.main.reduce_atomic_facts", lambda **kwargs: ("4", {})) - - first = mem.search_orchestrated( - query="How many projects have I led?", - user_id="u1", - question_type="multi-session", - orchestration_mode="hybrid", - orchestrator_llm=object(), - rerank=False, - ) - second = mem.search_orchestrated( - query="How many projects have I led?", - user_id="u1", - question_type="multi-session", - orchestration_mode="hybrid", - orchestrator_llm=object(), - rerank=False, - ) - - assert first["orchestration"]["cache_hit"] is False - assert second["orchestration"]["cache_hit"] is True - assert extract_calls["count"] == 1 diff --git a/tests/test_simple_zero_config.py b/tests/test_simple_zero_config.py new file mode 100644 index 0000000..39e137d --- /dev/null +++ b/tests/test_simple_zero_config.py @@ -0,0 +1,41 @@ +import os + +from dhee.simple import _detect_provider, _get_embedding_dims + + +def test_detect_provider_defaults_to_mock_without_keys(monkeypatch): + for key in ( + "OPENAI_API_KEY", + "GOOGLE_API_KEY", + "GEMINI_API_KEY", + "NVIDIA_API_KEY", + "NVIDIA_QWEN_API_KEY", + "NVIDIA_EMBEDDING_API_KEY", + "NVIDIA_EMBED_API_KEY", + "NVIDIA_LLAMA_4_MAV_API_KEY", + ): + monkeypatch.delenv(key, raising=False) + + assert _detect_provider() == "mock" + + +def test_embedding_dims_cover_mock_and_nvidia(): + assert _get_embedding_dims("mock") == 384 + assert _get_embedding_dims("nvidia") == 2048 + + +def test_detect_provider_accepts_nvidia_alias_keys(monkeypatch): + for 
key in ( + "OPENAI_API_KEY", + "GOOGLE_API_KEY", + "GEMINI_API_KEY", + "NVIDIA_API_KEY", + "NVIDIA_QWEN_API_KEY", + "NVIDIA_EMBEDDING_API_KEY", + "NVIDIA_EMBED_API_KEY", + "NVIDIA_LLAMA_4_MAV_API_KEY", + ): + monkeypatch.delenv(key, raising=False) + + monkeypatch.setenv("NVIDIA_EMBED_API_KEY", "test-key") + assert _detect_provider() == "nvidia" diff --git a/tests/test_v3_all_phases.py b/tests/test_v3_all_phases.py deleted file mode 100644 index 1ac808a..0000000 --- a/tests/test_v3_all_phases.py +++ /dev/null @@ -1,820 +0,0 @@ -"""Tests for Dhee v3 Phases 2-10. - -Phase 2: Anchor resolver (per-field candidates, re-anchoring) -Phase 4: Distillation + Promotion pipeline -Phase 5: Lease manager + Job registry -Phase 7: RRF Fusion (5-stage pipeline) -Phase 8: Three-tier invalidation + Conflicts -Phase 6: Read model + delta overlay -Phase 9: Observability (v3_health) -Phase 10: Migration bridge (dual-write, backfill) -Phase 3: Sparse to_dict on UniversalEngram -""" - -import json -import os -import sqlite3 -import threading -import time - -import pytest - -from dhee.core.storage import initialize_schema -from dhee.core.events import RawEventStore, EventStatus -from dhee.core.derived_store import ( - BeliefStore, PolicyStore, AnchorStore, InsightStore, - HeuristicStore, DerivedLineageStore, CognitionStore, -) - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - -@pytest.fixture -def db_path(tmp_path): - return str(tmp_path / "test_v3_phases.db") - - -@pytest.fixture -def store(db_path): - s = CognitionStore(db_path=db_path) - yield s - s.close() - - -@pytest.fixture -def conn_lock(db_path): - """Shared connection + lock for lower-level store tests.""" - conn = sqlite3.connect(db_path, check_same_thread=False) - conn.execute("PRAGMA journal_mode=WAL") - conn.row_factory = sqlite3.Row - initialize_schema(conn) - lock = threading.RLock() - yield conn, 
lock - conn.close() - - -# ========================================================================= -# Phase 2: Anchor Resolver -# ========================================================================= - -class TestAnchorResolver: - - def test_submit_and_resolve(self, store): - from dhee.core.anchor_resolver import AnchorCandidateStore, AnchorResolver - - anchor_id = store.anchors.add(user_id="u1") - cand_store = AnchorCandidateStore(store.events._conn, store.events._lock) - resolver = AnchorResolver(cand_store, store.anchors) - - # Submit competing candidates for 'place' - cand_store.submit(anchor_id, "place", "Ghazipur", confidence=0.6) - cand_store.submit(anchor_id, "place", "Bengaluru", confidence=0.9) - - result = resolver.resolve(anchor_id) - assert result["resolved_fields"]["place"] == "Bengaluru" - assert result["details"]["place"]["confidence"] == 0.9 - - # Verify anchor was updated - anchor = store.anchors.get(anchor_id) - assert anchor["place"] == "Bengaluru" - - def test_re_anchor_correction(self, store): - from dhee.core.anchor_resolver import AnchorCandidateStore, AnchorResolver - - anchor_id = store.anchors.add(user_id="u1", place="Ghazipur") - cand_store = AnchorCandidateStore(store.events._conn, store.events._lock) - resolver = AnchorResolver(cand_store, store.anchors) - - # Initial candidate - cand_store.submit(anchor_id, "place", "Ghazipur", confidence=0.6) - resolver.resolve(anchor_id) - assert store.anchors.get(anchor_id)["place"] == "Ghazipur" - - # User corrects - result = resolver.re_anchor( - anchor_id, "place", "Delhi", confidence=0.95 - ) - assert result["resolved_fields"]["place"] == "Delhi" - assert store.anchors.get(anchor_id)["place"] == "Delhi" - - def test_extract_and_submit(self, store): - from dhee.core.anchor_resolver import AnchorCandidateStore, AnchorResolver - - anchor_id = store.anchors.add(user_id="u1") - cand_store = AnchorCandidateStore(store.events._conn, store.events._lock) - resolver = AnchorResolver(cand_store, 
store.anchors) - - cids = resolver.extract_and_submit( - anchor_id, "I was coding at the office today" - ) - assert len(cids) >= 1 # should detect 'coding' activity and 'office' place_type - - candidates = cand_store.get_candidates(anchor_id) - field_names = {c["field_name"] for c in candidates} - assert "activity" in field_names - - def test_invalid_field_rejected(self, store): - from dhee.core.anchor_resolver import AnchorCandidateStore - - cand_store = AnchorCandidateStore(store.events._conn, store.events._lock) - with pytest.raises(ValueError, match="Invalid anchor field"): - cand_store.submit("a1", "invalid_field", "value") - - -# ========================================================================= -# Phase 4: Distillation + Promotion -# ========================================================================= - -class TestDistillationPromotion: - - def test_submit_and_promote_belief(self, store): - from dhee.core.distillation import ( - DistillationStore, DistillationCandidate, distill_belief_from_events, - ) - from dhee.core.promotion import PromotionEngine - - conn = store.events._conn - lock = store.events._lock - - # Create source events - e1 = store.events.add(content="Python uses GIL", user_id="u1") - e2 = store.events.add(content="Python GIL limits threading", user_id="u1") - - # Distill a belief candidate - candidate = distill_belief_from_events( - [e1.to_dict(), e2.to_dict()], - user_id="u1", domain="programming", - ) - assert candidate is not None - - # Submit to distillation store - dist_store = DistillationStore(conn, lock) - cid = dist_store.submit(candidate) - assert cid is not None - - # Promote - engine = PromotionEngine( - distillation=dist_store, - beliefs=store.beliefs, - policies=store.policies, - insights=store.insights, - heuristics=store.heuristics, - lineage=store.lineage, - ) - result = engine.promote_pending(target_type="belief") - assert result.promoted # at least one promoted - - # Verify lineage was created - promoted_id = 
result.promoted[0] - sources = store.lineage.get_sources("belief", promoted_id) - assert len(sources) >= 1 - - def test_idempotent_dedup(self, store): - from dhee.core.distillation import DistillationStore, DistillationCandidate - - conn = store.events._conn - lock = store.events._lock - dist_store = DistillationStore(conn, lock) - - candidate = DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1", "e2"], - target_type="belief", - canonical_key="test_dedup", - payload={"user_id": "u1", "claim": "test"}, - ) - - cid1 = dist_store.submit(candidate) - assert cid1 == "c1" - - # Same idempotency key — should be deduped - candidate2 = DistillationCandidate( - candidate_id="c2", - source_event_ids=["e1", "e2"], - target_type="belief", - canonical_key="test_dedup", - payload={"user_id": "u1", "claim": "test"}, - ) - cid2 = dist_store.submit(candidate2) - assert cid2 is None # deduped - - def test_low_confidence_rejected(self, store): - from dhee.core.distillation import DistillationStore, DistillationCandidate - from dhee.core.promotion import PromotionEngine - - conn = store.events._conn - lock = store.events._lock - dist_store = DistillationStore(conn, lock) - - candidate = DistillationCandidate( - candidate_id="c-low", - source_event_ids=["e1"], - target_type="belief", - canonical_key="low_conf", - confidence=0.1, # below MIN_PROMOTION_CONFIDENCE (0.3) - payload={"user_id": "u1", "claim": "uncertain thing"}, - ) - dist_store.submit(candidate) - - engine = PromotionEngine( - distillation=dist_store, - beliefs=store.beliefs, - policies=store.policies, - insights=store.insights, - heuristics=store.heuristics, - lineage=store.lineage, - ) - result = engine.promote_pending() - assert "c-low" in result.rejected - - -# ========================================================================= -# Phase 5: Lease Manager + Job Registry -# ========================================================================= - -class TestLeaseManager: - - def 
test_acquire_release(self, conn_lock): - from dhee.core.lease_manager import LeaseManager - - conn, lock = conn_lock - lm = LeaseManager(conn, lock) - - assert lm.acquire("job-1", "worker-a") is True - assert lm.is_held("job-1") is True - assert lm.get_holder("job-1") == "worker-a" - - # Different worker can't acquire - assert lm.acquire("job-1", "worker-b") is False - - # Release - assert lm.release("job-1", "worker-a") is True - assert lm.is_held("job-1") is False - - def test_same_owner_renew(self, conn_lock): - from dhee.core.lease_manager import LeaseManager - - conn, lock = conn_lock - lm = LeaseManager(conn, lock) - - lm.acquire("job-1", "worker-a") - assert lm.renew("job-1", "worker-a") is True - - def test_wrong_owner_cant_release(self, conn_lock): - from dhee.core.lease_manager import LeaseManager - - conn, lock = conn_lock - lm = LeaseManager(conn, lock) - - lm.acquire("job-1", "worker-a") - assert lm.release("job-1", "worker-b") is False - - -class TestJobRegistry: - - def test_register_and_run(self, conn_lock): - from dhee.core.lease_manager import LeaseManager - from dhee.core.jobs import JobRegistry, Job - - conn, lock = conn_lock - lm = LeaseManager(conn, lock) - registry = JobRegistry(conn, lock, lm) - - class TestJob(Job): - name = "test_job" - def execute(self, payload): - return {"sum": payload.get("a", 0) + payload.get("b", 0)} - - registry.register(TestJob) - result = registry.run("test_job", payload={"a": 3, "b": 7}) - assert result["status"] == "completed" - assert result["result"]["sum"] == 10 - - def test_run_unknown_job(self, conn_lock): - from dhee.core.lease_manager import LeaseManager - from dhee.core.jobs import JobRegistry - - conn, lock = conn_lock - lm = LeaseManager(conn, lock) - registry = JobRegistry(conn, lock, lm) - - result = registry.run("nonexistent") - assert result["status"] == "error" - - def test_health_check(self, conn_lock): - from dhee.core.lease_manager import LeaseManager - from dhee.core.jobs import JobRegistry, 
Job - - conn, lock = conn_lock - lm = LeaseManager(conn, lock) - registry = JobRegistry(conn, lock, lm) - - class NopJob(Job): - name = "nop" - def execute(self, payload): - return {} - - registry.register(NopJob) - registry.run("nop") - health = registry.get_health() - assert "nop" in health["job_status"] - assert health["job_status"]["nop"]["last_status"] == "completed" - - -# ========================================================================= -# Phase 7: RRF Fusion -# ========================================================================= - -class TestRRFFusion: - - def test_basic_fusion(self): - from dhee.core.fusion_v3 import RRFFusion, FusionCandidate, FusionConfig - - raw = [ - FusionCandidate( - row_id="r1", source_kind="raw", source_type="event", - source_id="e1", retrieval_text="raw fact", raw_score=0.9, - ), - FusionCandidate( - row_id="r2", source_kind="raw", source_type="event", - source_id="e2", retrieval_text="raw fact 2", raw_score=0.7, - ), - ] - distilled = [ - FusionCandidate( - row_id="d1", source_kind="distilled", source_type="belief", - source_id="b1", retrieval_text="distilled belief", raw_score=0.85, - confidence=0.9, - ), - ] - - fusion = RRFFusion(FusionConfig(final_top_n=5)) - results, breakdown = fusion.fuse(raw, distilled, query="test") - - assert len(results) >= 1 - assert breakdown.final_count >= 1 - # Distilled should rank high due to higher weight - top = results[0] - assert top.source_kind == "distilled" - - def test_staleness_penalty(self): - from dhee.core.fusion_v3 import RRFFusion, FusionCandidate - - fresh = FusionCandidate( - row_id="f1", source_kind="distilled", source_type="belief", - source_id="b1", retrieval_text="fresh", raw_score=0.8, - status="active", - ) - stale = FusionCandidate( - row_id="s1", source_kind="distilled", source_type="belief", - source_id="b2", retrieval_text="stale", raw_score=0.85, - status="stale", - ) - - fusion = RRFFusion() - results, _ = fusion.fuse([], [fresh, stale]) - - # Fresh should 
beat stale despite lower raw score - assert results[0].row_id == "f1" - - def test_invalidated_excluded(self): - from dhee.core.fusion_v3 import RRFFusion, FusionCandidate - - valid = FusionCandidate( - row_id="v1", source_kind="distilled", source_type="belief", - source_id="b1", retrieval_text="valid", raw_score=0.5, - ) - invalid = FusionCandidate( - row_id="i1", source_kind="distilled", source_type="belief", - source_id="b2", retrieval_text="invalidated", raw_score=0.9, - status="invalidated", - ) - - fusion = RRFFusion() - results, _ = fusion.fuse([], [valid, invalid]) - - # Invalidated should have score=0 and sort last - ids = [r.row_id for r in results if r.adjusted_score > 0] - assert "i1" not in ids - - def test_contradiction_penalty(self): - from dhee.core.fusion_v3 import RRFFusion, FusionCandidate - - clean = FusionCandidate( - row_id="c1", source_kind="distilled", source_type="belief", - source_id="b1", retrieval_text="clean", raw_score=0.8, - ) - conflicted = FusionCandidate( - row_id="c2", source_kind="distilled", source_type="belief", - source_id="b2", retrieval_text="conflicted", raw_score=0.85, - ) - - def checker(t, i): - return i == "b2" # b2 has conflicts - - fusion = RRFFusion() - results, _ = fusion.fuse([], [clean, conflicted], conflict_checker=checker) - assert results[0].row_id == "c1" # clean beats conflicted - - def test_breakdown_logged(self): - from dhee.core.fusion_v3 import RRFFusion, FusionCandidate - - raw = [FusionCandidate( - row_id="r1", source_kind="raw", source_type="event", - source_id="e1", retrieval_text="t", raw_score=0.5, - )] - dist = [FusionCandidate( - row_id="d1", source_kind="distilled", source_type="belief", - source_id="b1", retrieval_text="t", raw_score=0.5, - )] - - fusion = RRFFusion() - _, breakdown = fusion.fuse(raw, dist, query="test query") - - d = breakdown.to_dict() - assert "per_index_counts" in d - assert d["per_index_counts"]["raw"] == 1 - assert d["per_index_counts"]["distilled"] == 1 - - -# 
========================================================================= -# Phase 8: Three-Tier Invalidation -# ========================================================================= - -class TestInvalidation: - - def test_hard_invalidation_sole_source(self, store): - from dhee.core.invalidation import InvalidationEngine - - e = store.events.add(content="false memory", user_id="u1") - bid = store.beliefs.add(user_id="u1", claim="false claim") - store.lineage.add("belief", bid, e.event_id, contribution_weight=1.0) - - engine = InvalidationEngine( - lineage=store.lineage, - stores={"belief": store.beliefs}, - conn=store.events._conn, - lock=store.events._lock, - ) - - # Delete the source → hard invalidation - store.events.delete(e.event_id) - result = engine.on_event_deleted(e.event_id) - - assert len(result["hard_invalidated"]) == 1 - belief = store.beliefs.get(bid) - assert belief["status"] == "invalidated" - - def test_soft_invalidation_sole_source(self, store): - from dhee.core.invalidation import InvalidationEngine - - e = store.events.add(content="old fact", user_id="u1") - bid = store.beliefs.add(user_id="u1", claim="old fact", confidence=0.8) - store.lineage.add("belief", bid, e.event_id, contribution_weight=1.0) - - engine = InvalidationEngine( - lineage=store.lineage, - stores={"belief": store.beliefs}, - conn=store.events._conn, - lock=store.events._lock, - ) - - # Correct the source → soft invalidation - store.events.correct(e.event_id, "new fact") - result = engine.on_event_corrected(e.event_id) - - assert len(result["soft_invalidated"]) == 1 - assert len(result["jobs_enqueued"]) >= 1 - belief = store.beliefs.get(bid) - assert belief["status"] == "stale" - - def test_partial_invalidation_minor_source(self, store): - from dhee.core.invalidation import InvalidationEngine - - e1 = store.events.add(content="main fact", user_id="u1") - e2 = store.events.add(content="supporting detail", user_id="u1") - - bid = store.beliefs.add(user_id="u1", 
claim="combined claim", confidence=0.8) - store.lineage.add("belief", bid, e1.event_id, contribution_weight=0.8) - store.lineage.add("belief", bid, e2.event_id, contribution_weight=0.2) - - engine = InvalidationEngine( - lineage=store.lineage, - stores={"belief": store.beliefs}, - conn=store.events._conn, - lock=store.events._lock, - ) - - # Correct the minor source (weight=0.2 < 0.3 threshold) - store.events.correct(e2.event_id, "updated detail") - result = engine.on_event_corrected(e2.event_id) - - assert len(result["partial_invalidated"]) == 1 - belief = store.beliefs.get(bid) - assert belief["status"] == "suspect" - - def test_partial_escalates_on_high_weight(self, store): - from dhee.core.invalidation import InvalidationEngine - - e1 = store.events.add(content="main", user_id="u1") - e2 = store.events.add(content="secondary", user_id="u1") - - bid = store.beliefs.add(user_id="u1", claim="test", confidence=0.8) - store.lineage.add("belief", bid, e1.event_id, contribution_weight=0.6) - store.lineage.add("belief", bid, e2.event_id, contribution_weight=0.4) - - engine = InvalidationEngine( - lineage=store.lineage, - stores={"belief": store.beliefs}, - conn=store.events._conn, - lock=store.events._lock, - ) - - # Correct the secondary source (weight=0.4 >= 0.3 threshold) → soft - store.events.correct(e2.event_id, "updated secondary") - result = engine.on_event_corrected(e2.event_id) - - assert len(result["soft_invalidated"]) == 1 # escalated to soft - assert len(result["partial_invalidated"]) == 0 - - -# ========================================================================= -# Phase 8: Conflicts -# ========================================================================= - -class TestConflicts: - - def test_create_and_auto_resolve(self, conn_lock): - from dhee.core.conflicts import ConflictStore - - conn, lock = conn_lock - cs = ConflictStore(conn, lock) - - # Clear confidence gap → auto-resolve - result = cs.create( - "belief_contradiction", - "belief", "b1", 
"belief", "b2", - side_a_confidence=0.95, - side_b_confidence=0.1, - ) - assert result["resolution_status"] == "auto_resolved" - assert result["auto_resolution"]["winner"] == "side_a" - - def test_no_auto_resolve_when_close(self, conn_lock): - from dhee.core.conflicts import ConflictStore - - conn, lock = conn_lock - cs = ConflictStore(conn, lock) - - result = cs.create( - "belief_contradiction", - "belief", "b1", "belief", "b2", - side_a_confidence=0.6, - side_b_confidence=0.5, - ) - assert result["resolution_status"] == "open" - - def test_manual_resolve(self, conn_lock): - from dhee.core.conflicts import ConflictStore - - conn, lock = conn_lock - cs = ConflictStore(conn, lock) - - result = cs.create( - "anchor_disagreement", - "anchor", "a1", "anchor", "a2", - ) - cid = result["conflict_id"] - assert cs.resolve(cid, {"winner": "a1", "reason": "user chose"}) - - conflict = cs.get(cid) - assert conflict["resolution_status"] == "user_resolved" - - def test_has_open_conflicts(self, conn_lock): - from dhee.core.conflicts import ConflictStore - - conn, lock = conn_lock - cs = ConflictStore(conn, lock) - - cs.create("belief_contradiction", "belief", "b1", "belief", "b2") - assert cs.has_open_conflicts("belief", "b1") is True - assert cs.has_open_conflicts("belief", "b999") is False - - def test_count_open(self, conn_lock): - from dhee.core.conflicts import ConflictStore - - conn, lock = conn_lock - cs = ConflictStore(conn, lock) - - cs.create("belief_contradiction", "belief", "x1", "belief", "x2") - cs.create("distillation_conflict", "insight", "i1", "insight", "i2") - assert cs.count_open() == 2 - - -# ========================================================================= -# Phase 6: Read Model -# ========================================================================= - -class TestReadModel: - - def test_refresh_and_query(self, store): - from dhee.core.read_model import ReadModel - - conn = store.events._conn - lock = store.events._lock - rm = ReadModel(conn, 
lock) - - # Populate - store.events.add(content="raw fact 1", user_id="u1") - store.events.add(content="raw fact 2", user_id="u1") - store.beliefs.add(user_id="u1", claim="belief 1", confidence=0.8) - - counts = rm.refresh( - "u1", - events_store=store.events, - beliefs_store=store.beliefs, - ) - assert counts["raw_events"] == 2 - assert counts["beliefs"] == 1 - - results = rm.query("u1") - assert len(results) == 3 - - # Filter by kind - raw_only = rm.query("u1", source_kind="raw") - assert len(raw_only) == 2 - - def test_delta_overlay(self, store): - from dhee.core.read_model import ReadModel - - conn = store.events._conn - lock = store.events._lock - rm = ReadModel(conn, lock) - - # Add event before refresh - store.events.add(content="before refresh", user_id="u1") - rm.refresh("u1", events_store=store.events) - - # Add event after refresh - since = rm.last_refresh - store.events.add(content="after refresh", user_id="u1") - - delta = rm.get_delta("u1", since, events_store=store.events) - assert len(delta) == 1 - assert delta[0]["retrieval_text"] == "after refresh" - - def test_invalidated_excluded(self, store): - from dhee.core.read_model import ReadModel - - conn = store.events._conn - lock = store.events._lock - rm = ReadModel(conn, lock) - - bid = store.beliefs.add(user_id="u1", claim="will invalidate") - store.beliefs.set_status(bid, "invalidated") - - rm.refresh("u1", beliefs_store=store.beliefs) - results = rm.query("u1") - # Invalidated beliefs are skipped during refresh - assert all(r["source_id"] != bid for r in results) - - -# ========================================================================= -# Phase 9: Observability -# ========================================================================= - -class TestV3Health: - - def test_health_metrics(self, store): - from dhee.core.v3_health import v3_health - - conn = store.events._conn - lock = store.events._lock - - store.events.add(content="fact", user_id="u1") - bid = store.beliefs.add(user_id="u1", 
claim="test") - store.beliefs.set_status(bid, "stale") - - health = v3_health(conn, lock, user_id="u1") - - assert health["raw_events_active"] == 1 - assert health["derived_invalidation"]["beliefs"]["stale"] == 1 - assert "v3_warnings" in health - - def test_health_no_user_filter(self, store): - from dhee.core.v3_health import v3_health - - conn = store.events._conn - lock = store.events._lock - - store.events.add(content="a", user_id="u1") - store.events.add(content="b", user_id="u2") - - health = v3_health(conn, lock) - assert health["raw_events_active"] == 2 - - -# ========================================================================= -# Phase 10: Migration -# ========================================================================= - -class TestMigration: - - def test_dual_write(self, db_path): - from dhee.core.v3_migration import V3MigrationBridge - - cs = CognitionStore(db_path=db_path) - bridge = V3MigrationBridge(v3_store=cs) - - eid = bridge.on_remember("test fact", "u1", v2_memory_id="v2-123") - assert eid is not None - - event = cs.events.get(eid) - assert event.content == "test fact" - assert event.metadata.get("v2_memory_id") == "v2-123" - cs.close() - - def test_backfill(self, db_path): - from dhee.core.v3_migration import V3MigrationBridge - - cs = CognitionStore(db_path=db_path) - bridge = V3MigrationBridge(v3_store=cs) - - v2_memories = [ - {"memory": "fact 1", "id": "m1", "layer": "sml"}, - {"memory": "fact 2", "id": "m2", "layer": "lml"}, - {"memory": "fact 1", "id": "m3"}, # duplicate content - ] - - stats = bridge.backfill_from_v2(v2_memories, user_id="u1") - assert stats["created"] == 2 - assert stats["skipped_dedup"] == 1 - assert stats["total"] == 3 - - # Idempotent — running again should skip all - stats2 = bridge.backfill_from_v2(v2_memories, user_id="u1") - assert stats2["created"] == 0 - # All 3 are deduped: 2 unique already exist + 1 duplicate content - assert stats2["skipped_dedup"] == 3 - cs.close() - - def 
test_correction_bridge(self, db_path): - from dhee.core.v3_migration import V3MigrationBridge - - cs = CognitionStore(db_path=db_path) - bridge = V3MigrationBridge(v3_store=cs) - - # First add the original - bridge.on_remember("I live in Ghazipur", "u1") - - # Then correct - eid = bridge.on_correction("I live in Ghazipur", "I live in Bengaluru", "u1") - assert eid is not None - - event = cs.events.get(eid) - assert event.content == "I live in Bengaluru" - assert event.supersedes_event_id is not None - cs.close() - - def test_disabled_bridge(self): - from dhee.core.v3_migration import V3MigrationBridge - bridge = V3MigrationBridge(v3_store=None) - assert bridge.on_remember("test", "u1") is None - assert bridge.should_use_v3_read() is False - - -# ========================================================================= -# Phase 3: Sparse to_dict -# ========================================================================= - -class TestSparseDict: - - def test_sparse_omits_empty(self): - from dhee.core.engram import UniversalEngram - - e = UniversalEngram( - id="test-1", - raw_content="hello", - strength=1.0, - user_id="u1", - ) - full = e.to_dict() - sparse = e.to_dict(sparse=True) - - # Sparse should be smaller - assert len(sparse) < len(full) - # Should keep non-empty values - assert sparse["id"] == "test-1" - assert sparse["raw_content"] == "hello" - # Should omit empty lists, None, empty strings - assert "echo" not in sparse or sparse.get("echo") != [] - - def test_full_preserves_all(self): - from dhee.core.engram import UniversalEngram - - e = UniversalEngram(id="test-2", raw_content="x") - full = e.to_dict() - assert "echo" in full # even empty values - assert "entities" in full diff --git a/tests/test_v3_jobs.py b/tests/test_v3_jobs.py deleted file mode 100644 index 6eff31a..0000000 --- a/tests/test_v3_jobs.py +++ /dev/null @@ -1,706 +0,0 @@ -"""Tests for Dhee v3 Sprint 2: Lease Manager, Job Registry, Distillation, Promotion. 
- -Covers: -- LeaseManager: acquire, release, renew, expiry steal, cleanup -- JobRegistry: registration, execution, idempotency, history, health -- DistillationStore: submit, dedup, status transitions -- PromotionEngine: validation, type-specific promotion, lineage, batch -- ConsolidationEngine: feedback loop prevention -""" - -import json -import sqlite3 -import threading -import time - -import pytest - -from dhee.core.storage import initialize_schema -from dhee.core.lease_manager import LeaseManager -from dhee.core.jobs import Job, JobRegistry, ApplyForgettingJob -from dhee.core.distillation import ( - DistillationCandidate, - DistillationStore, - compute_idempotency_key, - distill_belief_from_events, - DERIVATION_VERSION, -) -from dhee.core.promotion import PromotionEngine, PromotionResult -from dhee.core.derived_store import ( - BeliefStore, - PolicyStore, - InsightStore, - HeuristicStore, - DerivedLineageStore, - CognitionStore, -) - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - -@pytest.fixture -def db_conn(tmp_path): - """Shared connection + lock for all Sprint 2 tests.""" - db_path = str(tmp_path / "test_v3_sprint2.db") - conn = sqlite3.connect(db_path, check_same_thread=False) - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA busy_timeout=5000") - conn.row_factory = sqlite3.Row - initialize_schema(conn) - lock = threading.RLock() - yield conn, lock - conn.close() - - -@pytest.fixture -def lease_manager(db_conn): - conn, lock = db_conn - return LeaseManager(conn, lock, default_duration_seconds=10) - - -@pytest.fixture -def stores(db_conn): - """All derived stores sharing one connection.""" - conn, lock = db_conn - return { - "beliefs": BeliefStore(conn, lock), - "policies": PolicyStore(conn, lock), - "insights": InsightStore(conn, lock), - "heuristics": HeuristicStore(conn, lock), - "lineage": 
DerivedLineageStore(conn, lock), - "distillation": DistillationStore(conn, lock), - } - - -@pytest.fixture -def promotion_engine(stores): - return PromotionEngine( - distillation=stores["distillation"], - beliefs=stores["beliefs"], - policies=stores["policies"], - insights=stores["insights"], - heuristics=stores["heuristics"], - lineage=stores["lineage"], - min_confidence=0.3, - ) - - -# ========================================================================= -# LeaseManager Tests -# ========================================================================= - -class TestLeaseManager: - - def test_acquire_new(self, lease_manager): - assert lease_manager.acquire("job:decay", "worker-1") is True - assert lease_manager.is_held("job:decay") is True - assert lease_manager.get_holder("job:decay") == "worker-1" - - def test_acquire_same_owner_renews(self, lease_manager): - assert lease_manager.acquire("job:decay", "worker-1") is True - assert lease_manager.acquire("job:decay", "worker-1") is True # renew - assert lease_manager.get_holder("job:decay") == "worker-1" - - def test_acquire_different_owner_blocked(self, lease_manager): - assert lease_manager.acquire("job:decay", "worker-1") is True - assert lease_manager.acquire("job:decay", "worker-2") is False - assert lease_manager.get_holder("job:decay") == "worker-1" - - def test_release(self, lease_manager): - lease_manager.acquire("job:decay", "worker-1") - assert lease_manager.release("job:decay", "worker-1") is True - assert lease_manager.is_held("job:decay") is False - - def test_release_wrong_owner(self, lease_manager): - lease_manager.acquire("job:decay", "worker-1") - assert lease_manager.release("job:decay", "worker-2") is False - assert lease_manager.is_held("job:decay") is True - - def test_release_nonexistent(self, lease_manager): - assert lease_manager.release("nonexistent", "w1") is False - - def test_renew(self, lease_manager): - lease_manager.acquire("job:decay", "worker-1") - assert 
lease_manager.renew("job:decay", "worker-1") is True - - def test_renew_wrong_owner(self, lease_manager): - lease_manager.acquire("job:decay", "worker-1") - assert lease_manager.renew("job:decay", "worker-2") is False - - def test_acquire_expired_lease(self, lease_manager): - """Expired lease can be stolen by another worker.""" - # Acquire with very short duration - assert lease_manager.acquire("job:decay", "worker-1", duration_seconds=1) is True - # Wait for expiry - time.sleep(1.1) - # Another worker can steal it - assert lease_manager.acquire("job:decay", "worker-2") is True - assert lease_manager.get_holder("job:decay") == "worker-2" - - def test_is_held_expired(self, lease_manager): - lease_manager.acquire("job:x", "w1", duration_seconds=1) - time.sleep(1.1) - assert lease_manager.is_held("job:x") is False - assert lease_manager.get_holder("job:x") is None - - def test_cleanup_expired(self, lease_manager): - lease_manager.acquire("job:a", "w1", duration_seconds=1) - lease_manager.acquire("job:b", "w1", duration_seconds=1) - lease_manager.acquire("job:c", "w1", duration_seconds=300) - time.sleep(1.1) - cleaned = lease_manager.cleanup_expired() - assert cleaned == 2 # a and b expired, c still held - - def test_multiple_locks_independent(self, lease_manager): - lease_manager.acquire("job:a", "w1") - lease_manager.acquire("job:b", "w2") - assert lease_manager.get_holder("job:a") == "w1" - assert lease_manager.get_holder("job:b") == "w2" - - -# ========================================================================= -# JobRegistry Tests -# ========================================================================= - -class _TestJob(Job): - name = "test_job" - - def execute(self, payload): - return {"echo": payload.get("value", "none")} - - -class _FailingJob(Job): - name = "failing_job" - - def execute(self, payload): - raise RuntimeError("intentional failure") - - -class _IdempotentJob(Job): - name = "idempotent_job" - - def execute(self, payload): - return 
{"processed": True} - - def make_idempotency_key(self, payload): - return f"idem:{payload.get('batch_id', '')}" - - -class TestJobRegistry: - - def test_register_and_list(self, db_conn, lease_manager): - conn, lock = db_conn - registry = JobRegistry(conn, lock, lease_manager) - registry.register(_TestJob) - assert "test_job" in registry.list_registered() - - def test_run_success(self, db_conn, lease_manager): - conn, lock = db_conn - registry = JobRegistry(conn, lock, lease_manager) - registry.register(_TestJob) - - result = registry.run("test_job", payload={"value": "hello"}) - assert result["status"] == "completed" - assert result["result"]["echo"] == "hello" - assert result["job_id"] - - def test_run_failure(self, db_conn, lease_manager): - conn, lock = db_conn - registry = JobRegistry(conn, lock, lease_manager) - registry.register(_FailingJob) - - result = registry.run("failing_job") - assert result["status"] == "failed" - assert "intentional failure" in result["error"] - - def test_run_unknown_job(self, db_conn, lease_manager): - conn, lock = db_conn - registry = JobRegistry(conn, lock, lease_manager) - - result = registry.run("nonexistent") - assert result["status"] == "error" - - def test_idempotency(self, db_conn, lease_manager): - conn, lock = db_conn - registry = JobRegistry(conn, lock, lease_manager) - registry.register(_IdempotentJob) - - r1 = registry.run("idempotent_job", payload={"batch_id": "b1"}) - assert r1["status"] == "completed" - - r2 = registry.run("idempotent_job", payload={"batch_id": "b1"}) - assert r2["status"] == "skipped_idempotent" - - # Different batch_id should run - r3 = registry.run("idempotent_job", payload={"batch_id": "b2"}) - assert r3["status"] == "completed" - - def test_job_history(self, db_conn, lease_manager): - conn, lock = db_conn - registry = JobRegistry(conn, lock, lease_manager) - registry.register(_TestJob) - - registry.run("test_job", payload={"n": 1}) - registry.run("test_job", payload={"n": 2}) - - history = 
registry.get_job_history("test_job", limit=5) - assert len(history) == 2 - assert history[0]["status"] == "completed" - - def test_health(self, db_conn, lease_manager): - conn, lock = db_conn - registry = JobRegistry(conn, lock, lease_manager) - registry.register(_TestJob) - - health = registry.get_health() - assert health["total_registered"] == 1 - assert "test_job" in health["job_status"] - assert health["job_status"]["test_job"]["last_status"] == "never_run" - - registry.run("test_job") - health = registry.get_health() - assert health["job_status"]["test_job"]["last_status"] == "completed" - - def test_lease_prevents_concurrent(self, db_conn, lease_manager): - """Two sequential runs of same job: first completes, second runs too - (lease released). But if lease held, second is blocked.""" - conn, lock = db_conn - registry = JobRegistry(conn, lock, lease_manager) - registry.register(_TestJob) - - # First run — should succeed - r1 = registry.run("test_job", owner_id="w1") - assert r1["status"] == "completed" - - # Second run — lease was released, should succeed - r2 = registry.run("test_job", owner_id="w2") - assert r2["status"] == "completed" - - def test_run_all(self, db_conn, lease_manager): - conn, lock = db_conn - registry = JobRegistry(conn, lock, lease_manager) - registry.register(_TestJob) - registry.register(_IdempotentJob) - - results = registry.run_all() - assert len(results) == 2 - completed = [r for r in results if r["status"] == "completed"] - assert len(completed) == 2 - - -# ========================================================================= -# Distillation Tests -# ========================================================================= - -class TestDistillation: - - def test_idempotency_key(self): - k1 = compute_idempotency_key(["e1", "e2"], 1, "belief:u1:test") - k2 = compute_idempotency_key(["e2", "e1"], 1, "belief:u1:test") # sorted - assert k1 == k2 # Same regardless of order - - k3 = compute_idempotency_key(["e1", "e2"], 2, 
"belief:u1:test") # diff version - assert k1 != k3 - - def test_candidate_auto_key(self): - c = DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1", "e2"], - target_type="belief", - canonical_key="belief:u1:test", - payload={"claim": "test"}, - ) - assert c.idempotency_key - assert len(c.idempotency_key) == 24 - - def test_submit_and_get(self, stores): - ds = stores["distillation"] - candidate = DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1"], - target_type="belief", - canonical_key="belief:u1:test", - payload={"user_id": "u1", "claim": "test claim"}, - confidence=0.6, - ) - result = ds.submit(candidate) - assert result == "c1" - - fetched = ds.get("c1") - assert fetched is not None - assert fetched["target_type"] == "belief" - assert fetched["status"] == "pending_validation" - - def test_submit_dedup(self, stores): - ds = stores["distillation"] - - c1 = DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1"], - target_type="belief", - canonical_key="belief:u1:test", - payload={"claim": "test"}, - ) - c2 = DistillationCandidate( - candidate_id="c2", - source_event_ids=["e1"], # same source + same key - target_type="belief", - canonical_key="belief:u1:test", - payload={"claim": "test"}, - ) - - assert ds.submit(c1) == "c1" - assert ds.submit(c2) is None # dedup - - def test_submit_after_reject(self, stores): - ds = stores["distillation"] - - c1 = DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1"], - target_type="belief", - canonical_key="belief:u1:test", - payload={"claim": "test"}, - ) - ds.submit(c1) - ds.set_status("c1", "rejected") - - # Same idempotency key, but rejected — should allow resubmit - c2 = DistillationCandidate( - candidate_id="c2", - source_event_ids=["e1"], - target_type="belief", - canonical_key="belief:u1:test", - payload={"claim": "test v2"}, - ) - assert ds.submit(c2) == "c2" - - def test_get_pending(self, stores): - ds = stores["distillation"] - - 
ds.submit(DistillationCandidate( - candidate_id="c1", source_event_ids=["e1"], - target_type="belief", canonical_key="k1", - payload={"claim": "a"}, confidence=0.8, - )) - ds.submit(DistillationCandidate( - candidate_id="c2", source_event_ids=["e2"], - target_type="policy", canonical_key="k2", - payload={"name": "b"}, confidence=0.5, - )) - - pending = ds.get_pending() - assert len(pending) == 2 - - beliefs_only = ds.get_pending("belief") - assert len(beliefs_only) == 1 - - def test_distill_belief_from_events(self): - events = [ - {"event_id": "e1", "content": "User prefers dark mode"}, - {"event_id": "e2", "content": "User prefers dark mode"}, - ] - candidate = distill_belief_from_events(events, user_id="u1") - assert candidate is not None - assert candidate.target_type == "belief" - assert candidate.confidence == 0.5 # 0.3 + 0.1 * 2 - assert len(candidate.source_event_ids) == 2 - - def test_distill_empty_events(self): - assert distill_belief_from_events([], user_id="u1") is None - - -# ========================================================================= -# Promotion Tests -# ========================================================================= - -class TestPromotion: - - def test_promote_belief(self, stores, promotion_engine): - ds = stores["distillation"] - - ds.submit(DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1", "e2"], - target_type="belief", - canonical_key="belief:u1:test", - payload={"user_id": "u1", "claim": "Python is great", "domain": "tech"}, - confidence=0.7, - )) - - result = promotion_engine.promote_pending("belief") - assert result.to_dict()["promoted"] == 1 - - # Verify belief was created - beliefs = stores["beliefs"].list_by_user("u1") - assert len(beliefs) == 1 - assert beliefs[0]["claim"] == "Python is great" - - # Verify lineage was written - lineage = stores["lineage"].get_sources("belief", beliefs[0]["belief_id"]) - assert len(lineage) == 2 # from e1 and e2 - - # Verify candidate was marked promoted - 
candidate = ds.get("c1") - assert candidate["status"] == "promoted" - assert candidate["promoted_id"] == beliefs[0]["belief_id"] - - def test_promote_policy(self, stores, promotion_engine): - ds = stores["distillation"] - - ds.submit(DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1"], - target_type="policy", - canonical_key="policy:u1:blame", - payload={ - "user_id": "u1", - "name": "git blame first", - "condition": {"task_types": ["bug_fix"]}, - "action": {"approach": "Run git blame"}, - }, - confidence=0.6, - )) - - result = promotion_engine.promote_pending("policy") - assert result.to_dict()["promoted"] == 1 - - policies = stores["policies"].list_by_user("u1") - assert len(policies) == 1 - assert policies[0]["name"] == "git blame first" - - def test_promote_insight(self, stores, promotion_engine): - ds = stores["distillation"] - - ds.submit(DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1"], - target_type="insight", - canonical_key="insight:u1:tokens", - payload={ - "user_id": "u1", - "content": "Token expiry causes outages in production", - "insight_type": "causal", - }, - confidence=0.5, - )) - - result = promotion_engine.promote_pending("insight") - assert result.to_dict()["promoted"] == 1 - - def test_promote_heuristic(self, stores, promotion_engine): - ds = stores["distillation"] - - ds.submit(DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1"], - target_type="heuristic", - canonical_key="heuristic:u1:constrained", - payload={ - "user_id": "u1", - "content": "Start with the most constrained component first", - "abstraction_level": "universal", - }, - confidence=0.6, - )) - - result = promotion_engine.promote_pending("heuristic") - assert result.to_dict()["promoted"] == 1 - - def test_reject_low_confidence(self, stores, promotion_engine): - ds = stores["distillation"] - - ds.submit(DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1"], - target_type="belief", - 
canonical_key="belief:u1:weak", - payload={"user_id": "u1", "claim": "maybe"}, - confidence=0.1, # Below min_confidence of 0.3 - )) - - result = promotion_engine.promote_pending("belief") - assert result.to_dict()["rejected"] == 1 - assert result.to_dict()["promoted"] == 0 - - def test_reject_empty_payload(self, stores, promotion_engine): - ds = stores["distillation"] - - ds.submit(DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1"], - target_type="belief", - canonical_key="belief:u1:empty", - payload={}, - confidence=0.8, - )) - - result = promotion_engine.promote_pending("belief") - assert result.to_dict()["rejected"] == 1 - - def test_reject_short_claim(self, stores, promotion_engine): - ds = stores["distillation"] - - ds.submit(DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1"], - target_type="belief", - canonical_key="belief:u1:short", - payload={"user_id": "u1", "claim": "hi"}, # < 5 chars - confidence=0.8, - )) - - result = promotion_engine.promote_pending("belief") - assert result.to_dict()["rejected"] == 1 - - def test_promote_single(self, stores, promotion_engine): - ds = stores["distillation"] - - ds.submit(DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1"], - target_type="belief", - canonical_key="belief:u1:single", - payload={"user_id": "u1", "claim": "Test single promotion"}, - confidence=0.7, - )) - - result = promotion_engine.promote_single("c1") - assert result["status"] == "promoted" - assert result["promoted_id"] - - def test_promote_single_nonexistent(self, promotion_engine): - result = promotion_engine.promote_single("nonexistent") - assert result["status"] == "error" - - def test_promote_single_already_promoted(self, stores, promotion_engine): - ds = stores["distillation"] - - ds.submit(DistillationCandidate( - candidate_id="c1", - source_event_ids=["e1"], - target_type="belief", - canonical_key="belief:u1:double", - payload={"user_id": "u1", "claim": "Test double promotion"}, - 
confidence=0.7, - )) - - r1 = promotion_engine.promote_single("c1") - assert r1["status"] == "promoted" - - r2 = promotion_engine.promote_single("c1") - assert r2["status"] == "skipped" - - def test_batch_promotion(self, stores, promotion_engine): - ds = stores["distillation"] - - for i in range(5): - ds.submit(DistillationCandidate( - candidate_id=f"c{i}", - source_event_ids=[f"e{i}"], - target_type="belief", - canonical_key=f"belief:u1:batch{i}", - payload={"user_id": "u1", "claim": f"Batch claim number {i}"}, - confidence=0.6, - )) - - result = promotion_engine.promote_pending("belief", limit=10) - assert result.to_dict()["promoted"] == 5 - - beliefs = stores["beliefs"].list_by_user("u1") - assert len(beliefs) == 5 - - -# ========================================================================= -# Consolidation Feedback Loop Tests -# ========================================================================= - -class TestConsolidationSafety: - - def test_should_promote_rejects_consolidated(self): - """Signals with source='consolidated' must be rejected.""" - from dhee.core.consolidation import ConsolidationEngine - - # We can't easily instantiate ConsolidationEngine without full deps, - # so test the logic directly by checking the source code contract. 
- import inspect - src = inspect.getsource(ConsolidationEngine._should_promote) - - # Verify the feedback loop guard is present - assert "consolidated" in src, ( - "ConsolidationEngine._should_promote must check for " - "'consolidated' source to prevent feedback loops" - ) - assert "consolidated_from" in src, ( - "ConsolidationEngine._should_promote must check for " - "'consolidated_from' metadata to prevent re-consolidation" - ) - - def test_promote_uses_infer_false(self): - """_promote_to_passive must use infer=False to skip enrichment.""" - from dhee.core.consolidation import ConsolidationEngine - import inspect - src = inspect.getsource(ConsolidationEngine._promote_to_passive) - - assert "infer=False" in src, ( - "ConsolidationEngine._promote_to_passive must use infer=False " - "to skip the LLM enrichment pipeline" - ) - - def test_promote_tags_provenance(self): - """_promote_to_passive must tag consolidated provenance.""" - from dhee.core.consolidation import ConsolidationEngine - import inspect - src = inspect.getsource(ConsolidationEngine._promote_to_passive) - - assert '"source": "consolidated"' in src or "'source': 'consolidated'" in src, ( - "ConsolidationEngine._promote_to_passive must tag " - "promoted memories with source='consolidated'" - ) - - -# ========================================================================= -# AGI Loop Cleanup Tests -# ========================================================================= - -class TestAgiLoopCleanup: - - def test_no_phantom_imports(self): - """agi_loop.py must not import non-existent engram_* packages.""" - import inspect - from dhee.core import agi_loop - src = inspect.getsource(agi_loop) - - phantom_packages = [ - "engram_reconsolidation", - "engram_procedural", - "engram_metamemory", - "engram_prospective", - "engram_working", - "engram_failure", - "engram_router", - "engram_identity", - "engram_heartbeat", - "engram_policy", - "engram_skills", - "engram_spawn", - "engram_resilience", - ] - - for 
pkg in phantom_packages: - assert pkg not in src, ( - f"agi_loop.py still references phantom package '{pkg}'. " - f"All engram_* phantom imports must be removed." - ) - - def test_run_agi_cycle_api_preserved(self): - """run_agi_cycle function must still exist (backward compat).""" - from dhee.core.agi_loop import run_agi_cycle - assert callable(run_agi_cycle) - - def test_get_system_health_api_preserved(self): - """get_system_health function must still exist (backward compat).""" - from dhee.core.agi_loop import get_system_health - assert callable(get_system_health) diff --git a/tests/test_v3_storage.py b/tests/test_v3_storage.py deleted file mode 100644 index 344a4a8..0000000 --- a/tests/test_v3_storage.py +++ /dev/null @@ -1,738 +0,0 @@ -"""Tests for Dhee v3 event-sourced storage layer. - -Covers: -- RawEventStore: add, dedup, correct, delete, supersedes chain -- BeliefStore: CRUD, confidence updates, contradiction tracking -- PolicyStore: CRUD, outcome recording, status transitions -- AnchorStore: CRUD, field updates, filtering -- InsightStore: CRUD, outcome recording -- HeuristicStore: CRUD, outcome recording -- DerivedLineageStore: source/dependent queries, contribution weights -- CognitionStore: coordinator integration -""" - -import os -import sqlite3 -import tempfile - -import pytest - -from dhee.core.events import RawEventStore, RawMemoryEvent, EventStatus -from dhee.core.derived_store import ( - BeliefStore, - PolicyStore, - AnchorStore, - InsightStore, - HeuristicStore, - DerivedLineageStore, - CognitionStore, -) -from dhee.core.storage import initialize_schema - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - -@pytest.fixture -def db_path(tmp_path): - return str(tmp_path / "test_v3.db") - - -@pytest.fixture -def event_store(db_path): - store = RawEventStore(db_path=db_path) - yield store - store.close() - - -@pytest.fixture -def 
cognition_store(db_path): - store = CognitionStore(db_path=db_path) - yield store - store.close() - - -@pytest.fixture -def shared_conn(db_path): - """Shared connection + lock for derived store tests.""" - import threading - conn = sqlite3.connect(db_path, check_same_thread=False) - conn.execute("PRAGMA journal_mode=WAL") - conn.row_factory = sqlite3.Row - initialize_schema(conn) - lock = threading.RLock() - yield conn, lock - conn.close() - - -# ========================================================================= -# RawEventStore Tests -# ========================================================================= - -class TestRawEventStore: - - def test_add_basic(self, event_store): - event = event_store.add(content="User prefers dark mode", user_id="u1") - assert event.event_id - assert event.content == "User prefers dark mode" - assert event.user_id == "u1" - assert event.status == EventStatus.ACTIVE - assert event.content_hash == RawMemoryEvent.compute_hash("User prefers dark mode") - - def test_add_with_metadata(self, event_store): - event = event_store.add( - content="test", - user_id="u1", - session_id="s1", - source="mcp", - metadata={"key": "value"}, - ) - assert event.session_id == "s1" - assert event.source == "mcp" - assert event.metadata == {"key": "value"} - - def test_dedup_same_content(self, event_store): - e1 = event_store.add(content="same fact", user_id="u1") - e2 = event_store.add(content="same fact", user_id="u1") - assert e1.event_id == e2.event_id # dedup returns existing - - def test_dedup_different_users(self, event_store): - e1 = event_store.add(content="shared fact", user_id="u1") - e2 = event_store.add(content="shared fact", user_id="u2") - assert e1.event_id != e2.event_id # different users = no dedup - - def test_get(self, event_store): - e = event_store.add(content="test get", user_id="u1") - fetched = event_store.get(e.event_id) - assert fetched is not None - assert fetched.content == "test get" - - def test_get_nonexistent(self, 
event_store): - assert event_store.get("nonexistent") is None - - def test_get_by_hash(self, event_store): - e = event_store.add(content="hash lookup", user_id="u1") - found = event_store.get_by_hash(e.content_hash, "u1") - assert found is not None - assert found.event_id == e.event_id - - def test_correct(self, event_store): - original = event_store.add(content="I live in Ghazipur", user_id="u1") - correction = event_store.correct( - original.event_id, "I live in Bengaluru" - ) - - assert correction.supersedes_event_id == original.event_id - assert correction.content == "I live in Bengaluru" - assert correction.status == EventStatus.ACTIVE - - # Original should now be 'corrected' - old = event_store.get(original.event_id) - assert old.status == EventStatus.CORRECTED - - def test_correct_nonexistent(self, event_store): - with pytest.raises(ValueError, match="not found"): - event_store.correct("nonexistent", "new content") - - def test_correct_already_corrected(self, event_store): - e = event_store.add(content="old", user_id="u1") - event_store.correct(e.event_id, "new") - with pytest.raises(ValueError, match="Cannot correct"): - event_store.correct(e.event_id, "newer") - - def test_delete(self, event_store): - e = event_store.add(content="to delete", user_id="u1") - assert event_store.delete(e.event_id) is True - deleted = event_store.get(e.event_id) - assert deleted.status == EventStatus.DELETED - - def test_delete_idempotent(self, event_store): - e = event_store.add(content="to delete", user_id="u1") - assert event_store.delete(e.event_id) is True - assert event_store.delete(e.event_id) is False # already deleted - - def test_delete_nonexistent(self, event_store): - with pytest.raises(ValueError, match="not found"): - event_store.delete("nonexistent") - - def test_list_by_user(self, event_store): - event_store.add(content="fact1", user_id="u1") - event_store.add(content="fact2", user_id="u1") - event_store.add(content="fact3", user_id="u2") - - u1_events = 
event_store.list_by_user("u1") - assert len(u1_events) == 2 - - u2_events = event_store.list_by_user("u2") - assert len(u2_events) == 1 - - def test_list_by_user_with_status(self, event_store): - e1 = event_store.add(content="active", user_id="u1") - e2 = event_store.add(content="will delete", user_id="u1") - event_store.delete(e2.event_id) - - active = event_store.list_by_user("u1", status=EventStatus.ACTIVE) - assert len(active) == 1 - assert active[0].content == "active" - - def test_supersedes_chain(self, event_store): - e1 = event_store.add(content="v1", user_id="u1") - e2 = event_store.correct(e1.event_id, "v2") - # Can't correct e1 again (it's already corrected), but we can - # correct e2 to get a 3-step chain - e3 = event_store.correct(e2.event_id, "v3") - - chain = event_store.get_supersedes_chain(e3.event_id) - assert len(chain) == 3 - assert chain[0].content == "v3" - assert chain[1].content == "v2" - assert chain[2].content == "v1" - - def test_count(self, event_store): - event_store.add(content="a", user_id="u1") - event_store.add(content="b", user_id="u1") - event_store.add(content="c", user_id="u1") - - assert event_store.count("u1") == 3 - assert event_store.count("u1", status=EventStatus.ACTIVE) == 3 - assert event_store.count("u2") == 0 - - def test_dedup_after_delete(self, event_store): - """Deleted content should not block new addition of same content.""" - e1 = event_store.add(content="ephemeral", user_id="u1") - event_store.delete(e1.event_id) - # Adding same content again should create new event (old is deleted, not active) - e2 = event_store.add(content="ephemeral", user_id="u1") - assert e2.event_id != e1.event_id - - -# ========================================================================= -# BeliefStore Tests -# ========================================================================= - -class TestBeliefStore: - - def test_add_and_get(self, shared_conn): - conn, lock = shared_conn - store = BeliefStore(conn, lock) - - bid = store.add( 
- user_id="u1", - claim="Python is dynamically typed", - domain="programming", - confidence=0.8, - ) - belief = store.get(bid) - assert belief is not None - assert belief["claim"] == "Python is dynamically typed" - assert belief["domain"] == "programming" - assert belief["confidence"] == 0.8 - assert belief["status"] == "proposed" - - def test_update_confidence_auto_status(self, shared_conn): - conn, lock = shared_conn - store = BeliefStore(conn, lock) - - bid = store.add(user_id="u1", claim="test", confidence=0.5) - - # High confidence → held - store.update_confidence(bid, 0.9) - b = store.get(bid) - assert b["status"] == "held" - assert b["confidence"] == 0.9 - assert len(b["revisions"]) == 1 - - # Low confidence → retracted - store.update_confidence(bid, 0.05) - b = store.get(bid) - assert b["status"] == "retracted" - - def test_update_confidence_with_evidence(self, shared_conn): - conn, lock = shared_conn - store = BeliefStore(conn, lock) - - bid = store.add(user_id="u1", claim="test", confidence=0.5) - store.update_confidence( - bid, 0.7, - evidence={"content": "saw it in docs", "supports": True}, - revision_reason="documentation found", - ) - - b = store.get(bid) - assert len(b["evidence"]) == 1 - assert b["evidence"][0]["content"] == "saw it in docs" - - def test_contradiction(self, shared_conn): - conn, lock = shared_conn - store = BeliefStore(conn, lock) - - b1 = store.add(user_id="u1", claim="Earth is round", confidence=0.9) - b2 = store.add(user_id="u1", claim="Earth is flat", confidence=0.3) - - store.add_contradiction(b1, b2) - - belief1 = store.get(b1) - belief2 = store.get(b2) - assert b2 in belief1["contradicts_ids"] - assert b1 in belief2["contradicts_ids"] - assert belief1["status"] == "challenged" - assert belief2["status"] == "challenged" - - def test_list_by_user(self, shared_conn): - conn, lock = shared_conn - store = BeliefStore(conn, lock) - - store.add(user_id="u1", claim="a", domain="sci", confidence=0.9) - store.add(user_id="u1", 
claim="b", domain="sci", confidence=0.3) - store.add(user_id="u1", claim="c", domain="eng", confidence=0.7) - - sci = store.list_by_user("u1", domain="sci") - assert len(sci) == 2 - - high = store.list_by_user("u1", min_confidence=0.5) - assert len(high) == 2 - - def test_set_status(self, shared_conn): - conn, lock = shared_conn - store = BeliefStore(conn, lock) - - bid = store.add(user_id="u1", claim="test") - assert store.set_status(bid, "stale") is True - assert store.get(bid)["status"] == "stale" - - def test_get_by_invalidation_status(self, shared_conn): - conn, lock = shared_conn - store = BeliefStore(conn, lock) - - b1 = store.add(user_id="u1", claim="stale one") - store.set_status(b1, "stale") - b2 = store.add(user_id="u1", claim="active one") - - stale = store.get_by_invalidation_status("stale") - assert len(stale) == 1 - assert stale[0]["belief_id"] == b1 - - -# ========================================================================= -# PolicyStore Tests -# ========================================================================= - -class TestPolicyStore: - - def test_add_and_get(self, shared_conn): - conn, lock = shared_conn - store = PolicyStore(conn, lock) - - pid = store.add( - user_id="u1", - name="Use git blame first", - condition={"task_types": ["bug_fix"]}, - action={"approach": "Run git blame on failing file"}, - granularity="task", - ) - policy = store.get(pid) - assert policy is not None - assert policy["name"] == "Use git blame first" - assert policy["granularity"] == "task" - assert policy["status"] == "proposed" - assert policy["condition"]["task_types"] == ["bug_fix"] - - def test_record_outcome_success(self, shared_conn): - conn, lock = shared_conn - store = PolicyStore(conn, lock) - - pid = store.add( - user_id="u1", name="test", - condition={}, action={}, - ) - - for _ in range(5): - store.record_outcome(pid, success=True, baseline_score=0.5, actual_score=0.8) - - p = store.get(pid) - assert p["apply_count"] == 5 - assert 
p["success_count"] == 5 - assert p["failure_count"] == 0 - assert p["status"] == "validated" # win_rate >= 0.6 after 5+ - assert p["utility"] > 0 - - def test_record_outcome_deprecated(self, shared_conn): - conn, lock = shared_conn - store = PolicyStore(conn, lock) - - pid = store.add( - user_id="u1", name="bad policy", - condition={}, action={}, - ) - - for _ in range(6): - store.record_outcome(pid, success=False) - - p = store.get(pid) - assert p["status"] == "deprecated" - - def test_record_outcome_utility_ema(self, shared_conn): - conn, lock = shared_conn - store = PolicyStore(conn, lock) - - pid = store.add( - user_id="u1", name="ema test", - condition={}, action={}, - ) - - store.record_outcome(pid, success=True, baseline_score=0.5, actual_score=0.9) - p = store.get(pid) - # First delta = 0.4, utility = 0.3 * 0.4 + 0.7 * 0.0 = 0.12 - assert abs(p["utility"] - 0.12) < 0.01 - - def test_list_by_user(self, shared_conn): - conn, lock = shared_conn - store = PolicyStore(conn, lock) - - store.add(user_id="u1", name="p1", condition={}, action={}, granularity="task") - store.add(user_id="u1", name="p2", condition={}, action={}, granularity="step") - - task_policies = store.list_by_user("u1", granularity="task") - assert len(task_policies) == 1 - - -# ========================================================================= -# AnchorStore Tests -# ========================================================================= - -class TestAnchorStore: - - def test_add_and_get(self, shared_conn): - conn, lock = shared_conn - store = AnchorStore(conn, lock) - - aid = store.add( - user_id="u1", - era="bengaluru_work", - place="Bengaluru", - place_type="city", - activity="coding", - ) - anchor = store.get(aid) - assert anchor is not None - assert anchor["era"] == "bengaluru_work" - assert anchor["place"] == "Bengaluru" - assert anchor["activity"] == "coding" - - def test_get_by_event(self, shared_conn): - conn, lock = shared_conn - store = AnchorStore(conn, lock) - - aid = 
store.add(user_id="u1", memory_event_id="evt-123") - found = store.get_by_event("evt-123") - assert found is not None - assert found["anchor_id"] == aid - - def test_list_filtered(self, shared_conn): - conn, lock = shared_conn - store = AnchorStore(conn, lock) - - store.add(user_id="u1", era="school", place="Ghazipur") - store.add(user_id="u1", era="work", place="Bengaluru") - store.add(user_id="u1", era="school", place="Delhi") - - school = store.list_by_user("u1", era="school") - assert len(school) == 2 - - blr = store.list_by_user("u1", place="Bengaluru") - assert len(blr) == 1 - - def test_update_fields(self, shared_conn): - conn, lock = shared_conn - store = AnchorStore(conn, lock) - - aid = store.add(user_id="u1", era="old_era") - assert store.update_fields(aid, era="new_era") is True - assert store.get(aid)["era"] == "new_era" - - def test_update_fields_rejects_invalid(self, shared_conn): - conn, lock = shared_conn - store = AnchorStore(conn, lock) - - aid = store.add(user_id="u1") - # user_id is not in the allowed update set - assert store.update_fields(aid, user_id="hacker") is False - - -# ========================================================================= -# InsightStore Tests -# ========================================================================= - -class TestInsightStore: - - def test_add_and_get(self, shared_conn): - conn, lock = shared_conn - store = InsightStore(conn, lock) - - iid = store.add( - user_id="u1", - content="Strict criteria + balanced scoring works best", - insight_type="strategy", - confidence=0.7, - ) - insight = store.get(iid) - assert insight["content"] == "Strict criteria + balanced scoring works best" - assert insight["insight_type"] == "strategy" - - def test_record_outcome_success(self, shared_conn): - conn, lock = shared_conn - store = InsightStore(conn, lock) - - iid = store.add(user_id="u1", content="test", confidence=0.5) - store.record_outcome(iid, success=True) - i = store.get(iid) - assert i["validation_count"] 
== 1 - assert i["confidence"] == 0.55 # 0.5 + 0.05 - - def test_record_outcome_failure(self, shared_conn): - conn, lock = shared_conn - store = InsightStore(conn, lock) - - iid = store.add(user_id="u1", content="test", confidence=0.5) - store.record_outcome(iid, success=False) - i = store.get(iid) - assert i["invalidation_count"] == 1 - assert i["confidence"] == 0.4 # 0.5 - 0.1 - - def test_record_outcome_with_scores(self, shared_conn): - conn, lock = shared_conn - store = InsightStore(conn, lock) - - iid = store.add(user_id="u1", content="test", confidence=0.5) - store.record_outcome(iid, success=True, baseline_score=0.5, actual_score=0.8) - i = store.get(iid) - # utility = 0.3 * 0.3 + 0.7 * 0.0 = 0.09 - assert abs(i["utility"] - 0.09) < 0.01 - - def test_list_by_type(self, shared_conn): - conn, lock = shared_conn - store = InsightStore(conn, lock) - - store.add(user_id="u1", content="a", insight_type="warning") - store.add(user_id="u1", content="b", insight_type="strategy") - store.add(user_id="u1", content="c", insight_type="warning") - - warnings = store.list_by_user("u1", insight_type="warning") - assert len(warnings) == 2 - - -# ========================================================================= -# HeuristicStore Tests -# ========================================================================= - -class TestHeuristicStore: - - def test_add_and_get(self, shared_conn): - conn, lock = shared_conn - store = HeuristicStore(conn, lock) - - hid = store.add( - user_id="u1", - content="Start with the most constrained component", - abstraction_level="universal", - ) - h = store.get(hid) - assert h["content"] == "Start with the most constrained component" - assert h["abstraction_level"] == "universal" - - def test_record_outcome(self, shared_conn): - conn, lock = shared_conn - store = HeuristicStore(conn, lock) - - hid = store.add(user_id="u1", content="test", confidence=0.5) - store.record_outcome(hid, success=True, baseline_score=0.4, actual_score=0.7) - - h = 
store.get(hid) - assert h["validation_count"] == 1 - assert h["confidence"] == 0.55 - assert abs(h["utility"] - 0.09) < 0.01 # 0.3 * 0.3 + 0.7 * 0 - - def test_list_by_level(self, shared_conn): - conn, lock = shared_conn - store = HeuristicStore(conn, lock) - - store.add(user_id="u1", content="a", abstraction_level="specific") - store.add(user_id="u1", content="b", abstraction_level="domain") - store.add(user_id="u1", content="c", abstraction_level="universal") - - universal = store.list_by_user("u1", abstraction_level="universal") - assert len(universal) == 1 - - -# ========================================================================= -# DerivedLineageStore Tests -# ========================================================================= - -class TestDerivedLineageStore: - - def test_add_and_get_sources(self, shared_conn): - conn, lock = shared_conn - store = DerivedLineageStore(conn, lock) - - store.add("belief", "b1", "evt1", contribution_weight=0.6) - store.add("belief", "b1", "evt2", contribution_weight=0.4) - - sources = store.get_sources("belief", "b1") - assert len(sources) == 2 - weights = {s["source_event_id"]: s["contribution_weight"] for s in sources} - assert weights["evt1"] == 0.6 - assert weights["evt2"] == 0.4 - - def test_get_dependents(self, shared_conn): - conn, lock = shared_conn - store = DerivedLineageStore(conn, lock) - - store.add("belief", "b1", "evt1") - store.add("policy", "p1", "evt1") - store.add("belief", "b2", "evt2") - - deps = store.get_dependents("evt1") - assert len(deps) == 2 - types = {d["derived_type"] for d in deps} - assert types == {"belief", "policy"} - - def test_add_batch(self, shared_conn): - conn, lock = shared_conn - store = DerivedLineageStore(conn, lock) - - ids = store.add_batch( - "insight", "i1", - ["evt1", "evt2", "evt3"], - weights=[0.5, 0.3, 0.2], - ) - assert len(ids) == 3 - assert store.get_source_count("insight", "i1") == 3 - - def test_contribution_weight(self, shared_conn): - conn, lock = shared_conn 
- store = DerivedLineageStore(conn, lock) - - store.add("belief", "b1", "evt1", contribution_weight=0.7) - w = store.get_contribution_weight("belief", "b1", "evt1") - assert w == 0.7 - - # Nonexistent - assert store.get_contribution_weight("belief", "b1", "evt999") is None - - def test_delete_for_derived(self, shared_conn): - conn, lock = shared_conn - store = DerivedLineageStore(conn, lock) - - store.add("belief", "b1", "evt1") - store.add("belief", "b1", "evt2") - assert store.get_source_count("belief", "b1") == 2 - - deleted = store.delete_for_derived("belief", "b1") - assert deleted == 2 - assert store.get_source_count("belief", "b1") == 0 - - -# ========================================================================= -# CognitionStore Integration Tests -# ========================================================================= - -class TestCognitionStore: - - def test_full_lifecycle(self, cognition_store): - """End-to-end: raw event → derived belief → lineage → invalidation status.""" - cs = cognition_store - - # 1. Store raw event - event = cs.events.add(content="User prefers dark mode", user_id="u1") - assert event.status == EventStatus.ACTIVE - - # 2. Derive a belief from it - bid = cs.beliefs.add( - user_id="u1", - claim="User prefers dark mode", - domain="preferences", - confidence=0.8, - source_memory_ids=[event.event_id], - ) - - # 3. Record lineage - cs.lineage.add("belief", bid, event.event_id, contribution_weight=1.0) - - # 4. Verify lineage - sources = cs.lineage.get_sources("belief", bid) - assert len(sources) == 1 - assert sources[0]["source_event_id"] == event.event_id - - deps = cs.lineage.get_dependents(event.event_id) - assert len(deps) == 1 - assert deps[0]["derived_id"] == bid - - # 5. Correct the raw event - correction = cs.events.correct( - event.event_id, "User prefers light mode" - ) - assert correction.supersedes_event_id == event.event_id - - # 6. 
Mark belief as stale (soft invalidation would do this) - cs.beliefs.set_status(bid, "stale") - stale = cs.beliefs.get_by_invalidation_status("stale") - assert len(stale) == 1 - - def test_policy_lifecycle(self, cognition_store): - cs = cognition_store - - # Create event + policy + lineage - e = cs.events.add(content="git blame helped find bug", user_id="u1") - pid = cs.policies.add( - user_id="u1", - name="git blame first", - condition={"task_types": ["bug_fix"]}, - action={"approach": "Run git blame"}, - ) - cs.lineage.add("policy", pid, e.event_id) - - # Record outcomes → validated - for _ in range(6): - cs.policies.record_outcome(pid, success=True, baseline_score=0.4, actual_score=0.8) - - p = cs.policies.get(pid) - assert p["status"] == "validated" - assert p["utility"] > 0 - - def test_multi_type_lineage(self, cognition_store): - """One raw event feeds multiple derived types.""" - cs = cognition_store - - e = cs.events.add(content="auth tokens expire in prod", user_id="u1") - - bid = cs.beliefs.add(user_id="u1", claim="Tokens expire in prod") - pid = cs.policies.add( - user_id="u1", name="check token expiry", - condition={}, action={}, - ) - iid = cs.insights.add(user_id="u1", content="Token expiry causes outages") - - cs.lineage.add("belief", bid, e.event_id) - cs.lineage.add("policy", pid, e.event_id) - cs.lineage.add("insight", iid, e.event_id) - - deps = cs.lineage.get_dependents(e.event_id) - assert len(deps) == 3 - types = {d["derived_type"] for d in deps} - assert types == {"belief", "policy", "insight"} - - def test_shared_connection(self, cognition_store): - """All stores share the same connection — verify cross-store visibility.""" - cs = cognition_store - - # Write through events store - e = cs.events.add(content="test visibility", user_id="u1") - - # Write through beliefs store - bid = cs.beliefs.add(user_id="u1", claim="test") - - # Lineage can see both (same connection) - cs.lineage.add("belief", bid, e.event_id) - sources = 
cs.lineage.get_sources("belief", bid) - assert len(sources) == 1 diff --git a/tests/test_vector_store_factory.py b/tests/test_vector_store_factory.py new file mode 100644 index 0000000..2c2238d --- /dev/null +++ b/tests/test_vector_store_factory.py @@ -0,0 +1,14 @@ +from dhee.utils.factory import _normalize_sqlite_vec_config + + +def test_normalize_sqlite_vec_config_keeps_file_paths(): + cfg = {"path": "/tmp/dhee/sqlite_vec.db", "collection_name": "x"} + normalized = _normalize_sqlite_vec_config(cfg) + assert normalized["path"] == "/tmp/dhee/sqlite_vec.db" + assert normalized["collection_name"] == "x" + + +def test_normalize_sqlite_vec_config_converts_directory_paths(): + cfg = {"path": "/tmp/dhee/zvec"} + normalized = _normalize_sqlite_vec_config(cfg) + assert normalized["path"] == "/tmp/dhee/zvec/sqlite_vec.db"