diff --git a/README.md b/README.md index 52b484f..9d26cac 100644 --- a/README.md +++ b/README.md @@ -276,8 +276,20 @@ pip install dhee[ollama,mcp] # Ollama (local inference, no API costs) ```bash git clone https://github.com/Sankhya-AI/Dhee.git cd Dhee -pip install -e ".[dev]" + +./scripts/bootstrap_dev_env.sh +source .venv-dhee/bin/activate + +# optional if you prefer manual bootstrap: +# python3 -m venv .venv-dhee +# .venv-dhee/bin/python -m pip install -e ./dhee-accel -e ./engram-bus -e ".[dev]" + pytest + +# live vendor-backed suites are explicit opt-in: +# DHEE_RUN_LIVE_TESTS=1 pytest -q tests/test_e2e_all_features.py tests/test_power_packages.py + +# manual smoke scripts live under scripts/manual/ ``` --- diff --git a/dhee/__init__.py b/dhee/__init__.py index 43196ad..3172da0 100644 --- a/dhee/__init__.py +++ b/dhee/__init__.py @@ -13,6 +13,7 @@ d.checkpoint("Fixed it", what_worked="git blame first") Memory Classes: + Engram — batteries-included memory interface with sensible defaults CoreMemory — lightweight: add/search/delete + decay (no LLM) SmartMemory — + echo encoding, categories, knowledge graph (needs LLM) FullMemory — + scenes, profiles, orchestration, cognition (everything) @@ -22,7 +23,7 @@ from dhee.memory.core import CoreMemory from dhee.memory.smart import SmartMemory from dhee.memory.main import FullMemory -from dhee.simple import Dhee +from dhee.simple import Dhee, Engram from dhee.adapters.base import DheePlugin from dhee.core.category import CategoryProcessor, Category, CategoryType, CategoryMatch from dhee.core.echo import EchoProcessor, EchoDepth, EchoResult @@ -31,9 +32,10 @@ # Default: CoreMemory (lightest, zero-config) Memory = CoreMemory -__version__ = "3.0.0" +__version__ = "3.0.1" __all__ = [ # Memory classes + "Engram", "CoreMemory", "SmartMemory", "FullMemory", diff --git a/dhee/adapters/base.py b/dhee/adapters/base.py index d7cf824..37912ac 100644 --- a/dhee/adapters/base.py +++ b/dhee/adapters/base.py @@ -36,6 +36,8 @@ from 
pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union +from dhee.checkpoint_runtime import run_checkpoint_common + logger = logging.getLogger(__name__) @@ -118,6 +120,11 @@ def kernel(self): """Access the CognitionKernel for direct state manipulation.""" return self._kernel + @property + def memory(self): + """Expose the configured runtime memory engine for advanced integrations.""" + return self._engram.memory + # ------------------------------------------------------------------ # Hook registry # ------------------------------------------------------------------ @@ -153,6 +160,13 @@ def _fire_hooks(self, event: str, data: Any) -> None: except Exception: logger.debug("Hook %s failed", event) + @staticmethod + def _health_error(component: str, exc: Exception) -> Dict[str, str]: + return { + "component": component, + "error": f"{type(exc).__name__}: {exc}", + } + # ------------------------------------------------------------------ # Tool 1: remember # ------------------------------------------------------------------ @@ -257,7 +271,7 @@ def context( hyper_ctx = self._buddhi.get_hyper_context( user_id=uid, task_description=task_description, - memory=self._engram._memory, + memory=self.memory, ) if operational: result = hyper_ctx.to_operational_dict() @@ -325,62 +339,29 @@ def checkpoint( if not what_worked: what_worked = outcome.get("what_worked") - result: Dict[str, Any] = {} - score = max(0.0, min(1.0, float(outcome_score))) if outcome_score is not None else None - - # 1. Session digest - try: - from dhee.core.kernel import save_session_digest - digest = save_session_digest( - task_summary=summary, agent_id=agent_id, repo=repo, - status=status, decisions_made=decisions, - files_touched=files_touched, todos_remaining=todos, - ) - result["session_saved"] = True - if isinstance(digest, dict): - result["session_id"] = digest.get("session_id") - except Exception: - result["session_saved"] = False - - # 2. 
Batch enrichment - memory = self._engram._memory - if hasattr(memory, "enrich_pending"): - try: - enrich_result = memory.enrich_pending( - user_id=uid, batch_size=10, max_batches=5, - ) - enriched = enrich_result.get("enriched_count", 0) - if enriched > 0: - result["memories_enriched"] = enriched - except Exception: - pass - - # 3. Outcome recording - if task_type and score is not None: - insight = self._buddhi.record_outcome( - user_id=uid, task_type=task_type, score=score, - ) - result["outcome_recorded"] = True - if insight: - result["auto_insight"] = insight.to_dict() - - # 4. Insight synthesis - if any([what_worked, what_failed, key_decision]): - insights = self._buddhi.reflect( - user_id=uid, task_type=task_type or "general", - what_worked=what_worked, what_failed=what_failed, - key_decision=key_decision, - outcome_score=score if score is not None else None, - ) - result["insights_created"] = len(insights) - - # 5. Intention storage - if remember_to: - intention = self._buddhi.store_intention( - user_id=uid, description=remember_to, - trigger_keywords=trigger_keywords, - ) - result["intention_stored"] = intention.to_dict() + result = run_checkpoint_common( + logger=logger, + log_prefix="Plugin checkpoint", + user_id=uid, + summary=summary, + status=status, + agent_id=agent_id, + repo=repo, + decisions=decisions, + files_touched=files_touched, + todos=todos, + task_type=task_type, + outcome_score=outcome_score, + what_worked=what_worked, + what_failed=what_failed, + key_decision=key_decision, + remember_to=remember_to, + trigger_keywords=trigger_keywords, + enrich_pending_fn=self._engram.enrich_pending, + record_outcome_fn=self._buddhi.record_outcome, + reflect_fn=self._buddhi.reflect, + store_intention_fn=self._buddhi.store_intention, + ) # 6. 
Episode closure (via kernel) ep_result = self._kernel.record_checkpoint_event( @@ -441,8 +422,8 @@ def session_start( task_description=task_description or "session", task_type=task_type or "general", ) - except Exception: - pass + except Exception as exc: + logger.warning("Session start episode initialization failed: %s", exc, exc_info=True) ctx = self.context(task_description=task_description, user_id=uid) return self._render_system_prompt(ctx, task_description) @@ -479,17 +460,19 @@ def _handle_tracker_signals(self, signals: Dict[str, Any], user_id: str) -> None if signals.get("needs_auto_checkpoint"): args = signals.get("auto_checkpoint_args", {}) try: - self.checkpoint(user_id=user_id, **args) - except Exception: - pass + checkpoint_result = self.checkpoint(user_id=user_id, **args) + for warning in checkpoint_result.get("warnings", []): + logger.warning("Plugin auto-checkpoint warning: %s", warning) + except Exception as exc: + logger.warning("Plugin auto-checkpoint failed: %s", exc, exc_info=True) # Auto-context for new session if signals.get("needs_auto_context"): task = signals.get("inferred_task") try: self.context(task_description=task, user_id=user_id) - except Exception: - pass + except Exception as exc: + logger.warning("Plugin auto-context failed: %s", exc, exc_info=True) # ------------------------------------------------------------------ # Cognition health (harness monitoring) @@ -503,6 +486,7 @@ def cognition_health(self, user_id: Optional[str] = None) -> Dict[str, Any]: """ uid = user_id or self._user_id health: Dict[str, Any] = {} + errors: List[Dict[str, str]] = [] health["kernel"] = self._kernel.get_stats() health["buddhi"] = self._buddhi.get_stats() @@ -513,16 +497,45 @@ def cognition_health(self, user_id: Optional[str] = None) -> Dict[str, Any]: low_util = [p for p in policies if p.utility < -0.2 and p.apply_count >= 3] if low_util: warnings.append(f"{len(low_util)} policies with negative utility") + except Exception as exc: + logger.warning( 
+ "Cognition health derivation failed for policies: %s", + exc, + exc_info=True, + ) + errors.append(self._health_error("policies.get_user_policies", exc)) + + try: active_intentions = self._kernel.intentions.get_active(uid) if len(active_intentions) > 20: - warnings.append(f"{len(active_intentions)} active intentions (consider cleanup)") + warnings.append( + f"{len(active_intentions)} active intentions (consider cleanup)" + ) + except Exception as exc: + logger.warning( + "Cognition health derivation failed for intentions: %s", + exc, + exc_info=True, + ) + errors.append(self._health_error("intentions.get_active", exc)) + + try: contradictions = self._kernel.beliefs.get_contradictions(uid) if len(contradictions) > 5: - warnings.append(f"{len(contradictions)} unresolved belief contradictions") - except Exception: - pass + warnings.append( + f"{len(contradictions)} unresolved belief contradictions" + ) + except Exception as exc: + logger.warning( + "Cognition health derivation failed for contradictions: %s", + exc, + exc_info=True, + ) + errors.append(self._health_error("beliefs.get_contradictions", exc)) health["warnings"] = warnings + if errors: + health["errors"] = errors return health # ------------------------------------------------------------------ @@ -640,14 +653,37 @@ def end_trajectory( # Store trajectory as memory for skill mining try: from dhee.skills.trajectory import TrajectoryStore - store = TrajectoryStore(memory=self._engram._memory) + store = TrajectoryStore(memory=self.memory) store.save(trajectory) result["stored"] = True - except Exception: + except Exception as exc: + logger.warning("Trajectory persistence failed: %s", exc, exc_info=True) result["stored"] = False + result["storage_error"] = str(exc) return result + def close(self) -> None: + """Flush cognition state and release runtime resources.""" + errors: List[str] = [] + + try: + self._buddhi.flush() + except Exception as exc: + logger.exception("DheePlugin close failed for buddhi.flush") 
+ errors.append(f"buddhi.flush: {type(exc).__name__}: {exc}") + + try: + self._engram.close() + except Exception as exc: + logger.exception("DheePlugin close failed for engram.close") + errors.append(f"engram.close: {type(exc).__name__}: {exc}") + + if errors: + raise RuntimeError( + "Failed to close DheePlugin resources: " + "; ".join(errors) + ) + # ------------------------------------------------------------------ # Framework export: OpenAI function calling # ------------------------------------------------------------------ diff --git a/dhee/api/app.py b/dhee/api/app.py index d5fede4..f49f382 100644 --- a/dhee/api/app.py +++ b/dhee/api/app.py @@ -95,7 +95,10 @@ async def handoff_checkpoint(request: CheckpointRequest): bus = _get_bus() # Find or create a session for this agent - session = bus.get_session(agent_id=request.agent_id) + session = bus.get_session( + agent_id=request.agent_id, + repo=request.repo_path, + ) if session is None: sid = bus.save_session( agent_id=request.agent_id, diff --git a/dhee/checkpoint_runtime.py b/dhee/checkpoint_runtime.py new file mode 100644 index 0000000..ab0c237 --- /dev/null +++ b/dhee/checkpoint_runtime.py @@ -0,0 +1,121 @@ +"""Shared checkpoint runtime helpers for supported Dhee entrypoints.""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, List, Optional + + +def _serialize_optional(value: Any) -> Any: + """Convert common model objects to plain dictionaries when available.""" + if value is None: + return None + to_dict = getattr(value, "to_dict", None) + if callable(to_dict): + return to_dict() + return value + + +def run_checkpoint_common( + *, + logger: logging.Logger, + log_prefix: str, + user_id: str, + summary: str, + status: str, + agent_id: str, + repo: Optional[str], + decisions: Optional[List[str]], + files_touched: Optional[List[str]], + todos: Optional[List[str]], + task_type: Optional[str], + outcome_score: Optional[float], + what_worked: Optional[str], + 
what_failed: Optional[str], + key_decision: Optional[str], + remember_to: Optional[str], + trigger_keywords: Optional[List[str]], + enrich_pending_fn: Callable[..., Dict[str, Any]], + record_outcome_fn: Callable[..., Any], + reflect_fn: Callable[..., List[Any]], + store_intention_fn: Callable[..., Any], +) -> Dict[str, Any]: + """Execute the shared checkpoint side-effects for Dhee entrypoints.""" + result: Dict[str, Any] = {} + warnings: List[str] = [] + + clamped_score = None + if outcome_score is not None: + clamped_score = max(0.0, min(1.0, float(outcome_score))) + + try: + from dhee.core.kernel import save_session_digest + + digest = save_session_digest( + task_summary=summary, + agent_id=agent_id, + repo=repo, + status=status, + decisions_made=decisions, + files_touched=files_touched, + todos_remaining=todos, + ) + result["session_saved"] = True + if isinstance(digest, dict): + result["session_id"] = digest.get("session_id") + except Exception as exc: + logger.warning("%s session digest save failed: %s", log_prefix, exc, exc_info=True) + result["session_saved"] = False + result["session_save_error"] = str(exc) + warnings.append(f"session_save_failed: {exc}") + + try: + enrich_result = enrich_pending_fn( + user_id=user_id, + batch_size=10, + max_batches=5, + ) + enriched = enrich_result.get("enriched_count", 0) + if enriched > 0: + result["memories_enriched"] = enriched + except Exception as exc: + logger.warning("%s deferred enrichment failed: %s", log_prefix, exc, exc_info=True) + result["enrichment_error"] = str(exc) + warnings.append(f"deferred_enrichment_failed: {exc}") + + if task_type and clamped_score is not None: + insight = record_outcome_fn( + user_id=user_id, + task_type=task_type, + score=clamped_score, + ) + result["outcome_recorded"] = True + insight_payload = _serialize_optional(insight) + if insight_payload: + result["auto_insight"] = insight_payload + + if any([what_worked, what_failed, key_decision]): + insights = reflect_fn( + 
user_id=user_id, + task_type=task_type or "general", + what_worked=what_worked, + what_failed=what_failed, + key_decision=key_decision, + outcome_score=clamped_score, + ) + result["insights_created"] = len(insights) + + if remember_to: + intention = store_intention_fn( + user_id=user_id, + description=remember_to, + trigger_keywords=trigger_keywords, + ) + intention_payload = _serialize_optional(intention) + if intention_payload: + result["intention_stored"] = intention_payload + + if warnings: + result["warnings"] = warnings + + return result diff --git a/dhee/core/buddhi.py b/dhee/core/buddhi.py index d38740a..006ac6a 100644 --- a/dhee/core/buddhi.py +++ b/dhee/core/buddhi.py @@ -194,6 +194,7 @@ class HyperContext: critical_blockers: List[str] = field(default_factory=list) contradictions: List[Dict[str, Any]] = field(default_factory=list) action_items: List[str] = field(default_factory=list) + state_errors: List[str] = field(default_factory=list) def to_dict(self) -> Dict[str, Any]: return { @@ -221,6 +222,7 @@ def to_dict(self) -> Dict[str, Any]: "critical_blockers": self.critical_blockers[:5], "contradictions": self.contradictions[:5], "action_items": self.action_items[:10], + "state_errors": self.state_errors[:10], "meta": { "n_insights": len(self.insights), "n_active_intentions": len(self.intentions), @@ -235,6 +237,7 @@ def to_dict(self) -> Dict[str, Any]: "n_action_items": len(self.action_items), "n_critical_blockers": len(self.critical_blockers), "performance_tracked": len(self.performance) > 0, + "n_state_errors": len(self.state_errors), }, } @@ -257,6 +260,8 @@ def to_operational_dict(self) -> Dict[str, Any]: result["warnings"] = self.warnings[:3] if self.contradictions: result["contradictions"] = self.contradictions[:3] + if self.state_errors: + result["state_errors"] = self.state_errors[:3] return result @@ -330,6 +335,32 @@ def _get_meta_buddhi(self): ) return self._meta_buddhi + @staticmethod + def _degradation_message(component: str, exc: Exception) 
-> str: + return f"{component}: {type(exc).__name__}: {exc}" + + def _record_context_degradation( + self, + state_errors: List[str], + warnings: List[str], + *, + component: str, + exc: Exception, + warning_prefix: str = "Context assembly degraded", + include_traceback: bool = True, + ) -> None: + message = self._degradation_message(component, exc) + logger.warning("%s", message, exc_info=include_traceback) + state_errors.append(message) + warnings.append(f"{warning_prefix}: {message}") + + @staticmethod + def _stats_error(component: str, exc: Exception) -> Dict[str, str]: + return { + "component": component, + "error": f"{type(exc).__name__}: {exc}", + } + # Deprecated forwarders — use self._kernel.* directly. # Kept for backward compat with test_cognition_v3.py. @@ -360,13 +391,20 @@ def get_hyper_context( Called at session start or when context is needed. Returns everything: performance, insights, skills, intentions, warnings. """ + context_errors: List[str] = [] + # 1. Last session (via kernel handoff, not memory object) last_session = None try: from dhee.core.kernel import get_last_session last_session = get_last_session() - except Exception: - pass + except Exception as exc: + self._record_context_degradation( + context_errors, + [], + component="handoff.get_last_session", + exc=exc, + ) # 2. Performance snapshots for relevant task types performance = self._get_performance_snapshots(user_id, task_description) @@ -390,14 +428,19 @@ def get_hyper_context( } for r in (results if isinstance(results, list) else []) ] - except Exception: - pass - - # 5. Check pending intentions (via kernel) - triggered = self._kernel.intentions.check_triggers(user_id, task_description) + except Exception as exc: + self._record_context_degradation( + context_errors, + [], + component="skills.search", + exc=exc, + ) - # 6. Generate proactive warnings + # 5. 
Generate proactive warnings warnings = self._generate_warnings(performance, insights) + warnings.extend( + [f"Context assembly degraded: {message}" for message in context_errors] + ) # 7. Top memories memories = [] @@ -411,8 +454,13 @@ def get_hyper_context( else: result = memory.get_all(user_id=user_id, limit=10) memories = result.get("results", []) - except Exception: - pass + except Exception as exc: + self._record_context_degradation( + context_errors, + warnings, + component="memory.search", + exc=exc, + ) # 8. Track query sequence (for future pattern prediction) if task_description: @@ -429,8 +477,13 @@ def get_hyper_context( task_description or "", user_id=user_id, limit=5, ) contrasts = [p.to_compact() for p in pairs] - except Exception: - pass + except Exception as exc: + self._record_context_degradation( + context_errors, + warnings, + component="contrastive.retrieve_contrasts", + exc=exc, + ) # 10. Heuristics (Phase 2: ERL pattern) heuristics = [] @@ -440,18 +493,51 @@ def get_hyper_context( task_description or "", user_id=user_id, limit=5, ) heuristics = [h.to_compact() for h in relevant] - except Exception: - pass + except Exception as exc: + self._record_context_degradation( + context_errors, + warnings, + component="heuristics.retrieve_relevant", + exc=exc, + ) - # 11-14. Cognitive state from kernel (episodes, tasks, policies, beliefs) - cog_state = self._kernel.get_cognitive_state(user_id, task_description) + # 6-10. 
Cognitive state from kernel (episodes, tasks, policies, beliefs) + try: + cog_state = self._kernel.get_cognitive_state(user_id, task_description) + except Exception as exc: + self._record_context_degradation( + context_errors, + warnings, + component="kernel.get_cognitive_state", + exc=exc, + ) + cog_state = { + "episodes": [], + "task_states": [], + "policies": [], + "beliefs": [], + "triggered_intentions": [], + "belief_warnings": [], + "step_policies": [], + "state_errors": [], + "active_step": None, + } episodes = cog_state.get("episodes", []) task_states = cog_state.get("task_states", []) policies = cog_state.get("policies", []) beliefs = cog_state.get("beliefs", []) + triggered = cog_state.get("triggered_intentions", []) warnings.extend(cog_state.get("belief_warnings", [])) + state_error_messages = [ + f"{entry.get('component', 'unknown')}: {entry.get('error', 'unknown error')}" + for entry in cog_state.get("state_errors", []) + if isinstance(entry, dict) + ] + for message in state_error_messages: + warnings.append(f"Cognitive state degraded: {message}") + state_error_messages = context_errors + state_error_messages - # 15. Operational cognition packet (Phase 4) + # 11. 
Operational cognition packet (Phase 4) active_step_desc = cog_state.get("active_step") active_step = {"description": active_step_desc} if active_step_desc else None step_policies_list = cog_state.get("step_policies", []) @@ -475,8 +561,18 @@ def get_hyper_context( "confidence_b": round(b2.confidence, 2), "severity": round(1.0 - severity, 2), }) - except Exception: - pass + except Exception as exc: + self._record_context_degradation( + context_errors, + warnings, + component="beliefs.get_contradictions", + exc=exc, + ) + state_error_messages = context_errors + [ + message + for message in state_error_messages + if message not in context_errors + ] # Build prioritized action items action_items = [] @@ -515,6 +611,7 @@ def get_hyper_context( critical_blockers=critical_blockers, contradictions=contradictions_list, action_items=action_items, + state_errors=state_error_messages, ) # ------------------------------------------------------------------ @@ -831,14 +928,22 @@ def on_memory_stored( content=content[:500], memory_id=memory_id, ) - except Exception: - pass + except Exception as exc: + logger.warning( + "Buddhi on_memory_stored episode event failed: %s", + exc, + exc_info=True, + ) # 3. 
Belief creation for factual statements (via kernel) try: self._maybe_create_belief(content, user_id, memory_id) - except Exception: - pass + except Exception as exc: + logger.warning( + "Buddhi on_memory_stored belief creation failed: %s", + exc, + exc_info=True, + ) return intention @@ -988,8 +1093,12 @@ def reflect( task_type=task_type, user_id=user_id, ) - except Exception: - pass + except Exception as exc: + logger.warning( + "Buddhi reflect contrastive update failed: %s", + exc, + exc_info=True, + ) # Compute baseline from moving average for utility scoring (D2Skill) # Moved before heuristic/policy blocks so all can use it @@ -1001,8 +1110,12 @@ def reflect( if len(records) >= 2: recent = records[-min(10, len(records)):] baseline_score = sum(r["score"] for r in recent) / len(recent) - except Exception: - pass + except Exception as exc: + logger.warning( + "Buddhi reflect baseline computation failed: %s", + exc, + exc_info=True, + ) # Phase 2: Distill heuristic from what_worked if what_worked: @@ -1021,8 +1134,12 @@ def reflect( baseline_score=baseline_score, actual_score=outcome_score, ) - except Exception: - pass + except Exception as exc: + logger.warning( + "Buddhi reflect heuristic distillation failed: %s", + exc, + exc_info=True, + ) # Delegate cross-structure learning to kernel # Kernel handles: policy outcomes, step extraction, belief-policy decay, @@ -1049,8 +1166,12 @@ def reflect( self._kernel.beliefs.reinforce_belief( belief.id, what_worked, source="outcome", ) - except Exception: - pass + except Exception as exc: + logger.warning( + "Buddhi reflect belief reinforcement failed: %s", + exc, + exc_info=True, + ) if what_failed: try: @@ -1061,8 +1182,12 @@ def reflect( self._kernel.beliefs.challenge_belief( belief.id, what_failed, source="outcome", ) - except Exception: - pass + except Exception as exc: + logger.warning( + "Buddhi reflect belief challenge failed: %s", + exc, + exc_info=True, + ) # Buddhi-owned: insight utility tracking (buddhi owns 
insights) try: @@ -1076,8 +1201,12 @@ def reflect( actual_score=outcome_score, ) self._save_insights() - except Exception: - pass + except Exception as exc: + logger.warning( + "Buddhi reflect insight utility tracking failed: %s", + exc, + exc_info=True, + ) # Buddhi-owned: contrastive pair utility (buddhi owns contrastive store) try: @@ -1092,8 +1221,12 @@ def reflect( actual_score=outcome_score, ) store._save_all() - except Exception: - pass + except Exception as exc: + logger.warning( + "Buddhi reflect contrastive utility tracking failed: %s", + exc, + exc_info=True, + ) # Buddhi-owned: heuristic reinforcement from positive policy deltas if what_worked: @@ -1114,8 +1247,12 @@ def reflect( actual_score=outcome_score, ) distiller._save_all() - except Exception: - pass + except Exception as exc: + logger.warning( + "Buddhi reflect heuristic reinforcement failed: %s", + exc, + exc_info=True, + ) return new_insights @@ -1143,14 +1280,18 @@ def _validate_used_heuristics( actual_score=actual_score, ) distiller._save_all() - except Exception: - pass + except Exception as exc: + logger.warning( + "Buddhi heuristic validation failed: %s", + exc, + exc_info=True, + ) # ------------------------------------------------------------------ # Persistence # ------------------------------------------------------------------ - def _save_insights(self) -> None: + def _save_insights(self, *, strict: bool = False) -> None: path = os.path.join(self._data_dir, "insights.jsonl") try: with open(path, "w", encoding="utf-8") as f: @@ -1172,18 +1313,22 @@ def _save_insights(self) -> None: } f.write(json.dumps(row, ensure_ascii=False) + "\n") except OSError as e: + if strict: + raise logger.debug("Failed to save insights: %s", e) def _save_intentions(self) -> None: """Deprecated: intentions now managed by kernel IntentionStore.""" self._kernel.intentions.flush() - def _save_performance(self) -> None: + def _save_performance(self, *, strict: bool = False) -> None: path = 
os.path.join(self._data_dir, "performance.json") try: with open(path, "w", encoding="utf-8") as f: json.dump(self._performance, f, ensure_ascii=False) except OSError as e: + if strict: + raise logger.debug("Failed to save performance: %s", e) def _load_state(self) -> None: @@ -1230,50 +1375,81 @@ def _load_state(self) -> None: def flush(self) -> None: """Persist all state. Call on shutdown.""" - self._save_insights() - self._save_performance() + errors: List[Dict[str, str]] = [] - # Flush kernel (all state stores) - self._kernel.flush() + try: + self._save_insights(strict=True) + except Exception as exc: + logger.exception("Buddhi flush failed for insights") + errors.append(self._stats_error("insights.save", exc)) + + try: + self._save_performance(strict=True) + except Exception as exc: + logger.exception("Buddhi flush failed for performance") + errors.append(self._stats_error("performance.save", exc)) - # Flush Phase 2 subsystems if initialized - for store in [ - self._contrastive, self._heuristic_distiller, + try: + self._kernel.flush() + except Exception as exc: + logger.exception("Buddhi flush failed for kernel") + errors.append(self._stats_error("kernel.flush", exc)) + + for name, store in [ + ("contrastive.flush", self._contrastive), + ("heuristics.flush", self._heuristic_distiller), ]: if store and hasattr(store, "flush"): try: store.flush() - except Exception: - pass + except Exception as exc: + logger.exception("Buddhi flush failed for %s", name) + errors.append(self._stats_error(name, exc)) + + if errors: + detail = "; ".join( + f"{entry['component']}: {entry['error']}" for entry in errors + ) + raise RuntimeError(f"Failed to flush Buddhi state: {detail}") def get_stats(self) -> Dict[str, Any]: """Get buddhi status for health checks.""" - intention_stats = self._kernel.intentions.get_stats() - stats = { + stats: Dict[str, Any] = { "insights": len(self._insights), - "active_intentions": intention_stats.get("active", 0), - "triggered_intentions": 
intention_stats.get("triggered", 0), + "active_intentions": 0, + "triggered_intentions": 0, "task_types_tracked": len(self._performance), "total_performance_records": sum( len(v) for v in self._performance.values() ), } + errors: List[Dict[str, str]] = [] # Kernel state store stats kernel_stats = self._kernel.get_stats() + intention_stats = kernel_stats.get("intentions", {}) + if isinstance(intention_stats, dict) and "error" not in intention_stats: + stats["active_intentions"] = intention_stats.get("active", 0) + stats["triggered_intentions"] = intention_stats.get("triggered", 0) + for entry in kernel_stats.get("errors", []): + if isinstance(entry, dict): + errors.append(entry) stats.update(kernel_stats) # Phase 2 stats (only if initialized) for name, store in [ ("contrastive", self._contrastive), ("heuristics", self._heuristic_distiller), - ("contrastive", self._contrastive), - ("heuristics", self._heuristic_distiller), ]: if store and hasattr(store, "get_stats"): try: stats[name] = store.get_stats() - except Exception: - pass + except Exception as exc: + logger.exception("Buddhi get_stats failed for %s", name) + errors.append(self._stats_error(f"{name}.get_stats", exc)) + stats[name] = {"error": f"{type(exc).__name__}: {exc}"} + + if errors: + stats["errors"] = errors return stats diff --git a/dhee/core/cognition_kernel.py b/dhee/core/cognition_kernel.py index 30cecc4..ac80010 100644 --- a/dhee/core/cognition_kernel.py +++ b/dhee/core/cognition_kernel.py @@ -65,6 +65,69 @@ def __init__(self, data_dir: Optional[str] = None): # Cognitive state snapshot (for HyperContext assembly) # ------------------------------------------------------------------ + @staticmethod + def _state_error(component: str, exc: Exception) -> Dict[str, str]: + return { + "component": component, + "error": f"{type(exc).__name__}: {exc}", + } + + @staticmethod + def _operation_error( + operation: str, + component: str, + exc: Exception, + *, + target: Optional[str] = None, + ) -> Dict[str, 
str]: + entry = { + "operation": operation, + "component": component, + "error": f"{type(exc).__name__}: {exc}", + } + if target is not None: + entry["target"] = target + return entry + + def _record_operation_error( + self, + result: Dict[str, Any], + *, + operation: str, + component: str, + exc: Exception, + target: Optional[str] = None, + ) -> None: + if target is None: + logger.exception( + "CognitionKernel %s failed for %s", operation, component + ) + else: + logger.exception( + "CognitionKernel %s failed for %s (%s)", + operation, + component, + target, + ) + result.setdefault("errors", []).append( + self._operation_error( + operation, + component, + exc, + target=target, + ) + ) + + @staticmethod + def _merge_operation_errors( + result: Dict[str, Any], nested_result: Optional[Dict[str, Any]] + ) -> None: + if not nested_result: + return + nested_errors = nested_result.get("errors", []) + if nested_errors: + result.setdefault("errors", []).extend(nested_errors) + def get_cognitive_state( self, user_id: str, @@ -75,7 +138,8 @@ def get_cognitive_state( Returns a dict with episodes, task_states, policies, beliefs, triggered_intentions, and belief_warnings. 
""" - result: Dict[str, Any] = {} + result: Dict[str, Any] = {"state_errors": []} + state_errors: List[Dict[str, str]] = result["state_errors"] # Episodes try: @@ -85,7 +149,9 @@ def get_cognitive_state( limit=5, ) result["episodes"] = [ep.to_compact() for ep in recent_eps] - except Exception: + except Exception as exc: + logger.exception("Failed to load cognitive state component 'episodes'") + state_errors.append(self._state_error("episodes", exc)) result["episodes"] = [] # Task states + active step context @@ -106,7 +172,9 @@ def get_cognitive_state( if c not in task_states: task_states.append(c) result["task_states"] = task_states - except Exception: + except Exception as exc: + logger.exception("Failed to load cognitive state component 'task_states'") + state_errors.append(self._state_error("task_states", exc)) result["task_states"] = [] result["active_step"] = step_context if step_context else None @@ -121,7 +189,9 @@ def get_cognitive_state( limit=3, ) result["policies"] = [p.to_compact() for p in matched] - except Exception: + except Exception as exc: + logger.exception("Failed to load cognitive state component 'policies'") + state_errors.append(self._state_error("policies", exc)) result["policies"] = [] # Step policies (separate, for operational context) @@ -137,7 +207,9 @@ def get_cognitive_state( result["step_policies"] = [p.to_compact() for p in step_matched] else: result["step_policies"] = [] - except Exception: + except Exception as exc: + logger.exception("Failed to load cognitive state component 'step_policies'") + state_errors.append(self._state_error("step_policies", exc)) result["step_policies"] = [] # Beliefs @@ -157,7 +229,9 @@ def get_cognitive_state( f"Contradicting beliefs: '{b1.claim[:80]}' vs '{b2.claim[:80]}' " f"(confidence: {b1.confidence:.2f} vs {b2.confidence:.2f})" ) - except Exception: + except Exception as exc: + logger.exception("Failed to load cognitive state component 'beliefs'") + state_errors.append(self._state_error("beliefs", 
exc)) result["beliefs"] = [] result["belief_warnings"] = belief_warnings @@ -166,7 +240,9 @@ def get_cognitive_state( try: triggered = self.intentions.check_triggers(user_id, task_description) result["triggered_intentions"] = triggered - except Exception: + except Exception as exc: + logger.exception("Failed to load cognitive state component 'triggered_intentions'") + state_errors.append(self._state_error("triggered_intentions", exc)) result["triggered_intentions"] = [] return result @@ -194,30 +270,66 @@ def record_checkpoint_event( content=summary[:500], metadata={"status": status, "outcome_score": outcome_score}, ) + except Exception as exc: + self._record_operation_error( + result, + operation="record_checkpoint_event", + component="episodes.record_event", + exc=exc, + ) - # Wire episode.connection_count for cross-primitive links + # Wire episode.connection_count for cross-primitive links + connections = 0 + try: + active_task = self.tasks.get_active_task(user_id) + if active_task: + connections += 1 + except Exception as exc: + self._record_operation_error( + result, + operation="record_checkpoint_event", + component="tasks.get_active_task", + exc=exc, + ) + + try: + matched_policies = self.policies.match_policies( + user_id, summary[:50], summary[:200], limit=3, + ) + connections += len(matched_policies) + except Exception as exc: + self._record_operation_error( + result, + operation="record_checkpoint_event", + component="policies.match_policies", + exc=exc, + ) + + if connections > 0: try: - connections = 0 - active_task = self.tasks.get_active_task(user_id) - if active_task: - connections += 1 - matched_policies = self.policies.match_policies( - user_id, summary[:50], summary[:200], limit=3, + self.episodes.increment_connections(user_id, connections) + except Exception as exc: + self._record_operation_error( + result, + operation="record_checkpoint_event", + component="episodes.increment_connections", + exc=exc, ) - connections += len(matched_policies) - 
if connections > 0: - self.episodes.increment_connections(user_id, connections) - except Exception: - pass - if status == "completed": + if status == "completed": + try: episode = self.episodes.end_episode( user_id, outcome_score, summary ) if episode: result["episode_closed"] = episode.id - except Exception: - pass + except Exception as exc: + self._record_operation_error( + result, + operation="record_checkpoint_event", + component="episodes.end_episode", + exc=exc, + ) return result def update_task_on_checkpoint( @@ -238,11 +350,20 @@ def update_task_on_checkpoint( Replaces the scattered task logic in DheePlugin.checkpoint(). """ result: Dict[str, Any] = {} + active_task = None try: active_task = self.tasks.get_active_task(user_id) + except Exception as exc: + self._record_operation_error( + result, + operation="update_task_on_checkpoint", + component="tasks.get_active_task", + exc=exc, + ) - if goal or plan: - if not active_task or active_task.goal != (goal or active_task.goal): + if goal or plan: + if not active_task or active_task.goal != (goal or active_task.goal): + try: active_task = self.tasks.create_task( user_id=user_id, goal=goal or summary, @@ -252,15 +373,40 @@ def update_task_on_checkpoint( ) active_task.start() result["task_created"] = active_task.id - elif plan: + except Exception as exc: + self._record_operation_error( + result, + operation="update_task_on_checkpoint", + component="tasks.create_task", + exc=exc, + ) + elif plan: + try: active_task.set_plan(plan, plan_rationale) + except Exception as exc: + self._record_operation_error( + result, + operation="update_task_on_checkpoint", + component="tasks.set_plan", + exc=exc, + ) - if active_task: - if blockers: - for b in blockers: - active_task.add_blocker(b, severity="soft") + if active_task: + if blockers: + for blocker in blockers: + try: + active_task.add_blocker(blocker, severity="soft") + except Exception as exc: + self._record_operation_error( + result, + 
operation="update_task_on_checkpoint", + component="tasks.add_blocker", + exc=exc, + target=blocker, + ) - if status == "completed" and outcome_score is not None: + if status == "completed" and outcome_score is not None: + try: if outcome_score >= 0.5: active_task.complete( score=outcome_score, @@ -270,24 +416,52 @@ def update_task_on_checkpoint( else: active_task.fail(summary, evidence=outcome_evidence) result["task_completed"] = active_task.id + except Exception as exc: + self._record_operation_error( + result, + operation="update_task_on_checkpoint", + component="tasks.complete_or_fail", + exc=exc, + ) + try: self.tasks.update_task(active_task) + except Exception as exc: + self._record_operation_error( + result, + operation="update_task_on_checkpoint", + component="tasks.update_task", + exc=exc, + ) - # Record outcomes on STEP policies for completed/failed steps - if status == "completed" and active_task.plan: - for step in active_task.plan: - if step.status.value == "completed": - self.record_step_outcome( - user_id, task_type, step.description, - success=True, actual_score=outcome_score, - ) - elif step.status.value == "failed": - self.record_step_outcome( - user_id, task_type, step.description, - success=False, actual_score=outcome_score, - ) - except Exception: - pass + # Record outcomes on STEP policies for completed/failed steps + if status == "completed" and active_task.plan: + step_updates = 0 + for step in active_task.plan: + if step.status.value == "completed": + step_result = self.record_step_outcome( + user_id, + task_type, + step.description, + success=True, + actual_score=outcome_score, + ) + elif step.status.value == "failed": + step_result = self.record_step_outcome( + user_id, + task_type, + step.description, + success=False, + actual_score=outcome_score, + ) + else: + continue + + step_updates += step_result.get("policies_updated", 0) + self._merge_operation_errors(result, step_result) + + if step_updates: + result["step_policies_updated"] = 
step_updates return result def record_step_outcome( @@ -298,12 +472,13 @@ def record_step_outcome( success: bool, baseline_score: Optional[float] = None, actual_score: Optional[float] = None, - ) -> None: + ) -> Dict[str, Any]: """Record outcome on STEP policies matching a completed/failed step. Finds matching STEP policies and records their outcomes. Zero LLM calls. """ + result: Dict[str, Any] = {"policies_updated": 0} try: matched = self.policies.match_step_policies( user_id=user_id, @@ -312,15 +487,33 @@ def record_step_outcome( step_context=step_description, limit=5, ) - for policy in matched: + except Exception as exc: + self._record_operation_error( + result, + operation="record_step_outcome", + component="policies.match_step_policies", + exc=exc, + ) + return result + + for policy in matched: + try: self.policies.record_outcome( policy.id, success=success, baseline_score=baseline_score, actual_score=actual_score, ) - except Exception: - pass + result["policies_updated"] += 1 + except Exception as exc: + self._record_operation_error( + result, + operation="record_step_outcome", + component="policies.record_outcome", + exc=exc, + target=policy.id, + ) + return result def record_learning_outcomes( self, @@ -357,33 +550,70 @@ def record_learning_outcomes( matched = self.policies.match_policies( user_id, task_type, task_desc, ) + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="policies.match_policies", + exc=exc, + ) + else: for policy in matched: - self.policies.record_outcome( - policy.id, - success=success, - baseline_score=baseline_score, - actual_score=actual_score, - ) - result["policies_updated"] += 1 - except Exception: - pass + try: + self.policies.record_outcome( + policy.id, + success=success, + baseline_score=baseline_score, + actual_score=actual_score, + ) + result["policies_updated"] += 1 + except Exception as exc: + self._record_operation_error( + result, + 
operation="record_learning_outcomes", + component="policies.record_outcome", + exc=exc, + target=policy.id, + ) # 2. Extract TASK + STEP policies from completed tasks try: completed = self.tasks.get_tasks_by_type( user_id, task_type, limit=10, ) + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="tasks.get_tasks_by_type", + exc=exc, + ) + else: if len(completed) >= 3: task_dicts = [t.to_dict() for t in completed] - self.policies.extract_from_tasks( - user_id, task_dicts, task_type, - ) - step_policies = self.policies.extract_step_policies( - user_id, task_dicts, task_type, - ) - result["step_policies_created"] = len(step_policies) - except Exception: - pass + try: + self.policies.extract_from_tasks( + user_id, task_dicts, task_type, + ) + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="policies.extract_from_tasks", + exc=exc, + ) + try: + step_policies = self.policies.extract_step_policies( + user_id, task_dicts, task_type, + ) + result["step_policies_created"] = len(step_policies) + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="policies.extract_step_policies", + exc=exc, + ) # 3. 
Belief-policy interaction: challenged beliefs degrade dependent policies if not success: @@ -391,75 +621,177 @@ def record_learning_outcomes( relevant_beliefs = self.beliefs.get_relevant_beliefs( user_id, task_desc, limit=3, ) + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="beliefs.get_relevant_beliefs", + exc=exc, + ) + else: + try: + user_policies = list(self.policies.get_user_policies(user_id)) + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="policies.get_user_policies", + exc=exc, + ) + user_policies = [] + for belief in relevant_beliefs: if belief.confidence < 0.3: claim_words = set(belief.claim.lower().split()[:5]) - for policy in self.policies.get_user_policies(user_id): + for policy in user_policies: approach_words = set(policy.action.approach.lower().split()) if len(claim_words & approach_words) >= 2: - self.policies.decay_utility(policy.id, factor=0.8) - result["beliefs_policy_decays"] += 1 - except Exception: - pass + try: + self.policies.decay_utility(policy.id, factor=0.8) + result["beliefs_policy_decays"] += 1 + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="policies.decay_utility", + exc=exc, + target=policy.id, + ) # 4. 
Intention outcome recording try: triggered = self.intentions.get_triggered_pending_feedback(user_id) + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="intentions.get_triggered_pending_feedback", + exc=exc, + ) + else: for intention in triggered: - self.intentions.record_outcome( - intention.id, - useful=success, - outcome_score=actual_score, - ) - result["intentions_updated"] += 1 - except Exception: - pass + try: + self.intentions.record_outcome( + intention.id, + useful=success, + outcome_score=actual_score, + ) + result["intentions_updated"] += 1 + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="intentions.record_outcome", + exc=exc, + target=intention.id, + ) # 5. Episode connection wiring + connections = 0 try: active_task = self.tasks.get_active_task(user_id) - connections = 0 if active_task: connections += 1 + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="tasks.get_active_task", + exc=exc, + ) + try: matched_policies = self.policies.match_policies( user_id, task_type, task_desc, limit=3, ) connections += len(matched_policies) - if connections > 0: + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="policies.match_policies", + exc=exc, + ) + if connections > 0: + try: self.episodes.increment_connections(user_id, connections) - except Exception: - pass + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="episodes.increment_connections", + exc=exc, + ) # 6. 
Temporal failure pattern detection (decision stumps) try: from dhee.core.pattern_detector import ( FailurePatternDetector, extract_features, ) - recent = self.tasks.get_recent_tasks( - user_id, limit=100, include_terminal=True, + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="pattern_detector.import", + exc=exc, ) - terminal = [t for t in recent if t.is_terminal] - if len(terminal) >= FailurePatternDetector.MIN_SAMPLES: - # Build episode lookup via public API - episode_map = {} - for t in terminal: - if t.episode_id: - ep = self.episodes.get_episode(t.episode_id) - if ep: - episode_map[ep.id] = ep - - features = extract_features(terminal, episode_map) - detector = FailurePatternDetector() - patterns = detector.detect_and_describe(features) - - for pattern in patterns[:3]: - stored = self._store_pattern_as_policy( - user_id, task_type, pattern, - ) - if stored: - result["patterns_detected"] += 1 - except Exception: - pass + else: + try: + recent = self.tasks.get_recent_tasks( + user_id, limit=100, include_terminal=True, + ) + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="tasks.get_recent_tasks", + exc=exc, + ) + else: + terminal = [t for t in recent if t.is_terminal] + if len(terminal) >= FailurePatternDetector.MIN_SAMPLES: + episode_map = {} + for task in terminal: + if not task.episode_id: + continue + try: + episode = self.episodes.get_episode(task.episode_id) + if episode: + episode_map[episode.id] = episode + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="episodes.get_episode", + exc=exc, + target=task.episode_id, + ) + + try: + features = extract_features(terminal, episode_map) + detector = FailurePatternDetector() + patterns = detector.detect_and_describe(features) + except Exception as exc: + self._record_operation_error( + result, + 
operation="record_learning_outcomes", + component="pattern_detector.detect_and_describe", + exc=exc, + ) + else: + for pattern in patterns[:3]: + try: + stored = self._store_pattern_as_policy( + user_id, task_type, pattern, + ) + if stored: + result["patterns_detected"] += 1 + except Exception as exc: + self._record_operation_error( + result, + operation="record_learning_outcomes", + component="policies.store_temporal_pattern", + exc=exc, + ) return result @@ -476,12 +808,22 @@ def selective_forget( ) if archived > 0: result["episodes_archived"] = archived - except Exception: - pass + except Exception as exc: + self._record_operation_error( + result, + operation="selective_forget", + component="episodes.selective_forget", + exc=exc, + ) try: self.beliefs.prune_retracted(user_id) - except Exception: - pass + except Exception as exc: + self._record_operation_error( + result, + operation="selective_forget", + component="beliefs.prune_retracted", + exc=exc, + ) return result # ------------------------------------------------------------------ @@ -537,19 +879,39 @@ def _store_pattern_as_policy( def flush(self) -> None: """Persist all store state to disk.""" + errors: List[Dict[str, str]] = [] for store in [ - self.episodes, self.tasks, self.beliefs, - self.policies, self.intentions, + ("episodes", self.episodes), + ("tasks", self.tasks), + ("beliefs", self.beliefs), + ("policies", self.policies), + ("intentions", self.intentions), ]: - if hasattr(store, "flush"): + name, store_instance = store + if hasattr(store_instance, "flush"): try: - store.flush() - except Exception: - pass + store_instance.flush() + except Exception as exc: + logger.exception( + "CognitionKernel flush failed for %s", name + ) + errors.append( + self._operation_error( + "flush", + f"{name}.flush", + exc, + ) + ) + if errors: + detail = "; ".join( + f"{entry['component']}: {entry['error']}" for entry in errors + ) + raise RuntimeError(f"Failed to flush cognition stores: {detail}") def 
get_stats(self) -> Dict[str, Any]: """Aggregated stats from all stores.""" stats: Dict[str, Any] = {} + errors: List[Dict[str, str]] = [] for name, store in [ ("episodes", self.episodes), ("tasks", self.tasks), @@ -559,8 +921,18 @@ def get_stats(self) -> Dict[str, Any]: ]: try: stats[name] = store.get_stats() - except Exception: - stats[name] = {} + except Exception as exc: + logger.exception("CognitionKernel get_stats failed for %s", name) + errors.append( + self._operation_error( + "get_stats", + f"{name}.get_stats", + exc, + ) + ) + stats[name] = {"error": f"{type(exc).__name__}: {exc}"} + if errors: + stats["errors"] = errors return stats def __repr__(self) -> str: diff --git a/dhee/core/consolidation.py b/dhee/core/consolidation.py index 764aa70..7c07add 100644 --- a/dhee/core/consolidation.py +++ b/dhee/core/consolidation.py @@ -18,10 +18,9 @@ """ import logging -from typing import Any, Dict, TYPE_CHECKING +from typing import Any, Dict, Protocol, TYPE_CHECKING from dhee.configs.active import ActiveMemoryConfig -from dhee.core.active_memory import ActiveMemoryStore if TYPE_CHECKING: from dhee.memory.main import FullMemory @@ -29,6 +28,21 @@ logger = logging.getLogger(__name__) +class ActiveMemoryStore(Protocol): + """Structural contract for consolidation's active-memory dependency.""" + + def get_consolidation_candidates( + self, + *, + min_age_seconds: int, + min_reads: int, + ) -> list[Dict[str, Any]]: + ... + + def mark_consolidated(self, signal_ids: list[str]) -> None: + ... + + class ConsolidationEngine: """Promotes qualifying active signals into passive (Engram) memory.""" diff --git a/dhee/core/decay.py b/dhee/core/decay.py index 322a1db..a1a54ee 100644 --- a/dhee/core/decay.py +++ b/dhee/core/decay.py @@ -1,6 +1,6 @@ """FadeMem decay calculations. -Requires dhee-accel (Rust) for the core decay math. +Uses dhee-accel (Rust) when available, pure-Python fallback otherwise. 
""" import math @@ -10,7 +10,26 @@ if TYPE_CHECKING: from dhee.configs.base import FadeMemConfig -from dhee_accel import calculate_decayed_strength as _rs_decay +try: + from dhee_accel import calculate_decayed_strength as _rs_decay + _ACCEL = True +except ImportError: + _ACCEL = False + + +def _py_decay( + strength: float, + elapsed_days: float, + decay_rate: float, + access_count: int, + dampening_factor: float, +) -> float: + """Pure-Python decay: strength * exp(-rate * days / dampening).""" + if math.isnan(strength): + return 0.0 + dampening = 1.0 + dampening_factor * math.log(1.0 + access_count) + decayed = strength * math.exp(-decay_rate * elapsed_days / dampening) + return max(0.0, min(1.0, decayed)) def calculate_decayed_strength( @@ -31,7 +50,15 @@ def calculate_decayed_strength( time_elapsed_days = (datetime.now(timezone.utc) - last_accessed).total_seconds() / 86400.0 decay_rate = config.sml_decay_rate if layer == "sml" else config.lml_decay_rate - return _rs_decay( + if _ACCEL: + return _rs_decay( + current_strength, + time_elapsed_days, + decay_rate, + access_count, + config.access_dampening_factor, + ) + return _py_decay( current_strength, time_elapsed_days, decay_rate, diff --git a/dhee/core/kernel.py b/dhee/core/kernel.py index 9694f81..5bc0f95 100644 --- a/dhee/core/kernel.py +++ b/dhee/core/kernel.py @@ -36,10 +36,13 @@ def get_last_session( repo: Optional[str] = None, fallback_log_recovery: bool = True, db_path: Optional[str] = None, + user_id: Optional[str] = None, + requester_agent_id: Optional[str] = None, ) -> Optional[Dict]: """Get the last session for *agent_id*, falling back to JSONL logs. - 1. Try ``bus.get_session(agent_id=agent_id)`` + 1. Try ``bus.get_session(agent_id=agent_id, repo=repo)`` when a repo is + provided, otherwise fall back to agent-only lookup. 2. If found, attach latest checkpoint journal entries. 3. 
If not found **and** *fallback_log_recovery* is ``True`` **and** *repo* is provided, parse the most recent Claude Code conversation log for @@ -50,11 +53,18 @@ def get_last_session( agent_id: The source agent whose session to load (default ``"mcp-server"``). repo: - Absolute path to the repository root, used for log-based fallback. + Absolute path to the repository root. When provided, bus lookup is + repo-scoped before log fallback is attempted. fallback_log_recovery: Whether to fall back to JSONL log parsing if no bus session exists. db_path: Override path for the handoff SQLite database. + user_id: + Reserved compatibility parameter for higher-level continuity callers. + Currently unused because handoff records are agent/repo scoped. + requester_agent_id: + Reserved compatibility parameter for higher-level continuity callers. + Currently unused because handoff lookup is driven by *agent_id*. Returns ------- @@ -63,7 +73,7 @@ def get_last_session( bus = None try: bus = _get_bus(db_path) - session = bus.get_session(agent_id=agent_id) + session = bus.get_session(agent_id=agent_id, repo=repo) if session is not None: # Attach latest checkpoints @@ -113,6 +123,8 @@ def save_session_digest( key_commands: Optional[List[str]] = None, test_results: Optional[str] = None, db_path: Optional[str] = None, + user_id: Optional[str] = None, + requester_agent_id: Optional[str] = None, ) -> Dict: """Save a session digest to the dhee-bus handoff store. diff --git a/dhee/core/retrieval.py b/dhee/core/retrieval.py index 68453ef..d6e0886 100644 --- a/dhee/core/retrieval.py +++ b/dhee/core/retrieval.py @@ -1,12 +1,19 @@ """Retrieval scoring functions for Engram memory search. -Requires dhee-accel (Rust) for tokenize and BM25 operations. +Uses dhee-accel (Rust) when available, pure-Python fallback otherwise. 
""" import math +import re from typing import Dict, List, Any, Optional, Set -from dhee_accel import tokenize as _rs_tokenize, bm25_score_batch as _rs_bm25_batch +try: + from dhee_accel import tokenize as _rs_tokenize, bm25_score_batch as _rs_bm25_batch + _ACCEL = True +except ImportError: + _ACCEL = False + +_TOKEN_RE = re.compile(r'[a-z0-9_]+') def composite_score(similarity: float, strength: float) -> float: @@ -14,9 +21,16 @@ def composite_score(similarity: float, strength: float) -> float: return similarity * strength +def _py_tokenize(text: str) -> List[str]: + """Pure-Python tokenize: lowercase, split on non-alphanumeric boundaries.""" + return _TOKEN_RE.findall(text.lower()) + + def tokenize(text: str) -> List[str]: - """Tokenize text for BM25 scoring (Rust-accelerated).""" - return _rs_tokenize(text) + """Tokenize text for BM25 scoring.""" + if _ACCEL: + return _rs_tokenize(text) + return _py_tokenize(text) def calculate_bm25_score( @@ -56,6 +70,53 @@ def calculate_bm25_score( return score +def _py_bm25_batch( + query_terms: List[str], + documents: List[List[str]], + total_docs: int, + avg_doc_len: float, + k1: float = 1.5, + b: float = 0.75, +) -> List[float]: + """Pure-Python batch BM25 scoring.""" + if not query_terms or not documents: + return [0.0] * len(documents) + + total_docs_f = float(total_docs) + if avg_doc_len == 0.0: + avg_doc_len = 1.0 + + # Document frequency for query terms + doc_freq: Dict[str, int] = {} + for term in query_terms: + count = sum(1 for doc in documents if term in doc) + doc_freq[term] = count + + scores = [] + for doc in documents: + if not doc: + scores.append(0.0) + continue + + tf: Dict[str, int] = {} + for t in doc: + tf[t] = tf.get(t, 0) + 1 + + doc_len = float(len(doc)) + score = 0.0 + for term in query_terms: + if term not in tf: + continue + term_f = float(tf[term]) + df = float(doc_freq.get(term, 1)) + idf = math.log((total_docs_f - df + 0.5) / (df + 0.5) + 1.0) + tf_component = (term_f * (k1 + 1.0)) / (term_f + k1 
* (1.0 - b + b * doc_len / avg_doc_len)) + score += idf * tf_component + scores.append(score) + + return scores + + def bm25_score_batch( query_terms: List[str], documents: List[List[str]], @@ -64,8 +125,10 @@ def bm25_score_batch( k1: float = 1.5, b: float = 0.75, ) -> List[float]: - """Batch BM25 scoring for N documents (Rust-accelerated).""" - return _rs_bm25_batch(query_terms, documents, total_docs, avg_doc_len, k1, b) + """Batch BM25 scoring for N documents.""" + if _ACCEL: + return _rs_bm25_batch(query_terms, documents, total_docs, avg_doc_len, k1, b) + return _py_bm25_batch(query_terms, documents, total_docs, avg_doc_len, k1, b) def calculate_keyword_score( @@ -99,32 +162,22 @@ def calculate_keyword_score( def build_sparse_vector(text: str, dim: int = 30000) -> Dict[int, float]: - """Build a sparse BM25-like weight vector from text. - - Tokenizes via Rust, hashes tokens to sparse indices, and returns - a dict mapping index → weight. Useful for hybrid dense+sparse search - if the vector store supports sparse fields. 
- """ + """Build a sparse BM25-like weight vector from text.""" import hashlib as _hashlib tokens = tokenize(text) if not tokens: return {} - # Term frequency tf: Dict[str, int] = {} for token in tokens: tf[token] = tf.get(token, 0) + 1 sparse: Dict[int, float] = {} - doc_len = len(tokens) for token, count in tf.items(): - # Hash token to a sparse index h = int(_hashlib.md5(token.encode("utf-8")).hexdigest(), 16) idx = h % dim - # BM25-like weight: tf / (tf + 1) weight = count / (count + 1.0) - # Accumulate in case of hash collision sparse[idx] = sparse.get(idx, 0.0) + weight return sparse @@ -165,7 +218,6 @@ def score_memory( hybrid = hybrid_score(semantic_similarity, keyword_score, self.alpha) - # Apply contrastive boost: results aligned with past successes score higher if self.contrastive_boost > 0 and contrastive_signal > 0: hybrid += self.contrastive_boost * contrastive_signal diff --git a/dhee/core/traces.py b/dhee/core/traces.py index 44516bd..258114c 100644 --- a/dhee/core/traces.py +++ b/dhee/core/traces.py @@ -3,7 +3,7 @@ Each memory has three traces (fast, mid, slow) that decay at different rates and cascade information from fast -> mid -> slow during sleep cycles. -Requires dhee-accel (Rust) for batch decay operations. +Uses dhee-accel (Rust) when available, pure-Python fallback otherwise. 
""" from __future__ import annotations @@ -15,7 +15,11 @@ if TYPE_CHECKING: from dhee.configs.base import DistillationConfig -from dhee_accel import decay_traces_batch as _rs_decay_traces_batch +try: + from dhee_accel import decay_traces_batch as _rs_decay_traces_batch + _ACCEL = True +except ImportError: + _ACCEL = False def initialize_traces( @@ -72,14 +76,44 @@ def decay_traces( ) +def _py_decay_traces_batch( + traces: List[Tuple[float, float, float]], + elapsed_days: List[float], + access_counts: List[int], + fast_rate: float, + mid_rate: float, + slow_rate: float, +) -> List[Tuple[float, float, float]]: + """Pure-Python batch trace decay.""" + results = [] + for i, (s_fast, s_mid, s_slow) in enumerate(traces): + days = elapsed_days[i] if i < len(elapsed_days) else 0.0 + access = access_counts[i] if i < len(access_counts) else 0 + dampening = 1.0 + 0.5 * math.log(1.0 + access) + new_fast = max(0.0, min(1.0, s_fast * math.exp(-fast_rate * days / dampening))) + new_mid = max(0.0, min(1.0, s_mid * math.exp(-mid_rate * days / dampening))) + new_slow = max(0.0, min(1.0, s_slow * math.exp(-slow_rate * days / dampening))) + results.append((new_fast, new_mid, new_slow)) + return results + + def decay_traces_batch( traces: List[Tuple[float, float, float]], elapsed_days: List[float], access_counts: List[int], config: "DistillationConfig", ) -> List[Tuple[float, float, float]]: - """Batch version of decay_traces (Rust-accelerated).""" - return _rs_decay_traces_batch( + """Batch version of decay_traces.""" + if _ACCEL: + return _rs_decay_traces_batch( + traces, + elapsed_days, + [int(a) for a in access_counts], + config.s_fast_decay_rate, + config.s_mid_decay_rate, + config.s_slow_decay_rate, + ) + return _py_decay_traces_batch( traces, elapsed_days, [int(a) for a in access_counts], diff --git a/dhee/db/sqlite.py b/dhee/db/sqlite.py index 297b312..e1cb60e 100644 --- a/dhee/db/sqlite.py +++ b/dhee/db/sqlite.py @@ -1,4 +1,3 @@ -import hashlib import json import logging 
import os @@ -6,44 +5,18 @@ import threading import uuid from contextlib import contextmanager -from datetime import datetime, timezone from typing import Any, Dict, List, Optional -logger = logging.getLogger(__name__) - -# Phase 5: Allowed column names for dynamic UPDATE queries to prevent SQL injection. -VALID_MEMORY_COLUMNS = frozenset({ - "memory", "metadata", "categories", "embedding", "strength", - "layer", "tombstone", "updated_at", "related_memories", "source_memories", - "confidentiality_scope", "source_type", "source_app", "source_event_id", - "decay_lambda", "status", "importance", "sensitivity", "namespace", - "access_count", "last_accessed", "immutable", "expiration_date", - "scene_id", "user_id", "agent_id", "run_id", "app_id", - "memory_type", "s_fast", "s_mid", "s_slow", "content_hash", - "conversation_context", "enrichment_status", -}) - -VALID_SCENE_COLUMNS = frozenset({ - "title", "summary", "topic", "location", "participants", "memory_ids", - "start_time", "end_time", "embedding", "strength", "access_count", - "tombstone", "layer", "scene_strength", "topic_embedding_ref", "namespace", -}) - -VALID_PROFILE_COLUMNS = frozenset({ - "name", "profile_type", "narrative", "facts", "preferences", - "relationships", "sentiment", "theory_of_mind", "aliases", - "embedding", "strength", "updated_at", "role_bias", "profile_summary", -}) +from .sqlite_analytics import SQLiteAnalyticsMixin +from .sqlite_common import ( + VALID_MEMORY_COLUMNS, + VALID_PROFILE_COLUMNS, + VALID_SCENE_COLUMNS, + _utcnow_iso, +) +from .sqlite_domains import SQLiteDomainMixin - -def _utcnow() -> datetime: - """Return current UTC datetime (timezone-aware).""" - return datetime.now(timezone.utc) - - -def _utcnow_iso() -> str: - """Return current UTC time as ISO string.""" - return _utcnow().isoformat() +logger = logging.getLogger(__name__) class _SQLiteBase: @@ -544,33 +517,7 @@ def purge_tombstoned(self) -> int: return count -class FullSQLiteManager(CoreSQLiteManager): - def 
__init__(self, db_path: str): - self.db_path = db_path - db_dir = os.path.dirname(db_path) - if db_dir: - os.makedirs(db_dir, exist_ok=True) - # Phase 1: Persistent connection with WAL mode. - self._conn = sqlite3.connect(db_path, check_same_thread=False) - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA busy_timeout=5000") - self._conn.execute("PRAGMA synchronous=FULL") - self._conn.execute("PRAGMA cache_size=-8000") # 8MB cache - self._conn.execute("PRAGMA temp_store=MEMORY") - self._conn.row_factory = sqlite3.Row - self._lock = threading.RLock() - self._init_db() - - def close(self) -> None: - """Close the persistent connection for clean shutdown.""" - with self._lock: - if self._conn: - try: - self._conn.close() - except Exception: - pass - self._conn = None # type: ignore[assignment] - +class FullSQLiteManager(SQLiteAnalyticsMixin, SQLiteDomainMixin, CoreSQLiteManager): def __repr__(self) -> str: return f"SQLiteManager(db_path={self.db_path!r})" @@ -724,17 +671,6 @@ def _init_db(self) -> None: # v2 schema + idempotent migrations. 
self._ensure_v2_schema(conn) - @contextmanager - def _get_connection(self): - """Yield the persistent connection under the thread lock.""" - with self._lock: - try: - yield self._conn - self._conn.commit() - except Exception: - self._conn.rollback() - raise - def _ensure_v2_schema(self, conn: sqlite3.Connection) -> None: """Create and migrate Engram v2 schema in-place (idempotent).""" conn.execute( @@ -1199,18 +1135,6 @@ def _ensure_v3_universal_engram(self, conn: sqlite3.Connection) -> None: "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v3_universal_engram')" ) - def _ensure_content_hash_column(self, conn: sqlite3.Connection) -> None: - """Add content_hash column + index for SHA-256 dedup.""" - if self._is_migration_applied(conn, "v2_content_hash"): - return - self._migrate_add_column_conn(conn, "memories", "content_hash", "TEXT") - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_content_hash ON memories(content_hash, user_id)" - ) - conn.execute( - "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_content_hash')" - ) - def _ensure_cls_columns(self, conn: sqlite3.Connection) -> None: """Add CLS Distillation Memory columns to memories table (idempotent).""" if self._is_migration_applied(conn, "v2_cls_columns_complete"): @@ -1234,36 +1158,6 @@ def _ensure_cls_columns(self, conn: sqlite3.Connection) -> None: "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_cls_columns_complete')" ) - def _is_migration_applied(self, conn: sqlite3.Connection, version: str) -> bool: - row = conn.execute( - "SELECT 1 FROM schema_migrations WHERE version = ?", - (version,), - ).fetchone() - return row is not None - - # Phase 5: Allowed table names for ALTER TABLE to prevent SQL injection. 
- _ALLOWED_TABLES = frozenset({ - "memories", "scenes", "profiles", "categories", - }) - - def _migrate_add_column_conn( - self, - conn: sqlite3.Connection, - table: str, - column: str, - col_type: str, - ) -> None: - """Add a column using an existing connection, if missing.""" - if table not in self._ALLOWED_TABLES: - raise ValueError(f"Invalid table for migration: {table!r}") - # Validate column name: must be alphanumeric/underscore only. - if not column.replace("_", "").isalnum(): - raise ValueError(f"Invalid column name: {column!r}") - try: - conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {col_type}") - except sqlite3.OperationalError: - pass - def add_memory(self, memory_data: Dict[str, Any]) -> str: memory_id = memory_data.get("id", str(uuid.uuid4())) now = _utcnow_iso() @@ -1842,1001 +1736,7 @@ def purge_tombstoned(self) -> int: conn.execute("DELETE FROM memories WHERE tombstone = 1") return count - # CLS Distillation Memory helpers - - def get_episodic_memories( - self, - user_id: str, - *, - scene_id: Optional[str] = None, - created_after: Optional[str] = None, - created_before: Optional[str] = None, - limit: int = 100, - namespace: Optional[str] = None, - ) -> List[Dict[str, Any]]: - """Fetch episodic-type memories for a user, optionally filtered by scene/time.""" - query = "SELECT * FROM memories WHERE user_id = ? AND memory_type = 'episodic' AND tombstone = 0" - params: List[Any] = [user_id] - if scene_id: - query += " AND scene_id = ?" - params.append(scene_id) - if created_after: - query += " AND created_at >= ?" - params.append(created_after) - if created_before: - query += " AND created_at <= ?" - params.append(created_before) - if namespace: - query += " AND namespace = ?" - params.append(namespace) - query += " ORDER BY created_at DESC LIMIT ?" 
- params.append(limit) - with self._get_connection() as conn: - rows = conn.execute(query, params).fetchall() - return [self._row_to_dict(row) for row in rows] - - def add_distillation_provenance( - self, - semantic_memory_id: str, - episodic_memory_ids: List[str], - run_id: str, - ) -> None: - """Record which episodic memories contributed to a distilled semantic memory.""" - with self._get_connection() as conn: - for ep_id in episodic_memory_ids: - conn.execute( - """ - INSERT INTO distillation_provenance (id, semantic_memory_id, episodic_memory_id, distillation_run_id) - VALUES (?, ?, ?, ?) - """, - (str(uuid.uuid4()), semantic_memory_id, ep_id, run_id), - ) - - def log_distillation_run( - self, - user_id: str, - episodes_sampled: int, - semantic_created: int, - semantic_deduplicated: int = 0, - errors: int = 0, - ) -> str: - """Log a distillation run and return the run ID.""" - run_id = str(uuid.uuid4()) - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO distillation_log (id, user_id, episodes_sampled, semantic_created, semantic_deduplicated, errors) - VALUES (?, ?, ?, ?, ?, ?) - """, - (run_id, user_id, episodes_sampled, semantic_created, semantic_deduplicated, errors), - ) - return run_id - - def get_memory_count_by_namespace(self, user_id: str) -> Dict[str, int]: - """Return {namespace: count} for active memories of a user.""" - with self._get_connection() as conn: - rows = conn.execute( - """ - SELECT COALESCE(namespace, 'default') AS ns, COUNT(*) AS cnt - FROM memories - WHERE user_id = ? 
AND tombstone = 0 - GROUP BY ns - """, - (user_id,), - ).fetchall() - return {row["ns"]: row["cnt"] for row in rows} - - def update_multi_trace( - self, - memory_id: str, - s_fast: float, - s_mid: float, - s_slow: float, - effective_strength: float, - ) -> bool: - """Update multi-trace columns and effective strength for a memory.""" - return self.update_memory(memory_id, { - "s_fast": s_fast, - "s_mid": s_mid, - "s_slow": s_slow, - "strength": effective_strength, - }) - - # CategoryMem methods - def save_category(self, category_data: Dict[str, Any]) -> str: - """Save or update a category.""" - category_id = category_data.get("id") - if not category_id: - return "" - - with self._get_connection() as conn: - conn.execute( - """ - INSERT OR REPLACE INTO categories ( - id, name, description, category_type, parent_id, - children_ids, memory_count, total_strength, access_count, - last_accessed, created_at, embedding, keywords, - summary, summary_updated_at, related_ids, strength - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, - ( - category_id, - category_data.get("name", ""), - category_data.get("description", ""), - category_data.get("category_type", "dynamic"), - category_data.get("parent_id"), - json.dumps(category_data.get("children_ids", [])), - category_data.get("memory_count", 0), - category_data.get("total_strength", 0.0), - category_data.get("access_count", 0), - category_data.get("last_accessed"), - category_data.get("created_at"), - json.dumps(category_data.get("embedding")) if category_data.get("embedding") else None, - json.dumps(category_data.get("keywords", [])), - category_data.get("summary"), - category_data.get("summary_updated_at"), - json.dumps(category_data.get("related_ids", [])), - category_data.get("strength", 1.0), - ), - ) - return category_id - - def get_category(self, category_id: str) -> Optional[Dict[str, Any]]: - """Get a category by ID.""" - with self._get_connection() as conn: - row = conn.execute( - "SELECT * FROM categories WHERE id = ?", - (category_id,) - ).fetchone() - if row: - return self._category_row_to_dict(row) - return None - - def get_all_categories(self) -> List[Dict[str, Any]]: - """Get all categories.""" - with self._get_connection() as conn: - rows = conn.execute( - "SELECT * FROM categories ORDER BY strength DESC" - ).fetchall() - return [self._category_row_to_dict(row) for row in rows] - - def delete_category(self, category_id: str) -> bool: - """Delete a category.""" - with self._get_connection() as conn: - conn.execute("DELETE FROM categories WHERE id = ?", (category_id,)) - return True - - def save_all_categories(self, categories: List[Dict[str, Any]]) -> int: - """Save multiple categories in a single transaction for performance.""" - if not categories: - return 0 - rows = [] - for cat in categories: - cat_id = cat.get("id") - if not cat_id: - continue - rows.append(( - cat_id, - cat.get("name", ""), - cat.get("description", ""), - cat.get("category_type", "dynamic"), - cat.get("parent_id"), - 
json.dumps(cat.get("children_ids", [])), - cat.get("memory_count", 0), - cat.get("total_strength", 0.0), - cat.get("access_count", 0), - cat.get("last_accessed"), - cat.get("created_at"), - json.dumps(cat.get("embedding")) if cat.get("embedding") else None, - json.dumps(cat.get("keywords", [])), - cat.get("summary"), - cat.get("summary_updated_at"), - json.dumps(cat.get("related_ids", [])), - cat.get("strength", 1.0), - )) - if not rows: - return 0 - with self._get_connection() as conn: - conn.executemany( - """ - INSERT OR REPLACE INTO categories ( - id, name, description, category_type, parent_id, - children_ids, memory_count, total_strength, access_count, - last_accessed, created_at, embedding, keywords, - summary, summary_updated_at, related_ids, strength - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - rows, - ) - return len(rows) - - def _category_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: - """Convert a category row to dict.""" - data = dict(row) - for key in ["children_ids", "keywords", "related_ids"]: - if key in data and data[key]: - data[key] = json.loads(data[key]) - else: - data[key] = [] - if data.get("embedding"): - data["embedding"] = json.loads(data["embedding"]) - return data - - def _migrate_add_column(self, table: str, column: str, col_type: str) -> None: - """Add a column to an existing table if it doesn't already exist.""" - with self._get_connection() as conn: - self._migrate_add_column_conn(conn, table, column, col_type) - - # ========================================================================= - # Scene methods - # ========================================================================= - - def add_scene(self, scene_data: Dict[str, Any]) -> str: - scene_id = scene_data.get("id", str(uuid.uuid4())) - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO scenes ( - id, user_id, title, summary, topic, location, - participants, memory_ids, start_time, end_time, - embedding, strength, 
access_count, tombstone, - layer, scene_strength, topic_embedding_ref, namespace - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - scene_id, - scene_data.get("user_id"), - scene_data.get("title"), - scene_data.get("summary"), - scene_data.get("topic"), - scene_data.get("location"), - json.dumps(scene_data.get("participants", [])), - json.dumps(scene_data.get("memory_ids", [])), - scene_data.get("start_time"), - scene_data.get("end_time"), - json.dumps(scene_data.get("embedding")) if scene_data.get("embedding") else None, - scene_data.get("strength", 1.0), - scene_data.get("access_count", 0), - 1 if scene_data.get("tombstone", False) else 0, - scene_data.get("layer", "sml"), - scene_data.get("scene_strength", scene_data.get("strength", 1.0)), - scene_data.get("topic_embedding_ref"), - scene_data.get("namespace", "default"), - ), - ) - return scene_id - - def get_scene(self, scene_id: str) -> Optional[Dict[str, Any]]: - with self._get_connection() as conn: - row = conn.execute( - "SELECT * FROM scenes WHERE id = ? 
AND tombstone = 0", (scene_id,) - ).fetchone() - if row: - return self._scene_row_to_dict(row) - return None - - def update_scene(self, scene_id: str, updates: Dict[str, Any]) -> bool: - set_clauses = [] - params: List[Any] = [] - for key, value in updates.items(): - if key not in VALID_SCENE_COLUMNS: - raise ValueError(f"Invalid scene column: {key!r}") - if key in {"participants", "memory_ids", "embedding"}: - value = json.dumps(value) - set_clauses.append(f"{key} = ?") - params.append(value) - if not set_clauses: - return False - params.append(scene_id) - with self._get_connection() as conn: - conn.execute( - f"UPDATE scenes SET {', '.join(set_clauses)} WHERE id = ?", - params, - ) - return True - - def get_open_scene(self, user_id: str) -> Optional[Dict[str, Any]]: - """Get the most recent scene without an end_time for a user.""" - with self._get_connection() as conn: - row = conn.execute( - """ - SELECT * FROM scenes - WHERE user_id = ? AND end_time IS NULL AND tombstone = 0 - ORDER BY start_time DESC LIMIT 1 - """, - (user_id,), - ).fetchone() - if row: - return self._scene_row_to_dict(row) - return None - - def get_scenes( - self, - user_id: Optional[str] = None, - topic: Optional[str] = None, - start_after: Optional[str] = None, - start_before: Optional[str] = None, - namespace: Optional[str] = None, - limit: int = 50, - ) -> List[Dict[str, Any]]: - query = "SELECT * FROM scenes WHERE tombstone = 0" - params: List[Any] = [] - if user_id: - query += " AND user_id = ?" - params.append(user_id) - if topic: - query += " AND topic LIKE ?" - params.append(f"%{topic}%") - if start_after: - query += " AND start_time >= ?" - params.append(start_after) - if start_before: - query += " AND start_time <= ?" - params.append(start_before) - if namespace: - query += " AND namespace = ?" - params.append(namespace) - query += " ORDER BY start_time DESC LIMIT ?" 
- params.append(limit) - with self._get_connection() as conn: - rows = conn.execute(query, params).fetchall() - return [self._scene_row_to_dict(row) for row in rows] - - def add_scene_memory(self, scene_id: str, memory_id: str, position: int = 0) -> None: - with self._get_connection() as conn: - conn.execute( - "INSERT OR IGNORE INTO scene_memories (scene_id, memory_id, position) VALUES (?, ?, ?)", - (scene_id, memory_id, position), - ) - - def get_scene_memories(self, scene_id: str) -> List[Dict[str, Any]]: - with self._get_connection() as conn: - rows = conn.execute( - """ - SELECT m.* FROM memories m - JOIN scene_memories sm ON m.id = sm.memory_id - WHERE sm.scene_id = ? AND m.tombstone = 0 - ORDER BY sm.position - """, - (scene_id,), - ).fetchall() - return [self._row_to_dict(row) for row in rows] - - def _scene_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: - data = dict(row) - for key in ["participants", "memory_ids"]: - if key in data and data[key]: - data[key] = json.loads(data[key]) - else: - data[key] = [] - if data.get("embedding"): - data["embedding"] = json.loads(data["embedding"]) - data["tombstone"] = bool(data.get("tombstone", 0)) - return data - - # ========================================================================= - # Profile methods - # ========================================================================= - - def add_profile(self, profile_data: Dict[str, Any]) -> str: - profile_id = profile_data.get("id", str(uuid.uuid4())) - now = _utcnow_iso() - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO profiles ( - id, user_id, name, profile_type, narrative, - facts, preferences, relationships, sentiment, - theory_of_mind, aliases, embedding, strength, - created_at, updated_at, role_bias, profile_summary - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, - ( - profile_id, - profile_data.get("user_id"), - profile_data.get("name", ""), - profile_data.get("profile_type", "contact"), - profile_data.get("narrative"), - json.dumps(profile_data.get("facts", [])), - json.dumps(profile_data.get("preferences", [])), - json.dumps(profile_data.get("relationships", [])), - profile_data.get("sentiment"), - json.dumps(profile_data.get("theory_of_mind", {})), - json.dumps(profile_data.get("aliases", [])), - json.dumps(profile_data.get("embedding")) if profile_data.get("embedding") else None, - profile_data.get("strength", 1.0), - profile_data.get("created_at", now), - profile_data.get("updated_at", now), - profile_data.get("role_bias"), - profile_data.get("profile_summary"), - ), - ) - return profile_id - - def get_profile(self, profile_id: str) -> Optional[Dict[str, Any]]: - with self._get_connection() as conn: - row = conn.execute( - "SELECT * FROM profiles WHERE id = ?", (profile_id,) - ).fetchone() - if row: - return self._profile_row_to_dict(row) - return None - - def update_profile(self, profile_id: str, updates: Dict[str, Any]) -> bool: - set_clauses = [] - params: List[Any] = [] - for key, value in updates.items(): - if key not in VALID_PROFILE_COLUMNS: - raise ValueError(f"Invalid profile column: {key!r}") - if key in {"facts", "preferences", "relationships", "aliases", "theory_of_mind", "embedding"}: - value = json.dumps(value) - set_clauses.append(f"{key} = ?") - params.append(value) - set_clauses.append("updated_at = ?") - params.append(_utcnow_iso()) - params.append(profile_id) - with self._get_connection() as conn: - conn.execute( - f"UPDATE profiles SET {', '.join(set_clauses)} WHERE id = ?", - params, - ) - return True - - def get_all_profiles(self, user_id: Optional[str] = None) -> List[Dict[str, Any]]: - query = "SELECT * FROM profiles" - params: List[Any] = [] - if user_id: - query += " WHERE user_id = ?" 
- params.append(user_id) - query += " ORDER BY strength DESC" - with self._get_connection() as conn: - rows = conn.execute(query, params).fetchall() - return [self._profile_row_to_dict(row) for row in rows] - - def get_profile_by_name(self, name: str, user_id: Optional[str] = None) -> Optional[Dict[str, Any]]: - """Find a profile by exact name match, then fall back to alias scan.""" - # Fast path: exact name match via indexed column. - query = "SELECT * FROM profiles WHERE lower(name) = ?" - params: List[Any] = [name.lower()] - if user_id: - query += " AND user_id = ?" - params.append(user_id) - query += " LIMIT 1" - with self._get_connection() as conn: - row = conn.execute(query, params).fetchone() - if row: - return self._profile_row_to_dict(row) - # Slow path: alias scan (aliases stored as JSON, can't index). - alias_query = "SELECT * FROM profiles WHERE aliases LIKE ?" - alias_params: List[Any] = [f'%"{name}"%'] - if user_id: - alias_query += " AND user_id = ?" - alias_params.append(user_id) - alias_query += " LIMIT 1" - row = conn.execute(alias_query, alias_params).fetchone() - if row: - result = self._profile_row_to_dict(row) - # Verify case-insensitive alias match. - if name.lower() in [a.lower() for a in result.get("aliases", [])]: - return result - return None - - def find_profile_by_substring(self, name: str, user_id: Optional[str] = None) -> Optional[Dict[str, Any]]: - """Find a profile where the name contains the query as a substring (case-insensitive).""" - query = "SELECT * FROM profiles WHERE lower(name) LIKE ?" - params: List[Any] = [f"%{name.lower()}%"] - if user_id: - query += " AND user_id = ?" 
- params.append(user_id) - query += " ORDER BY strength DESC LIMIT 1" - with self._get_connection() as conn: - row = conn.execute(query, params).fetchone() - if row: - return self._profile_row_to_dict(row) - return None - - def add_profile_memory(self, profile_id: str, memory_id: str, role: str = "mentioned") -> None: - with self._get_connection() as conn: - conn.execute( - "INSERT OR IGNORE INTO profile_memories (profile_id, memory_id, role) VALUES (?, ?, ?)", - (profile_id, memory_id, role), - ) - - def get_profile_memories(self, profile_id: str) -> List[Dict[str, Any]]: - with self._get_connection() as conn: - rows = conn.execute( - """ - SELECT m.*, pm.role AS profile_role FROM memories m - JOIN profile_memories pm ON m.id = pm.memory_id - WHERE pm.profile_id = ? AND m.tombstone = 0 - ORDER BY m.created_at DESC - """, - (profile_id,), - ).fetchall() - return [self._row_to_dict(row) for row in rows] - - def _profile_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: - data = dict(row) - for key in ["facts", "preferences", "relationships", "aliases"]: - if key in data and data[key]: - data[key] = json.loads(data[key]) - else: - data[key] = [] - if data.get("theory_of_mind"): - data["theory_of_mind"] = json.loads(data["theory_of_mind"]) - else: - data["theory_of_mind"] = {} - if data.get("embedding"): - data["embedding"] = json.loads(data["embedding"]) - return data - - def get_memories_by_category( - self, - category_id: str, - limit: int = 100, - min_strength: float = 0.0, - ) -> List[Dict[str, Any]]: - """Get memories belonging to a specific category.""" - with self._get_connection() as conn: - rows = conn.execute( - """ - SELECT * FROM memories - WHERE categories LIKE ? AND strength >= ? AND tombstone = 0 - ORDER BY strength DESC - LIMIT ? 
- """, - (f'%"{category_id}"%', min_strength, limit), - ).fetchall() - return [self._row_to_dict(row) for row in rows] - - # ========================================================================= - # User ID listing - # ========================================================================= - - def list_user_ids(self) -> List[str]: - with self._get_connection() as conn: - rows = conn.execute( - """ - SELECT DISTINCT user_id FROM memories - WHERE user_id IS NOT NULL AND user_id != '' - ORDER BY user_id - """ - ).fetchall() - return [str(row["user_id"]) for row in rows if row["user_id"]] - - # ========================================================================= - # Episodic event index + cost counters - # ========================================================================= - - def delete_episodic_events_for_memory(self, memory_id: str) -> int: - with self._get_connection() as conn: - cur = conn.execute( - "DELETE FROM episodic_events WHERE memory_id = ?", - (memory_id,), - ) - return int(cur.rowcount or 0) - - def add_episodic_events(self, events: List[Dict[str, Any]]) -> int: - if not events: - return 0 - rows = [] - for event in events: - rows.append( - ( - event.get("id"), - event.get("memory_id"), - event.get("user_id"), - event.get("conversation_id"), - event.get("session_id"), - int(event.get("turn_id", 0) or 0), - event.get("actor_id"), - event.get("actor_role"), - event.get("event_time"), - event.get("event_type"), - event.get("canonical_key"), - event.get("value_text"), - event.get("value_num"), - event.get("value_unit"), - event.get("currency"), - event.get("normalized_time_start"), - event.get("normalized_time_end"), - event.get("time_granularity"), - event.get("entity_key"), - event.get("value_norm"), - event.get("confidence", 0.0), - event.get("superseded_by"), - ) - ) - with self._get_connection() as conn: - conn.executemany( - """ - INSERT OR REPLACE INTO episodic_events ( - id, memory_id, user_id, conversation_id, session_id, turn_id, - 
actor_id, actor_role, event_time, event_type, canonical_key, - value_text, value_num, value_unit, currency, - normalized_time_start, normalized_time_end, time_granularity, entity_key, value_norm, - confidence, superseded_by - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - rows, - ) - return len(rows) - - def get_episodic_events( - self, - *, - user_id: str, - actor_id: Optional[str] = None, - event_types: Optional[List[str]] = None, - time_anchor: Optional[str] = None, - entity_hints: Optional[List[str]] = None, - include_superseded: bool = False, - limit: int = 300, - ) -> List[Dict[str, Any]]: - query = "SELECT * FROM episodic_events WHERE user_id = ?" - params: List[Any] = [user_id] - if actor_id: - query += " AND actor_id = ?" - params.append(actor_id) - if event_types: - placeholders = ",".join("?" for _ in event_types) - query += f" AND event_type IN ({placeholders})" - params.extend(event_types) - if time_anchor: - query += " AND COALESCE(normalized_time_start, event_time) <= ?" - params.append(str(time_anchor)) - normalized_hints = [str(h).strip().lower() for h in (entity_hints or []) if str(h).strip()] - if normalized_hints: - clauses = [] - for hint in normalized_hints: - wildcard = f"%{hint}%" - clauses.append( - "(" - "LOWER(COALESCE(entity_key, '')) LIKE ? " - "OR LOWER(COALESCE(actor_id, '')) LIKE ? " - "OR LOWER(COALESCE(actor_role, '')) LIKE ?" - ")" - ) - params.extend([wildcard, wildcard, wildcard]) - query += " AND (" + " OR ".join(clauses) + ")" - if not include_superseded: - query += " AND (superseded_by IS NULL OR superseded_by = '')" - query += " ORDER BY COALESCE(normalized_time_start, event_time) DESC, turn_id DESC LIMIT ?" 
- params.append(max(1, int(limit))) - with self._get_connection() as conn: - rows = conn.execute(query, params).fetchall() - return [dict(row) for row in rows] - - def record_cost_counter( - self, - *, - phase: str, - user_id: Optional[str] = None, - llm_calls: float = 0.0, - input_tokens: float = 0.0, - output_tokens: float = 0.0, - embed_calls: float = 0.0, - ) -> None: - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO cost_counters ( - user_id, phase, llm_calls, input_tokens, output_tokens, embed_calls - ) VALUES (?, ?, ?, ?, ?, ?) - """, - ( - user_id, - str(phase), - float(llm_calls or 0.0), - float(input_tokens or 0.0), - float(output_tokens or 0.0), - float(embed_calls or 0.0), - ), - ) - - def aggregate_cost_counters( - self, - *, - phase: str, - user_id: Optional[str] = None, - ) -> Dict[str, Any]: - query = """ - SELECT - COUNT(*) AS samples, - COALESCE(SUM(llm_calls), 0) AS llm_calls, - COALESCE(SUM(input_tokens), 0) AS input_tokens, - COALESCE(SUM(output_tokens), 0) AS output_tokens, - COALESCE(SUM(embed_calls), 0) AS embed_calls - FROM cost_counters - WHERE phase = ? - """ - params: List[Any] = [str(phase)] - if user_id: - query += " AND user_id = ?" 
- params.append(user_id) - with self._get_connection() as conn: - row = conn.execute(query, params).fetchone() - if not row: - return { - "phase": phase, - "samples": 0, - "llm_calls": 0.0, - "input_tokens": 0.0, - "output_tokens": 0.0, - "embed_calls": 0.0, - } - return { - "phase": phase, - "samples": int(row["samples"] or 0), - "llm_calls": float(row["llm_calls"] or 0.0), - "input_tokens": float(row["input_tokens"] or 0.0), - "output_tokens": float(row["output_tokens"] or 0.0), - "embed_calls": float(row["embed_calls"] or 0.0), - } - - # ========================================================================= - # Dashboard / Visualization methods - # ========================================================================= - - def get_constellation_data(self, user_id: Optional[str] = None, limit: int = 200) -> Dict[str, Any]: - """Build graph data for the constellation visualizer.""" - with self._get_connection() as conn: - # Nodes: memories - mem_query = "SELECT id, memory, strength, layer, categories, created_at FROM memories WHERE tombstone = 0" - params: List[Any] = [] - if user_id: - mem_query += " AND user_id = ?" - params.append(user_id) - mem_query += " ORDER BY strength DESC LIMIT ?" - params.append(limit) - mem_rows = conn.execute(mem_query, params).fetchall() - - nodes = [] - node_ids = set() - for row in mem_rows: - cats = row["categories"] - if cats: - try: - cats = json.loads(cats) - except Exception: - cats = [] - else: - cats = [] - nodes.append({ - "id": row["id"], - "memory": (row["memory"] or "")[:120], - "strength": row["strength"], - "layer": row["layer"], - "categories": cats, - "created_at": row["created_at"], - }) - node_ids.add(row["id"]) - - # Edges from scene_memories (memories sharing a scene) - edges: List[Dict[str, Any]] = [] - if node_ids: - placeholders = ",".join("?" 
for _ in node_ids) - scene_rows = conn.execute( - f""" - SELECT a.memory_id AS source, b.memory_id AS target, a.scene_id - FROM scene_memories a - JOIN scene_memories b ON a.scene_id = b.scene_id AND a.memory_id < b.memory_id - WHERE a.memory_id IN ({placeholders}) AND b.memory_id IN ({placeholders}) - """, - list(node_ids) + list(node_ids), - ).fetchall() - for row in scene_rows: - edges.append({"source": row["source"], "target": row["target"], "type": "scene"}) - - # Edges from profile_memories (memories sharing a profile) - profile_rows = conn.execute( - f""" - SELECT a.memory_id AS source, b.memory_id AS target, a.profile_id - FROM profile_memories a - JOIN profile_memories b ON a.profile_id = b.profile_id AND a.memory_id < b.memory_id - WHERE a.memory_id IN ({placeholders}) AND b.memory_id IN ({placeholders}) - """, - list(node_ids) + list(node_ids), - ).fetchall() - for row in profile_rows: - edges.append({"source": row["source"], "target": row["target"], "type": "profile"}) - - return {"nodes": nodes, "edges": edges} - - def get_decay_log_entries(self, limit: int = 20) -> List[Dict[str, Any]]: - """Return recent decay log entries for the dashboard sparkline.""" - with self._get_connection() as conn: - rows = conn.execute( - "SELECT * FROM decay_log ORDER BY run_at DESC LIMIT ?", - (limit,), - ).fetchall() - return [dict(row) for row in rows] - - # ========================================================================= - # Entity Aggregates - # ========================================================================= - - def _ensure_entity_table(self, conn: sqlite3.Connection) -> None: - """Lazily ensure entity_aggregates table exists.""" - conn.execute(""" - CREATE TABLE IF NOT EXISTS entity_aggregates ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL, - entity_key TEXT NOT NULL, - agg_type TEXT NOT NULL, - value_num REAL DEFAULT 0.0, - value_unit TEXT, - item_set TEXT, - contributing_sessions TEXT, - contributing_memory_ids TEXT, - last_updated TEXT, - 
created_at TEXT - ) - """) - conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_entity_agg_lookup - ON entity_aggregates(user_id, agg_type, entity_key) - """) - - def upsert_entity_aggregate( - self, - user_id: str, - entity_key: str, - agg_type: str, - value_delta: float, - value_unit: Optional[str] = None, - session_id: Optional[str] = None, - memory_id: Optional[str] = None, - ) -> None: - """Atomically increment an entity aggregate and append contributing session/memory.""" - agg_id = hashlib.sha256(f"{user_id}|{agg_type}|{entity_key}".encode()).hexdigest() - now = datetime.now(timezone.utc).isoformat() - with self._get_connection() as conn: - self._ensure_entity_table(conn) - existing = conn.execute( - "SELECT value_num, contributing_sessions, contributing_memory_ids FROM entity_aggregates WHERE id = ?", - (agg_id,), - ).fetchone() - if existing: - cur_val = float(existing["value_num"] or 0) - sessions = json.loads(existing["contributing_sessions"] or "[]") - memories = json.loads(existing["contributing_memory_ids"] or "[]") - if session_id and session_id not in sessions: - sessions.append(session_id) - if memory_id and memory_id not in memories: - memories.append(memory_id) - conn.execute( - """UPDATE entity_aggregates - SET value_num = ?, value_unit = COALESCE(?, value_unit), - contributing_sessions = ?, contributing_memory_ids = ?, - last_updated = ? 
- WHERE id = ?""", - (cur_val + value_delta, value_unit, - json.dumps(sessions), json.dumps(memories), now, agg_id), - ) - else: - sessions = [session_id] if session_id else [] - memories = [memory_id] if memory_id else [] - conn.execute( - """INSERT INTO entity_aggregates - (id, user_id, entity_key, agg_type, value_num, value_unit, - item_set, contributing_sessions, contributing_memory_ids, - last_updated, created_at) - VALUES (?, ?, ?, ?, ?, ?, '[]', ?, ?, ?, ?)""", - (agg_id, user_id, entity_key, agg_type, value_delta, value_unit, - json.dumps(sessions), json.dumps(memories), now, now), - ) - - def upsert_entity_set_member( - self, - user_id: str, - entity_key: str, - item_value: str, - session_id: Optional[str] = None, - memory_id: Optional[str] = None, - ) -> None: - """Add a unique item to an item_set aggregate and increment count.""" - agg_id = hashlib.sha256(f"{user_id}|item_set|{entity_key}".encode()).hexdigest() - now = datetime.now(timezone.utc).isoformat() - with self._get_connection() as conn: - self._ensure_entity_table(conn) - existing = conn.execute( - "SELECT value_num, item_set, contributing_sessions, contributing_memory_ids " - "FROM entity_aggregates WHERE id = ?", - (agg_id,), - ).fetchone() - if existing: - items = json.loads(existing["item_set"] or "[]") - sessions = json.loads(existing["contributing_sessions"] or "[]") - memories = json.loads(existing["contributing_memory_ids"] or "[]") - if item_value not in items: - items.append(item_value) - if session_id and session_id not in sessions: - sessions.append(session_id) - if memory_id and memory_id not in memories: - memories.append(memory_id) - conn.execute( - """UPDATE entity_aggregates - SET value_num = ?, item_set = ?, - contributing_sessions = ?, contributing_memory_ids = ?, - last_updated = ? 
- WHERE id = ?""", - (len(items), json.dumps(items), - json.dumps(sessions), json.dumps(memories), now, agg_id), - ) - else: - sessions = [session_id] if session_id else [] - memories = [memory_id] if memory_id else [] - conn.execute( - """INSERT INTO entity_aggregates - (id, user_id, entity_key, agg_type, value_num, value_unit, - item_set, contributing_sessions, contributing_memory_ids, - last_updated, created_at) - VALUES (?, ?, ?, 'item_set', 1, NULL, ?, ?, ?, ?, ?)""", - (agg_id, user_id, entity_key, - json.dumps([item_value]), - json.dumps(sessions), json.dumps(memories), now, now), - ) - - def get_entity_aggregates( - self, - user_id: str, - agg_type: Optional[str] = None, - entity_hints: Optional[List[str]] = None, - ) -> List[Dict[str, Any]]: - """Query entity aggregates with optional fuzzy match on entity_key.""" - with self._get_connection() as conn: - self._ensure_entity_table(conn) - if agg_type and entity_hints: - # Build fuzzy LIKE conditions for each hint - conditions = " OR ".join(["entity_key LIKE ?" for _ in entity_hints]) - params: list = [user_id, agg_type] + [f"%{h}%" for h in entity_hints] - rows = conn.execute( - f"SELECT * FROM entity_aggregates WHERE user_id = ? AND agg_type = ? AND ({conditions})", - params, - ).fetchall() - elif agg_type: - rows = conn.execute( - "SELECT * FROM entity_aggregates WHERE user_id = ? AND agg_type = ?", - (user_id, agg_type), - ).fetchall() - elif entity_hints: - conditions = " OR ".join(["entity_key LIKE ?" for _ in entity_hints]) - params = [user_id] + [f"%{h}%" for h in entity_hints] - rows = conn.execute( - f"SELECT * FROM entity_aggregates WHERE user_id = ? 
class SQLiteAnalyticsMixin:
    """Distillation, episodic indexing, counters, and aggregate APIs.

    Mixin for the SQLite manager hierarchy: it defines no state of its own
    and relies on the host class to provide ``_get_connection()`` (a
    commit-on-exit context manager yielding a connection with
    ``sqlite3.Row`` rows), ``_row_to_dict()``, ``_parse_json_value()`` and
    ``update_memory()``.  All tables referenced here (``memories``,
    ``episodic_events``, ``distillation_*``, ``cost_counters``,
    ``scene_memories``, ``profile_memories``, ``decay_log``) are created by
    the host schema, except ``entity_aggregates`` which is lazily created
    via ``_ensure_entity_table``.
    """

    def get_episodic_memories(
        self,
        user_id: str,
        *,
        scene_id: Optional[str] = None,
        created_after: Optional[str] = None,
        created_before: Optional[str] = None,
        limit: int = 100,
        namespace: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Fetch episodic-type memories for a user, optionally filtered.

        Filters are ANDed together; ``created_after``/``created_before`` are
        compared lexicographically against the stored ``created_at`` text
        (ISO-8601 timestamps sort correctly this way). Results are newest
        first, capped at ``limit``.
        """
        query = (
            "SELECT * FROM memories WHERE user_id = ? "
            "AND memory_type = 'episodic' AND tombstone = 0"
        )
        params: List[Any] = [user_id]
        if scene_id:
            query += " AND scene_id = ?"
            params.append(scene_id)
        if created_after:
            query += " AND created_at >= ?"
            params.append(created_after)
        if created_before:
            query += " AND created_at <= ?"
            params.append(created_before)
        if namespace:
            query += " AND namespace = ?"
            params.append(namespace)
        query += " ORDER BY created_at DESC LIMIT ?"
        params.append(limit)
        with self._get_connection() as conn:
            rows = conn.execute(query, params).fetchall()
            # _row_to_dict decodes the JSON columns (host-provided helper).
            return [self._row_to_dict(row) for row in rows]

    def add_distillation_provenance(
        self,
        semantic_memory_id: str,
        episodic_memory_ids: List[str],
        run_id: str,
    ) -> None:
        """Record which episodic memories contributed to a semantic memory.

        Inserts one provenance row per episodic id, all within a single
        transaction (the connection context commits once on exit).
        """
        with self._get_connection() as conn:
            for ep_id in episodic_memory_ids:
                conn.execute(
                    """
                    INSERT INTO distillation_provenance (
                        id, semantic_memory_id, episodic_memory_id,
                        distillation_run_id
                    ) VALUES (?, ?, ?, ?)
                    """,
                    (str(uuid.uuid4()), semantic_memory_id, ep_id, run_id),
                )

    def log_distillation_run(
        self,
        user_id: str,
        episodes_sampled: int,
        semantic_created: int,
        semantic_deduplicated: int = 0,
        errors: int = 0,
    ) -> str:
        """Log a distillation run and return the run ID."""
        run_id = str(uuid.uuid4())
        with self._get_connection() as conn:
            conn.execute(
                """
                INSERT INTO distillation_log (
                    id, user_id, episodes_sampled, semantic_created,
                    semantic_deduplicated, errors
                ) VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    run_id,
                    user_id,
                    episodes_sampled,
                    semantic_created,
                    semantic_deduplicated,
                    errors,
                ),
            )
        return run_id

    def get_memory_count_by_namespace(self, user_id: str) -> Dict[str, int]:
        """Return {namespace: count} for active memories of a user.

        NULL namespaces are reported under the key ``'default'``.
        """
        with self._get_connection() as conn:
            rows = conn.execute(
                """
                SELECT COALESCE(namespace, 'default') AS ns, COUNT(*) AS cnt
                FROM memories
                WHERE user_id = ? AND tombstone = 0
                GROUP BY ns
                """,
                (user_id,),
            ).fetchall()
            return {row["ns"]: row["cnt"] for row in rows}

    def update_multi_trace(
        self,
        memory_id: str,
        s_fast: float,
        s_mid: float,
        s_slow: float,
        effective_strength: float,
    ) -> bool:
        """Update multi-trace columns and effective strength for a memory.

        Thin wrapper over the host's ``update_memory``; returns False when
        the memory does not exist (same contract as ``update_memory``).
        """
        return self.update_memory(
            memory_id,
            {
                "s_fast": s_fast,
                "s_mid": s_mid,
                "s_slow": s_slow,
                "strength": effective_strength,
            },
        )

    def delete_episodic_events_for_memory(self, memory_id: str) -> int:
        """Delete all episodic events attached to a memory; return row count."""
        with self._get_connection() as conn:
            cur = conn.execute(
                "DELETE FROM episodic_events WHERE memory_id = ?",
                (memory_id,),
            )
            # rowcount can be -1/None depending on driver state; normalize.
            return int(cur.rowcount or 0)

    def add_episodic_events(self, events: List[Dict[str, Any]]) -> int:
        """Bulk-write episodic event rows; return the number written.

        Uses INSERT OR REPLACE keyed on the ``id`` column, so re-ingesting
        the same event ids overwrites in place (idempotent re-runs).
        Missing dict keys become NULL, except ``turn_id`` (coerced to int,
        default 0) and ``confidence`` (default 0.0).
        """
        if not events:
            return 0
        rows = []
        for event in events:
            rows.append(
                (
                    event.get("id"),
                    event.get("memory_id"),
                    event.get("user_id"),
                    event.get("conversation_id"),
                    event.get("session_id"),
                    int(event.get("turn_id", 0) or 0),
                    event.get("actor_id"),
                    event.get("actor_role"),
                    event.get("event_time"),
                    event.get("event_type"),
                    event.get("canonical_key"),
                    event.get("value_text"),
                    event.get("value_num"),
                    event.get("value_unit"),
                    event.get("currency"),
                    event.get("normalized_time_start"),
                    event.get("normalized_time_end"),
                    event.get("time_granularity"),
                    event.get("entity_key"),
                    event.get("value_norm"),
                    event.get("confidence", 0.0),
                    event.get("superseded_by"),
                )
            )
        with self._get_connection() as conn:
            conn.executemany(
                """
                INSERT OR REPLACE INTO episodic_events (
                    id, memory_id, user_id, conversation_id, session_id, turn_id,
                    actor_id, actor_role, event_time, event_type, canonical_key,
                    value_text, value_num, value_unit, currency,
                    normalized_time_start, normalized_time_end, time_granularity,
                    entity_key, value_norm, confidence, superseded_by
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                rows,
            )
        return len(rows)

    def get_episodic_events(
        self,
        *,
        user_id: str,
        actor_id: Optional[str] = None,
        event_types: Optional[List[str]] = None,
        time_anchor: Optional[str] = None,
        entity_hints: Optional[List[str]] = None,
        include_superseded: bool = False,
        limit: int = 300,
    ) -> List[Dict[str, Any]]:
        """Query episodic events with optional actor/type/time/entity filters.

        ``time_anchor`` keeps events whose effective time
        (``normalized_time_start``, falling back to ``event_time``) is at or
        before the anchor. ``entity_hints`` are lower-cased substrings
        fuzzily matched against entity_key/actor_id/actor_role.
        NOTE(review): hint text is interpolated into LIKE patterns without
        escaping '%'/'_' — intentional for fuzzy matching, but hints
        containing wildcards will over-match.
        """
        query = "SELECT * FROM episodic_events WHERE user_id = ?"
        params: List[Any] = [user_id]
        if actor_id:
            query += " AND actor_id = ?"
            params.append(actor_id)
        if event_types:
            placeholders = ",".join("?" for _ in event_types)
            query += f" AND event_type IN ({placeholders})"
            params.extend(event_types)
        if time_anchor:
            query += " AND COALESCE(normalized_time_start, event_time) <= ?"
            params.append(str(time_anchor))
        # Normalize hints: stringify, trim, lower-case, drop empties.
        normalized_hints = [
            str(h).strip().lower()
            for h in (entity_hints or [])
            if str(h).strip()
        ]
        if normalized_hints:
            clauses = []
            for hint in normalized_hints:
                wildcard = f"%{hint}%"
                clauses.append(
                    "("
                    "LOWER(COALESCE(entity_key, '')) LIKE ? "
                    "OR LOWER(COALESCE(actor_id, '')) LIKE ? "
                    "OR LOWER(COALESCE(actor_role, '')) LIKE ?"
                    ")"
                )
                # One wildcard per LIKE in the clause just appended.
                params.extend([wildcard, wildcard, wildcard])
            query += " AND (" + " OR ".join(clauses) + ")"
        if not include_superseded:
            query += " AND (superseded_by IS NULL OR superseded_by = '')"
        query += (
            " ORDER BY COALESCE(normalized_time_start, event_time) DESC, "
            "turn_id DESC LIMIT ?"
        )
        # Clamp limit to at least 1 so LIMIT 0 never returns nothing.
        params.append(max(1, int(limit)))
        with self._get_connection() as conn:
            rows = conn.execute(query, params).fetchall()
            return [dict(row) for row in rows]

    def record_cost_counter(
        self,
        *,
        phase: str,
        user_id: Optional[str] = None,
        llm_calls: float = 0.0,
        input_tokens: float = 0.0,
        output_tokens: float = 0.0,
        embed_calls: float = 0.0,
    ) -> None:
        """Append one cost-accounting sample for a pipeline phase.

        Values are coerced to float (None becomes 0.0) so callers can pass
        raw counters without cleaning them up first.
        """
        with self._get_connection() as conn:
            conn.execute(
                """
                INSERT INTO cost_counters (
                    user_id, phase, llm_calls, input_tokens, output_tokens,
                    embed_calls
                ) VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    user_id,
                    str(phase),
                    float(llm_calls or 0.0),
                    float(input_tokens or 0.0),
                    float(output_tokens or 0.0),
                    float(embed_calls or 0.0),
                ),
            )

    def aggregate_cost_counters(
        self,
        *,
        phase: str,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Sum cost counters for a phase (optionally per user).

        Always returns a fully-populated dict; when no rows match, the SQL
        aggregates yield samples=0 and 0.0 sums (COALESCE guards the SUMs),
        and the explicit zero dict covers the no-row edge case.
        """
        query = """
            SELECT
                COUNT(*) AS samples,
                COALESCE(SUM(llm_calls), 0) AS llm_calls,
                COALESCE(SUM(input_tokens), 0) AS input_tokens,
                COALESCE(SUM(output_tokens), 0) AS output_tokens,
                COALESCE(SUM(embed_calls), 0) AS embed_calls
            FROM cost_counters
            WHERE phase = ?
        """
        params: List[Any] = [str(phase)]
        if user_id:
            query += " AND user_id = ?"
            params.append(user_id)
        with self._get_connection() as conn:
            row = conn.execute(query, params).fetchone()
            if not row:
                return {
                    "phase": phase,
                    "samples": 0,
                    "llm_calls": 0.0,
                    "input_tokens": 0.0,
                    "output_tokens": 0.0,
                    "embed_calls": 0.0,
                }
            return {
                "phase": phase,
                "samples": int(row["samples"] or 0),
                "llm_calls": float(row["llm_calls"] or 0.0),
                "input_tokens": float(row["input_tokens"] or 0.0),
                "output_tokens": float(row["output_tokens"] or 0.0),
                "embed_calls": float(row["embed_calls"] or 0.0),
            }

    def get_constellation_data(
        self,
        user_id: Optional[str] = None,
        limit: int = 200,
    ) -> Dict[str, Any]:
        """Build graph data for the constellation visualizer.

        Nodes are the strongest non-tombstoned memories (text truncated to
        120 chars); edges connect memory pairs that share a scene or a
        profile, deduplicated via the ``a.memory_id < b.memory_id`` self-join.
        """
        with self._get_connection() as conn:
            mem_query = (
                "SELECT id, memory, strength, layer, categories, created_at "
                "FROM memories WHERE tombstone = 0"
            )
            params: List[Any] = []
            if user_id:
                mem_query += " AND user_id = ?"
                params.append(user_id)
            mem_query += " ORDER BY strength DESC LIMIT ?"
            params.append(limit)
            mem_rows = conn.execute(mem_query, params).fetchall()

            nodes = []
            node_ids = set()
            for row in mem_rows:
                nodes.append(
                    {
                        "id": row["id"],
                        "memory": (row["memory"] or "")[:120],
                        "strength": row["strength"],
                        "layer": row["layer"],
                        "categories": self._parse_json_value(
                            row["categories"], []
                        ),
                        "created_at": row["created_at"],
                    }
                )
                node_ids.add(row["id"])

            edges: List[Dict[str, Any]] = []
            if node_ids:
                placeholders = ",".join("?" for _ in node_ids)
                # The placeholder list appears twice (both IN clauses), so
                # the node ids are bound twice as well.
                scene_rows = conn.execute(
                    f"""
                    SELECT a.memory_id AS source, b.memory_id AS target, a.scene_id
                    FROM scene_memories a
                    JOIN scene_memories b
                        ON a.scene_id = b.scene_id
                        AND a.memory_id < b.memory_id
                    WHERE a.memory_id IN ({placeholders})
                        AND b.memory_id IN ({placeholders})
                    """,
                    list(node_ids) + list(node_ids),
                ).fetchall()
                for row in scene_rows:
                    edges.append(
                        {
                            "source": row["source"],
                            "target": row["target"],
                            "type": "scene",
                        }
                    )

                profile_rows = conn.execute(
                    f"""
                    SELECT a.memory_id AS source, b.memory_id AS target, a.profile_id
                    FROM profile_memories a
                    JOIN profile_memories b
                        ON a.profile_id = b.profile_id
                        AND a.memory_id < b.memory_id
                    WHERE a.memory_id IN ({placeholders})
                        AND b.memory_id IN ({placeholders})
                    """,
                    list(node_ids) + list(node_ids),
                ).fetchall()
                for row in profile_rows:
                    edges.append(
                        {
                            "source": row["source"],
                            "target": row["target"],
                            "type": "profile",
                        }
                    )

            return {"nodes": nodes, "edges": edges}

    def get_decay_log_entries(self, limit: int = 20) -> List[Dict[str, Any]]:
        """Return recent decay log entries for the dashboard sparkline."""
        with self._get_connection() as conn:
            rows = conn.execute(
                "SELECT * FROM decay_log ORDER BY run_at DESC LIMIT ?",
                (limit,),
            ).fetchall()
            return [dict(row) for row in rows]

    def _ensure_entity_table(self, conn: sqlite3.Connection) -> None:
        """Lazily ensure entity_aggregates table exists.

        Called by every entity-aggregate method before touching the table;
        both statements are IF NOT EXISTS, so repeated calls are no-ops.
        """
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS entity_aggregates (
                id TEXT PRIMARY KEY,
                user_id TEXT NOT NULL,
                entity_key TEXT NOT NULL,
                agg_type TEXT NOT NULL,
                value_num REAL DEFAULT 0.0,
                value_unit TEXT,
                item_set TEXT,
                contributing_sessions TEXT,
                contributing_memory_ids TEXT,
                last_updated TEXT,
                created_at TEXT
            )
            """
        )
        conn.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_entity_agg_lookup
            ON entity_aggregates(user_id, agg_type, entity_key)
            """
        )

    def upsert_entity_aggregate(
        self,
        user_id: str,
        entity_key: str,
        agg_type: str,
        value_delta: float,
        value_unit: Optional[str] = None,
        session_id: Optional[str] = None,
        memory_id: Optional[str] = None,
    ) -> None:
        """Increment an entity aggregate and append provenance metadata.

        The row id is a deterministic SHA-256 of (user, agg_type, key), so
        repeated calls for the same logical aggregate hit the same row.
        The read-then-write happens inside one connection context (single
        transaction under the host's connection lock).
        """
        agg_id = hashlib.sha256(
            f"{user_id}|{agg_type}|{entity_key}".encode()
        ).hexdigest()
        now = _utcnow_iso()
        with self._get_connection() as conn:
            self._ensure_entity_table(conn)
            existing = conn.execute(
                """
                SELECT value_num, contributing_sessions, contributing_memory_ids
                FROM entity_aggregates WHERE id = ?
                """,
                (agg_id,),
            ).fetchone()
            if existing:
                cur_val = float(existing["value_num"] or 0)
                sessions = self._parse_json_value(
                    existing["contributing_sessions"], []
                )
                memories = self._parse_json_value(
                    existing["contributing_memory_ids"], []
                )
                # Provenance lists are kept duplicate-free.
                if session_id and session_id not in sessions:
                    sessions.append(session_id)
                if memory_id and memory_id not in memories:
                    memories.append(memory_id)
                conn.execute(
                    """
                    UPDATE entity_aggregates
                    SET value_num = ?, value_unit = COALESCE(?, value_unit),
                        contributing_sessions = ?, contributing_memory_ids = ?,
                        last_updated = ?
                    WHERE id = ?
                    """,
                    (
                        cur_val + value_delta,
                        value_unit,
                        json.dumps(sessions),
                        json.dumps(memories),
                        now,
                        agg_id,
                    ),
                )
            else:
                sessions = [session_id] if session_id else []
                memories = [memory_id] if memory_id else []
                conn.execute(
                    """
                    INSERT INTO entity_aggregates (
                        id, user_id, entity_key, agg_type, value_num, value_unit,
                        item_set, contributing_sessions, contributing_memory_ids,
                        last_updated, created_at
                    ) VALUES (?, ?, ?, ?, ?, ?, '[]', ?, ?, ?, ?)
                    """,
                    (
                        agg_id,
                        user_id,
                        entity_key,
                        agg_type,
                        value_delta,
                        value_unit,
                        json.dumps(sessions),
                        json.dumps(memories),
                        now,
                        now,
                    ),
                )

    def upsert_entity_set_member(
        self,
        user_id: str,
        entity_key: str,
        item_value: str,
        session_id: Optional[str] = None,
        memory_id: Optional[str] = None,
    ) -> None:
        """Add a unique item to an item_set aggregate and increment count.

        ``value_num`` tracks the set's cardinality (len(items) after the
        optional insert), so re-adding an existing item leaves it unchanged.
        """
        # Same deterministic id scheme as upsert_entity_aggregate, with the
        # agg_type fixed to 'item_set'.
        agg_id = hashlib.sha256(
            f"{user_id}|item_set|{entity_key}".encode()
        ).hexdigest()
        now = _utcnow_iso()
        with self._get_connection() as conn:
            self._ensure_entity_table(conn)
            existing = conn.execute(
                """
                SELECT value_num, item_set, contributing_sessions,
                       contributing_memory_ids
                FROM entity_aggregates WHERE id = ?
                """,
                (agg_id,),
            ).fetchone()
            if existing:
                items = self._parse_json_value(existing["item_set"], [])
                sessions = self._parse_json_value(
                    existing["contributing_sessions"], []
                )
                memories = self._parse_json_value(
                    existing["contributing_memory_ids"], []
                )
                if item_value not in items:
                    items.append(item_value)
                if session_id and session_id not in sessions:
                    sessions.append(session_id)
                if memory_id and memory_id not in memories:
                    memories.append(memory_id)
                conn.execute(
                    """
                    UPDATE entity_aggregates
                    SET value_num = ?, item_set = ?,
                        contributing_sessions = ?, contributing_memory_ids = ?,
                        last_updated = ?
                    WHERE id = ?
                    """,
                    (
                        len(items),
                        json.dumps(items),
                        json.dumps(sessions),
                        json.dumps(memories),
                        now,
                        agg_id,
                    ),
                )
            else:
                sessions = [session_id] if session_id else []
                memories = [memory_id] if memory_id else []
                conn.execute(
                    """
                    INSERT INTO entity_aggregates (
                        id, user_id, entity_key, agg_type, value_num,
                        value_unit, item_set, contributing_sessions,
                        contributing_memory_ids, last_updated, created_at
                    ) VALUES (?, ?, ?, 'item_set', 1, NULL, ?, ?, ?, ?, ?)
                    """,
                    (
                        agg_id,
                        user_id,
                        entity_key,
                        json.dumps([item_value]),
                        json.dumps(sessions),
                        json.dumps(memories),
                        now,
                        now,
                    ),
                )

    def get_entity_aggregates(
        self,
        user_id: str,
        agg_type: Optional[str] = None,
        entity_hints: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Query entity aggregates with optional fuzzy match on entity_key.

        Four branches cover the filter combinations; hints become
        OR-combined LIKE patterns (``%hint%``). NOTE(review): hint text is
        not escaped for LIKE wildcards — fine for fuzzy matching.
        """
        with self._get_connection() as conn:
            self._ensure_entity_table(conn)
            if agg_type and entity_hints:
                conditions = " OR ".join(
                    ["entity_key LIKE ?" for _ in entity_hints]
                )
                params: List[Any] = [user_id, agg_type] + [
                    f"%{hint}%"
                    for hint in entity_hints
                ]
                rows = conn.execute(
                    f"""
                    SELECT * FROM entity_aggregates
                    WHERE user_id = ? AND agg_type = ? AND ({conditions})
                    """,
                    params,
                ).fetchall()
            elif agg_type:
                rows = conn.execute(
                    """
                    SELECT * FROM entity_aggregates
                    WHERE user_id = ? AND agg_type = ?
                    """,
                    (user_id, agg_type),
                ).fetchall()
            elif entity_hints:
                conditions = " OR ".join(
                    ["entity_key LIKE ?" for _ in entity_hints]
                )
                params = [user_id] + [f"%{hint}%" for hint in entity_hints]
                rows = conn.execute(
                    f"""
                    SELECT * FROM entity_aggregates
                    WHERE user_id = ? AND ({conditions})
                    """,
                    params,
                ).fetchall()
            else:
                rows = conn.execute(
                    "SELECT * FROM entity_aggregates WHERE user_id = ?",
                    (user_id,),
                ).fetchall()
            return [dict(row) for row in rows]

    def delete_entity_aggregates_for_user(self, user_id: str) -> int:
        """Delete all entity aggregates for a user; return rows removed."""
        with self._get_connection() as conn:
            self._ensure_entity_table(conn)
            cursor = conn.execute(
                "DELETE FROM entity_aggregates WHERE user_id = ?",
                (user_id,),
            )
            return int(cursor.rowcount or 0)
-VALID_MEMORY_COLUMNS = frozenset({ - "memory", "metadata", "categories", "embedding", "strength", - "layer", "tombstone", "updated_at", "related_memories", "source_memories", - "confidentiality_scope", "source_type", "source_app", "source_event_id", - "decay_lambda", "status", "importance", "sensitivity", "namespace", - "access_count", "last_accessed", "immutable", "expiration_date", - "scene_id", "user_id", "agent_id", "run_id", "app_id", - "memory_type", "s_fast", "s_mid", "s_slow", -}) - -VALID_SCENE_COLUMNS = frozenset({ - "title", "summary", "topic", "location", "participants", "memory_ids", - "start_time", "end_time", "embedding", "strength", "access_count", - "tombstone", "layer", "scene_strength", "topic_embedding_ref", "namespace", -}) - -VALID_PROFILE_COLUMNS = frozenset({ - "name", "profile_type", "narrative", "facts", "preferences", - "relationships", "sentiment", "theory_of_mind", "aliases", - "embedding", "strength", "updated_at", "role_bias", "profile_summary", -}) - - -def _utcnow() -> datetime: - """Return current UTC datetime (timezone-aware).""" - return datetime.now(timezone.utc) - - -def _utcnow_iso() -> str: - """Return current UTC time as ISO string.""" - return _utcnow().isoformat() - - -class _SQLiteBase: - """Base class for SQLite managers with common functionality.""" - - def __init__(self, db_path: str): - self.db_path = db_path - db_dir = os.path.dirname(db_path) - if db_dir: - os.makedirs(db_dir, exist_ok=True) - # Phase 1: Persistent connection with WAL mode. 
- self._conn = sqlite3.connect(db_path, check_same_thread=False) - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA busy_timeout=5000") - self._conn.execute("PRAGMA synchronous=FULL") - self._conn.execute("PRAGMA cache_size=-8000") # 8MB cache - self._conn.execute("PRAGMA temp_store=MEMORY") - self._conn.row_factory = sqlite3.Row - self._lock = threading.RLock() - - def close(self) -> None: - """Close the persistent connection for clean shutdown.""" - with self._lock: - if self._conn: - try: - self._conn.close() - except Exception: - pass - self._conn = None # type: ignore[assignment] - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(db_path={self.db_path!r})" - - @contextmanager - def _get_connection(self): - """Yield the persistent connection under the thread lock.""" - with self._lock: - try: - yield self._conn - self._conn.commit() - except Exception: - self._conn.rollback() - raise - - def _is_migration_applied(self, conn: sqlite3.Connection, version: str) -> bool: - row = conn.execute( - "SELECT 1 FROM schema_migrations WHERE version = ?", - (version,), - ).fetchone() - return row is not None - - # Phase 5: Allowed table names for ALTER TABLE to prevent SQL injection. - _ALLOWED_TABLES = frozenset({ - "memories", "scenes", "profiles", "categories", - }) - - def _migrate_add_column_conn( - self, - conn: sqlite3.Connection, - table: str, - column: str, - col_type: str, - ) -> None: - """Add a column using an existing connection, if missing.""" - if table not in self._ALLOWED_TABLES: - raise ValueError(f"Invalid table for migration: {table!r}") - # Validate column name: must be alphanumeric/underscore only. 
- if not column.replace("_", "").isalnum(): - raise ValueError(f"Invalid column name: {column!r}") - try: - conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {col_type}") - except sqlite3.OperationalError: - pass - - @staticmethod - def _parse_json_value(value: Any, default: Any) -> Any: - if value is None: - return default - if isinstance(value, (dict, list)): - return value - try: - return json.loads(value) - except Exception: - return default - - -class CoreSQLiteManager(_SQLiteBase): - """Minimal SQLite manager for CoreMemory - only essential tables. - - Tables created: - - memories: core memory storage with content_hash for deduplication - - memory_history: audit trail for memory operations - - decay_log: decay cycle metrics - - schema_migrations: migration tracking - """ - - def __init__(self, db_path: str): - super().__init__(db_path) - self._init_db() - - def _init_db(self) -> None: - """Initialize minimal schema for CoreMemory.""" - with self._get_connection() as conn: - conn.executescript( - """ - CREATE TABLE IF NOT EXISTS memories ( - id TEXT PRIMARY KEY, - memory TEXT NOT NULL, - user_id TEXT, - agent_id TEXT, - run_id TEXT, - app_id TEXT, - metadata TEXT DEFAULT '{}', - categories TEXT DEFAULT '[]', - immutable INTEGER DEFAULT 0, - expiration_date TEXT, - created_at TEXT DEFAULT CURRENT_TIMESTAMP, - updated_at TEXT DEFAULT CURRENT_TIMESTAMP, - layer TEXT DEFAULT 'sml' CHECK (layer IN ('sml', 'lml')), - strength REAL DEFAULT 1.0, - access_count INTEGER DEFAULT 0, - last_accessed TEXT DEFAULT CURRENT_TIMESTAMP, - embedding TEXT, - related_memories TEXT DEFAULT '[]', - source_memories TEXT DEFAULT '[]', - tombstone INTEGER DEFAULT 0, - content_hash TEXT - ); - - CREATE INDEX IF NOT EXISTS idx_user_layer ON memories(user_id, layer); - CREATE INDEX IF NOT EXISTS idx_strength ON memories(strength DESC); - CREATE INDEX IF NOT EXISTS idx_tombstone ON memories(tombstone); - CREATE INDEX IF NOT EXISTS idx_content_hash ON memories(content_hash, user_id); - 
- CREATE TABLE IF NOT EXISTS memory_history ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - memory_id TEXT NOT NULL, - event TEXT NOT NULL, - old_value TEXT, - new_value TEXT, - old_strength REAL, - new_strength REAL, - old_layer TEXT, - new_layer TEXT, - timestamp TEXT DEFAULT CURRENT_TIMESTAMP - ); - - CREATE TABLE IF NOT EXISTS decay_log ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - run_at TEXT DEFAULT CURRENT_TIMESTAMP, - memories_decayed INTEGER, - memories_forgotten INTEGER, - memories_promoted INTEGER, - storage_before_mb REAL, - storage_after_mb REAL - ); - - CREATE TABLE IF NOT EXISTS schema_migrations ( - version TEXT PRIMARY KEY, - applied_at TEXT DEFAULT CURRENT_TIMESTAMP - ); - """ - ) - # Apply content_hash column migration if needed - self._ensure_content_hash_column(conn) - - def _ensure_content_hash_column(self, conn: sqlite3.Connection) -> None: - """Add content_hash column + index for SHA-256 dedup (idempotent).""" - if self._is_migration_applied(conn, "v2_content_hash"): - return - self._migrate_add_column_conn(conn, "memories", "content_hash", "TEXT") - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_content_hash ON memories(content_hash, user_id)" - ) - conn.execute( - "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_content_hash')" - ) - - # Core memory operations - def add_memory(self, memory_data: Dict[str, Any]) -> str: - memory_id = memory_data.get("id", str(uuid.uuid4())) - now = _utcnow_iso() - metadata = memory_data.get("metadata", {}) or {} - - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO memories ( - id, memory, user_id, agent_id, run_id, app_id, - metadata, categories, immutable, expiration_date, - created_at, updated_at, layer, strength, access_count, - last_accessed, embedding, related_memories, source_memories, tombstone, - content_hash - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, - ( - memory_id, - memory_data.get("memory", ""), - memory_data.get("user_id"), - memory_data.get("agent_id"), - memory_data.get("run_id"), - memory_data.get("app_id"), - json.dumps(memory_data.get("metadata", {})), - json.dumps(memory_data.get("categories", [])), - 1 if memory_data.get("immutable", False) else 0, - memory_data.get("expiration_date"), - memory_data.get("created_at", now), - memory_data.get("updated_at", now), - memory_data.get("layer", "sml"), - memory_data.get("strength", 1.0), - memory_data.get("access_count", 0), - memory_data.get("last_accessed", now), - json.dumps(memory_data.get("embedding", [])), - json.dumps(memory_data.get("related_memories", [])), - json.dumps(memory_data.get("source_memories", [])), - 1 if memory_data.get("tombstone", False) else 0, - memory_data.get("content_hash"), - ), - ) - # Log the add event - conn.execute( - """ - INSERT INTO memory_history ( - memory_id, event, old_value, new_value, - old_strength, new_strength, old_layer, new_layer - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - (memory_id, "ADD", None, memory_data.get("memory"), None, None, None, None), - ) - return memory_id - - def get_memory(self, memory_id: str, include_tombstoned: bool = False) -> Optional[Dict[str, Any]]: - query = "SELECT * FROM memories WHERE id = ?" - params = [memory_id] - if not include_tombstoned: - query += " AND tombstone = 0" - - with self._get_connection() as conn: - row = conn.execute(query, params).fetchone() - if row: - return self._row_to_dict(row) - return None - - def get_memory_by_content_hash( - self, content_hash: str, user_id: str = "default" - ) -> Optional[Dict[str, Any]]: - """Find an existing memory by content hash (for deduplication).""" - with self._get_connection() as conn: - row = conn.execute( - "SELECT * FROM memories WHERE content_hash = ? AND user_id = ? 
AND tombstone = 0 LIMIT 1", - (content_hash, user_id), - ).fetchone() - if row: - return self._row_to_dict(row) - return None - - def get_all_memories( - self, - *, - user_id: Optional[str] = None, - agent_id: Optional[str] = None, - run_id: Optional[str] = None, - app_id: Optional[str] = None, - layer: Optional[str] = None, - namespace: Optional[str] = None, - min_strength: float = 0.0, - include_tombstoned: bool = False, - limit: Optional[int] = None, - ) -> List[Dict[str, Any]]: - query = "SELECT * FROM memories WHERE strength >= ?" - params: List[Any] = [min_strength] - - if not include_tombstoned: - query += " AND tombstone = 0" - if user_id: - query += " AND user_id = ?" - params.append(user_id) - if agent_id: - query += " AND agent_id = ?" - params.append(agent_id) - if run_id: - query += " AND run_id = ?" - params.append(run_id) - if app_id: - query += " AND app_id = ?" - params.append(app_id) - if layer: - query += " AND layer = ?" - params.append(layer) - - query += " ORDER BY strength DESC" - - if limit is not None and limit > 0: - query += " LIMIT ?" 
- params.append(limit) - - with self._get_connection() as conn: - rows = conn.execute(query, params).fetchall() - return [self._row_to_dict(row) for row in rows] - - def update_memory(self, memory_id: str, updates: Dict[str, Any]) -> bool: - set_clauses = [] - params: List[Any] = [] - for key, value in updates.items(): - if key not in VALID_MEMORY_COLUMNS: - raise ValueError(f"Invalid memory column: {key!r}") - if key in {"metadata", "categories", "embedding", "related_memories", "source_memories"}: - value = json.dumps(value) - set_clauses.append(f"{key} = ?") - params.append(value) - - set_clauses.append("updated_at = ?") - params.append(_utcnow_iso()) - params.append(memory_id) - - with self._get_connection() as conn: - old_row = conn.execute( - "SELECT memory, strength, layer FROM memories WHERE id = ?", - (memory_id,), - ).fetchone() - if not old_row: - return False - - conn.execute( - f"UPDATE memories SET {', '.join(set_clauses)} WHERE id = ?", - params, - ) - - # Log the update event - conn.execute( - """ - INSERT INTO memory_history ( - memory_id, event, old_value, new_value, - old_strength, new_strength, old_layer, new_layer - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - memory_id, - "UPDATE", - old_row["memory"], - updates.get("memory"), - old_row["strength"], - updates.get("strength"), - old_row["layer"], - updates.get("layer"), - ), - ) - return True - - def delete_memory(self, memory_id: str, use_tombstone: bool = True) -> bool: - if use_tombstone: - return self.update_memory(memory_id, {"tombstone": 1}) - with self._get_connection() as conn: - conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,)) - self._log_event(memory_id, "DELETE") - return True - - def increment_access(self, memory_id: str) -> None: - now = _utcnow_iso() - with self._get_connection() as conn: - conn.execute( - """ - UPDATE memories - SET access_count = access_count + 1, last_accessed = ? - WHERE id = ? 
- """, - (now, memory_id), - ) - - def increment_access_bulk(self, memory_ids: List[str]) -> None: - """Increment access count for multiple memories in a single transaction.""" - if not memory_ids: - return - now = _utcnow_iso() - with self._get_connection() as conn: - placeholders = ",".join("?" for _ in memory_ids) - conn.execute( - f""" - UPDATE memories - SET access_count = access_count + 1, last_accessed = ? - WHERE id IN ({placeholders}) - """, - [now] + list(memory_ids), - ) - - def get_memories_bulk( - self, memory_ids: List[str], include_tombstoned: bool = False - ) -> Dict[str, Dict[str, Any]]: - """Fetch multiple memories by ID in a single query.""" - if not memory_ids: - return {} - with self._get_connection() as conn: - placeholders = ",".join("?" for _ in memory_ids) - query = f"SELECT * FROM memories WHERE id IN ({placeholders})" - if not include_tombstoned: - query += " AND tombstone = 0" - rows = conn.execute(query, memory_ids).fetchall() - return {row["id"]: self._row_to_dict(row) for row in rows} - - def update_strength_bulk(self, updates: Dict[str, float]) -> None: - """Batch-update strength for multiple memories.""" - if not updates: - return - now = _utcnow_iso() - with self._get_connection() as conn: - conn.executemany( - "UPDATE memories SET strength = ?, updated_at = ? WHERE id = ?", - [(strength, now, memory_id) for memory_id, strength in updates.items()], - ) - - _MEMORY_JSON_FIELDS = ("metadata", "categories", "related_memories", "source_memories") - - def _row_to_dict(self, row: sqlite3.Row, *, skip_embedding: bool = False) -> Dict[str, Any]: - data = dict(row) - for key in self._MEMORY_JSON_FIELDS: - if key in data and data[key]: - data[key] = json.loads(data[key]) - # Embedding is the largest JSON field (~30-50KB for 3072-dim vectors). 
- if skip_embedding: - data.pop("embedding", None) - elif "embedding" in data and data["embedding"]: - data["embedding"] = json.loads(data["embedding"]) - data["immutable"] = bool(data.get("immutable", 0)) - data["tombstone"] = bool(data.get("tombstone", 0)) - return data - - def _log_event(self, memory_id: str, event: str, **kwargs: Any) -> None: - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO memory_history ( - memory_id, event, old_value, new_value, - old_strength, new_strength, old_layer, new_layer - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - memory_id, - event, - kwargs.get("old_value"), - kwargs.get("new_value"), - kwargs.get("old_strength"), - kwargs.get("new_strength"), - kwargs.get("old_layer"), - kwargs.get("new_layer"), - ), - ) - - def log_event(self, memory_id: str, event: str, **kwargs: Any) -> None: - """Public wrapper for logging custom events like DECAY or FUSE.""" - self._log_event(memory_id, event, **kwargs) - - def get_history(self, memory_id: str) -> List[Dict[str, Any]]: - with self._get_connection() as conn: - rows = conn.execute( - "SELECT * FROM memory_history WHERE memory_id = ? ORDER BY timestamp DESC", - (memory_id,), - ).fetchall() - return [dict(row) for row in rows] - - def log_decay( - self, - decayed: int, - forgotten: int, - promoted: int, - storage_before_mb: Optional[float] = None, - storage_after_mb: Optional[float] = None, - ) -> None: - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO decay_log (memories_decayed, memories_forgotten, memories_promoted, storage_before_mb, storage_after_mb) - VALUES (?, ?, ?, ?, ?) 
- """, - (decayed, forgotten, promoted, storage_before_mb, storage_after_mb), - ) - - def purge_tombstoned(self) -> int: - """Permanently delete all tombstoned memories.""" - with self._get_connection() as conn: - rows = conn.execute( - "SELECT id, user_id, memory FROM memories WHERE tombstone = 1" - ).fetchall() - count = len(rows) - if count > 0: - for row in rows: - self._log_event(row["id"], "PURGE", old_value=row["memory"]) - conn.execute("DELETE FROM memories WHERE tombstone = 1") - return count - - -# Backward compatibility alias -SQLiteManager = CoreSQLiteManager - - -class FullSQLiteManager(CoreSQLiteManager): - def __init__(self, db_path: str): - self.db_path = db_path - db_dir = os.path.dirname(db_path) - if db_dir: - os.makedirs(db_dir, exist_ok=True) - # Phase 1: Persistent connection with WAL mode. - self._conn = sqlite3.connect(db_path, check_same_thread=False) - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA busy_timeout=5000") - self._conn.execute("PRAGMA synchronous=FULL") - self._conn.execute("PRAGMA cache_size=-8000") # 8MB cache - self._conn.execute("PRAGMA temp_store=MEMORY") - self._conn.row_factory = sqlite3.Row - self._lock = threading.RLock() - self._init_db() - - def close(self) -> None: - """Close the persistent connection for clean shutdown.""" - with self._lock: - if self._conn: - try: - self._conn.close() - except Exception: - pass - self._conn = None # type: ignore[assignment] - - def __repr__(self) -> str: - return f"SQLiteManager(db_path={self.db_path!r})" - - def _init_db(self) -> None: - with self._get_connection() as conn: - conn.executescript( - """ - CREATE TABLE IF NOT EXISTS memories ( - id TEXT PRIMARY KEY, - memory TEXT NOT NULL, - user_id TEXT, - agent_id TEXT, - run_id TEXT, - app_id TEXT, - metadata TEXT DEFAULT '{}', - categories TEXT DEFAULT '[]', - immutable INTEGER DEFAULT 0, - expiration_date TEXT, - created_at TEXT DEFAULT CURRENT_TIMESTAMP, - updated_at TEXT DEFAULT CURRENT_TIMESTAMP, 
- layer TEXT DEFAULT 'sml' CHECK (layer IN ('sml', 'lml')), - strength REAL DEFAULT 1.0, - access_count INTEGER DEFAULT 0, - last_accessed TEXT DEFAULT CURRENT_TIMESTAMP, - embedding TEXT, - related_memories TEXT DEFAULT '[]', - source_memories TEXT DEFAULT '[]', - tombstone INTEGER DEFAULT 0 - ); - - CREATE INDEX IF NOT EXISTS idx_user_layer ON memories(user_id, layer); - CREATE INDEX IF NOT EXISTS idx_strength ON memories(strength DESC); - CREATE INDEX IF NOT EXISTS idx_tombstone ON memories(tombstone); - - CREATE TABLE IF NOT EXISTS memory_history ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - memory_id TEXT NOT NULL, - event TEXT NOT NULL, - old_value TEXT, - new_value TEXT, - old_strength REAL, - new_strength REAL, - old_layer TEXT, - new_layer TEXT, - timestamp TEXT DEFAULT CURRENT_TIMESTAMP - ); - - CREATE TABLE IF NOT EXISTS decay_log ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - run_at TEXT DEFAULT CURRENT_TIMESTAMP, - memories_decayed INTEGER, - memories_forgotten INTEGER, - memories_promoted INTEGER, - storage_before_mb REAL, - storage_after_mb REAL - ); - - -- CategoryMem tables - CREATE TABLE IF NOT EXISTS categories ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, - description TEXT, - category_type TEXT DEFAULT 'dynamic', - parent_id TEXT, - children_ids TEXT DEFAULT '[]', - memory_count INTEGER DEFAULT 0, - total_strength REAL DEFAULT 0.0, - access_count INTEGER DEFAULT 0, - last_accessed TEXT, - created_at TEXT DEFAULT CURRENT_TIMESTAMP, - embedding TEXT, - keywords TEXT DEFAULT '[]', - summary TEXT, - summary_updated_at TEXT, - related_ids TEXT DEFAULT '[]', - strength REAL DEFAULT 1.0, - FOREIGN KEY (parent_id) REFERENCES categories(id) - ); - - CREATE INDEX IF NOT EXISTS idx_category_type ON categories(category_type); - CREATE INDEX IF NOT EXISTS idx_category_parent ON categories(parent_id); - CREATE INDEX IF NOT EXISTS idx_category_strength ON categories(strength DESC); - - -- Episodic scenes - CREATE TABLE IF NOT EXISTS scenes ( - id TEXT PRIMARY 
KEY, - user_id TEXT, - title TEXT, - summary TEXT, - topic TEXT, - location TEXT, - participants TEXT DEFAULT '[]', - memory_ids TEXT DEFAULT '[]', - start_time TEXT, - end_time TEXT, - embedding TEXT, - strength REAL DEFAULT 1.0, - access_count INTEGER DEFAULT 0, - tombstone INTEGER DEFAULT 0 - ); - - CREATE INDEX IF NOT EXISTS idx_scene_user ON scenes(user_id); - CREATE INDEX IF NOT EXISTS idx_scene_start ON scenes(start_time DESC); - - -- Scene-Memory junction - CREATE TABLE IF NOT EXISTS scene_memories ( - scene_id TEXT NOT NULL, - memory_id TEXT NOT NULL, - position INTEGER DEFAULT 0, - PRIMARY KEY (scene_id, memory_id), - FOREIGN KEY (scene_id) REFERENCES scenes(id), - FOREIGN KEY (memory_id) REFERENCES memories(id) - ); - - -- Character profiles - CREATE TABLE IF NOT EXISTS profiles ( - id TEXT PRIMARY KEY, - user_id TEXT, - name TEXT NOT NULL, - profile_type TEXT DEFAULT 'contact' CHECK (profile_type IN ('self', 'contact', 'entity')), - narrative TEXT, - facts TEXT DEFAULT '[]', - preferences TEXT DEFAULT '[]', - relationships TEXT DEFAULT '[]', - sentiment TEXT, - theory_of_mind TEXT DEFAULT '{}', - aliases TEXT DEFAULT '[]', - embedding TEXT, - strength REAL DEFAULT 1.0, - created_at TEXT DEFAULT CURRENT_TIMESTAMP, - updated_at TEXT DEFAULT CURRENT_TIMESTAMP - ); - - CREATE INDEX IF NOT EXISTS idx_profile_user ON profiles(user_id); - CREATE INDEX IF NOT EXISTS idx_profile_name ON profiles(name); - CREATE INDEX IF NOT EXISTS idx_profile_type ON profiles(profile_type); - - -- Profile-Memory junction - CREATE TABLE IF NOT EXISTS profile_memories ( - profile_id TEXT NOT NULL, - memory_id TEXT NOT NULL, - role TEXT DEFAULT 'mentioned' CHECK (role IN ('subject', 'mentioned', 'about')), - PRIMARY KEY (profile_id, memory_id), - FOREIGN KEY (profile_id) REFERENCES profiles(id), - FOREIGN KEY (memory_id) REFERENCES memories(id) - ); - """ - ) - # Legacy migration: add scene_id column to memories if missing. 
- self._migrate_add_column_conn(conn, "memories", "scene_id", "TEXT") - # v2 schema + idempotent migrations. - self._ensure_v2_schema(conn) - - @contextmanager - def _get_connection(self): - """Yield the persistent connection under the thread lock.""" - with self._lock: - try: - yield self._conn - self._conn.commit() - except Exception: - self._conn.rollback() - raise - - def _ensure_v2_schema(self, conn: sqlite3.Connection) -> None: - """Create and migrate Engram v2 schema in-place (idempotent).""" - conn.execute( - """ - CREATE TABLE IF NOT EXISTS schema_migrations ( - version TEXT PRIMARY KEY, - applied_at TEXT DEFAULT CURRENT_TIMESTAMP - ) - """ - ) - - migrations: Dict[str, str] = { - "v2_013": """ - CREATE TABLE IF NOT EXISTS distillation_provenance ( - id TEXT PRIMARY KEY, - semantic_memory_id TEXT NOT NULL, - episodic_memory_id TEXT NOT NULL, - distillation_run_id TEXT, - created_at TEXT DEFAULT CURRENT_TIMESTAMP - ); - CREATE INDEX IF NOT EXISTS idx_distill_prov_semantic ON distillation_provenance(semantic_memory_id); - CREATE INDEX IF NOT EXISTS idx_distill_prov_episodic ON distillation_provenance(episodic_memory_id); - CREATE INDEX IF NOT EXISTS idx_distill_prov_run ON distillation_provenance(distillation_run_id); - - CREATE TABLE IF NOT EXISTS distillation_log ( - id TEXT PRIMARY KEY, - run_at TEXT DEFAULT CURRENT_TIMESTAMP, - user_id TEXT, - episodes_sampled INTEGER DEFAULT 0, - semantic_created INTEGER DEFAULT 0, - semantic_deduplicated INTEGER DEFAULT 0, - errors INTEGER DEFAULT 0 - ); - CREATE INDEX IF NOT EXISTS idx_distill_log_user ON distillation_log(user_id, run_at DESC); - """, - } - - for version, ddl in migrations.items(): - if not self._is_migration_applied(conn, version): - conn.executescript(ddl) - conn.execute( - "INSERT OR IGNORE INTO schema_migrations (version) VALUES (?)", - (version,), - ) - - # Phase 3: Skip column migrations + backfills if already complete. 
- if self._is_migration_applied(conn, "v2_columns_complete"): - # CLS Distillation Memory columns (idempotent). - self._ensure_cls_columns(conn) - return - - # v2 columns on existing canonical tables. - self._migrate_add_column_conn(conn, "memories", "confidentiality_scope", "TEXT DEFAULT 'work'") - self._migrate_add_column_conn(conn, "memories", "source_type", "TEXT") - self._migrate_add_column_conn(conn, "memories", "source_app", "TEXT") - self._migrate_add_column_conn(conn, "memories", "source_event_id", "TEXT") - self._migrate_add_column_conn(conn, "memories", "decay_lambda", "REAL DEFAULT 0.12") - self._migrate_add_column_conn(conn, "memories", "status", "TEXT DEFAULT 'active'") - self._migrate_add_column_conn(conn, "memories", "importance", "REAL DEFAULT 0.5") - self._migrate_add_column_conn(conn, "memories", "sensitivity", "TEXT DEFAULT 'normal'") - self._migrate_add_column_conn(conn, "memories", "namespace", "TEXT DEFAULT 'default'") - - self._migrate_add_column_conn(conn, "scenes", "layer", "TEXT DEFAULT 'sml'") - self._migrate_add_column_conn(conn, "scenes", "scene_strength", "REAL DEFAULT 1.0") - self._migrate_add_column_conn(conn, "scenes", "topic_embedding_ref", "TEXT") - self._migrate_add_column_conn(conn, "scenes", "namespace", "TEXT DEFAULT 'default'") - - self._migrate_add_column_conn(conn, "profiles", "role_bias", "TEXT") - self._migrate_add_column_conn(conn, "profiles", "profile_summary", "TEXT") - - conn.execute( - """ - CREATE INDEX IF NOT EXISTS idx_memories_user_source_event - ON memories(user_id, source_event_id, namespace, created_at DESC) - """ - ) - - # Backfills. 
- conn.execute( - """ - UPDATE memories - SET confidentiality_scope = 'work' - WHERE confidentiality_scope IS NULL OR confidentiality_scope = '' - """ - ) - conn.execute( - """ - UPDATE memories - SET status = 'active' - WHERE status IS NULL OR status = '' - """ - ) - conn.execute( - """ - UPDATE memories - SET namespace = 'default' - WHERE namespace IS NULL OR namespace = '' - """ - ) - conn.execute( - """ - UPDATE scenes - SET namespace = 'default' - WHERE namespace IS NULL OR namespace = '' - """ - ) - conn.execute( - """ - UPDATE memories - SET decay_lambda = 0.12 - WHERE decay_lambda IS NULL - """ - ) - conn.execute( - """ - UPDATE memories - SET importance = COALESCE( - CASE - WHEN json_extract(metadata, '$.importance') IS NOT NULL - THEN json_extract(metadata, '$.importance') - ELSE importance - END, - 0.5 - ) - """ - ) - conn.execute( - """ - UPDATE memories - SET sensitivity = CASE - WHEN lower(memory) LIKE '%password%' OR lower(memory) LIKE '%api key%' OR lower(memory) LIKE '%token%' - THEN 'secret' - WHEN lower(memory) LIKE '%health%' OR lower(memory) LIKE '%medical%' - THEN 'sensitive' - WHEN lower(memory) LIKE '%bank%' OR lower(memory) LIKE '%salary%' OR lower(memory) LIKE '%credit card%' - THEN 'sensitive' - ELSE COALESCE(NULLIF(sensitivity, ''), 'normal') - END - """ - ) - - # Phase 3: Mark column migrations + backfills as complete. - conn.execute( - "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_columns_complete')" - ) - - # CLS Distillation Memory columns (idempotent). - self._ensure_cls_columns(conn) - - # Content-hash dedup column (idempotent). 
- self._ensure_content_hash_column(conn) - - def _ensure_content_hash_column(self, conn: sqlite3.Connection) -> None: - """Add content_hash column + index for SHA-256 dedup.""" - if self._is_migration_applied(conn, "v2_content_hash"): - return - self._migrate_add_column_conn(conn, "memories", "content_hash", "TEXT") - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_content_hash ON memories(content_hash, user_id)" - ) - conn.execute( - "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_content_hash')" - ) - - def _ensure_cls_columns(self, conn: sqlite3.Connection) -> None: - """Add CLS Distillation Memory columns to memories table (idempotent).""" - if self._is_migration_applied(conn, "v2_cls_columns_complete"): - return - - self._migrate_add_column_conn(conn, "memories", "memory_type", "TEXT DEFAULT 'semantic'") - self._migrate_add_column_conn(conn, "memories", "s_fast", "REAL DEFAULT NULL") - self._migrate_add_column_conn(conn, "memories", "s_mid", "REAL DEFAULT NULL") - self._migrate_add_column_conn(conn, "memories", "s_slow", "REAL DEFAULT NULL") - - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_memories_memory_type ON memories(memory_type, user_id)" - ) - - # Backfill: set memory_type to 'semantic' for existing memories. - conn.execute( - "UPDATE memories SET memory_type = 'semantic' WHERE memory_type IS NULL" - ) - - conn.execute( - "INSERT OR IGNORE INTO schema_migrations (version) VALUES ('v2_cls_columns_complete')" - ) - - def _is_migration_applied(self, conn: sqlite3.Connection, version: str) -> bool: - row = conn.execute( - "SELECT 1 FROM schema_migrations WHERE version = ?", - (version,), - ).fetchone() - return row is not None - - # Phase 5: Allowed table names for ALTER TABLE to prevent SQL injection. 
- _ALLOWED_TABLES = frozenset({ - "memories", "scenes", "profiles", "categories", - }) - - def _migrate_add_column_conn( - self, - conn: sqlite3.Connection, - table: str, - column: str, - col_type: str, - ) -> None: - """Add a column using an existing connection, if missing.""" - if table not in self._ALLOWED_TABLES: - raise ValueError(f"Invalid table for migration: {table!r}") - # Validate column name: must be alphanumeric/underscore only. - if not column.replace("_", "").isalnum(): - raise ValueError(f"Invalid column name: {column!r}") - try: - conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {col_type}") - except sqlite3.OperationalError: - pass - - def add_memory(self, memory_data: Dict[str, Any]) -> str: - memory_id = memory_data.get("id", str(uuid.uuid4())) - now = _utcnow_iso() - metadata = memory_data.get("metadata", {}) or {} - source_app = memory_data.get("source_app") or memory_data.get("app_id") or metadata.get("source_app") - - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO memories ( - id, memory, user_id, agent_id, run_id, app_id, - metadata, categories, immutable, expiration_date, - created_at, updated_at, layer, strength, access_count, - last_accessed, embedding, related_memories, source_memories, tombstone, - confidentiality_scope, namespace, source_type, source_app, source_event_id, decay_lambda, - status, importance, sensitivity, - memory_type, s_fast, s_mid, s_slow, content_hash - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, - ( - memory_id, - memory_data.get("memory", ""), - memory_data.get("user_id"), - memory_data.get("agent_id"), - memory_data.get("run_id"), - memory_data.get("app_id"), - json.dumps(memory_data.get("metadata", {})), - json.dumps(memory_data.get("categories", [])), - 1 if memory_data.get("immutable", False) else 0, - memory_data.get("expiration_date"), - memory_data.get("created_at", now), - memory_data.get("updated_at", now), - memory_data.get("layer", "sml"), - memory_data.get("strength", 1.0), - memory_data.get("access_count", 0), - memory_data.get("last_accessed", now), - json.dumps(memory_data.get("embedding", [])), - json.dumps(memory_data.get("related_memories", [])), - json.dumps(memory_data.get("source_memories", [])), - 1 if memory_data.get("tombstone", False) else 0, - memory_data.get("confidentiality_scope", "work"), - memory_data.get("namespace", metadata.get("namespace", "default")), - memory_data.get("source_type") or metadata.get("source_type") or "mcp", - source_app, - memory_data.get("source_event_id") or metadata.get("source_event_id"), - memory_data.get("decay_lambda", 0.12), - memory_data.get("status", "active"), - memory_data.get("importance", metadata.get("importance", 0.5)), - memory_data.get("sensitivity", metadata.get("sensitivity", "normal")), - memory_data.get("memory_type", "semantic"), - memory_data.get("s_fast"), - memory_data.get("s_mid"), - memory_data.get("s_slow"), - memory_data.get("content_hash"), - ), - ) - - # Log within the same transaction -- atomic with the insert. - conn.execute( - """ - INSERT INTO memory_history ( - memory_id, event, old_value, new_value, - old_strength, new_strength, old_layer, new_layer - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - (memory_id, "ADD", None, memory_data.get("memory"), None, None, None, None), - ) - - return memory_id - - def add_memories_batch(self, memories: List[Dict[str, Any]]) -> List[str]: - """Insert multiple memories in a single transaction (atomic). 
- - Returns list of memory IDs in the same order as input. - """ - if not memories: - return [] - now = _utcnow_iso() - ids: List[str] = [] - insert_rows = [] - history_rows = [] - - for memory_data in memories: - memory_id = memory_data.get("id", str(uuid.uuid4())) - ids.append(memory_id) - metadata = memory_data.get("metadata", {}) or {} - source_app = memory_data.get("source_app") or memory_data.get("app_id") or metadata.get("source_app") - - insert_rows.append(( - memory_id, - memory_data.get("memory", ""), - memory_data.get("user_id"), - memory_data.get("agent_id"), - memory_data.get("run_id"), - memory_data.get("app_id"), - json.dumps(memory_data.get("metadata", {})), - json.dumps(memory_data.get("categories", [])), - 1 if memory_data.get("immutable", False) else 0, - memory_data.get("expiration_date"), - memory_data.get("created_at", now), - memory_data.get("updated_at", now), - memory_data.get("layer", "sml"), - memory_data.get("strength", 1.0), - memory_data.get("access_count", 0), - memory_data.get("last_accessed", now), - json.dumps(memory_data.get("embedding", [])), - json.dumps(memory_data.get("related_memories", [])), - json.dumps(memory_data.get("source_memories", [])), - 1 if memory_data.get("tombstone", False) else 0, - memory_data.get("confidentiality_scope", "work"), - memory_data.get("namespace", metadata.get("namespace", "default")), - memory_data.get("source_type") or metadata.get("source_type") or "mcp", - source_app, - memory_data.get("source_event_id") or metadata.get("source_event_id"), - memory_data.get("decay_lambda", 0.12), - memory_data.get("status", "active"), - memory_data.get("importance", metadata.get("importance", 0.5)), - memory_data.get("sensitivity", metadata.get("sensitivity", "normal")), - memory_data.get("memory_type", "semantic"), - memory_data.get("s_fast"), - memory_data.get("s_mid"), - memory_data.get("s_slow"), - )) - history_rows.append(( - memory_id, "ADD", None, memory_data.get("memory"), None, None, None, None, - )) 
- - with self._get_connection() as conn: - conn.executemany( - """ - INSERT INTO memories ( - id, memory, user_id, agent_id, run_id, app_id, - metadata, categories, immutable, expiration_date, - created_at, updated_at, layer, strength, access_count, - last_accessed, embedding, related_memories, source_memories, tombstone, - confidentiality_scope, namespace, source_type, source_app, source_event_id, decay_lambda, - status, importance, sensitivity, - memory_type, s_fast, s_mid, s_slow - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - insert_rows, - ) - conn.executemany( - """ - INSERT INTO memory_history ( - memory_id, event, old_value, new_value, - old_strength, new_strength, old_layer, new_layer - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - history_rows, - ) - - return ids - - def get_memory(self, memory_id: str, include_tombstoned: bool = False) -> Optional[Dict[str, Any]]: - query = "SELECT * FROM memories WHERE id = ?" - params = [memory_id] - if not include_tombstoned: - query += " AND tombstone = 0" - - with self._get_connection() as conn: - row = conn.execute(query, params).fetchone() - if row: - return self._row_to_dict(row) - return None - - def get_memory_by_content_hash( - self, content_hash: str, user_id: str = "default" - ) -> Optional[Dict[str, Any]]: - """Find an existing memory by content hash (for deduplication).""" - with self._get_connection() as conn: - row = conn.execute( - "SELECT * FROM memories WHERE content_hash = ? AND user_id = ? 
AND tombstone = 0 LIMIT 1", - (content_hash, user_id), - ).fetchone() - if row: - return self._row_to_dict(row) - return None - - def get_memory_by_source_event( - self, - *, - user_id: str, - source_event_id: str, - namespace: Optional[str] = None, - source_app: Optional[str] = None, - include_tombstoned: bool = False, - ) -> Optional[Dict[str, Any]]: - normalized_event = str(source_event_id or "").strip() - if not normalized_event: - return None - query = """ - SELECT * - FROM memories - WHERE user_id = ? - AND source_event_id = ? - """ - params: List[Any] = [user_id, normalized_event] - if namespace: - query += " AND namespace = ?" - params.append(namespace) - if source_app: - query += " AND source_app = ?" - params.append(source_app) - if not include_tombstoned: - query += " AND tombstone = 0" - query += " ORDER BY created_at DESC LIMIT 1" - - with self._get_connection() as conn: - row = conn.execute(query, params).fetchone() - if row: - return self._row_to_dict(row) - return None - - def get_all_memories( - self, - *, - user_id: Optional[str] = None, - agent_id: Optional[str] = None, - run_id: Optional[str] = None, - app_id: Optional[str] = None, - layer: Optional[str] = None, - namespace: Optional[str] = None, - memory_type: Optional[str] = None, - min_strength: float = 0.0, - include_tombstoned: bool = False, - created_after: Optional[str] = None, - created_before: Optional[str] = None, - limit: Optional[int] = None, - ) -> List[Dict[str, Any]]: - query = "SELECT * FROM memories WHERE strength >= ?" - params: List[Any] = [min_strength] - - if not include_tombstoned: - query += " AND tombstone = 0" - if memory_type: - query += " AND memory_type = ?" - params.append(memory_type) - if user_id: - query += " AND user_id = ?" - params.append(user_id) - if agent_id: - query += " AND agent_id = ?" - params.append(agent_id) - if run_id: - query += " AND run_id = ?" - params.append(run_id) - if app_id: - query += " AND app_id = ?" 
- params.append(app_id) - if layer: - query += " AND layer = ?" - params.append(layer) - if namespace: - query += " AND namespace = ?" - params.append(namespace) - if created_after: - query += " AND created_at >= ?" - params.append(created_after) - if created_before: - query += " AND created_at <= ?" - params.append(created_before) - - query += " ORDER BY strength DESC" - - # Apply SQL-level LIMIT to avoid fetching unbounded rows into memory. - if limit is not None and limit > 0: - query += " LIMIT ?" - params.append(limit) - - with self._get_connection() as conn: - rows = conn.execute(query, params).fetchall() - return [self._row_to_dict(row) for row in rows] - - def update_memory(self, memory_id: str, updates: Dict[str, Any]) -> bool: - set_clauses = [] - params: List[Any] = [] - for key, value in updates.items(): - if key not in VALID_MEMORY_COLUMNS: - raise ValueError(f"Invalid memory column: {key!r}") - if key in {"metadata", "categories", "embedding", "related_memories", "source_memories"}: - value = json.dumps(value) - set_clauses.append(f"{key} = ?") - params.append(value) - - set_clauses.append("updated_at = ?") - params.append(_utcnow_iso()) - params.append(memory_id) - - with self._get_connection() as conn: - # Read old values and update in a single transaction. - old_row = conn.execute( - "SELECT memory, strength, layer FROM memories WHERE id = ?", - (memory_id,), - ).fetchone() - if not old_row: - return False - - conn.execute( - f"UPDATE memories SET {', '.join(set_clauses)} WHERE id = ?", - params, - ) - - # Log within the same transaction. - conn.execute( - """ - INSERT INTO memory_history ( - memory_id, event, old_value, new_value, - old_strength, new_strength, old_layer, new_layer - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
- """, - ( - memory_id, - "UPDATE", - old_row["memory"], - updates.get("memory"), - old_row["strength"], - updates.get("strength"), - old_row["layer"], - updates.get("layer"), - ), - ) - return True - - def delete_memory(self, memory_id: str, use_tombstone: bool = True) -> bool: - if use_tombstone: - return self.update_memory(memory_id, {"tombstone": 1}) - with self._get_connection() as conn: - conn.execute("DELETE FROM memories WHERE id = ?", (memory_id,)) - self._log_event(memory_id, "DELETE") - return True - - def increment_access(self, memory_id: str) -> None: - now = _utcnow_iso() - with self._get_connection() as conn: - conn.execute( - """ - UPDATE memories - SET access_count = access_count + 1, last_accessed = ? - WHERE id = ? - """, - (now, memory_id), - ) - - # Phase 2: Batch operations to eliminate N+1 queries in search. - - def get_memories_bulk(self, memory_ids: List[str], include_tombstoned: bool = False) -> Dict[str, Dict[str, Any]]: - """Fetch multiple memories by ID in a single query. Returns {id: memory_dict}.""" - if not memory_ids: - return {} - with self._get_connection() as conn: - placeholders = ",".join("?" for _ in memory_ids) - query = f"SELECT * FROM memories WHERE id IN ({placeholders})" - if not include_tombstoned: - query += " AND tombstone = 0" - rows = conn.execute(query, memory_ids).fetchall() - return {row["id"]: self._row_to_dict(row) for row in rows} - - def increment_access_bulk(self, memory_ids: List[str]) -> None: - """Increment access count for multiple memories in a single transaction.""" - if not memory_ids: - return - now = _utcnow_iso() - with self._get_connection() as conn: - placeholders = ",".join("?" for _ in memory_ids) - conn.execute( - f""" - UPDATE memories - SET access_count = access_count + 1, last_accessed = ? - WHERE id IN ({placeholders}) - """, - [now] + list(memory_ids), - ) - - def update_strength_bulk(self, updates: Dict[str, float]) -> None: - """Batch-update strength for multiple memories. 
updates = {memory_id: new_strength}.""" - if not updates: - return - now = _utcnow_iso() - with self._get_connection() as conn: - conn.executemany( - "UPDATE memories SET strength = ?, updated_at = ? WHERE id = ?", - [(strength, now, memory_id) for memory_id, strength in updates.items()], - ) - - _MEMORY_JSON_FIELDS = ("metadata", "categories", "related_memories", "source_memories") - - def _row_to_dict(self, row: sqlite3.Row, *, skip_embedding: bool = False) -> Dict[str, Any]: - data = dict(row) - for key in self._MEMORY_JSON_FIELDS: - if key in data and data[key]: - data[key] = json.loads(data[key]) - # Embedding is the largest JSON field (~30-50KB for 3072-dim vectors). - # Skip deserialization when the caller doesn't need it. - if skip_embedding: - data.pop("embedding", None) - elif "embedding" in data and data["embedding"]: - data["embedding"] = json.loads(data["embedding"]) - data["immutable"] = bool(data.get("immutable", 0)) - data["tombstone"] = bool(data.get("tombstone", 0)) - return data - - def _log_event(self, memory_id: str, event: str, **kwargs: Any) -> None: - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO memory_history ( - memory_id, event, old_value, new_value, - old_strength, new_strength, old_layer, new_layer - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - memory_id, - event, - kwargs.get("old_value"), - kwargs.get("new_value"), - kwargs.get("old_strength"), - kwargs.get("new_strength"), - kwargs.get("old_layer"), - kwargs.get("new_layer"), - ), - ) - - def log_event(self, memory_id: str, event: str, **kwargs: Any) -> None: - """Public wrapper for logging custom events like DECAY or FUSE.""" - self._log_event(memory_id, event, **kwargs) - - def get_history(self, memory_id: str) -> List[Dict[str, Any]]: - with self._get_connection() as conn: - rows = conn.execute( - "SELECT * FROM memory_history WHERE memory_id = ? 
ORDER BY timestamp DESC", - (memory_id,), - ).fetchall() - return [dict(row) for row in rows] - - def log_decay(self, decayed: int, forgotten: int, promoted: int, storage_before_mb: Optional[float] = None, storage_after_mb: Optional[float] = None) -> None: - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO decay_log (memories_decayed, memories_forgotten, memories_promoted, storage_before_mb, storage_after_mb) - VALUES (?, ?, ?, ?, ?) - """, - (decayed, forgotten, promoted, storage_before_mb, storage_after_mb), - ) - - def purge_tombstoned(self) -> int: - """Permanently delete all tombstoned memories. This is IRREVERSIBLE.""" - with self._get_connection() as conn: - # Log what will be purged before deletion for audit trail. - rows = conn.execute( - "SELECT id, user_id, memory FROM memories WHERE tombstone = 1" - ).fetchall() - count = len(rows) - if count > 0: - ids = [row["id"] for row in rows] - logger.warning( - "purge_tombstoned: permanently deleting %d memories: %s", - count, - ids, - ) - for row in rows: - conn.execute( - """INSERT INTO memory_history (memory_id, event, old_value, new_value, - old_strength, new_strength, old_layer, new_layer) - VALUES (?, ?, ?, NULL, NULL, NULL, NULL, NULL)""", - (row["id"], "PURGE", row["memory"]), - ) - conn.execute("DELETE FROM memories WHERE tombstone = 1") - return count - - # CLS Distillation Memory helpers - - def get_episodic_memories( - self, - user_id: str, - *, - scene_id: Optional[str] = None, - created_after: Optional[str] = None, - created_before: Optional[str] = None, - limit: int = 100, - namespace: Optional[str] = None, - ) -> List[Dict[str, Any]]: - """Fetch episodic-type memories for a user, optionally filtered by scene/time.""" - query = "SELECT * FROM memories WHERE user_id = ? AND memory_type = 'episodic' AND tombstone = 0" - params: List[Any] = [user_id] - if scene_id: - query += " AND scene_id = ?" - params.append(scene_id) - if created_after: - query += " AND created_at >= ?" 
- params.append(created_after) - if created_before: - query += " AND created_at <= ?" - params.append(created_before) - if namespace: - query += " AND namespace = ?" - params.append(namespace) - query += " ORDER BY created_at DESC LIMIT ?" - params.append(limit) - with self._get_connection() as conn: - rows = conn.execute(query, params).fetchall() - return [self._row_to_dict(row) for row in rows] - - def add_distillation_provenance( - self, - semantic_memory_id: str, - episodic_memory_ids: List[str], - run_id: str, - ) -> None: - """Record which episodic memories contributed to a distilled semantic memory.""" - with self._get_connection() as conn: - for ep_id in episodic_memory_ids: - conn.execute( - """ - INSERT INTO distillation_provenance (id, semantic_memory_id, episodic_memory_id, distillation_run_id) - VALUES (?, ?, ?, ?) - """, - (str(uuid.uuid4()), semantic_memory_id, ep_id, run_id), - ) - - def log_distillation_run( - self, - user_id: str, - episodes_sampled: int, - semantic_created: int, - semantic_deduplicated: int = 0, - errors: int = 0, - ) -> str: - """Log a distillation run and return the run ID.""" - run_id = str(uuid.uuid4()) - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO distillation_log (id, user_id, episodes_sampled, semantic_created, semantic_deduplicated, errors) - VALUES (?, ?, ?, ?, ?, ?) - """, - (run_id, user_id, episodes_sampled, semantic_created, semantic_deduplicated, errors), - ) - return run_id - - def get_memory_count_by_namespace(self, user_id: str) -> Dict[str, int]: - """Return {namespace: count} for active memories of a user.""" - with self._get_connection() as conn: - rows = conn.execute( - """ - SELECT COALESCE(namespace, 'default') AS ns, COUNT(*) AS cnt - FROM memories - WHERE user_id = ? 
AND tombstone = 0 - GROUP BY ns - """, - (user_id,), - ).fetchall() - return {row["ns"]: row["cnt"] for row in rows} - - def update_multi_trace( - self, - memory_id: str, - s_fast: float, - s_mid: float, - s_slow: float, - effective_strength: float, - ) -> bool: - """Update multi-trace columns and effective strength for a memory.""" - return self.update_memory(memory_id, { - "s_fast": s_fast, - "s_mid": s_mid, - "s_slow": s_slow, - "strength": effective_strength, - }) - - # CategoryMem methods - def save_category(self, category_data: Dict[str, Any]) -> str: - """Save or update a category.""" - category_id = category_data.get("id") - if not category_id: - return "" - - with self._get_connection() as conn: - conn.execute( - """ - INSERT OR REPLACE INTO categories ( - id, name, description, category_type, parent_id, - children_ids, memory_count, total_strength, access_count, - last_accessed, created_at, embedding, keywords, - summary, summary_updated_at, related_ids, strength - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, - ( - category_id, - category_data.get("name", ""), - category_data.get("description", ""), - category_data.get("category_type", "dynamic"), - category_data.get("parent_id"), - json.dumps(category_data.get("children_ids", [])), - category_data.get("memory_count", 0), - category_data.get("total_strength", 0.0), - category_data.get("access_count", 0), - category_data.get("last_accessed"), - category_data.get("created_at"), - json.dumps(category_data.get("embedding")) if category_data.get("embedding") else None, - json.dumps(category_data.get("keywords", [])), - category_data.get("summary"), - category_data.get("summary_updated_at"), - json.dumps(category_data.get("related_ids", [])), - category_data.get("strength", 1.0), - ), - ) - return category_id - - def get_category(self, category_id: str) -> Optional[Dict[str, Any]]: - """Get a category by ID.""" - with self._get_connection() as conn: - row = conn.execute( - "SELECT * FROM categories WHERE id = ?", - (category_id,) - ).fetchone() - if row: - return self._category_row_to_dict(row) - return None - - def get_all_categories(self) -> List[Dict[str, Any]]: - """Get all categories.""" - with self._get_connection() as conn: - rows = conn.execute( - "SELECT * FROM categories ORDER BY strength DESC" - ).fetchall() - return [self._category_row_to_dict(row) for row in rows] - - def delete_category(self, category_id: str) -> bool: - """Delete a category.""" - with self._get_connection() as conn: - conn.execute("DELETE FROM categories WHERE id = ?", (category_id,)) - return True - - def save_all_categories(self, categories: List[Dict[str, Any]]) -> int: - """Save multiple categories in a single transaction for performance.""" - if not categories: - return 0 - rows = [] - for cat in categories: - cat_id = cat.get("id") - if not cat_id: - continue - rows.append(( - cat_id, - cat.get("name", ""), - cat.get("description", ""), - cat.get("category_type", "dynamic"), - cat.get("parent_id"), - 
json.dumps(cat.get("children_ids", [])), - cat.get("memory_count", 0), - cat.get("total_strength", 0.0), - cat.get("access_count", 0), - cat.get("last_accessed"), - cat.get("created_at"), - json.dumps(cat.get("embedding")) if cat.get("embedding") else None, - json.dumps(cat.get("keywords", [])), - cat.get("summary"), - cat.get("summary_updated_at"), - json.dumps(cat.get("related_ids", [])), - cat.get("strength", 1.0), - )) - if not rows: - return 0 - with self._get_connection() as conn: - conn.executemany( - """ - INSERT OR REPLACE INTO categories ( - id, name, description, category_type, parent_id, - children_ids, memory_count, total_strength, access_count, - last_accessed, created_at, embedding, keywords, - summary, summary_updated_at, related_ids, strength - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - rows, - ) - return len(rows) - - def _category_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: - """Convert a category row to dict.""" - data = dict(row) - for key in ["children_ids", "keywords", "related_ids"]: - if key in data and data[key]: - data[key] = json.loads(data[key]) - else: - data[key] = [] - if data.get("embedding"): - data["embedding"] = json.loads(data["embedding"]) - return data - - def _migrate_add_column(self, table: str, column: str, col_type: str) -> None: - """Add a column to an existing table if it doesn't already exist.""" - with self._get_connection() as conn: - self._migrate_add_column_conn(conn, table, column, col_type) - - # ========================================================================= - # Scene methods - # ========================================================================= - - def add_scene(self, scene_data: Dict[str, Any]) -> str: - scene_id = scene_data.get("id", str(uuid.uuid4())) - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO scenes ( - id, user_id, title, summary, topic, location, - participants, memory_ids, start_time, end_time, - embedding, strength, 
access_count, tombstone, - layer, scene_strength, topic_embedding_ref, namespace - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - scene_id, - scene_data.get("user_id"), - scene_data.get("title"), - scene_data.get("summary"), - scene_data.get("topic"), - scene_data.get("location"), - json.dumps(scene_data.get("participants", [])), - json.dumps(scene_data.get("memory_ids", [])), - scene_data.get("start_time"), - scene_data.get("end_time"), - json.dumps(scene_data.get("embedding")) if scene_data.get("embedding") else None, - scene_data.get("strength", 1.0), - scene_data.get("access_count", 0), - 1 if scene_data.get("tombstone", False) else 0, - scene_data.get("layer", "sml"), - scene_data.get("scene_strength", scene_data.get("strength", 1.0)), - scene_data.get("topic_embedding_ref"), - scene_data.get("namespace", "default"), - ), - ) - return scene_id - - def get_scene(self, scene_id: str) -> Optional[Dict[str, Any]]: - with self._get_connection() as conn: - row = conn.execute( - "SELECT * FROM scenes WHERE id = ? 
AND tombstone = 0", (scene_id,) - ).fetchone() - if row: - return self._scene_row_to_dict(row) - return None - - def update_scene(self, scene_id: str, updates: Dict[str, Any]) -> bool: - set_clauses = [] - params: List[Any] = [] - for key, value in updates.items(): - if key not in VALID_SCENE_COLUMNS: - raise ValueError(f"Invalid scene column: {key!r}") - if key in {"participants", "memory_ids", "embedding"}: - value = json.dumps(value) - set_clauses.append(f"{key} = ?") - params.append(value) - if not set_clauses: - return False - params.append(scene_id) - with self._get_connection() as conn: - conn.execute( - f"UPDATE scenes SET {', '.join(set_clauses)} WHERE id = ?", - params, - ) - return True - - def get_open_scene(self, user_id: str) -> Optional[Dict[str, Any]]: - """Get the most recent scene without an end_time for a user.""" - with self._get_connection() as conn: - row = conn.execute( - """ - SELECT * FROM scenes - WHERE user_id = ? AND end_time IS NULL AND tombstone = 0 - ORDER BY start_time DESC LIMIT 1 - """, - (user_id,), - ).fetchone() - if row: - return self._scene_row_to_dict(row) - return None - - def get_scenes( - self, - user_id: Optional[str] = None, - topic: Optional[str] = None, - start_after: Optional[str] = None, - start_before: Optional[str] = None, - namespace: Optional[str] = None, - limit: int = 50, - ) -> List[Dict[str, Any]]: - query = "SELECT * FROM scenes WHERE tombstone = 0" - params: List[Any] = [] - if user_id: - query += " AND user_id = ?" - params.append(user_id) - if topic: - query += " AND topic LIKE ?" - params.append(f"%{topic}%") - if start_after: - query += " AND start_time >= ?" - params.append(start_after) - if start_before: - query += " AND start_time <= ?" - params.append(start_before) - if namespace: - query += " AND namespace = ?" - params.append(namespace) - query += " ORDER BY start_time DESC LIMIT ?" 
- params.append(limit) - with self._get_connection() as conn: - rows = conn.execute(query, params).fetchall() - return [self._scene_row_to_dict(row) for row in rows] - - def add_scene_memory(self, scene_id: str, memory_id: str, position: int = 0) -> None: - with self._get_connection() as conn: - conn.execute( - "INSERT OR IGNORE INTO scene_memories (scene_id, memory_id, position) VALUES (?, ?, ?)", - (scene_id, memory_id, position), - ) - - def get_scene_memories(self, scene_id: str) -> List[Dict[str, Any]]: - with self._get_connection() as conn: - rows = conn.execute( - """ - SELECT m.* FROM memories m - JOIN scene_memories sm ON m.id = sm.memory_id - WHERE sm.scene_id = ? AND m.tombstone = 0 - ORDER BY sm.position - """, - (scene_id,), - ).fetchall() - return [self._row_to_dict(row) for row in rows] - - def _scene_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: - data = dict(row) - for key in ["participants", "memory_ids"]: - if key in data and data[key]: - data[key] = json.loads(data[key]) - else: - data[key] = [] - if data.get("embedding"): - data["embedding"] = json.loads(data["embedding"]) - data["tombstone"] = bool(data.get("tombstone", 0)) - return data - - # ========================================================================= - # Profile methods - # ========================================================================= - - def add_profile(self, profile_data: Dict[str, Any]) -> str: - profile_id = profile_data.get("id", str(uuid.uuid4())) - now = _utcnow_iso() - with self._get_connection() as conn: - conn.execute( - """ - INSERT INTO profiles ( - id, user_id, name, profile_type, narrative, - facts, preferences, relationships, sentiment, - theory_of_mind, aliases, embedding, strength, - created_at, updated_at, role_bias, profile_summary - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, - ( - profile_id, - profile_data.get("user_id"), - profile_data.get("name", ""), - profile_data.get("profile_type", "contact"), - profile_data.get("narrative"), - json.dumps(profile_data.get("facts", [])), - json.dumps(profile_data.get("preferences", [])), - json.dumps(profile_data.get("relationships", [])), - profile_data.get("sentiment"), - json.dumps(profile_data.get("theory_of_mind", {})), - json.dumps(profile_data.get("aliases", [])), - json.dumps(profile_data.get("embedding")) if profile_data.get("embedding") else None, - profile_data.get("strength", 1.0), - profile_data.get("created_at", now), - profile_data.get("updated_at", now), - profile_data.get("role_bias"), - profile_data.get("profile_summary"), - ), - ) - return profile_id - - def get_profile(self, profile_id: str) -> Optional[Dict[str, Any]]: - with self._get_connection() as conn: - row = conn.execute( - "SELECT * FROM profiles WHERE id = ?", (profile_id,) - ).fetchone() - if row: - return self._profile_row_to_dict(row) - return None - - def update_profile(self, profile_id: str, updates: Dict[str, Any]) -> bool: - set_clauses = [] - params: List[Any] = [] - for key, value in updates.items(): - if key not in VALID_PROFILE_COLUMNS: - raise ValueError(f"Invalid profile column: {key!r}") - if key in {"facts", "preferences", "relationships", "aliases", "theory_of_mind", "embedding"}: - value = json.dumps(value) - set_clauses.append(f"{key} = ?") - params.append(value) - set_clauses.append("updated_at = ?") - params.append(_utcnow_iso()) - params.append(profile_id) - with self._get_connection() as conn: - conn.execute( - f"UPDATE profiles SET {', '.join(set_clauses)} WHERE id = ?", - params, - ) - return True - - def get_all_profiles(self, user_id: Optional[str] = None) -> List[Dict[str, Any]]: - query = "SELECT * FROM profiles" - params: List[Any] = [] - if user_id: - query += " WHERE user_id = ?" 
- params.append(user_id) - query += " ORDER BY strength DESC" - with self._get_connection() as conn: - rows = conn.execute(query, params).fetchall() - return [self._profile_row_to_dict(row) for row in rows] - - def get_profile_by_name(self, name: str, user_id: Optional[str] = None) -> Optional[Dict[str, Any]]: - """Find a profile by exact name match, then fall back to alias scan.""" - # Fast path: exact name match via indexed column. - query = "SELECT * FROM profiles WHERE lower(name) = ?" - params: List[Any] = [name.lower()] - if user_id: - query += " AND user_id = ?" - params.append(user_id) - query += " LIMIT 1" - with self._get_connection() as conn: - row = conn.execute(query, params).fetchone() - if row: - return self._profile_row_to_dict(row) - # Slow path: alias scan (aliases stored as JSON, can't index). - alias_query = "SELECT * FROM profiles WHERE aliases LIKE ?" - alias_params: List[Any] = [f'%"{name}"%'] - if user_id: - alias_query += " AND user_id = ?" - alias_params.append(user_id) - alias_query += " LIMIT 1" - row = conn.execute(alias_query, alias_params).fetchone() - if row: - result = self._profile_row_to_dict(row) - # Verify case-insensitive alias match. - if name.lower() in [a.lower() for a in result.get("aliases", [])]: - return result - return None - - def find_profile_by_substring(self, name: str, user_id: Optional[str] = None) -> Optional[Dict[str, Any]]: - """Find a profile where the name contains the query as a substring (case-insensitive).""" - query = "SELECT * FROM profiles WHERE lower(name) LIKE ?" - params: List[Any] = [f"%{name.lower()}%"] - if user_id: - query += " AND user_id = ?" 
- params.append(user_id) - query += " ORDER BY strength DESC LIMIT 1" - with self._get_connection() as conn: - row = conn.execute(query, params).fetchone() - if row: - return self._profile_row_to_dict(row) - return None - - def add_profile_memory(self, profile_id: str, memory_id: str, role: str = "mentioned") -> None: - with self._get_connection() as conn: - conn.execute( - "INSERT OR IGNORE INTO profile_memories (profile_id, memory_id, role) VALUES (?, ?, ?)", - (profile_id, memory_id, role), - ) - - def get_profile_memories(self, profile_id: str) -> List[Dict[str, Any]]: - with self._get_connection() as conn: - rows = conn.execute( - """ - SELECT m.*, pm.role AS profile_role FROM memories m - JOIN profile_memories pm ON m.id = pm.memory_id - WHERE pm.profile_id = ? AND m.tombstone = 0 - ORDER BY m.created_at DESC - """, - (profile_id,), - ).fetchall() - return [self._row_to_dict(row) for row in rows] - - def _profile_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: - data = dict(row) - for key in ["facts", "preferences", "relationships", "aliases"]: - if key in data and data[key]: - data[key] = json.loads(data[key]) - else: - data[key] = [] - if data.get("theory_of_mind"): - data["theory_of_mind"] = json.loads(data["theory_of_mind"]) - else: - data["theory_of_mind"] = {} - if data.get("embedding"): - data["embedding"] = json.loads(data["embedding"]) - return data - - def get_memories_by_category( - self, - category_id: str, - limit: int = 100, - min_strength: float = 0.0, - ) -> List[Dict[str, Any]]: - """Get memories belonging to a specific category.""" - with self._get_connection() as conn: - rows = conn.execute( - """ - SELECT * FROM memories - WHERE categories LIKE ? AND strength >= ? AND tombstone = 0 - ORDER BY strength DESC - LIMIT ? 
- """, - (f'%"{category_id}"%', min_strength, limit), - ).fetchall() - return [self._row_to_dict(row) for row in rows] - - # ========================================================================= - # User ID listing - # ========================================================================= - - def list_user_ids(self) -> List[str]: - with self._get_connection() as conn: - rows = conn.execute( - """ - SELECT DISTINCT user_id FROM memories - WHERE user_id IS NOT NULL AND user_id != '' - ORDER BY user_id - """ - ).fetchall() - return [str(row["user_id"]) for row in rows if row["user_id"]] - - # ========================================================================= - # Dashboard / Visualization methods - # ========================================================================= - - def get_constellation_data(self, user_id: Optional[str] = None, limit: int = 200) -> Dict[str, Any]: - """Build graph data for the constellation visualizer.""" - with self._get_connection() as conn: - # Nodes: memories - mem_query = "SELECT id, memory, strength, layer, categories, created_at FROM memories WHERE tombstone = 0" - params: List[Any] = [] - if user_id: - mem_query += " AND user_id = ?" - params.append(user_id) - mem_query += " ORDER BY strength DESC LIMIT ?" - params.append(limit) - mem_rows = conn.execute(mem_query, params).fetchall() - - nodes = [] - node_ids = set() - for row in mem_rows: - cats = row["categories"] - if cats: - try: - cats = json.loads(cats) - except Exception: - cats = [] - else: - cats = [] - nodes.append({ - "id": row["id"], - "memory": (row["memory"] or "")[:120], - "strength": row["strength"], - "layer": row["layer"], - "categories": cats, - "created_at": row["created_at"], - }) - node_ids.add(row["id"]) - - # Edges from scene_memories (memories sharing a scene) - edges: List[Dict[str, Any]] = [] - if node_ids: - placeholders = ",".join("?" 
for _ in node_ids) - scene_rows = conn.execute( - f""" - SELECT a.memory_id AS source, b.memory_id AS target, a.scene_id - FROM scene_memories a - JOIN scene_memories b ON a.scene_id = b.scene_id AND a.memory_id < b.memory_id - WHERE a.memory_id IN ({placeholders}) AND b.memory_id IN ({placeholders}) - """, - list(node_ids) + list(node_ids), - ).fetchall() - for row in scene_rows: - edges.append({"source": row["source"], "target": row["target"], "type": "scene"}) - - # Edges from profile_memories (memories sharing a profile) - profile_rows = conn.execute( - f""" - SELECT a.memory_id AS source, b.memory_id AS target, a.profile_id - FROM profile_memories a - JOIN profile_memories b ON a.profile_id = b.profile_id AND a.memory_id < b.memory_id - WHERE a.memory_id IN ({placeholders}) AND b.memory_id IN ({placeholders}) - """, - list(node_ids) + list(node_ids), - ).fetchall() - for row in profile_rows: - edges.append({"source": row["source"], "target": row["target"], "type": "profile"}) - - return {"nodes": nodes, "edges": edges} - - def get_decay_log_entries(self, limit: int = 20) -> List[Dict[str, Any]]: - """Return recent decay log entries for the dashboard sparkline.""" - with self._get_connection() as conn: - rows = conn.execute( - "SELECT * FROM decay_log ORDER BY run_at DESC LIMIT ?", - (limit,), - ).fetchall() - return [dict(row) for row in rows] - - # ========================================================================= - # Utilities - # ========================================================================= - - @staticmethod - def _parse_json_value(value: Any, default: Any) -> Any: - if value is None: - return default - if isinstance(value, (dict, list)): - return value - try: - return json.loads(value) - except Exception: - return default diff --git a/dhee/db/sqlite_common.py b/dhee/db/sqlite_common.py new file mode 100644 index 0000000..ef42811 --- /dev/null +++ b/dhee/db/sqlite_common.py @@ -0,0 +1,35 @@ +from datetime import datetime, timezone + + 
+VALID_MEMORY_COLUMNS = frozenset({ + "memory", "metadata", "categories", "embedding", "strength", + "layer", "tombstone", "updated_at", "related_memories", "source_memories", + "confidentiality_scope", "source_type", "source_app", "source_event_id", + "decay_lambda", "status", "importance", "sensitivity", "namespace", + "access_count", "last_accessed", "immutable", "expiration_date", + "scene_id", "user_id", "agent_id", "run_id", "app_id", + "memory_type", "s_fast", "s_mid", "s_slow", "content_hash", + "conversation_context", "enrichment_status", +}) + +VALID_SCENE_COLUMNS = frozenset({ + "title", "summary", "topic", "location", "participants", "memory_ids", + "start_time", "end_time", "embedding", "strength", "access_count", + "tombstone", "layer", "scene_strength", "topic_embedding_ref", "namespace", +}) + +VALID_PROFILE_COLUMNS = frozenset({ + "name", "profile_type", "narrative", "facts", "preferences", + "relationships", "sentiment", "theory_of_mind", "aliases", + "embedding", "strength", "updated_at", "role_bias", "profile_summary", +}) + + +def _utcnow() -> datetime: + """Return current UTC datetime (timezone-aware).""" + return datetime.now(timezone.utc) + + +def _utcnow_iso() -> str: + """Return current UTC time as ISO string.""" + return _utcnow().isoformat() diff --git a/dhee/db/sqlite_domains.py b/dhee/db/sqlite_domains.py new file mode 100644 index 0000000..93beb3b --- /dev/null +++ b/dhee/db/sqlite_domains.py @@ -0,0 +1,498 @@ +import json +import sqlite3 +import uuid +from typing import Any, Dict, List, Optional + +from .sqlite_common import VALID_PROFILE_COLUMNS, VALID_SCENE_COLUMNS, _utcnow_iso + + +class SQLiteDomainMixin: + """Category, scene, and profile storage APIs for FullSQLiteManager.""" + + def save_category(self, category_data: Dict[str, Any]) -> str: + """Save or update a category.""" + category_id = category_data.get("id") + if not category_id: + return "" + + with self._get_connection() as conn: + conn.execute( + """ + INSERT OR 
REPLACE INTO categories ( + id, name, description, category_type, parent_id, + children_ids, memory_count, total_strength, access_count, + last_accessed, created_at, embedding, keywords, + summary, summary_updated_at, related_ids, strength + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + category_id, + category_data.get("name", ""), + category_data.get("description", ""), + category_data.get("category_type", "dynamic"), + category_data.get("parent_id"), + json.dumps(category_data.get("children_ids", [])), + category_data.get("memory_count", 0), + category_data.get("total_strength", 0.0), + category_data.get("access_count", 0), + category_data.get("last_accessed"), + category_data.get("created_at"), + json.dumps(category_data.get("embedding")) + if category_data.get("embedding") + else None, + json.dumps(category_data.get("keywords", [])), + category_data.get("summary"), + category_data.get("summary_updated_at"), + json.dumps(category_data.get("related_ids", [])), + category_data.get("strength", 1.0), + ), + ) + return category_id + + def get_category(self, category_id: str) -> Optional[Dict[str, Any]]: + """Get a category by ID.""" + with self._get_connection() as conn: + row = conn.execute( + "SELECT * FROM categories WHERE id = ?", + (category_id,), + ).fetchone() + if row: + return self._category_row_to_dict(row) + return None + + def get_all_categories(self) -> List[Dict[str, Any]]: + """Get all categories.""" + with self._get_connection() as conn: + rows = conn.execute( + "SELECT * FROM categories ORDER BY strength DESC" + ).fetchall() + return [self._category_row_to_dict(row) for row in rows] + + def delete_category(self, category_id: str) -> bool: + """Delete a category.""" + with self._get_connection() as conn: + conn.execute("DELETE FROM categories WHERE id = ?", (category_id,)) + return True + + def save_all_categories(self, categories: List[Dict[str, Any]]) -> int: + """Save multiple categories in a single transaction for 
performance.""" + if not categories: + return 0 + rows = [] + for cat in categories: + cat_id = cat.get("id") + if not cat_id: + continue + rows.append( + ( + cat_id, + cat.get("name", ""), + cat.get("description", ""), + cat.get("category_type", "dynamic"), + cat.get("parent_id"), + json.dumps(cat.get("children_ids", [])), + cat.get("memory_count", 0), + cat.get("total_strength", 0.0), + cat.get("access_count", 0), + cat.get("last_accessed"), + cat.get("created_at"), + json.dumps(cat.get("embedding")) + if cat.get("embedding") + else None, + json.dumps(cat.get("keywords", [])), + cat.get("summary"), + cat.get("summary_updated_at"), + json.dumps(cat.get("related_ids", [])), + cat.get("strength", 1.0), + ) + ) + if not rows: + return 0 + with self._get_connection() as conn: + conn.executemany( + """ + INSERT OR REPLACE INTO categories ( + id, name, description, category_type, parent_id, + children_ids, memory_count, total_strength, access_count, + last_accessed, created_at, embedding, keywords, + summary, summary_updated_at, related_ids, strength + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + rows, + ) + return len(rows) + + def _category_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: + """Convert a category row to dict.""" + data = dict(row) + data["children_ids"] = self._parse_json_value( + data.get("children_ids"), [] + ) + data["keywords"] = self._parse_json_value(data.get("keywords"), []) + data["related_ids"] = self._parse_json_value( + data.get("related_ids"), [] + ) + data["embedding"] = self._parse_json_value(data.get("embedding"), None) + return data + + def _migrate_add_column(self, table: str, column: str, col_type: str) -> None: + """Add a column to an existing table if it doesn't already exist.""" + with self._get_connection() as conn: + self._migrate_add_column_conn(conn, table, column, col_type) + + def add_scene(self, scene_data: Dict[str, Any]) -> str: + scene_id = scene_data.get("id", str(uuid.uuid4())) + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO scenes ( + id, user_id, title, summary, topic, location, + participants, memory_ids, start_time, end_time, + embedding, strength, access_count, tombstone, + layer, scene_strength, topic_embedding_ref, namespace + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + scene_id, + scene_data.get("user_id"), + scene_data.get("title"), + scene_data.get("summary"), + scene_data.get("topic"), + scene_data.get("location"), + json.dumps(scene_data.get("participants", [])), + json.dumps(scene_data.get("memory_ids", [])), + scene_data.get("start_time"), + scene_data.get("end_time"), + json.dumps(scene_data.get("embedding")) + if scene_data.get("embedding") + else None, + scene_data.get("strength", 1.0), + scene_data.get("access_count", 0), + 1 if scene_data.get("tombstone", False) else 0, + scene_data.get("layer", "sml"), + scene_data.get( + "scene_strength", scene_data.get("strength", 1.0) + ), + scene_data.get("topic_embedding_ref"), + scene_data.get("namespace", "default"), + ), + ) + return scene_id + + def get_scene(self, scene_id: str) -> Optional[Dict[str, Any]]: + with self._get_connection() as conn: + row = conn.execute( + "SELECT * FROM scenes WHERE id = ? AND tombstone = 0", + (scene_id,), + ).fetchone() + if row: + return self._scene_row_to_dict(row) + return None + + def update_scene(self, scene_id: str, updates: Dict[str, Any]) -> bool: + set_clauses = [] + params: List[Any] = [] + for key, value in updates.items(): + if key not in VALID_SCENE_COLUMNS: + raise ValueError(f"Invalid scene column: {key!r}") + if key in {"participants", "memory_ids", "embedding"}: + value = json.dumps(value) + set_clauses.append(f"{key} = ?") + params.append(value) + if not set_clauses: + return False + params.append(scene_id) + with self._get_connection() as conn: + conn.execute( + f"UPDATE scenes SET {', '.join(set_clauses)} WHERE id = ?", + params, + ) + return True + + def get_open_scene(self, user_id: str) -> Optional[Dict[str, Any]]: + """Get the most recent scene without an end_time for a user.""" + with self._get_connection() as conn: + row = conn.execute( + """ + SELECT * FROM scenes + WHERE user_id = ? 
AND end_time IS NULL AND tombstone = 0 + ORDER BY start_time DESC LIMIT 1 + """, + (user_id,), + ).fetchone() + if row: + return self._scene_row_to_dict(row) + return None + + def get_scenes( + self, + user_id: Optional[str] = None, + topic: Optional[str] = None, + start_after: Optional[str] = None, + start_before: Optional[str] = None, + namespace: Optional[str] = None, + limit: int = 50, + ) -> List[Dict[str, Any]]: + query = "SELECT * FROM scenes WHERE tombstone = 0" + params: List[Any] = [] + if user_id: + query += " AND user_id = ?" + params.append(user_id) + if topic: + query += " AND topic LIKE ?" + params.append(f"%{topic}%") + if start_after: + query += " AND start_time >= ?" + params.append(start_after) + if start_before: + query += " AND start_time <= ?" + params.append(start_before) + if namespace: + query += " AND namespace = ?" + params.append(namespace) + query += " ORDER BY start_time DESC LIMIT ?" + params.append(limit) + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [self._scene_row_to_dict(row) for row in rows] + + def add_scene_memory( + self, + scene_id: str, + memory_id: str, + position: int = 0, + ) -> None: + with self._get_connection() as conn: + conn.execute( + "INSERT OR IGNORE INTO scene_memories (scene_id, memory_id, position) VALUES (?, ?, ?)", + (scene_id, memory_id, position), + ) + + def get_scene_memories(self, scene_id: str) -> List[Dict[str, Any]]: + with self._get_connection() as conn: + rows = conn.execute( + """ + SELECT m.* FROM memories m + JOIN scene_memories sm ON m.id = sm.memory_id + WHERE sm.scene_id = ? 
AND m.tombstone = 0 + ORDER BY sm.position + """, + (scene_id,), + ).fetchall() + return [self._row_to_dict(row) for row in rows] + + def _scene_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: + data = dict(row) + data["participants"] = self._parse_json_value( + data.get("participants"), [] + ) + data["memory_ids"] = self._parse_json_value(data.get("memory_ids"), []) + data["embedding"] = self._parse_json_value(data.get("embedding"), None) + data["tombstone"] = bool(data.get("tombstone", 0)) + return data + + def add_profile(self, profile_data: Dict[str, Any]) -> str: + profile_id = profile_data.get("id", str(uuid.uuid4())) + now = _utcnow_iso() + with self._get_connection() as conn: + conn.execute( + """ + INSERT INTO profiles ( + id, user_id, name, profile_type, narrative, + facts, preferences, relationships, sentiment, + theory_of_mind, aliases, embedding, strength, + created_at, updated_at, role_bias, profile_summary + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + profile_id, + profile_data.get("user_id"), + profile_data.get("name", ""), + profile_data.get("profile_type", "contact"), + profile_data.get("narrative"), + json.dumps(profile_data.get("facts", [])), + json.dumps(profile_data.get("preferences", [])), + json.dumps(profile_data.get("relationships", [])), + profile_data.get("sentiment"), + json.dumps(profile_data.get("theory_of_mind", {})), + json.dumps(profile_data.get("aliases", [])), + json.dumps(profile_data.get("embedding")) + if profile_data.get("embedding") + else None, + profile_data.get("strength", 1.0), + profile_data.get("created_at", now), + profile_data.get("updated_at", now), + profile_data.get("role_bias"), + profile_data.get("profile_summary"), + ), + ) + return profile_id + + def get_profile(self, profile_id: str) -> Optional[Dict[str, Any]]: + with self._get_connection() as conn: + row = conn.execute( + "SELECT * FROM profiles WHERE id = ?", + (profile_id,), + ).fetchone() + if row: + return self._profile_row_to_dict(row) + return None + + def update_profile(self, profile_id: str, updates: Dict[str, Any]) -> bool: + set_clauses = [] + params: List[Any] = [] + for key, value in updates.items(): + if key not in VALID_PROFILE_COLUMNS: + raise ValueError(f"Invalid profile column: {key!r}") + if key in { + "facts", + "preferences", + "relationships", + "aliases", + "theory_of_mind", + "embedding", + }: + value = json.dumps(value) + set_clauses.append(f"{key} = ?") + params.append(value) + set_clauses.append("updated_at = ?") + params.append(_utcnow_iso()) + params.append(profile_id) + with self._get_connection() as conn: + conn.execute( + f"UPDATE profiles SET {', '.join(set_clauses)} WHERE id = ?", + params, + ) + return True + + def get_all_profiles( + self, + user_id: Optional[str] = None, + ) -> List[Dict[str, Any]]: + query = "SELECT * FROM profiles" + params: List[Any] = [] + if user_id: + query += " WHERE user_id = ?" 
+ params.append(user_id) + query += " ORDER BY strength DESC" + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [self._profile_row_to_dict(row) for row in rows] + + def get_profile_by_name( + self, + name: str, + user_id: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: + """Find a profile by exact name match, then fall back to alias scan.""" + query = "SELECT * FROM profiles WHERE lower(name) = ?" + params: List[Any] = [name.lower()] + if user_id: + query += " AND user_id = ?" + params.append(user_id) + query += " LIMIT 1" + with self._get_connection() as conn: + row = conn.execute(query, params).fetchone() + if row: + return self._profile_row_to_dict(row) + alias_query = "SELECT * FROM profiles WHERE aliases LIKE ?" + alias_params: List[Any] = [f'%"{name}"%'] + if user_id: + alias_query += " AND user_id = ?" + alias_params.append(user_id) + alias_query += " LIMIT 1" + row = conn.execute(alias_query, alias_params).fetchone() + if row: + result = self._profile_row_to_dict(row) + if name.lower() in [a.lower() for a in result.get("aliases", [])]: + return result + return None + + def find_profile_by_substring( + self, + name: str, + user_id: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: + """Find a profile where the name contains the query as a substring.""" + query = "SELECT * FROM profiles WHERE lower(name) LIKE ?" + params: List[Any] = [f"%{name.lower()}%"] + if user_id: + query += " AND user_id = ?" 
+ params.append(user_id) + query += " ORDER BY strength DESC LIMIT 1" + with self._get_connection() as conn: + row = conn.execute(query, params).fetchone() + if row: + return self._profile_row_to_dict(row) + return None + + def add_profile_memory( + self, + profile_id: str, + memory_id: str, + role: str = "mentioned", + ) -> None: + with self._get_connection() as conn: + conn.execute( + "INSERT OR IGNORE INTO profile_memories (profile_id, memory_id, role) VALUES (?, ?, ?)", + (profile_id, memory_id, role), + ) + + def get_profile_memories(self, profile_id: str) -> List[Dict[str, Any]]: + with self._get_connection() as conn: + rows = conn.execute( + """ + SELECT m.*, pm.role AS profile_role FROM memories m + JOIN profile_memories pm ON m.id = pm.memory_id + WHERE pm.profile_id = ? AND m.tombstone = 0 + ORDER BY m.created_at DESC + """, + (profile_id,), + ).fetchall() + return [self._row_to_dict(row) for row in rows] + + def _profile_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: + data = dict(row) + data["facts"] = self._parse_json_value(data.get("facts"), []) + data["preferences"] = self._parse_json_value( + data.get("preferences"), [] + ) + data["relationships"] = self._parse_json_value( + data.get("relationships"), [] + ) + data["aliases"] = self._parse_json_value(data.get("aliases"), []) + data["theory_of_mind"] = self._parse_json_value( + data.get("theory_of_mind"), {} + ) + data["embedding"] = self._parse_json_value(data.get("embedding"), None) + return data + + def get_memories_by_category( + self, + category_id: str, + limit: int = 100, + min_strength: float = 0.0, + ) -> List[Dict[str, Any]]: + """Get memories belonging to a specific category.""" + with self._get_connection() as conn: + rows = conn.execute( + """ + SELECT * FROM memories + WHERE categories LIKE ? AND strength >= ? AND tombstone = 0 + ORDER BY strength DESC + LIMIT ? 
+ """, + (f'%"{category_id}"%', min_strength, limit), + ).fetchall() + return [self._row_to_dict(row) for row in rows] + + def list_user_ids(self) -> List[str]: + with self._get_connection() as conn: + rows = conn.execute( + """ + SELECT DISTINCT user_id FROM memories + WHERE user_id IS NOT NULL AND user_id != '' + ORDER BY user_id + """ + ).fetchall() + return [str(row["user_id"]) for row in rows if row["user_id"]] diff --git a/dhee/mcp_slim.py b/dhee/mcp_slim.py index ec6f79d..1013f99 100644 --- a/dhee/mcp_slim.py +++ b/dhee/mcp_slim.py @@ -41,7 +41,7 @@ def _get_plugin(): from dhee.adapters.base import DheePlugin _plugin = DheePlugin() # Enable deferred enrichment on the underlying memory - memory = _plugin._engram._memory + memory = _plugin.memory if hasattr(memory, "config") and hasattr(memory.config, "enrichment"): memory.config.enrichment.defer_enrichment = True memory.config.enrichment.enable_unified = True @@ -52,9 +52,11 @@ def _auto_checkpoint_on_exit(): try: args = _plugin._tracker.finalize() if args: - _plugin.checkpoint(**args) - except Exception: - pass + result = _plugin.checkpoint(**args) + for warning in result.get("warnings", []): + logger.warning("MCP auto-checkpoint warning: %s", warning) + except Exception as exc: + logger.warning("MCP auto-checkpoint on exit failed: %s", exc, exc_info=True) atexit.register(_auto_checkpoint_on_exit) return _plugin diff --git a/dhee/memory/main.py b/dhee/memory/main.py index a2a8dad..69089b3 100644 --- a/dhee/memory/main.py +++ b/dhee/memory/main.py @@ -33,6 +33,7 @@ from dhee.core.graph import KnowledgeGraph from dhee.core.scene import SceneProcessor from dhee.core.profile import ProfileProcessor +from dhee.core.answer_orchestration import extract_atomic_facts, reduce_atomic_facts from dhee.db.sqlite import SQLiteManager from dhee.exceptions import FadeMemValidationError from dhee.memory.base import MemoryBase @@ -293,6 +294,8 @@ def __init__(self, config: Optional[MemoryConfig] = None, preset: Optional[str] 
config = MemoryConfig.full() # Initialize parent SmartMemory (handles db, llm, embedder, etc.) super().__init__(config=config, preset=preset) + self._runtime_root_dir = self._resolve_runtime_root_dir() + self._buddhi_data_dir = os.path.join(self._runtime_root_dir, "buddhi") # Only FullMemory-specific lazy init self._scene_processor: Optional[SceneProcessor] = None self._profile_processor: Optional[ProfileProcessor] = None @@ -325,6 +328,33 @@ def __init__(self, config: Optional[MemoryConfig] = None, preset: Optional[str] # Search pipeline (lazy — created on first use) self.__search_pipeline: Optional[SearchPipeline] = None + def _resolve_runtime_root_dir(self) -> str: + """Pick a stable runtime root for cognition sidecars. + + FullMemory may be backed by a temporary or custom data dir even when the + vector store itself is in-memory. Use the configured on-disk paths first + so helper layers do not silently spill into ~/.dhee. + """ + candidate_paths: List[object] = [getattr(self.config, "history_db_path", None)] + vector_config = getattr(getattr(self.config, "vector_store", None), "config", {}) + if isinstance(vector_config, dict): + candidate_paths.append(vector_config.get("path")) + + for raw_path in candidate_paths: + if raw_path is None: + continue + path = str(raw_path).strip() + if not path or path == ":memory:": + continue + root = os.path.dirname(os.path.abspath(os.path.expanduser(path))) + if root: + os.makedirs(root, exist_ok=True) + return root + + fallback = os.path.join(os.path.expanduser("~"), ".dhee") + os.makedirs(fallback, exist_ok=True) + return fallback + @property def _write_pipeline(self) -> MemoryWritePipeline: """Lazy-initialized write pipeline that delegates heavy write-path logic.""" @@ -354,6 +384,7 @@ def _write_pipeline(self) -> MemoryWritePipeline: assign_to_scene_fn=self._assign_to_scene, update_profiles_fn=self._update_profiles, store_prospective_scenes_fn=self._store_prospective_scenes, + 
persist_categories_fn=self._persist_categories, ) return self.__write_pipeline @@ -399,6 +430,8 @@ def _orchestration_engine(self) -> OrchestrationEngine: profile_processor_fn=lambda: self.profile_processor, evolution_layer_fn=lambda: self.evolution_layer, llm_fn=lambda: self.llm, + extract_atomic_facts_fn=extract_atomic_facts, + reduce_atomic_facts_fn=reduce_atomic_facts, ) return self.__orchestration_engine @@ -492,7 +525,9 @@ def evolution_layer(self): if self._evolution_layer is None: try: from dhee.core.evolution import EvolutionLayer - self._evolution_layer = EvolutionLayer() + self._evolution_layer = EvolutionLayer( + data_dir=self._runtime_root_dir, + ) except Exception as e: logger.debug("Evolution layer init skipped: %s", e) return self._evolution_layer @@ -503,7 +538,9 @@ def buddhi_layer(self): if self._buddhi_layer is None: try: from dhee.core.buddhi import Buddhi - self._buddhi_layer = Buddhi() + self._buddhi_layer = Buddhi( + data_dir=self._buddhi_data_dir, + ) except Exception as e: logger.debug("Buddhi layer init skipped: %s", e) return self._buddhi_layer @@ -691,22 +728,56 @@ def get_skill_stats(self) -> Dict[str, Any]: def close(self) -> None: """Release all resources held by the Memory instance.""" + errors = [] + # Flush self-evolution state before shutdown if self._evolution_layer is not None: try: self._evolution_layer.flush() - except Exception: - pass + except Exception as exc: + logger.exception("FullMemory close failed for evolution.flush") + errors.append( + f"evolution.flush: {type(exc).__name__}: {exc}" + ) + # Shutdown parallel executor if it was created if self._executor is not None: - self._executor.shutdown() - self._executor = None + try: + self._executor.shutdown() + except Exception as exc: + logger.exception("FullMemory close failed for executor.shutdown") + errors.append( + f"executor.shutdown: {type(exc).__name__}: {exc}" + ) + finally: + self._executor = None + # Release vector store if self.vector_store is not None: - 
self.vector_store.close() + try: + self.vector_store.close() + except Exception as exc: + logger.exception("FullMemory close failed for vector_store.close") + errors.append( + f"vector_store.close: {type(exc).__name__}: {exc}" + ) + finally: + self.vector_store = None + # Release database if self.db is not None: - self.db.close() + try: + self.db.close() + except Exception as exc: + logger.exception("FullMemory close failed for db.close") + errors.append(f"db.close: {type(exc).__name__}: {exc}") + finally: + self.db = None + + if errors: + raise RuntimeError( + "Failed to close FullMemory resources: " + "; ".join(errors) + ) def __repr__(self) -> str: return f"FullMemory(db={self.db!r}, echo={self.config.echo.enable_echo}, scenes={self.config.scene.enable_scenes})" @@ -959,7 +1030,6 @@ def add_batch( use_batch = batch_config and batch_config.enable_batch if not use_batch or not items: - # Fallback: sequential add per item all_results = [] for item in items: content = item.get("content") or item.get("messages", "") @@ -983,27 +1053,25 @@ def add_batch( return {"results": all_results} max_batch = batch_config.max_batch_size - - # Split into sub-batches if needed all_results: List[Dict[str, Any]] = [] for start in range(0, len(items), max_batch): - chunk = items[start : start + max_batch] - chunk_results = self._process_memory_batch( - chunk, - user_id=user_id, - agent_id=agent_id, - run_id=run_id, - app_id=app_id, - metadata=metadata, - filters=filters, - initial_strength=initial_strength, - echo_depth=echo_depth, - batch_config=batch_config, - **common_kwargs, + chunk = items[start:start + max_batch] + all_results.extend( + self._process_memory_batch( + chunk, + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + app_id=app_id, + metadata=metadata, + filters=filters, + initial_strength=initial_strength, + echo_depth=echo_depth, + batch_config=batch_config, + **common_kwargs, + ) ) - all_results.extend(chunk_results) - # Persist categories after full batch if 
self.category_processor: self._persist_categories() @@ -1025,459 +1093,19 @@ def _process_memory_batch( **common_kwargs: Any, ) -> List[Dict[str, Any]]: """Process a batch of memory items with batched echo/embed/DB.""" - - # Extract contents - contents = [] - item_metadata_list = [] - for item in items: - content = item.get("content") or item.get("messages", "") - if isinstance(content, list): - # Flatten message list to string - content = " ".join( - m.get("content", "") for m in content if isinstance(m, dict) - ) - contents.append(str(content).strip()) - item_meta = dict(metadata or {}) - item_meta.update(item.get("metadata") or {}) - item_metadata_list.append(item_meta) - - # Write-path telemetry aggregates for this batch (later normalized per memory). - batch_llm_calls_total = 0.0 - batch_embed_calls_total = 0.0 - batch_input_tokens_total = 0.0 - batch_output_tokens_total = 0.0 - - # 0. Try unified enrichment (single LLM call for echo+category+entities+profiles) - echo_results = [None] * len(contents) - category_results = [None] * len(contents) - enrichment_results = [None] * len(contents) # stash for post-store hooks - - enrichment_config = getattr(self.config, "enrichment", None) - _use_unified = ( - self.unified_enrichment is not None - and self.echo_config.enable_echo - and batch_config.batch_echo - ) - - if _use_unified: - try: - depth_override = EchoDepth(echo_depth) if echo_depth else EchoDepth(self.echo_config.default_depth) - existing_cats = None - if self.category_processor: - cats = self.category_processor.get_all_categories() - if cats: - existing_cats = "\n".join( - f"- {c['id']}: {c['name']} — {c.get('description', '')}" - for c in cats[:30] - ) - - # Process in sub-batches of enrichment_config.max_batch_size - enrich_batch_size = enrichment_config.max_batch_size if enrichment_config else 10 - for start in range(0, len(contents), enrich_batch_size): - end = min(start + enrich_batch_size, len(contents)) - sub_contents = contents[start:end] - 
sub_results = self.unified_enrichment.enrich_batch( - sub_contents, - depth=depth_override, - existing_categories=existing_cats, - include_entities=enrichment_config.include_entities if enrichment_config else True, - include_profiles=enrichment_config.include_profiles if enrichment_config else True, - ) - sub_input_tokens = sum(self._estimate_token_count(c) for c in sub_contents) - sub_input_tokens += self._estimate_token_count(existing_cats) - batch_llm_calls_total += 1.0 - batch_input_tokens_total += sub_input_tokens - batch_output_tokens_total += self._estimate_output_tokens(sub_input_tokens) - for j, enrichment in enumerate(sub_results): - idx = start + j - if enrichment.echo_result: - echo_results[idx] = enrichment.echo_result - if enrichment.category_match: - category_results[idx] = enrichment.category_match - enrichment_results[idx] = enrichment - - logger.info("Unified batch enrichment completed for %d memories", len(contents)) - except Exception as e: - logger.warning("Unified batch enrichment failed, falling back to separate: %s", e) - # Reset — let the fallback below handle it - echo_results = [None] * len(contents) - category_results = [None] * len(contents) - enrichment_results = [None] * len(contents) - _use_unified = False - - # 1. 
Batch echo encoding (fallback if unified was not used or failed) - if not _use_unified: - if self.echo_processor and self.echo_config.enable_echo and batch_config.batch_echo: - depth_override = EchoDepth(echo_depth) if echo_depth else EchoDepth(self.echo_config.default_depth) - if depth_override != EchoDepth.SHALLOW: - echo_input_tokens = sum(self._estimate_token_count(c) for c in contents if c) - non_empty_count = sum(1 for c in contents if c) - batch_llm_calls_total += float(non_empty_count) - batch_input_tokens_total += echo_input_tokens - batch_output_tokens_total += self._estimate_output_tokens(echo_input_tokens) - try: - echo_results = self.echo_processor.process_batch( - contents, depth=depth_override - ) - except Exception as e: - logger.warning("Batch echo failed, processing individually: %s", e) - for i, c in enumerate(contents): - if c: - try: - depth_override = EchoDepth(echo_depth) if echo_depth else None - echo_results[i] = self.echo_processor.process(c, depth=depth_override) - except Exception: - pass - - # 2. Batch category detection - if ( - self.category_processor - and self.category_config.auto_categorize - and batch_config.batch_category - ): - if self.category_config.use_llm_categorization: - cat_input_tokens = sum(self._estimate_token_count(c) for c in contents if c) - non_empty_count = sum(1 for c in contents if c) - batch_llm_calls_total += float(non_empty_count) - batch_input_tokens_total += cat_input_tokens - batch_output_tokens_total += self._estimate_output_tokens(cat_input_tokens) - try: - category_results = self.category_processor.detect_categories_batch( - contents, - use_llm=self.category_config.use_llm_categorization, - ) - except Exception as e: - logger.warning("Batch category failed: %s", e) - - # 3. 
Batch embeddings - primary_texts = [] - for i, content in enumerate(contents): - echo_result = echo_results[i] - primary_texts.append(self._select_primary_text(content, echo_result)) - - if batch_config.batch_embed: - try: - # Sub-batch to stay within API limits (~50 per call) - embeddings: List[List[float]] = [] - for start in range(0, len(primary_texts), 50): - sub = primary_texts[start:start + 50] - embeddings.extend(self.embedder.embed_batch(sub, memory_action="add")) - batch_embed_calls_total += 1.0 - except Exception as e: - logger.warning("Batch embed failed, falling back to sequential: %s", e) - embeddings = [ - self.embedder.embed(t, memory_action="add") for t in primary_texts - ] - batch_embed_calls_total += float(len(primary_texts)) - else: - embeddings = [ - self.embedder.embed(t, memory_action="add") for t in primary_texts - ] - batch_embed_calls_total += float(len(primary_texts)) - - # 3b. Pre-embed all echo node texts (paraphrases, questions, content variants) - # so _build_index_vectors can use the cache instead of individual embed() calls. 
- echo_node_texts = [] - for i, content in enumerate(contents): - echo_result = echo_results[i] - pt = primary_texts[i] - if pt != content: - cleaned = content.strip() - if cleaned: - echo_node_texts.append(cleaned) - if echo_result: - for p in echo_result.paraphrases: - cleaned = str(p).strip() - if cleaned: - echo_node_texts.append(cleaned) - for q in echo_result.questions: - cleaned = str(q).strip() - if cleaned: - echo_node_texts.append(cleaned) - - embedding_cache: Dict[str, List[float]] = {} - if echo_node_texts: - # Deduplicate while preserving order for batch embedding - unique_texts = list(dict.fromkeys(echo_node_texts)) - try: - # Sub-batch to stay within NVIDIA API limits (~50 per call) - all_echo_embeddings: List[List[float]] = [] - for start in range(0, len(unique_texts), 50): - sub = unique_texts[start:start + 50] - sub_embs = self.embedder.embed_batch(sub, memory_action="add") - all_echo_embeddings.extend(sub_embs) - batch_embed_calls_total += 1.0 - for text, emb in zip(unique_texts, all_echo_embeddings): - embedding_cache[text] = emb - logger.info("Batch-embedded %d echo node texts in %d API calls", - len(unique_texts), (len(unique_texts) + 49) // 50) - except Exception as e: - logger.warning("Batch echo node embedding failed, will embed individually: %s", e) - - # 4. 
Build memory records and batch-insert into DB - processed_metadata_base, effective_filters = build_filters_and_metadata( + return self._write_pipeline.process_memory_batch( + items, user_id=user_id, agent_id=agent_id, run_id=run_id, - input_metadata=metadata, - input_filters=filters, + app_id=app_id, + metadata=metadata, + filters=filters, + initial_strength=initial_strength, + echo_depth=echo_depth, + batch_config=batch_config, + **common_kwargs, ) - if app_id: - processed_metadata_base["app_id"] = app_id - - now = datetime.now(timezone.utc).isoformat() - memory_records = [] - episodic_rows: List[Tuple[str, Optional[str], str, Dict[str, Any]]] = [] - vector_batch = [] # (vectors, payloads, ids) - results = [] - - for i, content in enumerate(contents): - if not content: - continue - - memory_id = str(uuid.uuid4()) - mem_metadata = dict(processed_metadata_base) - mem_metadata.update(item_metadata_list[i]) - mem_metadata = self._attach_bitemporal_metadata(mem_metadata, observed_time=now) - - echo_result = echo_results[i] - effective_strength = initial_strength - mem_categories = list(items[i].get("categories") or []) - - if echo_result: - effective_strength = initial_strength * echo_result.strength_multiplier - mem_metadata.update(echo_result.to_metadata()) - if not mem_categories and echo_result.category: - mem_categories = [echo_result.category] - - cat_match = category_results[i] - if cat_match and not mem_categories: - mem_categories = [cat_match.category_id] - mem_metadata["category_confidence"] = cat_match.confidence - mem_metadata["category_auto"] = True - - embedding = embeddings[i] - namespace_value = str(mem_metadata.get("namespace", "default") or "default").strip() or "default" - - memory_type = self._classify_memory_type(mem_metadata, mem_metadata.get("role", "user")) - - s_fast_val = s_mid_val = s_slow_val = None - if self.distillation_config and self.distillation_config.enable_multi_trace: - s_fast_val, s_mid_val, s_slow_val = 
initialize_traces(effective_strength, is_new=True) - - memory_data = { - "id": memory_id, - "memory": content, - "user_id": items[i].get("user_id") or user_id, - "agent_id": agent_id, - "run_id": run_id, - "app_id": app_id, - "metadata": mem_metadata, - "categories": mem_categories, - "immutable": items[i].get("immutable", False), - "expiration_date": items[i].get("expiration_date"), - "created_at": now, - "updated_at": now, - "layer": "sml", - "strength": effective_strength, - "access_count": 0, - "last_accessed": now, - "embedding": embedding, - "confidentiality_scope": "work", - "source_type": "mcp", - "source_app": items[i].get("source_app"), - "source_event_id": mem_metadata.get("source_event_id"), - "decay_lambda": self.fade_config.sml_decay_rate, - "status": "active", - "importance": mem_metadata.get("importance", 0.5), - "sensitivity": mem_metadata.get("sensitivity", "normal"), - "namespace": namespace_value, - "memory_type": memory_type, - "s_fast": s_fast_val, - "s_mid": s_mid_val, - "s_slow": s_slow_val, - } - memory_records.append(memory_data) - episodic_rows.append( - ( - memory_id, - items[i].get("user_id") or user_id, - content, - mem_metadata, - ) - ) - - # Build vector index entries - vectors, payloads, vector_ids = self._build_index_vectors( - memory_id=memory_id, - content=content, - primary_text=primary_texts[i], - embedding=embedding, - echo_result=echo_result, - metadata=mem_metadata, - categories=mem_categories, - user_id=items[i].get("user_id") or user_id, - agent_id=agent_id, - run_id=run_id, - app_id=app_id, - embedding_cache=embedding_cache if embedding_cache else None, - ) - if vectors: - vector_batch.append((vectors, payloads, vector_ids)) - - results.append({ - "id": memory_id, - "memory": content, - "event": "ADD", - "layer": "sml", - "strength": effective_strength, - "echo_depth": echo_result.echo_depth.value if echo_result else None, - "categories": mem_categories, - "namespace": namespace_value, - "memory_type": memory_type, - }) - 
- # 4a. Batch DB insert - if memory_records: - try: - self.db.add_memories_batch(memory_records) - except Exception as e: - logger.error("Batch DB insert failed, falling back to sequential: %s", e) - for record in memory_records: - self.db.add_memory(record) - - # 4b. Batch vector insert - for vectors, payloads, vector_ids in vector_batch: - try: - self.vector_store.insert(vectors=vectors, payloads=payloads, ids=vector_ids) - except Exception as e: - logger.error("Vector insert failed in batch: %s", e) - - # Deterministic episodic index. - for memory_id, owner_user_id, content, mem_metadata in episodic_rows: - self._index_episodic_events_for_memory( - memory_id=memory_id, - user_id=owner_user_id, - content=content, - metadata=mem_metadata, - ) - - # Post-store hooks: category stats - for i, record in enumerate(memory_records): - if self.category_processor and record.get("categories"): - for cat_id in record["categories"]: - self.category_processor.update_category_stats( - cat_id, record["strength"], is_addition=True - ) - - # Post-store hooks: fact decomposition (batch embed + insert) - all_fact_texts = [] - all_fact_meta = [] # (memory_id, fact_index) - for i, record in enumerate(memory_records): - enrichment = enrichment_results[i] if i < len(enrichment_results) else None - if enrichment and enrichment.facts: - for fi, fact_text in enumerate(enrichment.facts[:8]): - fact_text = fact_text.strip() - if fact_text and len(fact_text) >= 10: - all_fact_texts.append(fact_text) - all_fact_meta.append((record["id"], fi)) - - if all_fact_texts: - try: - # Sub-batch fact embeddings to stay within API limits - fact_embeddings: List[List[float]] = [] - for fs in range(0, len(all_fact_texts), 50): - sub = all_fact_texts[fs:fs + 50] - fact_embeddings.extend(self.embedder.embed_batch(sub, memory_action="add")) - batch_embed_calls_total += 1.0 - fact_vectors = [] - fact_payloads = [] - fact_ids = [] - for (memory_id, fi), fact_text, fact_emb in zip(all_fact_meta, all_fact_texts, 
fact_embeddings): - fact_id = f"{memory_id}__fact_{fi}" - fact_vectors.append(fact_emb) - fact_payloads.append({ - "memory_id": memory_id, - "is_fact": True, - "fact_index": fi, - "fact_text": fact_text, - "user_id": user_id, - "agent_id": agent_id, - }) - fact_ids.append(fact_id) - if fact_vectors: - self.vector_store.insert(vectors=fact_vectors, payloads=fact_payloads, ids=fact_ids) - except Exception as e: - logger.warning("Batch fact embedding/insert failed: %s", e) - - # Post-store hooks: entity linking and profile updates - for i, record in enumerate(memory_records): - enrichment = enrichment_results[i] if i < len(enrichment_results) else None - if not enrichment: - continue - memory_id = record["id"] - content = record.get("memory", "") - - if self.knowledge_graph and enrichment.entities: - try: - for entity in enrichment.entities: - existing_ent = self.knowledge_graph._get_or_create_entity( - entity.name, entity.entity_type, - ) - existing_ent.memory_ids.add(memory_id) - self.knowledge_graph.memory_entities[memory_id] = { - e.name for e in enrichment.entities - } - if self.graph_config.auto_link_entities: - self.knowledge_graph.link_by_shared_entities(memory_id) - except Exception as e: - logger.warning("Entity linking failed for %s: %s", memory_id, e) - - if self.profile_processor and enrichment.profile_updates: - try: - for profile_update in enrichment.profile_updates: - self.profile_processor.apply_update( - profile_update=profile_update, - memory_id=memory_id, - user_id=record.get("user_id") or user_id or "default", - ) - except Exception as e: - logger.warning("Profile update failed for %s: %s", memory_id, e) - - # Post-store hooks: Universal Engram extraction (structured facts + context anchors) - if self.engram_extractor: - for i, record in enumerate(memory_records): - memory_id = record["id"] - content = record.get("memory", "") - try: - engram = self.engram_extractor.extract( - content=content, - session_context=None, - 
existing_metadata=record.get("metadata"), - user_id=record.get("user_id") or user_id or "default", - ) - if self.context_resolver and engram: - self.context_resolver.store_engram(engram, memory_id) - except Exception as e: - logger.warning("Engram extraction failed for %s: %s", memory_id, e) - - if episodic_rows: - sample_count = float(len(episodic_rows)) - llm_calls_per_memory = batch_llm_calls_total / sample_count - input_tokens_per_memory = batch_input_tokens_total / sample_count - output_tokens_per_memory = batch_output_tokens_total / sample_count - embed_calls_per_memory = batch_embed_calls_total / sample_count - for _, owner_user_id, _, _ in episodic_rows: - self._record_cost_counter( - phase="write", - user_id=owner_user_id, - llm_calls=llm_calls_per_memory, - input_tokens=input_tokens_per_memory, - output_tokens=output_tokens_per_memory, - embed_calls=embed_calls_per_memory, - ) - - return results def _resolve_memory_metadata(self, **kwargs) -> tuple: return self._write_pipeline.resolve_memory_metadata(**kwargs) @@ -1497,160 +1125,12 @@ def enrich_pending( batch_size: int = 10, max_batches: int = 5, ) -> Dict[str, Any]: - """Batch-enrich memories that were stored with deferred enrichment. - - Uses unified enrichment: 1 LLM call per batch_size memories. - Returns {enriched_count, batches, remaining}. 
- """ - limit = batch_size * max_batches - pending = self.db.get_pending_enrichment(user_id=user_id, limit=limit) - if not pending: - return {"enriched_count": 0, "batches": 0, "remaining": 0} - - enriched_count = 0 - batches_processed = 0 - - for start in range(0, len(pending), batch_size): - batch = pending[start:start + batch_size] - contents = [m.get("memory", "") for m in batch] - - # Try unified enrichment (single LLM call for the batch) - enrichment_results = None - if self.unified_enrichment is not None: - try: - existing_cats = None - if self.category_processor: - cats = self.category_processor.get_all_categories() - if cats: - existing_cats = "\n".join( - f"- {c['id']}: {c['name']} — {c.get('description', '')}" - for c in cats[:30] - ) - - enrichment_results = self.unified_enrichment.enrich_batch( - contents, - depth=EchoDepth.MEDIUM, - existing_categories=existing_cats, - include_entities=True, - include_profiles=True, - ) - except Exception as e: - logger.warning("Unified batch enrichment failed in enrich_pending: %s", e) - enrichment_results = None - - # Fallback: individual enrichment per memory - if enrichment_results is None: - enrichment_results = [] - for c in contents: - if self.unified_enrichment is not None: - try: - enrichment_results.append( - self.unified_enrichment.enrich(c, depth=EchoDepth.MEDIUM) - ) - except Exception: - enrichment_results.append(None) - else: - enrichment_results.append(None) - - # Apply enrichment results and update DB - db_updates: List[Dict[str, Any]] = [] - for mem, enrichment in zip(batch, enrichment_results): - mem_id = mem["id"] - mem_meta = mem.get("metadata", {}) or {} - mem_cats = mem.get("categories", []) or [] - - if enrichment: - # Apply echo result - if enrichment.echo_result: - mem_meta.update(enrichment.echo_result.to_metadata()) - if not mem_cats and enrichment.echo_result.category: - mem_cats = [enrichment.echo_result.category] - - # Apply category result - if enrichment.category_match and not 
mem_cats: - mem_cats = [enrichment.category_match.category_id] - mem_meta["category_confidence"] = enrichment.category_match.confidence - mem_meta["category_auto"] = True - - # Apply extracted facts to metadata - if enrichment.facts: - mem_meta["enrichment_facts"] = enrichment.facts[:8] - - # Post-store hooks: entities - if self.knowledge_graph and enrichment.entities: - for entity in enrichment.entities: - existing_ent = self.knowledge_graph._get_or_create_entity( - entity.name, entity.entity_type, - ) - existing_ent.memory_ids.add(mem_id) - self.knowledge_graph.memory_entities[mem_id] = { - e.name for e in enrichment.entities - } - - # Post-store hooks: profiles - if self.profile_processor and enrichment.profile_updates: - for profile_update in enrichment.profile_updates: - try: - self.profile_processor.apply_update( - profile_update=profile_update, - memory_id=mem_id, - user_id=user_id, - ) - except Exception as e: - logger.warning("Profile update failed during enrichment for %s: %s", mem_id, e) - - # Generate fact decomposition vectors - if enrichment.facts: - valid_facts = [ - (i, f.strip()) for i, f in enumerate(enrichment.facts[:8]) - if f.strip() and len(f.strip()) >= 10 - ] - if valid_facts: - try: - fact_texts = [ft for _, ft in valid_facts] - fact_embeddings = self.embedder.embed_batch(fact_texts, memory_action="add") - fact_vectors, fact_payloads, fact_ids = [], [], [] - for (i, fact_text), fact_emb in zip(valid_facts, fact_embeddings): - fact_id = f"{mem_id}__fact_{i}" - fact_vectors.append(fact_emb) - fact_payloads.append({ - "memory_id": mem_id, - "is_fact": True, - "fact_index": i, - "fact_text": fact_text, - "user_id": user_id, - }) - fact_ids.append(fact_id) - if fact_vectors: - self.vector_store.insert( - vectors=fact_vectors, - payloads=fact_payloads, - ids=fact_ids, - ) - except Exception as e: - logger.warning("Fact embedding failed during enrichment for %s: %s", mem_id, e) - - mem_meta["enrichment_status"] = "complete" - db_updates.append({ - 
"id": mem_id, - "metadata": mem_meta, - "categories": mem_cats, - "enrichment_status": "complete", - }) - enriched_count += 1 - - # Batch DB update - self.db.update_enrichment_bulk(db_updates) - batches_processed += 1 - - # Check remaining - remaining_count = len(self.db.get_pending_enrichment(user_id=user_id, limit=1)) - - return { - "enriched_count": enriched_count, - "batches": batches_processed, - "remaining": remaining_count, - } + """Batch-enrich memories that were stored with deferred enrichment.""" + return self._write_pipeline.enrich_pending( + user_id=user_id, + batch_size=batch_size, + max_batches=max_batches, + ) _normalize_bitemporal_value = staticmethod(normalize_bitemporal_value) _parse_bitemporal_datetime = classmethod(lambda cls, v: parse_bitemporal_datetime(v)) @@ -3131,3 +2611,6 @@ def get_decay_log(self, limit: int = 20) -> List[Dict[str, Any]]: """Get recent decay history for dashboard sparkline.""" return self.db.get_decay_log_entries(limit=limit) + +# Historical alias kept for legacy ``engram.memory.main.Memory`` imports. 
+Memory = FullMemory diff --git a/dhee/memory/orchestration.py b/dhee/memory/orchestration.py index ee1873d..ff3295b 100644 --- a/dhee/memory/orchestration.py +++ b/dhee/memory/orchestration.py @@ -17,9 +17,7 @@ build_map_candidates, build_query_plan, deterministic_inconsistency_check, - extract_atomic_facts, is_low_confidence_answer, - reduce_atomic_facts, render_fact_context, ) @@ -47,6 +45,8 @@ def __init__( profile_processor_fn: Callable, evolution_layer_fn: Callable, llm_fn: Callable, + extract_atomic_facts_fn: Callable, + reduce_atomic_facts_fn: Callable, ): self._config = config self._db = db @@ -59,6 +59,8 @@ def __init__( self._profile_processor_fn = profile_processor_fn self._evolution_layer_fn = evolution_layer_fn self._llm_fn = llm_fn + self._extract_atomic_facts_fn = extract_atomic_facts_fn + self._reduce_atomic_facts_fn = reduce_atomic_facts_fn # Internal state self._reducer_cache: Dict[str, Dict[str, Any]] = {} self._guardrail_auto_disabled: bool = False @@ -676,7 +678,11 @@ def _execute_map_reduce( reason_codes: List[str] = [] active_orchestrator_llm = orchestrator_llm or self._llm_fn() orch_cfg = getattr(self._config, "orchestration", None) - max_query_llm_calls = int(getattr(orch_cfg, "max_query_llm_calls", 2) or 2) + raw_max_query_llm_calls = getattr(orch_cfg, "max_query_llm_calls", 2) + try: + max_query_llm_calls = int(raw_max_query_llm_calls if raw_max_query_llm_calls is not None else 2) + except (TypeError, ValueError): + max_query_llm_calls = 2 coverage_sufficient = bool((coverage or {}).get("sufficient")) if coverage_sufficient: @@ -697,12 +703,16 @@ def _execute_map_reduce( # NOTE: Event-first reduction (Phase 2) disabled — episodic events # alone lack sufficient coverage for accurate multi-session counting. # The LLM-based map-reduce path below is more reliable. 
+ if mode == "strict": + mode_requires_map_reduce = True + else: + mode_requires_map_reduce = (not coverage_sufficient) or inconsistency_detected should_run_map_reduce = bool( query_plan.should_map_reduce and active_orchestrator_llm is not None and results - and (mode in ("strict", "hybrid") or not coverage_sufficient or inconsistency_detected) + and mode_requires_map_reduce ) if query_plan.should_map_reduce and active_orchestrator_llm is None: reason_codes.append("no_orchestrator_llm") @@ -733,14 +743,14 @@ def _execute_map_reduce( per_candidate_max_chars=map_max_chars_value, ) if llm_calls_used < float(max_query_llm_calls): - facts = extract_atomic_facts( + facts = self._extract_atomic_facts_fn( llm=active_orchestrator_llm, question=query, question_type=question_type, question_date=question_date, candidates=map_candidates, ) - reduced_answer, _ = reduce_atomic_facts( + reduced_answer, _ = self._reduce_atomic_facts_fn( question=query, intent=query_plan.intent, facts=facts, @@ -808,14 +818,14 @@ def _execute_map_reduce( per_candidate_max_chars=map_max_chars_value, ) if llm_calls_used < float(max_query_llm_calls): - facts = extract_atomic_facts( + facts = self._extract_atomic_facts_fn( llm=active_orchestrator_llm, question=query, question_type=question_type, question_date=question_date, candidates=map_candidates, ) - reduced_answer, _ = reduce_atomic_facts( + reduced_answer, _ = self._reduce_atomic_facts_fn( question=query, intent=query_plan.intent, facts=facts, diff --git a/dhee/memory/write_pipeline.py b/dhee/memory/write_pipeline.py index c78805e..602ce41 100644 --- a/dhee/memory/write_pipeline.py +++ b/dhee/memory/write_pipeline.py @@ -25,6 +25,7 @@ normalize_bitemporal_value, ) from dhee.memory.utils import ( + build_filters_and_metadata, normalize_categories, parse_messages, strip_code_fences, @@ -70,6 +71,7 @@ def __init__( assign_to_scene_fn: Optional[Callable] = None, update_profiles_fn: Optional[Callable] = None, store_prospective_scenes_fn: 
Optional[Callable] = None, + persist_categories_fn: Optional[Callable] = None, ): self._db = db self._embedder = embedder @@ -100,6 +102,7 @@ def __init__( self._assign_to_scene_fn = assign_to_scene_fn self._update_profiles_fn = update_profiles_fn self._store_prospective_scenes_fn = store_prospective_scenes_fn + self._persist_categories_fn = persist_categories_fn # ------------------------------------------------------------------ # Convenience accessors for lazy processors @@ -190,6 +193,118 @@ def _normalize_connector_id(self, connector_id): def _infer_scope(self, **kwargs): return self._scope_resolver.infer_scope(**kwargs) if self._scope_resolver else "agent" + def _persist_categories(self) -> None: + if self._persist_categories_fn: + self._persist_categories_fn() + + def _render_existing_categories(self, limit: int = 30) -> Optional[str]: + cat_proc = self._category_processor + if not cat_proc: + return None + cats = cat_proc.get_all_categories() + if not cats: + return None + return "\n".join( + f"- {c['id']}: {c['name']} — {c.get('description', '')}" + for c in cats[:limit] + ) + + def _insert_fact_vectors( + self, + *, + fact_entries: List[Dict[str, Any]], + warning_prefix: str, + ) -> float: + if not fact_entries or not self._vector_store: + return 0.0 + + embed_calls = 0.0 + fact_embeddings: List[List[float]] = [] + try: + fact_texts = [entry["fact_text"] for entry in fact_entries] + for start in range(0, len(fact_texts), 50): + sub = fact_texts[start:start + 50] + fact_embeddings.extend( + self._embedder.embed_batch(sub, memory_action="add") + ) + embed_calls += 1.0 + + fact_vectors: List[List[float]] = [] + fact_payloads: List[Dict[str, Any]] = [] + fact_ids: List[str] = [] + for entry, fact_embedding in zip(fact_entries, fact_embeddings): + fact_vectors.append(fact_embedding) + fact_payloads.append( + { + "memory_id": entry["memory_id"], + "is_fact": True, + "fact_index": entry["fact_index"], + "fact_text": entry["fact_text"], + "user_id": 
entry.get("user_id"), + "agent_id": entry.get("agent_id"), + } + ) + fact_ids.append( + f"{entry['memory_id']}__fact_{entry['fact_index']}" + ) + if fact_vectors: + self._vector_store.insert( + vectors=fact_vectors, + payloads=fact_payloads, + ids=fact_ids, + ) + except Exception as exc: + logger.warning("%s: %s", warning_prefix, exc) + return 0.0 + + return embed_calls + + def _apply_entity_updates( + self, + *, + memory_id: str, + entities: Optional[List[Any]], + warning_prefix: str, + auto_link: bool = True, + ) -> None: + knowledge_graph = self._graph + if not knowledge_graph or not entities: + return + try: + for entity in entities: + existing_ent = knowledge_graph._get_or_create_entity( + entity.name, entity.entity_type, + ) + existing_ent.memory_ids.add(memory_id) + knowledge_graph.memory_entities[memory_id] = { + entity.name for entity in entities + } + if auto_link and self._graph_config.auto_link_entities: + knowledge_graph.link_by_shared_entities(memory_id) + except Exception as exc: + logger.warning("%s for %s: %s", warning_prefix, memory_id, exc) + + def _apply_profile_updates_batch( + self, + *, + memory_id: str, + user_id: str, + profile_updates: Optional[List[Any]], + warning_prefix: str, + ) -> None: + profile_proc = self._profile_processor + if not profile_proc or not profile_updates: + return + try: + for profile_update in profile_updates: + profile_proc.apply_update( + profile_update=profile_update, + memory_id=memory_id, + user_id=user_id, + ) + except Exception as exc: + logger.warning("%s for %s: %s", warning_prefix, memory_id, exc) + # ------------------------------------------------------------------ # Extracted public methods # ------------------------------------------------------------------ @@ -1177,6 +1292,644 @@ def process_single_memory_lite( "enrichment_status": "pending", } + def process_memory_batch( + self, + items: List[Dict[str, Any]], + *, + user_id: Optional[str], + agent_id: Optional[str], + run_id: Optional[str], + app_id: 
Optional[str], + metadata: Optional[Dict[str, Any]], + filters: Optional[Dict[str, Any]], + initial_strength: float, + echo_depth: Optional[str], + batch_config, + **common_kwargs: Any, + ) -> List[Dict[str, Any]]: + """Process a batch of memory items with batched echo/embed/DB.""" + contents: List[str] = [] + item_metadata_list: List[Dict[str, Any]] = [] + for item in items: + content = item.get("content") or item.get("messages", "") + if isinstance(content, list): + content = " ".join( + m.get("content", "") for m in content if isinstance(m, dict) + ) + contents.append(str(content).strip()) + item_meta = dict(metadata or {}) + item_meta.update(item.get("metadata") or {}) + item_metadata_list.append(item_meta) + + batch_llm_calls_total = 0.0 + batch_embed_calls_total = 0.0 + batch_input_tokens_total = 0.0 + batch_output_tokens_total = 0.0 + + echo_results = [None] * len(contents) + category_results = [None] * len(contents) + enrichment_results = [None] * len(contents) + + enrichment_config = getattr(self._config, "enrichment", None) + echo_proc = self._echo_processor + cat_proc = self._category_processor + unified = self._unified_enrichment + use_unified = ( + unified is not None + and self._echo_config.enable_echo + and batch_config.batch_echo + ) + + if use_unified: + try: + depth_override = ( + EchoDepth(echo_depth) + if echo_depth + else EchoDepth(self._echo_config.default_depth) + ) + existing_cats = self._render_existing_categories() + enrich_batch_size = ( + enrichment_config.max_batch_size if enrichment_config else 10 + ) + for start in range(0, len(contents), enrich_batch_size): + end = min(start + enrich_batch_size, len(contents)) + sub_contents = contents[start:end] + sub_results = unified.enrich_batch( + sub_contents, + depth=depth_override, + existing_categories=existing_cats, + include_entities=( + enrichment_config.include_entities + if enrichment_config + else True + ), + include_profiles=( + enrichment_config.include_profiles + if 
enrichment_config + else True + ), + ) + sub_input_tokens = sum( + estimate_token_count(c) for c in sub_contents + ) + sub_input_tokens += estimate_token_count(existing_cats) + batch_llm_calls_total += 1.0 + batch_input_tokens_total += sub_input_tokens + batch_output_tokens_total += estimate_output_tokens( + sub_input_tokens + ) + for offset, enrichment in enumerate(sub_results): + idx = start + offset + if enrichment.echo_result: + echo_results[idx] = enrichment.echo_result + if enrichment.category_match: + category_results[idx] = enrichment.category_match + enrichment_results[idx] = enrichment + logger.info( + "Unified batch enrichment completed for %d memories", + len(contents), + ) + except Exception as exc: + logger.warning( + "Unified batch enrichment failed, falling back to separate: %s", + exc, + ) + echo_results = [None] * len(contents) + category_results = [None] * len(contents) + enrichment_results = [None] * len(contents) + use_unified = False + + if not use_unified: + if echo_proc and self._echo_config.enable_echo and batch_config.batch_echo: + depth_override = ( + EchoDepth(echo_depth) + if echo_depth + else EchoDepth(self._echo_config.default_depth) + ) + if depth_override != EchoDepth.SHALLOW: + echo_input_tokens = sum( + estimate_token_count(c) for c in contents if c + ) + non_empty_count = sum(1 for c in contents if c) + batch_llm_calls_total += float(non_empty_count) + batch_input_tokens_total += echo_input_tokens + batch_output_tokens_total += estimate_output_tokens( + echo_input_tokens + ) + try: + echo_results = echo_proc.process_batch( + contents, depth=depth_override + ) + except Exception as exc: + logger.warning( + "Batch echo failed, processing individually: %s", + exc, + ) + for idx, content in enumerate(contents): + if not content: + continue + try: + depth_override = ( + EchoDepth(echo_depth) if echo_depth else None + ) + echo_results[idx] = echo_proc.process( + content, depth=depth_override + ) + except Exception as fallback_exc: + 
logger.debug( + "Individual echo fallback failed for batch item %d: %s", + idx, + fallback_exc, + ) + + if ( + cat_proc + and self._category_config.auto_categorize + and batch_config.batch_category + ): + if self._category_config.use_llm_categorization: + cat_input_tokens = sum( + estimate_token_count(c) for c in contents if c + ) + non_empty_count = sum(1 for c in contents if c) + batch_llm_calls_total += float(non_empty_count) + batch_input_tokens_total += cat_input_tokens + batch_output_tokens_total += estimate_output_tokens( + cat_input_tokens + ) + try: + category_results = cat_proc.detect_categories_batch( + contents, + use_llm=self._category_config.use_llm_categorization, + ) + except Exception as exc: + logger.warning("Batch category failed: %s", exc) + + primary_texts: List[str] = [] + for idx, content in enumerate(contents): + primary_texts.append( + self.select_primary_text(content, echo_results[idx]) + ) + + if batch_config.batch_embed: + try: + embeddings: List[List[float]] = [] + for start in range(0, len(primary_texts), 50): + sub = primary_texts[start:start + 50] + embeddings.extend( + self._embedder.embed_batch(sub, memory_action="add") + ) + batch_embed_calls_total += 1.0 + except Exception as exc: + logger.warning( + "Batch embed failed, falling back to sequential: %s", + exc, + ) + embeddings = [ + self._embedder.embed(text, memory_action="add") + for text in primary_texts + ] + batch_embed_calls_total += float(len(primary_texts)) + else: + embeddings = [ + self._embedder.embed(text, memory_action="add") + for text in primary_texts + ] + batch_embed_calls_total += float(len(primary_texts)) + + echo_node_texts: List[str] = [] + for idx, content in enumerate(contents): + echo_result = echo_results[idx] + primary_text = primary_texts[idx] + if primary_text != content: + cleaned = content.strip() + if cleaned: + echo_node_texts.append(cleaned) + if echo_result: + for paraphrase in echo_result.paraphrases: + cleaned = str(paraphrase).strip() + if 
cleaned: + echo_node_texts.append(cleaned) + for question in echo_result.questions: + cleaned = str(question).strip() + if cleaned: + echo_node_texts.append(cleaned) + + embedding_cache: Dict[str, List[float]] = {} + if echo_node_texts: + unique_texts = list(dict.fromkeys(echo_node_texts)) + try: + all_echo_embeddings: List[List[float]] = [] + for start in range(0, len(unique_texts), 50): + sub = unique_texts[start:start + 50] + all_echo_embeddings.extend( + self._embedder.embed_batch(sub, memory_action="add") + ) + batch_embed_calls_total += 1.0 + for text, emb in zip(unique_texts, all_echo_embeddings): + embedding_cache[text] = emb + logger.info( + "Batch-embedded %d echo node texts in %d API calls", + len(unique_texts), + (len(unique_texts) + 49) // 50, + ) + except Exception as exc: + logger.warning( + "Batch echo node embedding failed, will embed individually: %s", + exc, + ) + + processed_metadata_base, effective_filters = build_filters_and_metadata( + user_id=user_id, + agent_id=agent_id, + run_id=run_id, + input_metadata=metadata, + input_filters=filters, + ) + if app_id: + processed_metadata_base["app_id"] = app_id + + now = datetime.now(timezone.utc).isoformat() + memory_records: List[Dict[str, Any]] = [] + record_entries: List[Dict[str, Any]] = [] + episodic_rows: List[Tuple[str, Optional[str], str, Dict[str, Any]]] = [] + vector_batch: List[Tuple[List[List[float]], List[Dict[str, Any]], List[str]]] = [] + results: List[Dict[str, Any]] = [] + + for idx, content in enumerate(contents): + if not content: + continue + + owner_user_id = items[idx].get("user_id") or user_id + memory_id = str(uuid.uuid4()) + mem_metadata = dict(processed_metadata_base) + mem_metadata.update(item_metadata_list[idx]) + mem_metadata = attach_bitemporal_metadata( + mem_metadata, observed_time=now + ) + + echo_result = echo_results[idx] + effective_strength = initial_strength + mem_categories = list(items[idx].get("categories") or []) + + if echo_result: + effective_strength = ( + 
initial_strength * echo_result.strength_multiplier + ) + mem_metadata.update(echo_result.to_metadata()) + if not mem_categories and echo_result.category: + mem_categories = [echo_result.category] + + cat_match = category_results[idx] + if cat_match and not mem_categories: + mem_categories = [cat_match.category_id] + mem_metadata["category_confidence"] = cat_match.confidence + mem_metadata["category_auto"] = True + + embedding = embeddings[idx] + namespace_value = str( + mem_metadata.get("namespace", "default") or "default" + ).strip() or "default" + memory_type = self.classify_memory_type( + mem_metadata, mem_metadata.get("role", "user") + ) + + s_fast_val = s_mid_val = s_slow_val = None + distillation_config = self._distillation_config + if distillation_config and distillation_config.enable_multi_trace: + s_fast_val, s_mid_val, s_slow_val = initialize_traces( + effective_strength, is_new=True + ) + + memory_data = { + "id": memory_id, + "memory": content, + "user_id": owner_user_id, + "agent_id": items[idx].get("agent_id") or agent_id, + "run_id": items[idx].get("run_id") or run_id, + "app_id": items[idx].get("app_id") or app_id, + "metadata": mem_metadata, + "categories": mem_categories, + "immutable": items[idx].get("immutable", False), + "expiration_date": items[idx].get("expiration_date"), + "created_at": now, + "updated_at": now, + "layer": "sml", + "strength": effective_strength, + "access_count": 0, + "last_accessed": now, + "embedding": embedding, + "confidentiality_scope": "work", + "source_type": "mcp", + "source_app": items[idx].get("source_app"), + "source_event_id": mem_metadata.get("source_event_id"), + "decay_lambda": self._fade_config.sml_decay_rate, + "status": "active", + "importance": mem_metadata.get("importance", 0.5), + "sensitivity": mem_metadata.get("sensitivity", "normal"), + "namespace": namespace_value, + "memory_type": memory_type, + "s_fast": s_fast_val, + "s_mid": s_mid_val, + "s_slow": s_slow_val, + } + 
memory_records.append(memory_data) + record_entries.append( + { + "source_index": idx, + "record": memory_data, + "user_id": owner_user_id, + "content": content, + } + ) + episodic_rows.append((memory_id, owner_user_id, content, mem_metadata)) + + vectors, payloads, vector_ids = build_index_vectors( + memory_id=memory_id, + content=content, + primary_text=primary_texts[idx], + embedding=embedding, + echo_result=echo_result, + metadata=mem_metadata, + categories=mem_categories, + user_id=owner_user_id, + agent_id=items[idx].get("agent_id") or agent_id, + run_id=items[idx].get("run_id") or run_id, + app_id=items[idx].get("app_id") or app_id, + embedder=self._embedder, + embedding_cache=embedding_cache if embedding_cache else None, + ) + if vectors: + vector_batch.append((vectors, payloads, vector_ids)) + + results.append( + { + "id": memory_id, + "memory": content, + "event": "ADD", + "layer": "sml", + "strength": effective_strength, + "echo_depth": ( + echo_result.echo_depth.value if echo_result else None + ), + "categories": mem_categories, + "namespace": namespace_value, + "memory_type": memory_type, + } + ) + + if memory_records: + try: + self._db.add_memories_batch(memory_records) + except Exception as exc: + logger.error( + "Batch DB insert failed, falling back to sequential: %s", + exc, + ) + for record in memory_records: + self._db.add_memory(record) + + for vectors, payloads, vector_ids in vector_batch: + try: + self._vector_store.insert( + vectors=vectors, + payloads=payloads, + ids=vector_ids, + ) + except Exception as exc: + logger.error("Vector insert failed in batch: %s", exc) + + for memory_id, owner_user_id, content, mem_metadata in episodic_rows: + _index_episodic( + db=self._db, + config=self._config, + memory_id=memory_id, + user_id=owner_user_id, + content=content, + metadata=mem_metadata, + ) + + if cat_proc: + for entry in record_entries: + record = entry["record"] + if record.get("categories"): + for cat_id in record["categories"]: + 
cat_proc.update_category_stats( + cat_id, + record["strength"], + is_addition=True, + ) + + fact_entries: List[Dict[str, Any]] = [] + for entry in record_entries: + idx = entry["source_index"] + enrichment = enrichment_results[idx] + if not enrichment or not enrichment.facts: + continue + for fact_index, fact_text in enumerate(enrichment.facts[:8]): + cleaned = fact_text.strip() + if cleaned and len(cleaned) >= 10: + fact_entries.append( + { + "memory_id": entry["record"]["id"], + "fact_index": fact_index, + "fact_text": cleaned, + "user_id": entry["user_id"], + "agent_id": entry["record"].get("agent_id"), + } + ) + batch_embed_calls_total += self._insert_fact_vectors( + fact_entries=fact_entries, + warning_prefix="Batch fact embedding/insert failed", + ) + + engram_extractor = self._engram_extractor + context_resolver = self._context_resolver + for entry in record_entries: + idx = entry["source_index"] + record = entry["record"] + enrichment = enrichment_results[idx] + if enrichment: + self._apply_entity_updates( + memory_id=record["id"], + entities=enrichment.entities, + warning_prefix="Entity linking failed", + ) + self._apply_profile_updates_batch( + memory_id=record["id"], + user_id=record.get("user_id") or user_id or "default", + profile_updates=enrichment.profile_updates, + warning_prefix="Profile update failed", + ) + + if engram_extractor: + try: + engram = engram_extractor.extract( + content=record.get("memory", ""), + session_context=None, + existing_metadata=record.get("metadata"), + user_id=record.get("user_id") or user_id or "default", + ) + if context_resolver and engram: + context_resolver.store_engram(engram, record["id"]) + except Exception as exc: + logger.warning( + "Engram extraction failed for %s: %s", + record["id"], + exc, + ) + + if episodic_rows: + sample_count = float(len(episodic_rows)) + llm_calls_per_memory = batch_llm_calls_total / sample_count + input_tokens_per_memory = batch_input_tokens_total / sample_count + 
output_tokens_per_memory = batch_output_tokens_total / sample_count + embed_calls_per_memory = batch_embed_calls_total / sample_count + for _, owner_user_id, _, _ in episodic_rows: + self._record_cost( + phase="write", + user_id=owner_user_id, + llm_calls=llm_calls_per_memory, + input_tokens=input_tokens_per_memory, + output_tokens=output_tokens_per_memory, + embed_calls=embed_calls_per_memory, + ) + + return results + + def enrich_pending( + self, + *, + user_id: str = "default", + batch_size: int = 10, + max_batches: int = 5, + ) -> Dict[str, Any]: + """Batch-enrich memories that were stored with deferred enrichment.""" + limit = batch_size * max_batches + pending = self._db.get_pending_enrichment(user_id=user_id, limit=limit) + if not pending: + return {"enriched_count": 0, "batches": 0, "remaining": 0} + + enriched_count = 0 + batches_processed = 0 + unified = self._unified_enrichment + + for start in range(0, len(pending), batch_size): + batch = pending[start:start + batch_size] + contents = [memory.get("memory", "") for memory in batch] + + enrichment_results = None + if unified is not None: + try: + enrichment_results = unified.enrich_batch( + contents, + depth=EchoDepth.MEDIUM, + existing_categories=self._render_existing_categories(), + include_entities=True, + include_profiles=True, + ) + except Exception as exc: + logger.warning( + "Unified batch enrichment failed in enrich_pending: %s", + exc, + ) + enrichment_results = None + + if enrichment_results is None: + enrichment_results = [] + for content in contents: + if unified is not None: + try: + enrichment_results.append( + unified.enrich(content, depth=EchoDepth.MEDIUM) + ) + except Exception as exc: + logger.debug( + "Single-memory enrichment fallback failed: %s", + exc, + ) + enrichment_results.append(None) + else: + enrichment_results.append(None) + + db_updates: List[Dict[str, Any]] = [] + fact_entries: List[Dict[str, Any]] = [] + + for memory, enrichment in zip(batch, enrichment_results): + mem_id 
= memory["id"] + mem_meta = memory.get("metadata", {}) or {} + mem_cats = memory.get("categories", []) or [] + + if enrichment: + if enrichment.echo_result: + mem_meta.update(enrichment.echo_result.to_metadata()) + if not mem_cats and enrichment.echo_result.category: + mem_cats = [enrichment.echo_result.category] + + if enrichment.category_match and not mem_cats: + mem_cats = [enrichment.category_match.category_id] + mem_meta["category_confidence"] = ( + enrichment.category_match.confidence + ) + mem_meta["category_auto"] = True + + if enrichment.facts: + mem_meta["enrichment_facts"] = enrichment.facts[:8] + + self._apply_entity_updates( + memory_id=mem_id, + entities=enrichment.entities, + warning_prefix="Entity linking failed during enrichment", + ) + self._apply_profile_updates_batch( + memory_id=mem_id, + user_id=user_id, + profile_updates=enrichment.profile_updates, + warning_prefix="Profile update failed during enrichment", + ) + + if enrichment.facts: + for fact_index, fact_text in enumerate(enrichment.facts[:8]): + cleaned = fact_text.strip() + if cleaned and len(cleaned) >= 10: + fact_entries.append( + { + "memory_id": mem_id, + "fact_index": fact_index, + "fact_text": cleaned, + "user_id": user_id, + "agent_id": memory.get("agent_id"), + } + ) + + mem_meta["enrichment_status"] = "complete" + db_updates.append( + { + "id": mem_id, + "metadata": mem_meta, + "categories": mem_cats, + "enrichment_status": "complete", + } + ) + enriched_count += 1 + + self._insert_fact_vectors( + fact_entries=fact_entries, + warning_prefix="Fact embedding failed during enrichment", + ) + self._db.update_enrichment_bulk(db_updates) + batches_processed += 1 + + remaining_count = len( + self._db.get_pending_enrichment(user_id=user_id, limit=1) + ) + + return { + "enriched_count": enriched_count, + "batches": batches_processed, + "remaining": remaining_count, + } + def extract_memories( self, messages: List[Dict[str, Any]], diff --git a/dhee/simple.py b/dhee/simple.py index 
2142acc..376f9b0 100644 --- a/dhee/simple.py +++ b/dhee/simple.py @@ -18,11 +18,13 @@ from __future__ import annotations +import logging import os import tempfile from pathlib import Path from typing import Any, Dict, List, Optional, Union +from dhee.checkpoint_runtime import run_checkpoint_common from dhee.configs.base import ( CategoryMemConfig, EchoMemConfig, @@ -34,6 +36,8 @@ ) from dhee.memory.main import FullMemory +logger = logging.getLogger(__name__) + def _detect_provider() -> str: """Detect which LLM/embedder provider to use based on environment.""" @@ -366,11 +370,33 @@ def provider(self) -> str: """Current LLM/embedder provider.""" return self._provider + @property + def memory(self) -> FullMemory: + """Expose the configured runtime memory engine for advanced integrations.""" + return self._memory + @property def data_dir(self) -> Path: """Data storage directory.""" return self._data_dir + def enrich_pending( + self, + user_id: str = "default", + batch_size: int = 10, + max_batches: int = 5, + ) -> Dict[str, Any]: + """Run deferred enrichment using the configured runtime memory engine.""" + return self.memory.enrich_pending( + user_id=user_id, + batch_size=batch_size, + max_batches=max_batches, + ) + + def close(self) -> None: + """Release runtime resources held by the underlying memory engine.""" + self._memory.close() + class Dhee: """4-tool HyperAgent interface — the simplest way to make any agent intelligent. @@ -550,7 +576,7 @@ def context( hyper_ctx = self._buddhi.get_hyper_context( user_id=uid, task_description=task_description, - memory=self._engram._memory, + memory=self._engram.memory, ) if operational: return hyper_ctx.to_operational_dict() @@ -623,73 +649,29 @@ def checkpoint( if not what_worked: what_worked = outcome.get("what_worked") - result: Dict[str, Any] = {} - - # 1. 
Session digest - try: - from dhee.core.kernel import save_session_digest - digest = save_session_digest( - task_summary=summary, - agent_id=agent_id, - repo=repo, - status=status, - decisions_made=decisions, - files_touched=files_touched, - todos_remaining=todos, - ) - result["session_saved"] = True - if isinstance(digest, dict): - result["session_id"] = digest.get("session_id") - except Exception: - result["session_saved"] = False - - # 2. Batch enrichment of deferred memories - memory = self._engram._memory - if hasattr(memory, "enrich_pending"): - try: - enrich_result = memory.enrich_pending( - user_id=uid, batch_size=10, max_batches=5, - ) - enriched = enrich_result.get("enriched_count", 0) - if enriched > 0: - result["memories_enriched"] = enriched - except Exception: - pass - - # 3. Outcome recording - clamped_score = None - if outcome_score is not None: - clamped_score = max(0.0, min(1.0, float(outcome_score))) - if task_type and clamped_score is not None: - insight = self._buddhi.record_outcome( - user_id=uid, task_type=task_type, score=clamped_score, - ) - result["outcome_recorded"] = True - if insight: - result["auto_insight"] = insight.to_dict() - - # 4. Insight synthesis - if any([what_worked, what_failed, key_decision]): - insights = self._buddhi.reflect( - user_id=uid, - task_type=task_type or "general", - what_worked=what_worked, - what_failed=what_failed, - key_decision=key_decision, - outcome_score=clamped_score, - ) - result["insights_created"] = len(insights) - - # 5. 
Intention storage - if remember_to: - intention = self._buddhi.store_intention( - user_id=uid, - description=remember_to, - trigger_keywords=trigger_keywords, - ) - result["intention_stored"] = intention.to_dict() - - return result + return run_checkpoint_common( + logger=logger, + log_prefix="Checkpoint", + user_id=uid, + summary=summary, + status=status, + agent_id=agent_id, + repo=repo, + decisions=decisions, + files_touched=files_touched, + todos=todos, + task_type=task_type, + outcome_score=outcome_score, + what_worked=what_worked, + what_failed=what_failed, + key_decision=key_decision, + remember_to=remember_to, + trigger_keywords=trigger_keywords, + enrich_pending_fn=self._engram.enrich_pending, + record_outcome_fn=self._buddhi.record_outcome, + reflect_fn=self._buddhi.reflect, + store_intention_fn=self._buddhi.store_intention, + ) # ------------------------------------------------------------------ # Auto-lifecycle (driven by SessionTracker) @@ -704,14 +686,37 @@ def _handle_tracker_signals(self, signals: Dict[str, Any], user_id: str) -> None if signals.get("needs_auto_checkpoint"): args = signals.get("auto_checkpoint_args", {}) try: - self.checkpoint(user_id=user_id, **args) - except Exception: - pass + checkpoint_result = self.checkpoint(user_id=user_id, **args) + for warning in checkpoint_result.get("warnings", []): + logger.warning("Auto-checkpoint warning: %s", warning) + except Exception as exc: + logger.warning("Auto-checkpoint failed: %s", exc, exc_info=True) # Auto-context for new session if signals.get("needs_auto_context"): task = signals.get("inferred_task") try: self.context(task_description=task, user_id=user_id) - except Exception: - pass + except Exception as exc: + logger.warning("Auto-context failed: %s", exc, exc_info=True) + + def close(self) -> None: + """Flush cognition state and release runtime resources.""" + errors: List[str] = [] + + try: + self._buddhi.flush() + except Exception as exc: + logger.exception("Dhee close failed for 
buddhi.flush") + errors.append(f"buddhi.flush: {type(exc).__name__}: {exc}") + + try: + self._engram.close() + except Exception as exc: + logger.exception("Dhee close failed for engram.close") + errors.append(f"engram.close: {type(exc).__name__}: {exc}") + + if errors: + raise RuntimeError( + "Failed to close Dhee resources: " + "; ".join(errors) + ) diff --git a/dhee/utils/math.py b/dhee/utils/math.py index f67e52e..0d538ff 100644 --- a/dhee/utils/math.py +++ b/dhee/utils/math.py @@ -1,25 +1,47 @@ -"""Vector math — Rust-powered, no fallbacks.""" +"""Vector math — Rust-accelerated with pure-Python fallback.""" +import math from typing import List, Optional -from dhee_accel import ( - cosine_similarity as _rs_cosine, - cosine_similarity_batch as _rs_cosine_batch, -) -ACCEL_AVAILABLE = True +try: + from dhee_accel import ( + cosine_similarity as _rs_cosine, + cosine_similarity_batch as _rs_cosine_batch, + ) + ACCEL_AVAILABLE = True +except ImportError: + ACCEL_AVAILABLE = False + + +def _py_cosine(a: List[float], b: List[float]) -> float: + """Pure-Python cosine similarity.""" + dot = sum(x * y for x, y in zip(a, b)) + norm_a = math.sqrt(sum(x * x for x in a)) + norm_b = math.sqrt(sum(x * x for x in b)) + denom = norm_a * norm_b + if denom == 0.0: + return 0.0 + result = dot / denom + if math.isnan(result) or math.isinf(result): + return 0.0 + return result def cosine_similarity(a: Optional[List[float]], b: Optional[List[float]]) -> float: - """Compute cosine similarity between two vectors (Rust-accelerated).""" + """Compute cosine similarity between two vectors.""" if not a or not b or len(a) != len(b): return 0.0 - return _rs_cosine(list(a), list(b)) + if ACCEL_AVAILABLE: + return _rs_cosine(list(a), list(b)) + return _py_cosine(list(a), list(b)) def cosine_similarity_batch( query: List[float], store: List[List[float]] ) -> List[float]: - """Compute cosine similarity of *query* against every vector in *store* (SIMD).""" + """Compute cosine similarity of *query* 
against every vector in *store*.""" if not query or not store: return [0.0] * len(store) - return _rs_cosine_batch(list(query), [list(v) for v in store]) + if ACCEL_AVAILABLE: + return _rs_cosine_batch(list(query), [list(v) for v in store]) + return [_py_cosine(list(query), list(v)) if len(v) == len(query) else 0.0 for v in store] diff --git a/engram-bus/engram_bus/bus.py b/engram-bus/engram_bus/bus.py index a458c7a..937f9c7 100644 --- a/engram-bus/engram_bus/bus.py +++ b/engram-bus/engram_bus/bus.py @@ -307,15 +307,25 @@ def get_session( self, session_id: Optional[str] = None, agent_id: Optional[str] = None, + repo: Optional[str] = None, ) -> Optional[Dict]: - return self._ensure_store().get_session(session_id=session_id, agent_id=agent_id) + return self._ensure_store().get_session( + session_id=session_id, + agent_id=agent_id, + repo=repo, + ) def list_sessions( self, agent_id: Optional[str] = None, status: Optional[str] = None, + repo: Optional[str] = None, ) -> List[Dict]: - return self._ensure_store().list_sessions(agent_id=agent_id, status=status) + return self._ensure_store().list_sessions( + agent_id=agent_id, + status=status, + repo=repo, + ) def update_session(self, session_id: str, **kwargs: Any) -> None: self._ensure_store().update_session(session_id, **kwargs) diff --git a/engram-bus/engram_bus/server.py b/engram-bus/engram_bus/server.py index 3666786..5467b3d 100644 --- a/engram-bus/engram_bus/server.py +++ b/engram-bus/engram_bus/server.py @@ -156,6 +156,7 @@ def cb(_topic: str, data: Any, agent_id: Optional[str]) -> None: result = bus.get_session( session_id=req.get("session_id"), agent_id=req.get("agent_id"), + repo=req.get("repo"), ) return {"ok": True, "session": result} @@ -163,6 +164,7 @@ def cb(_topic: str, data: Any, agent_id: Optional[str]) -> None: result = bus.list_sessions( agent_id=req.get("agent_id"), status=req.get("status"), + repo=req.get("repo"), ) return {"ok": True, "sessions": result} @@ -405,16 +407,32 @@ def get_session( self, 
session_id: Optional[str] = None, agent_id: Optional[str] = None, + repo: Optional[str] = None, ) -> Optional[Dict]: - resp = self._send({"op": "get_session", "session_id": session_id, "agent_id": agent_id}) + resp = self._send( + { + "op": "get_session", + "session_id": session_id, + "agent_id": agent_id, + "repo": repo, + } + ) return resp.get("session") def list_sessions( self, agent_id: Optional[str] = None, status: Optional[str] = None, + repo: Optional[str] = None, ) -> List[Dict]: - resp = self._send({"op": "list_sessions", "agent_id": agent_id, "status": status}) + resp = self._send( + { + "op": "list_sessions", + "agent_id": agent_id, + "status": status, + "repo": repo, + } + ) return resp.get("sessions", []) def update_session(self, session_id: str, **kwargs: Any) -> None: diff --git a/engram-bus/engram_bus/store.py b/engram-bus/engram_bus/store.py index 5580b36..00d6376 100644 --- a/engram-bus/engram_bus/store.py +++ b/engram-bus/engram_bus/store.py @@ -110,6 +110,7 @@ def get_session( self, session_id: Optional[str] = None, agent_id: Optional[str] = None, + repo: Optional[str] = None, ) -> Optional[Dict]: with self._lock: if session_id: @@ -117,9 +118,32 @@ def get_session( "SELECT * FROM handoff_sessions WHERE id = ?", (session_id,) ).fetchone() elif agent_id: + if repo is not None: + row = self._conn.execute( + """ + SELECT * FROM handoff_sessions + WHERE agent_id = ? AND repo = ? + ORDER BY updated DESC LIMIT 1 + """, + (agent_id, repo), + ).fetchone() + else: + row = self._conn.execute( + """ + SELECT * FROM handoff_sessions + WHERE agent_id = ? + ORDER BY updated DESC LIMIT 1 + """, + (agent_id,), + ).fetchone() + elif repo is not None: row = self._conn.execute( - "SELECT * FROM handoff_sessions WHERE agent_id = ? ORDER BY updated DESC LIMIT 1", - (agent_id,), + """ + SELECT * FROM handoff_sessions + WHERE repo = ? 
+ ORDER BY updated DESC LIMIT 1 + """, + (repo,), ).fetchone() else: return None @@ -131,6 +155,7 @@ def list_sessions( self, agent_id: Optional[str] = None, status: Optional[str] = None, + repo: Optional[str] = None, ) -> List[Dict]: clauses: List[str] = [] params: List[Any] = [] @@ -140,6 +165,9 @@ def list_sessions( if status: clauses.append("status = ?") params.append(status) + if repo is not None: + clauses.append("repo = ?") + params.append(repo) where = " WHERE " + " AND ".join(clauses) if clauses else "" with self._lock: rows = self._conn.execute( diff --git a/engram-bus/tests/test_bus.py b/engram-bus/tests/test_bus.py index fde225e..5dc2148 100644 --- a/engram-bus/tests/test_bus.py +++ b/engram-bus/tests/test_bus.py @@ -546,6 +546,18 @@ def test_get_session_by_agent_id(self): assert session["task_summary"] == "second" # most recent bus.close() + def test_get_session_scoped_by_repo(self): + bus = Bus() + bus.save_session("agent-1", repo="/tmp/repo-a", task_summary="repo a") + bus.save_session("agent-1", repo="/tmp/repo-b", task_summary="repo b") + + session = bus.get_session(agent_id="agent-1", repo="/tmp/repo-a") + + assert session is not None + assert session["task_summary"] == "repo a" + assert session["repo"] == "/tmp/repo-a" + bus.close() + def test_get_session_not_found(self): bus = Bus() assert bus.get_session(session_id="nonexistent") is None @@ -562,6 +574,18 @@ def test_list_sessions(self): assert len(a1_sessions) == 2 bus.close() + def test_list_sessions_by_repo(self): + bus = Bus() + bus.save_session("a1", repo="/tmp/repo-a", task_summary="t1") + bus.save_session("a2", repo="/tmp/repo-b", task_summary="t2") + bus.save_session("a1", repo="/tmp/repo-a", task_summary="t3") + + repo_a_sessions = bus.list_sessions(repo="/tmp/repo-a") + + assert len(repo_a_sessions) == 2 + assert all(session["repo"] == "/tmp/repo-a" for session in repo_a_sessions) + bus.close() + def test_list_sessions_by_status(self): bus = Bus() sid = bus.save_session("a1") diff 
--git a/engram/__init__.py b/engram/__init__.py new file mode 100644 index 0000000..9b596d9 --- /dev/null +++ b/engram/__init__.py @@ -0,0 +1,19 @@ +"""Compatibility shim for the historical ``engram`` package namespace. + +The project was renamed to ``dhee`` but several legacy modules, tests, and +sidecar packages still import ``engram.*``. Keep that import surface working +by pointing Python's package search path at the live ``dhee`` package. +""" + +from __future__ import annotations + +import dhee as _dhee + +from dhee import * # noqa: F401,F403 + +__all__ = getattr(_dhee, "__all__", []) +__path__ = list(getattr(_dhee, "__path__", [])) +__doc__ = _dhee.__doc__ +__file__ = __file__ +__package__ = "engram" +__version__ = getattr(_dhee, "__version__", "0.0.0") diff --git a/pyproject.toml b/pyproject.toml index 62b7e24..09b62b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dhee" -version = "3.0.0" +version = "3.0.1" description = "Cognition layer for AI agents — persistent memory, performance tracking, and insight synthesis" readme = "README.md" requires-python = ">=3.9" @@ -27,7 +27,6 @@ dependencies = [ "pydantic>=2.0", "requests>=2.28.0", "pyyaml>=6.0", - "dhee-accel>=0.1.0", ] [project.optional-dependencies] @@ -56,13 +55,14 @@ all = [ "mcp>=1.0.0", "fastapi>=0.100.0", "uvicorn>=0.20.0", - "dhee-accel>=0.1.0", "engram-bus>=0.1.0", "llama-cpp-python>=0.3", "sentence-transformers>=3.0", ] dev = [ "pytest>=7.0.0", + "pytest-asyncio>=0.23.0", + "openai>=1.0.0", "build>=1.0.0", "twine>=5.0.0", ] @@ -81,4 +81,4 @@ Changelog = "https://github.com/Sankhya-AI/Dhee/blob/main/CHANGELOG.md" [tool.setuptools.packages.find] where = ["."] -include = ["dhee*"] +include = ["dhee*", "engram*"] diff --git a/pytest.ini b/pytest.ini index b05c451..1d4fe0a 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,11 +1,11 @@ [pytest] testpaths = tests - dhee-bus/tests - dhee-enterprise/tests + engram-bus/tests + 
engram-enterprise/tests pythonpath = . - dhee-bus - dhee-enterprise + engram-bus + engram-enterprise markers = integration: tests that require external services or credentials diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..426da20 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test support package for shared suite helpers.""" diff --git a/tests/_live.py b/tests/_live.py new file mode 100644 index 0000000..dea423e --- /dev/null +++ b/tests/_live.py @@ -0,0 +1,80 @@ +"""Shared helpers for opt-in live integration tests.""" + +from __future__ import annotations + +import importlib +import os +from pathlib import Path + +import pytest + +LIVE_TESTS_ENV = "DHEE_RUN_LIVE_TESTS" +NVIDIA_KEYS = ( + "NVIDIA_API_KEY", + "NVIDIA_EMBEDDING_API_KEY", + "NVIDIA_QWEN_API_KEY", + "LLAMA_API_KEY", +) + + +def load_project_env() -> None: + """Populate environment variables from the repo-root .env if present.""" + env_path = Path(__file__).resolve().parents[1] / ".env" + if not env_path.exists(): + return + + for raw_line in env_path.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'")) + + +def _missing_optional_packages(packages: tuple[str, ...]) -> list[str]: + missing = [] + for package in packages: + try: + importlib.import_module(package) + except ImportError: + missing.append(package) + return missing + + +def require_live_nvidia_tests(*packages: str) -> None: + """Skip the current pytest module unless live NVIDIA tests are explicitly enabled.""" + load_project_env() + + if os.getenv(LIVE_TESTS_ENV) != "1": + pytest.skip( + f"set {LIVE_TESTS_ENV}=1 to run live NVIDIA integration tests", + allow_module_level=True, + ) + + if not any(os.environ.get(key) for key in NVIDIA_KEYS): + pytest.skip("requires NVIDIA API credentials", 
allow_module_level=True) + + missing = _missing_optional_packages(tuple(packages)) + if missing: + pytest.skip( + f"requires optional dependencies: {', '.join(sorted(missing))}", + allow_module_level=True, + ) + + +def ensure_live_nvidia_runtime(*packages: str) -> None: + """Raise a runtime error when a live suite is executed without prerequisites.""" + load_project_env() + + problems = [] + if os.getenv(LIVE_TESTS_ENV) != "1": + problems.append(f"set {LIVE_TESTS_ENV}=1") + if not any(os.environ.get(key) for key in NVIDIA_KEYS): + problems.append("provide NVIDIA API credentials") + + missing = _missing_optional_packages(tuple(packages)) + if missing: + problems.append(f"install optional dependencies: {', '.join(sorted(missing))}") + + if problems: + raise RuntimeError("Cannot run live NVIDIA suite: " + "; ".join(problems)) diff --git a/tests/test_auto_lifecycle.py b/tests/test_auto_lifecycle.py index 8f888c0..7411e6f 100644 --- a/tests/test_auto_lifecycle.py +++ b/tests/test_auto_lifecycle.py @@ -7,7 +7,9 @@ import time import pytest -from dhee.simple import Dhee +from dhee.memory.main import FullMemory +from dhee.adapters.base import DheePlugin +from dhee.simple import Dhee, Engram @pytest.fixture @@ -109,6 +111,27 @@ def test_checkpoint_auto_fills_outcome(self, dhee): result = dhee.checkpoint("Finished the task") assert isinstance(result, dict) + def test_checkpoint_surfaces_session_and_enrichment_errors(self, tmp_path, monkeypatch): + """checkpoint() should report degraded lifecycle work instead of hiding it.""" + d = Dhee(in_memory=True, data_dir=str(tmp_path)) + + def fail_digest(**_kwargs): + raise RuntimeError("handoff store offline") + + def fail_enrichment(**_kwargs): + raise RuntimeError("batch enrichment unavailable") + + monkeypatch.setattr("dhee.core.kernel.save_session_digest", fail_digest) + monkeypatch.setattr(d._engram.memory, "enrich_pending", fail_enrichment) + + result = d.checkpoint("Finished task") + + assert result["session_saved"] is False + 
assert result["session_save_error"] == "handoff store offline" + assert result["enrichment_error"] == "batch enrichment unavailable" + assert any("handoff store offline" in warning for warning in result["warnings"]) + assert any("batch enrichment unavailable" in warning for warning in result["warnings"]) + class TestShruti: """Verify shruti-tier memories are auto-detected.""" @@ -143,3 +166,141 @@ def test_disable_auto_checkpoint(self, tmp_path): d._tracker._last_activity_time = time.time() - 2.0 signals = d._tracker.on_remember("session 2", "m2") assert signals.get("needs_auto_checkpoint") is None + + +class TestEngramSurface: + def test_public_memory_property_exposes_runtime_engine(self, tmp_path): + e = Engram(provider="mock", in_memory=True, data_dir=str(tmp_path)) + assert isinstance(e.memory, FullMemory) + + def test_plugin_public_memory_property_exposes_runtime_engine(self, tmp_path): + plugin = DheePlugin(in_memory=True, data_dir=str(tmp_path)) + assert isinstance(plugin.memory, FullMemory) + + def test_engram_close_delegates_to_runtime_memory(self, tmp_path, monkeypatch): + e = Engram(provider="mock", in_memory=True, data_dir=str(tmp_path)) + closed = [] + + def _close(): + closed.append(True) + + monkeypatch.setattr(e.memory, "close", _close) + + e.close() + + assert closed == [True] + + def test_full_memory_close_raises_aggregated_errors(self, tmp_path): + e = Engram(provider="mock", in_memory=True, data_dir=str(tmp_path)) + memory = e.memory + + class BrokenEvolution: + def flush(self): + raise RuntimeError("evolution flush failed") + + class BrokenExecutor: + def shutdown(self): + raise RuntimeError("executor shutdown failed") + + class BrokenVectorStore: + def close(self): + raise RuntimeError("vector close failed") + + class BrokenDb: + def close(self): + raise RuntimeError("db close failed") + + memory._evolution_layer = BrokenEvolution() + memory._executor = BrokenExecutor() + memory.vector_store = BrokenVectorStore() + memory.db = BrokenDb() + + 
with pytest.raises(RuntimeError) as exc_info: + memory.close() + + message = str(exc_info.value) + assert "evolution.flush" in message + assert "evolution flush failed" in message + assert "executor.shutdown" in message + assert "executor shutdown failed" in message + assert "vector_store.close" in message + assert "vector close failed" in message + assert "db.close" in message + assert "db close failed" in message + assert memory._executor is None + assert memory.vector_store is None + assert memory.db is None + + +class TestShutdownSurface: + def test_dhee_close_flushes_cognition_and_memory(self, tmp_path, monkeypatch): + d = Dhee(in_memory=True, data_dir=str(tmp_path)) + calls = [] + + monkeypatch.setattr(d._buddhi, "flush", lambda: calls.append("buddhi")) + monkeypatch.setattr(d._engram, "close", lambda: calls.append("engram")) + + d.close() + + assert calls == ["buddhi", "engram"] + + def test_dhee_close_raises_aggregated_errors(self, tmp_path, monkeypatch): + d = Dhee(in_memory=True, data_dir=str(tmp_path)) + + def _buddhi_boom(): + raise RuntimeError("buddhi flush failed") + + def _engram_boom(): + raise RuntimeError("engram close failed") + + monkeypatch.setattr(d._buddhi, "flush", _buddhi_boom) + monkeypatch.setattr(d._engram, "close", _engram_boom) + + with pytest.raises(RuntimeError) as exc_info: + d.close() + + message = str(exc_info.value) + assert "buddhi.flush" in message + assert "buddhi flush failed" in message + assert "engram.close" in message + assert "engram close failed" in message + + def test_plugin_close_flushes_cognition_and_memory(self, tmp_path, monkeypatch): + plugin = DheePlugin(in_memory=True, data_dir=str(tmp_path)) + calls = [] + + monkeypatch.setattr(plugin._buddhi, "flush", lambda: calls.append("buddhi")) + monkeypatch.setattr(plugin._engram, "close", lambda: calls.append("engram")) + + plugin.close() + + assert calls == ["buddhi", "engram"] + + def test_plugin_cognition_health_reports_derivation_failures( + self, tmp_path, 
monkeypatch + ): + plugin = DheePlugin(in_memory=True, data_dir=str(tmp_path)) + + def _policy_boom(*_args, **_kwargs): + raise RuntimeError("policy health unavailable") + + def _belief_boom(*_args, **_kwargs): + raise RuntimeError("belief health unavailable") + + monkeypatch.setattr(plugin._kernel.policies, "get_user_policies", _policy_boom) + monkeypatch.setattr(plugin._kernel.beliefs, "get_contradictions", _belief_boom) + + health = plugin.cognition_health() + + assert "kernel" in health + assert "buddhi" in health + assert any( + err["component"] == "policies.get_user_policies" + and "policy health unavailable" in err["error"] + for err in health.get("errors", []) + ) + assert any( + err["component"] == "beliefs.get_contradictions" + and "belief health unavailable" in err["error"] + for err in health.get("errors", []) + ) diff --git a/tests/test_backward_compat.py b/tests/test_backward_compat.py index 8391221..0a2eab2 100644 --- a/tests/test_backward_compat.py +++ b/tests/test_backward_compat.py @@ -21,6 +21,21 @@ def db(): class TestMemoryBackwardCompat: """Existing memory CRUD operations should work unchanged.""" + def test_top_level_engram_import(self): + from dhee import Engram + + memory = Engram(provider="mock", in_memory=True) + + assert memory is not None + + def test_legacy_engram_namespace_imports(self): + from engram.configs.base import MemoryConfig + from engram.memory.main import FullMemory, Memory + + assert MemoryConfig is not None + assert FullMemory is not None + assert Memory is FullMemory + def test_add_memory(self, db): mem_id = db.add_memory({ "memory": "The user prefers dark mode", diff --git a/tests/test_cognition_evals.py b/tests/test_cognition_evals.py index e9bf1f6..860d16c 100644 --- a/tests/test_cognition_evals.py +++ b/tests/test_cognition_evals.py @@ -266,6 +266,78 @@ def test_cross_agent_handoff(self): assert session_b is None or session_b.get("agent_id") != "agent-a", \ "Agent B should not find Agent A's session" + def 
test_handoff_lookup_respects_repo(self): + """Repo-scoped handoff lookup should not bleed across repositories.""" + if not self.has_bus: + pytest.skip("engram-bus not installed") + + from dhee.core.kernel import save_session_digest, get_last_session + + save_session_digest( + task_summary="Repo A work", + agent_id="agent-repo", + repo="/tmp/repo-a", + status="paused", + db_path=self.db_path, + ) + save_session_digest( + task_summary="Repo B work", + agent_id="agent-repo", + repo="/tmp/repo-b", + status="paused", + db_path=self.db_path, + ) + + session_a = get_last_session( + agent_id="agent-repo", + repo="/tmp/repo-a", + fallback_log_recovery=False, + db_path=self.db_path, + ) + session_b = get_last_session( + agent_id="agent-repo", + repo="/tmp/repo-b", + fallback_log_recovery=False, + db_path=self.db_path, + ) + + assert session_a is not None + assert session_a["repo"] == "/tmp/repo-a" + assert "Repo A" in session_a["task_summary"] + assert session_b is not None + assert session_b["repo"] == "/tmp/repo-b" + assert "Repo B" in session_b["task_summary"] + + def test_handoff_helpers_accept_continuity_kwargs(self): + """Continuity helper kwargs should be accepted even if unused internally.""" + if not self.has_bus: + pytest.skip("engram-bus not installed") + + from dhee.core.kernel import save_session_digest, get_last_session + + save_session_digest( + task_summary="Continuity compatibility session", + agent_id="agent-compat", + repo="/tmp/test-repo", + status="paused", + db_path=self.db_path, + user_id="default", + requester_agent_id="codex", + ) + + session = get_last_session( + agent_id="agent-compat", + repo="/tmp/test-repo", + fallback_log_recovery=False, + db_path=self.db_path, + user_id="default", + requester_agent_id="codex", + ) + + assert session is not None + assert session["repo"] == "/tmp/test-repo" + assert "Continuity compatibility" in session["task_summary"] + def test_cognitive_state_in_handoff(self): """Store beliefs + policies + intentions. 
Do checkpoint. Verify get_cognitive_state() returns these primitives intact.""" diff --git a/tests/test_cognition_kernel.py b/tests/test_cognition_kernel.py index 6002f85..4514bed 100644 --- a/tests/test_cognition_kernel.py +++ b/tests/test_cognition_kernel.py @@ -95,6 +95,8 @@ def test_get_cognitive_state_empty(self, kernel): assert "beliefs" in state assert "triggered_intentions" in state assert "belief_warnings" in state + assert "state_errors" in state + assert state["state_errors"] == [] def test_get_cognitive_state_with_data(self, kernel): kernel.beliefs.add_belief("u", "Python is great", "programming", 0.9) @@ -102,11 +104,43 @@ def test_get_cognitive_state_with_data(self, kernel): state = kernel.get_cognitive_state("u", "programming") assert len(state["beliefs"]) > 0 or len(state["episodes"]) > 0 + def test_get_cognitive_state_reports_component_failures(self, kernel, monkeypatch): + def _boom(*_args, **_kwargs): + raise RuntimeError("belief store unavailable") + + monkeypatch.setattr(kernel.beliefs, "get_relevant_beliefs", _boom) + + state = kernel.get_cognitive_state("u", "programming") + + assert state["beliefs"] == [] + assert len(state["state_errors"]) == 1 + assert state["state_errors"][0]["component"] == "beliefs" + assert "belief store unavailable" in state["state_errors"][0]["error"] + def test_record_checkpoint_event(self, kernel): kernel.episodes.begin_episode("u", "working on auth", "bug_fix") result = kernel.record_checkpoint_event("u", "fixed auth bug", "completed", 0.9) assert "episode_closed" in result + def test_record_checkpoint_event_reports_errors_without_hiding_progress( + self, kernel, monkeypatch + ): + kernel.episodes.begin_episode("u", "working on auth", "bug_fix") + + def _boom(*_args, **_kwargs): + raise RuntimeError("event write failed") + + monkeypatch.setattr(kernel.episodes, "record_event", _boom) + + result = kernel.record_checkpoint_event("u", "fixed auth bug", "completed", 0.9) + + assert "episode_closed" in result + assert 
any( + err["component"] == "episodes.record_event" + and "event write failed" in err["error"] + for err in result.get("errors", []) + ) + def test_update_task_on_checkpoint(self, kernel): result = kernel.update_task_on_checkpoint( user_id="u", @@ -119,15 +153,82 @@ def test_update_task_on_checkpoint(self, kernel): ) assert "task_created" in result or "task_completed" in result + def test_update_task_on_checkpoint_surfaces_step_outcome_errors( + self, kernel, monkeypatch + ): + def _step_result(*_args, **_kwargs): + return { + "policies_updated": 0, + "errors": [ + { + "operation": "record_step_outcome", + "component": "policies.record_outcome", + "error": "RuntimeError: step policy write failed", + } + ], + } + + monkeypatch.setattr(kernel, "record_step_outcome", _step_result) + + result = kernel.update_task_on_checkpoint( + user_id="u", + goal="Fix login crash", + plan=["reproduce", "debug", "fix"], + task_type="bug_fix", + status="completed", + outcome_score=0.8, + summary="Fixed the crash", + ) + + assert "task_created" in result or "task_completed" in result + assert any( + err["component"] == "policies.record_outcome" + and "step policy write failed" in err["error"] + for err in result.get("errors", []) + ) + def test_selective_forget(self, kernel): # Should not error on empty state result = kernel.selective_forget("u") assert isinstance(result, dict) + def test_selective_forget_reports_store_failures(self, kernel, monkeypatch): + def _boom(*_args, **_kwargs): + raise RuntimeError("belief pruning failed") + + monkeypatch.setattr(kernel.beliefs, "prune_retracted", _boom) + + result = kernel.selective_forget("u") + + assert any( + err["component"] == "beliefs.prune_retracted" + and "belief pruning failed" in err["error"] + for err in result.get("errors", []) + ) + def test_flush(self, kernel): kernel.intentions.store("u", "test intention") kernel.flush() # Should not error + def test_flush_raises_aggregated_store_failures(self, kernel, monkeypatch): + def 
_task_boom(): + raise RuntimeError("task flush failed") + + def _policy_boom(): + raise RuntimeError("policy flush failed") + + monkeypatch.setattr(kernel.tasks, "flush", _task_boom) + monkeypatch.setattr(kernel.policies, "flush", _policy_boom) + + with pytest.raises(RuntimeError) as exc_info: + kernel.flush() + + message = str(exc_info.value) + assert "tasks.flush" in message + assert "task flush failed" in message + assert "policies.flush" in message + assert "policy flush failed" in message + def test_get_stats(self, kernel): stats = kernel.get_stats() assert "episodes" in stats @@ -136,6 +237,41 @@ def test_get_stats(self, kernel): assert "policies" in stats assert "intentions" in stats + def test_get_stats_reports_store_failures(self, kernel, monkeypatch): + def _boom(*_args, **_kwargs): + raise RuntimeError("task stats unavailable") + + monkeypatch.setattr(kernel.tasks, "get_stats", _boom) + + stats = kernel.get_stats() + + assert "error" in stats["tasks"] + assert "task stats unavailable" in stats["tasks"]["error"] + assert any( + err["component"] == "tasks.get_stats" + and "task stats unavailable" in err["error"] + for err in stats.get("errors", []) + ) + + def test_record_learning_outcomes_reports_component_failures( + self, kernel, monkeypatch + ): + def _boom(*_args, **_kwargs): + raise RuntimeError("policy matching failed") + + monkeypatch.setattr(kernel.policies, "match_policies", _boom) + + result = kernel.record_learning_outcomes( + "u", "bug_fix", success=True, baseline_score=0.5, actual_score=0.8 + ) + + assert result["policies_updated"] == 0 + assert any( + err["component"] == "policies.match_policies" + and "policy matching failed" in err["error"] + for err in result.get("errors", []) + ) + def test_repr(self, kernel): r = repr(kernel) assert "CognitionKernel" in r @@ -207,6 +343,57 @@ def test_flush_propagates(self, buddhi_with_kernel): kernel.intentions.store("u", "test") buddhi.flush() # Should flush kernel too + def 
test_buddhi_flush_raises_aggregated_failures( + self, buddhi_with_kernel, monkeypatch + ): + buddhi, kernel = buddhi_with_kernel + + class BrokenContrastive: + def flush(self): + raise RuntimeError("contrastive flush failed") + + def _kernel_boom(): + raise RuntimeError("kernel flush failed") + + buddhi._contrastive = BrokenContrastive() + monkeypatch.setattr(kernel, "flush", _kernel_boom) + + with pytest.raises(RuntimeError) as exc_info: + buddhi.flush() + + message = str(exc_info.value) + assert "kernel.flush" in message + assert "kernel flush failed" in message + assert "contrastive.flush" in message + assert "contrastive flush failed" in message + + def test_buddhi_get_stats_reports_errors_once( + self, buddhi_with_kernel + ): + buddhi, _kernel = buddhi_with_kernel + + class CountingContrastive: + def __init__(self): + self.calls = 0 + + def get_stats(self): + self.calls += 1 + raise RuntimeError("contrastive stats unavailable") + + store = CountingContrastive() + buddhi._contrastive = store + + stats = buddhi.get_stats() + + assert store.calls == 1 + assert "error" in stats["contrastive"] + assert "contrastive stats unavailable" in stats["contrastive"]["error"] + assert any( + err["component"] == "contrastive.get_stats" + and "contrastive stats unavailable" in err["error"] + for err in stats.get("errors", []) + ) + # ── Dhee + Kernel integration ────────────────────────────────────── @@ -353,6 +540,38 @@ def test_record_step_outcome(self, tmp_path): assert step_policy.apply_count == 1 assert step_policy.success_count == 1 + def test_record_step_outcome_reports_policy_write_failures( + self, tmp_path, monkeypatch + ): + """record_step_outcome returns explicit errors for policy update failures.""" + from dhee.core.cognition_kernel import CognitionKernel + + kernel = CognitionKernel(data_dir=str(tmp_path / "kernel")) + kernel.policies.create_step_policy( + user_id="u", + name="check_imports_fix", + task_types=["bug_fix"], + step_patterns=["check", "imports"], + 
approach="trace call stack instead", + ) + + def _boom(*_args, **_kwargs): + raise RuntimeError("policy write failed") + + monkeypatch.setattr(kernel.policies, "record_outcome", _boom) + + result = kernel.record_step_outcome( + "u", "bug_fix", "check imports first", + success=True, actual_score=0.8, + ) + + assert result["policies_updated"] == 0 + assert any( + err["component"] == "policies.record_outcome" + and "policy write failed" in err["error"] + for err in result.get("errors", []) + ) + # ── Utility Tracking (Phase 3) ─────────────────────────────────── @@ -526,6 +745,65 @@ def test_context_operational_flag(self, tmp_path): assert "user_id" in full assert "user_id" not in op + def test_hyper_context_surfaces_state_errors(self, tmp_path, monkeypatch): + """Buddhi should expose degraded kernel state instead of hiding it.""" + from dhee.core.buddhi import Buddhi + from dhee.core.cognition_kernel import CognitionKernel + + kernel = CognitionKernel(data_dir=str(tmp_path / "buddhi")) + buddhi = Buddhi(data_dir=str(tmp_path / "buddhi"), kernel=kernel) + + def _boom(*_args, **_kwargs): + raise RuntimeError("intentions unavailable") + + monkeypatch.setattr(kernel.intentions, "check_triggers", _boom) + + ctx = buddhi.get_hyper_context(user_id="u", task_description="bug_fix") + payload = ctx.to_dict() + + assert len(payload["state_errors"]) >= 1 + assert any("intentions unavailable" in msg for msg in payload["state_errors"]) + assert any("Cognitive state degraded" in warning for warning in payload["warnings"]) + + def test_hyper_context_surfaces_context_assembly_errors( + self, tmp_path, monkeypatch + ): + """Buddhi should expose degraded non-kernel context assembly too.""" + from dhee.core import kernel as handoff_kernel + from dhee.core.buddhi import Buddhi + from dhee.core.cognition_kernel import CognitionKernel + + kernel = CognitionKernel(data_dir=str(tmp_path / "buddhi")) + buddhi = Buddhi(data_dir=str(tmp_path / "buddhi"), kernel=kernel) + + class 
BrokenSkillStore: + def search(self, *_args, **_kwargs): + raise RuntimeError("skill store unavailable") + + class BrokenMemory: + def __init__(self): + self.skill_store = BrokenSkillStore() + + def search(self, *_args, **_kwargs): + raise RuntimeError("memory search unavailable") + + def _handoff_boom(*_args, **_kwargs): + raise RuntimeError("handoff store offline") + + monkeypatch.setattr(handoff_kernel, "get_last_session", _handoff_boom) + + ctx = buddhi.get_hyper_context( + user_id="u", + task_description="bug_fix", + memory=BrokenMemory(), + ) + payload = ctx.to_dict() + + assert any("handoff store offline" in msg for msg in payload["state_errors"]) + assert any("skill store unavailable" in msg for msg in payload["state_errors"]) + assert any("memory search unavailable" in msg for msg in payload["state_errors"]) + assert any("Context assembly degraded" in warning for warning in payload["warnings"]) + def test_critical_blockers_surfaced(self, tmp_path): """Blockers from active task appear in critical_blockers.""" from dhee.core.buddhi import Buddhi diff --git a/tests/test_cognition_v3.py b/tests/test_cognition_v3.py index 22cfa40..1173400 100644 --- a/tests/test_cognition_v3.py +++ b/tests/test_cognition_v3.py @@ -49,6 +49,20 @@ def test_engram_add_and_search(self): stats = e.stats() assert isinstance(stats, dict) + def test_in_memory_engram_isolates_cognition_state(self, tmpdir): + """In-memory Engram should keep cognition sidecars under its temp root.""" + from dhee.simple import Engram + + e = Engram(provider="mock", in_memory=True, data_dir=tmpdir) + + buddhi = e._memory.buddhi_layer + evolution = e._memory.evolution_layer + + assert buddhi is not None + assert evolution is not None + assert buddhi._data_dir == os.path.join(tmpdir, "buddhi") + assert evolution._data_dir == tmpdir + # ═══════════════════════════════════════════════════════════════════════════ # 2. 
Contrastive Pairs — closed loop diff --git a/tests/test_e2e_all_features.py b/tests/test_e2e_all_features.py index 8d2a50d..f7efbe4 100644 --- a/tests/test_e2e_all_features.py +++ b/tests/test_e2e_all_features.py @@ -1,50 +1,42 @@ -"""End-to-end real-user test of ALL Engram AGI Memory Kernel features. +"""Opt-in live end-to-end coverage for the Dhee memory stack. -Exercises every feature as a real user would: creates a Memory instance with -a real NVIDIA LLM backend, stores real memories, and validates every subsystem. - -Run: - .venv/bin/python tests/test_e2e_all_features.py +This suite exercises the NVIDIA-backed runtime plus optional power packages. +It is intentionally skipped unless ``DHEE_RUN_LIVE_TESTS=1`` is set. """ +from __future__ import annotations + import json -import os import sys import time -import threading import traceback -from datetime import datetime, timezone +from pathlib import Path import pytest -# ── Setup ────────────────────────────────────────────────────── - -_ENV_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env") -if os.path.exists(_ENV_PATH): - with open(_ENV_PATH) as f: - for line in f: - line = line.strip() - if line and "=" in line and not line.startswith("#"): - key, _, value = line.partition("=") - os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'")) - -_NVIDIA_KEYS = ( - "NVIDIA_API_KEY", - "NVIDIA_EMBEDDING_API_KEY", - "NVIDIA_QWEN_API_KEY", - "LLAMA_API_KEY", -) -if not any(os.environ.get(key) for key in _NVIDIA_KEYS): - pytest.skip("requires NVIDIA API credentials", allow_module_level=True) +if __package__ in (None, ""): + sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from tests._live import ensure_live_nvidia_runtime, require_live_nvidia_tests + +pytestmark = pytest.mark.integration + +if __name__ != "__main__": + require_live_nvidia_tests("openai") from dhee.configs.base import ( - MemoryConfig, LLMConfig, EmbedderConfig, VectorStoreConfig, - EchoMemConfig, 
CategoryMemConfig, ProfileConfig, + CategoryMemConfig, + EchoMemConfig, + EmbedderConfig, + LLMConfig, + MemoryConfig, + ProfileConfig, + VectorStoreConfig, ) from dhee.memory.main import FullMemory as Memory -def create_memory(): +def create_memory() -> Memory: config = MemoryConfig( llm=LLMConfig( provider="nvidia", @@ -57,7 +49,6 @@ def create_memory(): vector_store=VectorStoreConfig(provider="memory", config={}), history_db_path=":memory:", embedding_model_dims=1024, - # Disable features that cause extra API calls or Llama-8B parsing failures: echo=EchoMemConfig(enable_echo=False), category=CategoryMemConfig(use_llm_categorization=False), profile=ProfileConfig(enable_profiles=False), @@ -65,849 +56,391 @@ def create_memory(): return Memory(config=config) -# ── Test runner ──────────────────────────────────────────────── - -passed = 0 -failed = 0 -errors = [] - - -TEST_TIMEOUT = 30 # seconds per test — prevents hanging on API failures - - -def _run_with_timeout(fn, timeout): - """Run fn in a daemon thread; raise TimeoutError if it doesn't finish.""" - result = [None] - error = [None] - - def target(): - try: - fn() - except Exception as e: - error[0] = e - - t = threading.Thread(target=target, daemon=True) - t.start() - t.join(timeout) - if t.is_alive(): - raise TimeoutError(f"Test timed out after {timeout}s (NVIDIA API may be down)") - if error[0] is not None: - raise error[0] - - -def test(name): - """Decorator to run a test and track results.""" - def decorator(fn): - global passed, failed - print(f"\n{'─'*60}") - print(f"TEST: {name}") - print(f"{'─'*60}") - try: - _run_with_timeout(fn, TEST_TIMEOUT) - passed += 1 - print(f" ✓ PASSED") - except Exception as e: - failed += 1 - errors.append((name, str(e))) - print(f" ✗ FAILED: {e}") - traceback.print_exc() - return fn - return decorator - - -# ── Create memory instance ───────────────────────────────────── - -print("=" * 60) -print("ENGRAM AGI MEMORY KERNEL — FULL E2E TEST") -print("=" * 60) 
-print("\nCreating Memory instance with NVIDIA backend...") -memory = create_memory() -print("Memory created successfully.\n") - -# Store IDs for cross-feature tests -stored_ids = {} - - -# ═══════════════════════════════════════════════════════════════ -# PHASE 1: Core Memory (existing) -# ═══════════════════════════════════════════════════════════════ - -@test("1.1 — Add a memory") -def _(): - result = memory.add( - "Vivek prefers Python over JavaScript for backend work", - user_id="vivek", - metadata={"source": "e2e_test"}, - infer=False, - ) - items = result.get("results", []) - assert len(items) > 0, "No results returned from add" - stored_ids["first_memory"] = items[0]["id"] - print(f" Stored memory ID: {stored_ids['first_memory']}") - - -@test("1.2 — Add multiple memories for search testing") -def _(): - memories_to_add = [ - "Production deploy uses GitHub Actions, not Jenkins", - "PostgreSQL to MongoDB migration failed due to schema issues", - "Team chose React for frontend because of components", - "Rate limiting uses Redis with sliding window", - "CI pipeline: pytest, eslint, then Docker build", - "Auth uses JWT with 15-min expiry and refresh tokens", - "Microservices use gRPC internally, REST externally", - ] - for content in memories_to_add: - result = memory.add(content, user_id="vivek", infer=False) - items = result.get("results", []) - assert len(items) > 0 - stored_ids["all_count"] = len(memories_to_add) + 1 # +1 for first - print(f" Added {len(memories_to_add)} more memories") - - -@test("1.3 — Search memories semantically") -def _(): - result = memory.search("what deployment tool do we use", user_id="vivek", limit=5) - items = result.get("results", []) - assert len(items) > 0, "No search results" - top = items[0] - print(f" Top result: {top.get('memory', '')[:80]}...") - print(f" Score: {top.get('composite_score', 0):.3f}") - # Semantic search — top result should be related to deployment/CI/infrastructure - top_text = top.get("memory", "").lower() 
- assert any(kw in top_text for kw in ("deploy", "github", "ci", "docker", "pipeline", "actions")), \ - f"Top result not deployment-related: {top_text[:80]}" - - -@test("1.4 — Get a specific memory by ID") -def _(): - mem = memory.get(stored_ids["first_memory"]) - assert mem is not None, "Memory not found" - assert "Python" in mem.get("memory", "") - print(f" Retrieved: {mem['memory'][:60]}...") - - -@test("1.5 — Update a memory") -def _(): - memory.update(stored_ids["first_memory"], { - "content": "Vivek prefers Python over JS for all work", - }) - updated = memory.get(stored_ids["first_memory"]) - assert "Python" in updated.get("memory", "") - print(f" Updated: {updated['memory'][:60]}...") - - -@test("1.6 — Get all memories") -def _(): - result = memory.get_all(user_id="vivek", limit=50) - items = result.get("results", []) - assert len(items) >= stored_ids["all_count"], f"Expected >= {stored_ids['all_count']}, got {len(items)}" - print(f" Total memories: {len(items)}") - - -@test("1.7 — Get memory stats") -def _(): - stats = memory.get_stats(user_id="vivek") - assert stats is not None - print(f" Stats: {json.dumps(stats, indent=2, default=str)[:200]}") - - -@test("1.8 — Apply memory decay") -def _(): - result = memory.apply_decay(scope={"user_id": "vivek"}) - print(f" Decay result: {result}") - - -# Allow embeddings to settle -time.sleep(1) - -# ═══════════════════════════════════════════════════════════════ -# PHASE 2: Procedural Memory -# ═══════════════════════════════════════════════════════════════ - -@test("2.1 — Create episode memories for procedure extraction") -def _(): - episodes = [ - "Debugged login: checked auth logs, traced JWT, found expired token, fixed refresh", - "Fixed signup: checked logs, traced auth flow, found token expiry, updated refresh", - "Session timeout fix: reviewed logs, traced token lifecycle, updated refresh process", - ] - ep_ids = [] - for ep in episodes: - result = memory.add(ep, user_id="vivek", metadata={"memory_type": 
"episodic", "explicit_remember": True}, infer=False) - items = result.get("results", []) - assert len(items) > 0, f"add() returned no results for episode: {ep[:40]}" - item = items[0] - assert "id" in item, f"add() result missing 'id' key — got event={item.get('event')}: {item}" - ep_ids.append(item["id"]) - stored_ids["episode_ids"] = ep_ids - print(f" Created {len(ep_ids)} episode memories") - - -@test("2.2 — Extract a procedure from episodes") -def _(): - from engram_procedural import Procedural - - ep_ids = stored_ids.get("episode_ids") - assert ep_ids, "Skipping — no episode IDs from previous test" - - proc = Procedural(memory, user_id="vivek") - result = proc.extract_procedure( - episode_ids=ep_ids, - name="debug_auth_issues", - domain="authentication", - ) - assert "error" not in result, f"Extraction failed: {result}" - stored_ids["procedure_id"] = result.get("id", "") - print(f" Procedure: {result.get('name')}") - steps = result.get("steps", []) - if isinstance(steps, str): - try: - steps = json.loads(steps) - except (json.JSONDecodeError, TypeError): - steps = [steps] - print(f" Steps: {len(steps)}") - for i, s in enumerate(steps[:5]): - print(f" {i+1}. 
{s}") - - -@test("2.3 — Get procedure by name") -def _(): - from engram_procedural import Procedural - - proc = Procedural(memory, user_id="vivek") - result = proc.get_procedure("debug_auth_issues") - assert result is not None, "Procedure not found" - print(f" Found: {result['name']} (use_count={result.get('use_count', 0)})") - - -@test("2.4 — Log procedure execution (success)") -def _(): - from engram_procedural import Procedural - - proc = Procedural(memory, user_id="vivek") - proc_id = stored_ids.get("procedure_id", "") - if not proc_id: - print(" Skipping — no procedure ID") - return - - for i in range(3): - result = proc.log_execution(proc_id, success=True, context=f"Run {i+1}") - print(f" Execution {i+1}: use_count={result.get('use_count')}, " - f"success_rate={result.get('success_rate')}, " - f"automaticity={result.get('automaticity')}") - - -@test("2.5 — Log procedure execution (failure)") -def _(): - from engram_procedural import Procedural - - proc = Procedural(memory, user_id="vivek") - proc_id = stored_ids.get("procedure_id", "") - if not proc_id: - return - result = proc.log_execution(proc_id, success=False, context="Token was not expired this time") - print(f" After failure: success_rate={result.get('success_rate')}, " - f"automaticity={result.get('automaticity')}") - - -@test("2.6 — Refine a procedure") -def _(): - from engram_procedural import Procedural - - proc = Procedural(memory, user_id="vivek") - proc_id = stored_ids.get("procedure_id", "") - if not proc_id: - print(" Skipping — no procedure ID") - return - result = proc.refine_procedure( - proc_id, - correction="Also check session cookie before tracing JWT", - ) - assert result.get("refined") is True, f"Refinement failed: {result}" - new_steps = result.get("new_steps", []) - print(f" Refined to {len(new_steps)} steps") - - -@test("2.7 — Search procedures semantically") -def _(): - from engram_procedural import Procedural - time.sleep(0.5) - - proc = Procedural(memory, user_id="vivek") - 
results = proc.search_procedures("how to debug authentication problems") - assert len(results) > 0, "No procedures found" - print(f" Found {len(results)} procedures") - for r in results[:3]: - print(f" - {r.get('name')} (automaticity={r.get('automaticity', 0):.2f})") - - -@test("2.8 — List all active procedures") -def _(): - from engram_procedural import Procedural - - proc = Procedural(memory, user_id="vivek") - results = proc.list_procedures(status="active") - print(f" Active procedures: {len(results)}") - for r in results: - print(f" - {r.get('name')} (uses={r.get('use_count', 0)}, success={r.get('success_rate', 0):.0%})") - - -time.sleep(2) # avoid NVIDIA API rate limits - -# ═══════════════════════════════════════════════════════════════ -# PHASE 2: Reconsolidation -# ═══════════════════════════════════════════════════════════════ - -@test("3.1 — Propose a memory update") -def _(): - from engram_reconsolidation import Reconsolidation - - rc = Reconsolidation(memory, user_id="vivek") - # Find a memory about the deploy pipeline - search = memory.search("deploy pipeline", user_id="vivek", limit=1) - items = search.get("results", []) - assert len(items) > 0 - target_id = items[0]["id"] - stored_ids["rc_target"] = target_id - print(f" Target memory: {items[0].get('memory', '')[:60]}...") - - proposal = rc.propose_update( - memory_id=target_id, - new_context="We now also run Snyk security scanning before Docker build", - ) - print(f" Proposal: {json.dumps(proposal, indent=2, default=str)[:300]}") - if proposal.get("id"): - stored_ids["proposal_id"] = proposal["id"] - - -@test("3.2 — List pending proposals") -def _(): - from engram_reconsolidation import Reconsolidation - - rc = Reconsolidation(memory, user_id="vivek") - pending = rc.list_pending_proposals() - print(f" Pending proposals: {len(pending)}") - for p in pending: - print(f" - target={p.get('target_memory_id', '')[:12]}... 
" - f"confidence={p.get('confidence', 0):.2f} " - f"type={p.get('change_type')}") - - -@test("3.3 — Apply a reconsolidation proposal") -def _(): - from engram_reconsolidation import Reconsolidation - - rc = Reconsolidation(memory, user_id="vivek") - proposal_id = stored_ids.get("proposal_id") - if not proposal_id: - print(" Skipping — no pending proposal") - return - - result = rc.apply_update(proposal_id) - print(f" Apply result: status={result.get('status')}, version={result.get('version')}") - - # Verify the target memory was updated - target = memory.get(stored_ids.get("rc_target", "")) - if target: - print(f" Updated memory: {target.get('memory', '')[:80]}...") - - -@test("3.4 — Propose and reject an update") -def _(): - from engram_reconsolidation import Reconsolidation - - rc = Reconsolidation(memory, user_id="vivek") - # Create a test memory - result = memory.add("The team standup is every Monday at 10am", user_id="vivek", infer=False) - items = result.get("results", []) - assert items - item = items[0] - assert "id" in item, f"add() returned non-stored item: {item.get('event')}" - mid = item["id"] - - proposal = rc.propose_update(mid, new_context="Standups moved to Tuesday at 2pm") - if proposal.get("id"): - reject = rc.reject_update(proposal["id"], reason="Not confirmed yet") - print(f" Rejected: {reject.get('status')}") - else: - print(f" Proposal was: {proposal.get('status', 'n/a')} (no_change/skipped is ok)") - - -@test("3.5 — Get reconsolidation stats") -def _(): - from engram_reconsolidation import Reconsolidation - - rc = Reconsolidation(memory, user_id="vivek") - stats = rc.get_stats() - print(f" Stats: {json.dumps(stats, indent=2)}") - - -@test("3.6 — Get version history") -def _(): - from engram_reconsolidation import Reconsolidation - - rc = Reconsolidation(memory, user_id="vivek") - target_id = stored_ids.get("rc_target", "") - if not target_id: - return - history = rc.get_version_history(target_id) - print(f" Version history entries: 
{len(history)}") - - -time.sleep(2) # avoid NVIDIA API rate limits - -# ═══════════════════════════════════════════════════════════════ -# PHASE 3: Failure Learning -# ═══════════════════════════════════════════════════════════════ - -@test("4.1 — Log a failure") -def _(): - from engram_failure import FailureLearning - - fl = FailureLearning(memory, user_id="vivek") - result = fl.log_failure( - action="deploy_to_production", - error="Connection timeout to AWS ECS", - context="Deploy with reduced capacity", - severity="high", - agent_id="claude-code", - ) - print(f" Result: {result}") - assert result.get("action") == "deploy_to_production" or result.get("status") == "logged", \ - f"Unexpected log_failure result: {result}" - stored_ids["failure_1"] = result.get("id", "") - print(f" Logged: {result.get('action')} — {result.get('error', '')[:50]}") - - -@test("4.2 — Log multiple related failures") -def _(): - from engram_failure import FailureLearning - - fl = FailureLearning(memory, user_id="vivek") - failures = [ - ("deploy_staging", "ECS cluster timeout", "Staging, Friday evening"), - ("deploy_canary", "Load balancer refused", "Canary during peak"), - ("deploy_hotfix", "ECS task start timeout", "Emergency hotfix midnight"), - ] - ids = [stored_ids.get("failure_1", "")] - for action, error, context in failures: - result = fl.log_failure(action=action, error=error, context=context, severity="high") - if result.get("id"): - ids.append(result["id"]) - stored_ids["failure_ids"] = [i for i in ids if i] - print(f" Logged {len(failures)} more failures (total IDs: {len(stored_ids['failure_ids'])})") - - -@test("4.3 — Search past failures") -def _(): - from engram_failure import FailureLearning - time.sleep(0.5) - - fl = FailureLearning(memory, user_id="vivek") - results = fl.search_failures("deployment timeout ECS") - assert len(results) > 0, "No failures found" - print(f" Found {len(results)} matching failures") - for r in results[:3]: - print(f" - {r.get('action')}: 
{r.get('error', '')[:50]}") - - -@test("4.4 — Extract an anti-pattern from failures") -def _(): - from engram_failure import FailureLearning - - fl = FailureLearning(memory, user_id="vivek") - failure_ids = stored_ids.get("failure_ids", []) - if len(failure_ids) < 3: - print(f" Skipping — only {len(failure_ids)} failures, need 3") - return - - result = fl.extract_antipattern( - failure_ids=failure_ids[:3], - name="deploy_during_off_hours", - ) - print(f" Anti-pattern: {result.get('name', 'n/a')}") - print(f" Description: {result.get('description', 'n/a')[:100]}") - warning_signs = result.get("warning_signs", []) - if warning_signs: - print(f" Warning signs: {warning_signs[:3]}") - - -@test("4.5 — List anti-patterns") -def _(): - from engram_failure import FailureLearning - - fl = FailureLearning(memory, user_id="vivek") - result = fl.list_antipatterns() - print(f" Anti-patterns: {len(result)}") - for ap in result: - print(f" - {ap.get('name')}: {ap.get('description', '')[:60]}") - - -@test("4.6 — Get failure stats") -def _(): - from engram_failure import FailureLearning - - fl = FailureLearning(memory, user_id="vivek") - stats = fl.get_failure_stats() - print(f" Stats: {json.dumps(stats, indent=2)}") - assert stats["total_failures"] >= 4 - - -@test("4.7 — Search recovery strategies") -def _(): - from engram_failure import FailureLearning - - fl = FailureLearning(memory, user_id="vivek") - results = fl.search_recovery_strategies("timeout during deploy") - print(f" Recovery strategies found: {len(results)}") - - -# ═══════════════════════════════════════════════════════════════ -# PHASE 3: Working Memory -# ═══════════════════════════════════════════════════════════════ - -@test("5.1 — Push items to working memory") -def _(): - from engram_working import WorkingMemory - - wm = WorkingMemory(memory, user_id="vivek", capacity=5) - r1 = wm.push("Current task: fix the auth token refresh bug", tag="task") - r2 = wm.push("The JWT secret is rotated every 24h", 
tag="context") - r3 = wm.push("Related PR: #1234 by Alice", tag="reference") - stored_ids["wm_key_1"] = r1["key"] - stored_ids["wm_key_2"] = r2["key"] - print(f" Pushed 3 items, buffer size: {r3.get('buffer_size')}") - - -@test("5.2 — List working memory (sorted by activation)") -def _(): - from engram_working import WorkingMemory - - wm = WorkingMemory(memory, user_id="vivek_wm_list") - wm.push("Item A", tag="a") - wm.push("Item B", tag="b") - items = wm.list() - assert len(items) == 2 - print(f" Items in WM: {len(items)}") - for item in items: - print(f" [{item['tag']}] {item['content'][:40]} (activation={item['activation']:.2f})") - - -@test("5.3 — Peek at item (refreshes activation)") -def _(): - from engram_working import WorkingMemory - - wm = WorkingMemory(memory, user_id="vivek_wm_peek") - pushed = wm.push("Important context") - key = pushed["key"] - - peeked = wm.peek(key) - assert peeked is not None - assert peeked["access_count"] == 1 - print(f" Peeked: activation={peeked['activation']:.2f}, accesses={peeked['access_count']}") - - -@test("5.4 — Capacity eviction (Miller's Law)") -def _(): - from engram_working import WorkingMemory - - wm = WorkingMemory(memory, user_id="vivek_wm_evict", capacity=3) - wm.push("Item 1", tag="1") - wm.push("Item 2", tag="2") - wm.push("Item 3", tag="3") - r = wm.push("Item 4 — should evict item 1", tag="4") - - assert r.get("evicted") is not None, "No eviction happened" - assert len(wm.list()) == 3 - print(f" Evicted: {r['evicted']['content'][:40]}") - print(f" Buffer size after eviction: {wm.size}") - - -@test("5.5 — Pop item from working memory") -def _(): - from engram_working import WorkingMemory - - wm = WorkingMemory(memory, user_id="vivek_wm_pop") - pushed = wm.push("Temporary note") - key = pushed["key"] - - popped = wm.pop(key) - assert popped is not None - assert wm.size == 0 - print(f" Popped: {popped['content']}") - - -@test("5.6 — Flush working memory to long-term") -def _(): - from engram_working import 
WorkingMemory - - wm = WorkingMemory(memory, user_id="vivek_wm_flush") - wm.push("Important insight 1", tag="insight") - wm.push("Important insight 2", tag="insight") - - result = wm.flush_to_longterm() - assert result["flushed"] == 2 - assert wm.size == 0 - print(f" Flushed {result['flushed']} items to long-term memory") - - # Verify they ended up in long-term memory - time.sleep(0.5) - search = memory.search("Important insight", user_id="vivek_wm_flush", limit=5) - found = search.get("results", []) - print(f" Found {len(found)} flushed items in long-term memory") - - -@test("5.7 — Get relevant working memory items") -def _(): - from engram_working import WorkingMemory - - wm = WorkingMemory(memory, user_id="vivek_wm_relevant") - wm.push("Fix authentication token refresh bug", tag="task") - wm.push("Database migration plan for Q2 2026", tag="plan") - wm.push("Team standup at 10am tomorrow", tag="reminder") - - relevant = wm.get_relevant("authentication token") - assert len(relevant) >= 1 - print(f" Query: 'authentication token'") - print(f" Relevant items: {len(relevant)}") - for r in relevant: - print(f" [{r['tag']}] {r['content'][:50]}") - - -time.sleep(2) # avoid NVIDIA API rate limits - -# ═══════════════════════════════════════════════════════════════ -# PHASE 4: Salience -# ═══════════════════════════════════════════════════════════════ - -@test("6.1 — Compute salience (heuristic) for different content") -def _(): - from dhee.core.salience import compute_salience - - neutral = compute_salience("The meeting is at 3pm in room 204") - positive = compute_salience("We just achieved an amazing breakthrough on the project!") - urgent = compute_salience("CRITICAL production crash! 
Emergency deployment needed immediately!") - - print(f" Neutral: valence={neutral['sal_valence']:+.2f}, arousal={neutral['sal_arousal']:.2f}, salience={neutral['sal_salience_score']:.2f}") - print(f" Positive: valence={positive['sal_valence']:+.2f}, arousal={positive['sal_arousal']:.2f}, salience={positive['sal_salience_score']:.2f}") - print(f" Urgent: valence={urgent['sal_valence']:+.2f}, arousal={urgent['sal_arousal']:.2f}, salience={urgent['sal_salience_score']:.2f}") - - assert urgent["sal_salience_score"] > neutral["sal_salience_score"] - - -@test("6.2 — Tag a memory with salience") -def _(): - from dhee.core.salience import compute_salience - - # Add a high-salience memory - result = memory.add( - "CRITICAL: Production database corruption! Emergency fix!", - user_id="vivek", - infer=False, - ) - items = result.get("results", []) - assert items, "No results from add()" - assert "id" in items[0], f"add() returned non-stored item: {items[0].get('event')}" - mid = items[0]["id"] - - salience = compute_salience(items[0].get("memory", "")) - md = items[0].get("metadata", {}) or {} - md.update(salience) - memory.update(mid, {"metadata": md}) - - updated = memory.get(mid) - updated_md = updated.get("metadata", {}) or {} - print(f" Tagged memory with salience: {updated_md.get('sal_salience_score', 'N/A')}") - assert updated_md.get("sal_salience_score", 0) > 0 - - -@test("6.3 — Salience decay modifier") -def _(): - from dhee.core.salience import salience_decay_modifier - - # High salience → decay slower - high = salience_decay_modifier(1.0) - mid = salience_decay_modifier(0.5) - none = salience_decay_modifier(0.0) - - print(f" High salience (1.0) → decay multiplier: {high:.2f}") - print(f" Mid salience (0.5) → decay multiplier: {mid:.2f}") - print(f" No salience (0.0) → decay multiplier: {none:.2f}") - - assert high < mid < none - assert none == 1.0 # No salience = normal decay - - -# ═══════════════════════════════════════════════════════════════ -# PHASE 4: Causal 
Reasoning -# ═══════════════════════════════════════════════════════════════ - -@test("7.1 — Add causal links between memories") -def _(): - from dhee.core.graph import KnowledgeGraph, RelationType - - graph = KnowledgeGraph() - - # Create a causal chain: bug_found → investigation → root_cause → fix - graph.add_relationship("mem_bug", "mem_investigation", RelationType.LED_TO) - graph.add_relationship("mem_investigation", "mem_root_cause", RelationType.LED_TO) - graph.add_relationship("mem_root_cause", "mem_fix", RelationType.LED_TO) - graph.add_relationship("mem_fix", "mem_bug", RelationType.PREVENTS) - - stats = graph.stats() - print(f" Graph stats: {stats['total_relationships']} relationships") - print(f" Causal types: LED_TO={stats['relationship_types'].get('led_to', 0)}, " - f"PREVENTS={stats['relationship_types'].get('prevents', 0)}") - stored_ids["causal_graph"] = graph - - -@test("7.2 — Traverse causal chain (backward)") -def _(): - graph = stored_ids.get("causal_graph") - if not graph: - return - - chain = graph.get_causal_chain("mem_fix", direction="backward", depth=5) - print(f" Backward from fix: {len(chain)} nodes") - for mid, depth, path in chain: - print(f" depth={depth}: {mid}") - - -@test("7.3 — Traverse causal chain (forward)") -def _(): - graph = stored_ids.get("causal_graph") - if not graph: - return - - chain = graph.get_causal_chain("mem_bug", direction="forward", depth=5) - print(f" Forward from bug: {len(chain)} nodes") - for mid, depth, path in chain: - print(f" depth={depth}: {mid}") - - -@test("7.4 — Detect causal language in text") -def _(): - from dhee.core.graph import detect_causal_language, RelationType - - texts = [ - ("The outage was caused by a misconfigured load balancer", [RelationType.CAUSED_BY]), - ("Upgrading the library led to a performance regression", [RelationType.LED_TO]), - ("Input validation prevents SQL injection attacks", [RelationType.PREVENTS]), - ("The caching layer enables sub-millisecond response times", 
[RelationType.ENABLES]), - ("The payment service requires authentication", [RelationType.REQUIRES]), - ("Nothing special in this text at all", []), - ] - - for text, expected in texts: - found = detect_causal_language(text) - status = "✓" if set(found) == set(expected) else "✗" - print(f" {status} '{text[:50]}...' → {[r.value for r in found]}") - - -# ═══════════════════════════════════════════════════════════════ -# PHASE 5: AGI Loop -# ═══════════════════════════════════════════════════════════════ - -@test("8.1 — Get system health") -def _(): - from dhee.core.agi_loop import get_system_health - - health = get_system_health(memory, user_id="vivek") - print(f" Health: {health['health_pct']:.0f}% ({health['available']}/{health['total']} systems)") - for name, status in health["systems"].items(): - avail = "✓" if status.get("available") else "✗" - print(f" {avail} {name}") - - -@test("8.2 — Run full AGI cognitive cycle") -def _(): - from dhee.core.agi_loop import run_agi_cycle - - result = run_agi_cycle(memory, user_id="vivek") - summary = result.get("summary", {}) - print(f" Cycle complete: {summary.get('ok', 0)} ok, " - f"{summary.get('errors', 0)} errors, " - f"{summary.get('skipped', 0)} skipped " - f"(out of {summary.get('total_subsystems', 0)} subsystems)") - - for key, val in result.items(): - if isinstance(val, dict) and "status" in val: - status_icon = "✓" if val["status"] == "ok" else ("⊘" if val["status"] == "skipped" else "✗") - print(f" {status_icon} {key}: {val['status']}") - - -# ═══════════════════════════════════════════════════════════════ -# PHASE 1 (existing): Heartbeat Behaviors -# ═══════════════════════════════════════════════════════════════ - -@test("9.1 — Run all new heartbeat behaviors") -def _(): - from engram_heartbeat.behaviors import run_behavior, BUILTIN_BEHAVIORS - - print(f" All behaviors: {list(BUILTIN_BEHAVIORS.keys())}") - - for action in ["extract_procedures", "process_reconsolidation", - "extract_antipatterns", "wm_decay", 
"agi_loop"]: - result = run_behavior(action, memory, {"user_id": "vivek"}, agent_id="test") - status_icon = "✓" if result["status"] == "ok" else ("⊘" if result["status"] == "skipped" else "✗") - print(f" {status_icon} {action}: {result['status']}") - - -# ═══════════════════════════════════════════════════════════════ -# CROSS-FEATURE: Search with all boosts active -# ═══════════════════════════════════════════════════════════════ - -@test("10.1 — Search with procedural + salience boosts in results") -def _(): - results = memory.search("debug authentication", user_id="vivek", limit=5) - items = results.get("results", []) - print(f" Found {len(items)} results") - for item in items[:3]: - print(f" score={item.get('composite_score', 0):.3f} | " - f"echo={item.get('echo_boost', 0):.3f} | " - f"cat={item.get('category_boost', 0):.3f} | " - f"proc={item.get('proc_boost', 0):.3f} | " - f"sal={item.get('salience_boost', 0):.3f} | " - f"{item.get('memory', '')[:50]}...") - - -# ═══════════════════════════════════════════════════════════════ -# INLINE CONFIG VERIFICATION -# ═══════════════════════════════════════════════════════════════ - -@test("11.1 — Verify all inline configs on MemoryConfig") -def _(): - config = MemoryConfig() - configs = { - "procedural": config.procedural.model_dump(), - "reconsolidation": config.reconsolidation.model_dump(), - "failure": config.failure.model_dump(), - "working_memory": config.working_memory.model_dump(), - "salience": config.salience.model_dump(), - "causal": config.causal.model_dump(), +def run_live_e2e_suite() -> dict[str, object]: + passed = 0 + failed = 0 + errors: list[tuple[str, str]] = [] + stored_ids: dict[str, object] = {} + + memory = create_memory() + + def step(name: str): + def decorator(fn): + nonlocal passed, failed + print(f"\n{'-' * 60}") + print(f"STEP: {name}") + print(f"{'-' * 60}") + try: + fn() + passed += 1 + print(" PASS") + except Exception as exc: # pragma: no cover - exercised only in live mode + failed += 
1 + errors.append((name, f"{exc}\n{traceback.format_exc()}")) + print(f" FAIL: {exc}") + return fn + + return decorator + + try: + @step("Core memory CRUD, search, stats, and decay") + def _(): + first = memory.add( + "Vivek prefers Python over JavaScript for backend work", + user_id="vivek", + metadata={"source": "e2e_test"}, + infer=False, + ) + items = first.get("results", []) + assert items, "add() returned no results" + stored_ids["first_memory"] = items[0]["id"] + + for content in [ + "Production deploy uses GitHub Actions, not Jenkins", + "PostgreSQL to MongoDB migration failed due to schema issues", + "Team chose React for frontend because of components", + "Rate limiting uses Redis with sliding window", + "CI pipeline: pytest, eslint, then Docker build", + "Auth uses JWT with 15-min expiry and refresh tokens", + "Microservices use gRPC internally, REST externally", + ]: + result = memory.add(content, user_id="vivek", infer=False) + assert result.get("results"), f"failed to add memory: {content}" + + search = memory.search("what deployment tool do we use", user_id="vivek", limit=5) + hits = search.get("results", []) + assert hits, "search returned no results" + top_text = hits[0].get("memory", "").lower() + assert any(word in top_text for word in ("deploy", "github", "ci", "docker", "pipeline")) + + stored = memory.get(stored_ids["first_memory"]) + assert stored and "Python" in stored.get("memory", "") + + updated = memory.update( + stored_ids["first_memory"], + {"content": "Vivek prefers Python over JS for all work"}, + ) + assert updated is not None + + all_memories = memory.get_all(user_id="vivek", limit=50).get("results", []) + assert len(all_memories) >= 8 + + stats = memory.get_stats(user_id="vivek") + assert stats is not None + + decay_result = memory.apply_decay(scope={"user_id": "vivek"}) + assert isinstance(decay_result, dict) + + time.sleep(1) + + @step("Procedural memory extraction and refinement") + def _(): + from engram_procedural import 
Procedural + + episodes = [ + "Debugged login: checked auth logs, traced JWT, found expired token, fixed refresh", + "Fixed signup: checked logs, traced auth flow, found token expiry, updated refresh", + "Session timeout fix: reviewed logs, traced token lifecycle, updated refresh process", + ] + episode_ids = [] + for episode in episodes: + result = memory.add( + episode, + user_id="vivek", + metadata={"memory_type": "episodic", "explicit_remember": True}, + infer=False, + ) + items = result.get("results", []) + assert items and "id" in items[0], f"invalid add() result: {items}" + episode_ids.append(items[0]["id"]) + + proc = Procedural(memory, user_id="vivek") + extracted = proc.extract_procedure( + episode_ids=episode_ids, + name="debug_auth_issues", + domain="authentication", + ) + assert "error" not in extracted, f"procedure extraction failed: {extracted}" + proc_id = extracted.get("id") + assert proc_id, "procedure id missing" + stored_ids["procedure_id"] = proc_id + + looked_up = proc.get_procedure("debug_auth_issues") + assert looked_up is not None + + proc.log_execution(proc_id, success=True, context="Run 1") + proc.log_execution(proc_id, success=True, context="Run 2") + proc.log_execution(proc_id, success=False, context="Token was not expired this time") + + refined = proc.refine_procedure( + proc_id, + correction="Also check session cookie before tracing JWT", + ) + assert refined.get("refined") is True + + results = proc.search_procedures("how to debug authentication problems") + assert results, "procedure search returned no results" + + active = proc.list_procedures(status="active") + assert active, "expected at least one active procedure" + + time.sleep(1) + + @step("Reconsolidation proposal, apply, reject, and history") + def _(): + from engram_reconsolidation import Reconsolidation + + rc = Reconsolidation(memory, user_id="vivek") + search = memory.search("deploy pipeline", user_id="vivek", limit=1) + items = search.get("results", []) + assert items, 
"could not find deploy pipeline memory" + target_id = items[0]["id"] + + proposal = rc.propose_update( + memory_id=target_id, + new_context="We now also run Snyk security scanning before Docker build", + ) + if proposal.get("id"): + stored_ids["proposal_id"] = proposal["id"] + + pending = rc.list_pending_proposals() + assert pending is not None + + proposal_id = stored_ids.get("proposal_id") + if proposal_id: + applied = rc.apply_update(proposal_id) + assert applied.get("status") in {"applied", "accepted", "updated"} + + created = memory.add("The team standup is every Monday at 10am", user_id="vivek", infer=False) + created_items = created.get("results", []) + assert created_items and "id" in created_items[0] + extra_id = created_items[0]["id"] + + rejection_candidate = rc.propose_update(extra_id, new_context="Standups moved to Tuesday at 2pm") + if rejection_candidate.get("id"): + rejected = rc.reject_update(rejection_candidate["id"], reason="Not confirmed yet") + assert rejected.get("status") in {"rejected", "declined"} + + stats = rc.get_stats() + assert isinstance(stats, dict) + + history = rc.get_version_history(target_id) + assert isinstance(history, list) + + time.sleep(1) + + @step("Failure learning and anti-pattern extraction") + def _(): + from engram_failure import FailureLearning + + fl = FailureLearning(memory, user_id="vivek") + first = fl.log_failure( + action="deploy_to_production", + error="Connection timeout to AWS ECS", + context="Deploy with reduced capacity", + severity="high", + agent_id="claude-code", + ) + assert first.get("action") == "deploy_to_production" or first.get("status") == "logged" + + failure_ids = [first.get("id")] + for action, error, context in [ + ("deploy_staging", "ECS cluster timeout", "Staging, Friday evening"), + ("deploy_canary", "Load balancer refused", "Canary during peak"), + ("deploy_hotfix", "ECS task start timeout", "Emergency hotfix midnight"), + ]: + result = fl.log_failure(action=action, error=error, 
context=context, severity="high") + if result.get("id"): + failure_ids.append(result["id"]) + failure_ids = [failure_id for failure_id in failure_ids if failure_id] + assert len(failure_ids) >= 3 + + search_results = fl.search_failures("deployment timeout ECS") + assert search_results, "expected failure search results" + + antipattern = fl.extract_antipattern( + failure_ids=failure_ids[:3], + name="deploy_during_off_hours", + ) + assert antipattern is not None + + listed = fl.list_antipatterns() + assert isinstance(listed, list) + + stats = fl.get_failure_stats() + assert stats["total_failures"] >= 4 + + recovery = fl.search_recovery_strategies("timeout during deploy") + assert isinstance(recovery, list) + + time.sleep(1) + + @step("Working memory operations and long-term flush") + def _(): + from engram_working import WorkingMemory + + wm = WorkingMemory(memory, user_id="vivek", capacity=5) + first = wm.push("Current task: fix the auth token refresh bug", tag="task") + second = wm.push("The JWT secret is rotated every 24h", tag="context") + wm.push("Related PR: #1234 by Alice", tag="reference") + assert first.get("key") and second.get("key") + + listed = wm.list() + assert len(listed) == 3 + + peeked = wm.peek(first["key"]) + assert peeked is not None and peeked["access_count"] >= 1 + + evictor = WorkingMemory(memory, user_id="vivek_wm_evict", capacity=3) + evictor.push("Item 1", tag="1") + evictor.push("Item 2", tag="2") + evictor.push("Item 3", tag="3") + eviction = evictor.push("Item 4", tag="4") + assert eviction.get("evicted") is not None + + popper = WorkingMemory(memory, user_id="vivek_wm_pop") + popped_key = popper.push("Temporary note")["key"] + popped = popper.pop(popped_key) + assert popped is not None and popper.size == 0 + + flusher = WorkingMemory(memory, user_id="vivek_wm_flush") + flusher.push("Important insight 1", tag="insight") + flusher.push("Important insight 2", tag="insight") + flushed = flusher.flush_to_longterm() + assert 
flushed["flushed"] == 2 + time.sleep(0.5) + + found = memory.search("Important insight", user_id="vivek_wm_flush", limit=5).get("results", []) + assert found, "flushed items not found in long-term memory" + + relevance = WorkingMemory(memory, user_id="vivek_wm_relevant") + relevance.push("Fix authentication token refresh bug", tag="task") + relevance.push("Database migration plan for Q2 2026", tag="plan") + relevance.push("Team standup at 10am tomorrow", tag="reminder") + relevant_items = relevance.get_relevant("authentication token") + assert relevant_items, "expected relevant working-memory items" + + @step("Salience scoring and decay modifiers") + def _(): + from dhee.core.salience import compute_salience, salience_decay_modifier + + neutral = compute_salience("The meeting is at 3pm in room 204") + positive = compute_salience("We just achieved an amazing breakthrough on the project!") + urgent = compute_salience("CRITICAL production crash! Emergency deployment needed immediately!") + assert urgent["sal_salience_score"] > neutral["sal_salience_score"] + assert positive["sal_salience_score"] >= neutral["sal_salience_score"] + + created = memory.add( + "CRITICAL: Production database corruption! 
Emergency fix!", + user_id="vivek", + infer=False, + ) + items = created.get("results", []) + assert items and "id" in items[0] + memory_id = items[0]["id"] + + salience = compute_salience(items[0].get("memory", "")) + metadata = items[0].get("metadata", {}) or {} + metadata.update(salience) + memory.update(memory_id, {"metadata": metadata}) + + updated = memory.get(memory_id) + updated_md = updated.get("metadata", {}) or {} + assert updated_md.get("sal_salience_score", 0) > 0 + + high = salience_decay_modifier(1.0) + mid = salience_decay_modifier(0.5) + none = salience_decay_modifier(0.0) + assert high < mid < none + assert none == 1.0 + + @step("Causal reasoning graph utilities") + def _(): + from dhee.core.graph import KnowledgeGraph, RelationType, detect_causal_language + + graph = KnowledgeGraph() + graph.add_relationship("mem_bug", "mem_investigation", RelationType.LED_TO) + graph.add_relationship("mem_investigation", "mem_root_cause", RelationType.LED_TO) + graph.add_relationship("mem_root_cause", "mem_fix", RelationType.LED_TO) + graph.add_relationship("mem_fix", "mem_bug", RelationType.PREVENTS) + + stats = graph.stats() + assert stats["total_relationships"] >= 4 + + backward = graph.get_causal_chain("mem_fix", direction="backward", depth=5) + forward = graph.get_causal_chain("mem_bug", direction="forward", depth=5) + assert backward and forward + + found = detect_causal_language("The outage was caused by a misconfigured load balancer") + assert RelationType.CAUSED_BY in found + + @step("AGI loop health and cycle execution") + def _(): + from dhee.core.agi_loop import get_system_health, run_agi_cycle + + health = get_system_health(memory, user_id="vivek") + assert health["total"] >= health["available"] + + cycle = run_agi_cycle(memory, user_id="vivek") + summary = cycle.get("summary", {}) + assert summary.get("total_subsystems", 0) >= summary.get("ok", 0) + + @step("Heartbeat behaviors") + def _(): + from engram_heartbeat.behaviors import 
BUILTIN_BEHAVIORS, run_behavior + + assert BUILTIN_BEHAVIORS + for action in [ + "extract_procedures", + "process_reconsolidation", + "extract_antipatterns", + "wm_decay", + "agi_loop", + ]: + result = run_behavior(action, memory, {"user_id": "vivek"}, agent_id="test") + assert result["status"] in {"ok", "skipped"} + + @step("Cross-feature search and inline configuration validation") + def _(): + results = memory.search("debug authentication", user_id="vivek", limit=5).get("results", []) + assert isinstance(results, list) + + config = MemoryConfig() + configs = { + "procedural": config.procedural.model_dump(), + "reconsolidation": config.reconsolidation.model_dump(), + "failure": config.failure.model_dump(), + "working_memory": config.working_memory.model_dump(), + "salience": config.salience.model_dump(), + "causal": config.causal.model_dump(), + } + assert all(isinstance(values, dict) for values in configs.values()) + print(json.dumps(configs, indent=2)[:500]) + + finally: + memory.close() + + return { + "passed": passed, + "failed": failed, + "errors": errors, } - for name, values in configs.items(): - print(f" {name}: {values}") - - -# ═══════════════════════════════════════════════════════════════ -# FINAL REPORT -# ═══════════════════════════════════════════════════════════════ -print("\n" + "=" * 60) -print("FINAL REPORT") -print("=" * 60) -print(f"\n PASSED: {passed}") -print(f" FAILED: {failed}") -print(f" TOTAL: {passed + failed}") -if errors: - print(f"\n FAILURES:") - for name, err in errors: - print(f" ✗ {name}: {err}") +def test_live_e2e_all_features() -> None: + results = run_live_e2e_suite() + failures = results["errors"] + if failures: + summary = "\n".join(f"- {name}: {details.splitlines()[0]}" for name, details in failures) + pytest.fail( + f"live E2E suite reported {results['failed']} failed step(s) out of " + f"{results['passed'] + results['failed']}:\n{summary}" + ) -print() -memory.close() -sys.exit(1 if failed > 0 else 0) +if __name__ == 
"__main__": + ensure_live_nvidia_runtime("openai") + results = run_live_e2e_suite() + total = results["passed"] + results["failed"] + print(f"\nLive E2E summary: {results['passed']} passed, {results['failed']} failed, {total} total") + raise SystemExit(1 if results["failed"] else 0) diff --git a/tests/test_power_packages.py b/tests/test_power_packages.py index 859802e..0f8fc9d 100644 --- a/tests/test_power_packages.py +++ b/tests/test_power_packages.py @@ -10,26 +10,11 @@ from dhee.configs.base import MemoryConfig, LLMConfig, EmbedderConfig, VectorStoreConfig from dhee.memory.main import FullMemory as Memory +from tests._live import require_live_nvidia_tests -# Load keys from .env file in project root -_ENV_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env") -if os.path.exists(_ENV_PATH): - with open(_ENV_PATH) as f: - for line in f: - line = line.strip() - if line and "=" in line and not line.startswith("#"): - key, _, value = line.partition("=") - value = value.strip().strip('"').strip("'") - os.environ.setdefault(key.strip(), value) - -_NVIDIA_KEYS = ( - "NVIDIA_API_KEY", - "NVIDIA_EMBEDDING_API_KEY", - "NVIDIA_QWEN_API_KEY", - "LLAMA_API_KEY", -) -if not any(os.environ.get(key) for key in _NVIDIA_KEYS): - pytest.skip("requires NVIDIA API credentials", allow_module_level=True) +pytestmark = pytest.mark.integration + +require_live_nvidia_tests("openai") # Power packages are separate installable packages (engram_router, engram_heartbeat, etc.) 
try: diff --git a/tests/test_presets.py b/tests/test_presets.py index 3feac01..5dbde10 100644 --- a/tests/test_presets.py +++ b/tests/test_presets.py @@ -20,8 +20,8 @@ def test_minimal_disables_features(self): def test_smart_detects_provider(self): c = MemoryConfig.smart() # Smart should use the best available provider - assert c.embedder.provider in {"gemini", "openai", "ollama", "simple"} - assert c.llm.provider in {"gemini", "openai", "ollama", "mock"} + assert c.embedder.provider in {"gemini", "openai", "ollama", "nvidia", "qwen", "simple"} + assert c.llm.provider in {"dhee", "gemini", "openai", "ollama", "nvidia", "mock"} def test_smart_no_scenes(self): c = MemoryConfig.smart() diff --git a/tests/test_scene.py b/tests/test_scene.py index 6a969ff..4f1c1c0 100644 --- a/tests/test_scene.py +++ b/tests/test_scene.py @@ -3,7 +3,7 @@ import os import tempfile import uuid -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import pytest @@ -69,12 +69,12 @@ def test_no_location(self): class TestBoundaryDetection: def test_no_current_scene(self, processor): - result = processor.detect_boundary("hello", datetime.utcnow().isoformat(), None) + result = processor.detect_boundary("hello", datetime.now(timezone.utc).isoformat(), None) assert result.is_new_scene is True assert result.reason == "no_scene" def test_time_gap(self, processor): - now = datetime.utcnow() + now = datetime.now(timezone.utc) old_time = (now - timedelta(minutes=60)).isoformat() scene = {"start_time": old_time, "end_time": old_time, "memory_ids": ["a"]} result = processor.detect_boundary("hi", now.isoformat(), scene) @@ -82,7 +82,7 @@ def test_time_gap(self, processor): assert result.reason == "time_gap" def test_no_gap(self, processor): - now = datetime.utcnow() + now = datetime.now(timezone.utc) recent = (now - timedelta(minutes=5)).isoformat() scene = { "start_time": recent, @@ -95,7 +95,7 @@ def test_no_gap(self, processor): assert result.is_new_scene is 
False def test_max_memories(self, processor): - now = datetime.utcnow() + now = datetime.now(timezone.utc) recent = (now - timedelta(minutes=1)).isoformat() scene = { "start_time": recent, @@ -109,7 +109,7 @@ def test_max_memories(self, processor): assert result.reason == "max_memories" def test_topic_shift(self, processor): - now = datetime.utcnow() + now = datetime.now(timezone.utc) recent = (now - timedelta(minutes=1)).isoformat() # Orthogonal embeddings = similarity 0 scene_emb = [1.0, 0.0, 0.0] @@ -126,7 +126,7 @@ def test_topic_shift(self, processor): assert result.reason == "topic_shift" def test_location_change(self, processor): - now = datetime.utcnow() + now = datetime.now(timezone.utc) recent = (now - timedelta(minutes=1)).isoformat() scene = { "start_time": recent, @@ -143,7 +143,7 @@ def test_location_change(self, processor): class TestSceneLifecycle: def test_create_scene(self, processor, db): mem_id = str(uuid.uuid4()) - now = datetime.utcnow().isoformat() + now = datetime.now(timezone.utc).isoformat() # Add a memory first db.add_memory({"id": mem_id, "memory": "test", "user_id": "u1"}) @@ -166,7 +166,7 @@ def test_create_scene(self, processor, db): def test_add_memory_to_scene(self, processor, db): mem1 = str(uuid.uuid4()) mem2 = str(uuid.uuid4()) - now = datetime.utcnow().isoformat() + now = datetime.now(timezone.utc).isoformat() db.add_memory({"id": mem1, "memory": "first", "user_id": "u1"}) db.add_memory({"id": mem2, "memory": "second", "user_id": "u1"}) @@ -178,7 +178,7 @@ def test_add_memory_to_scene(self, processor, db): def test_close_scene(self, processor, db): mem_id = str(uuid.uuid4()) - now = datetime.utcnow().isoformat() + now = datetime.now(timezone.utc).isoformat() db.add_memory({"id": mem_id, "memory": "test", "user_id": "u1"}) scene = processor.create_scene(mem_id, "u1", now, topic="topic") @@ -190,7 +190,7 @@ def test_close_scene(self, processor, db): def test_get_open_scene(self, processor, db): mem_id = str(uuid.uuid4()) - now = 
datetime.utcnow().isoformat() + now = datetime.now(timezone.utc).isoformat() db.add_memory({"id": mem_id, "memory": "test", "user_id": "u1"}) processor.create_scene(mem_id, "u1", now, topic="t1") @@ -201,7 +201,7 @@ def test_get_open_scene(self, processor, db): class TestSceneSearch: def test_keyword_search(self, processor, db): mem_id = str(uuid.uuid4()) - now = datetime.utcnow().isoformat() + now = datetime.now(timezone.utc).isoformat() db.add_memory({"id": mem_id, "memory": "test", "user_id": "u1"}) processor.create_scene(mem_id, "u1", now, topic="python debugging session")