From fac351de48d3a3b232384c6a0ee6ba223d968919 Mon Sep 17 00:00:00 2001 From: Ashish-dwi99 Date: Tue, 31 Mar 2026 13:39:53 +0530 Subject: [PATCH] V2.2.2 --- dhee/adapters/base.py | 1 + dhee/core/buddhi.py | 34 +++- dhee/core/policy.py | 315 ++++++++++++++++++++++++++++++++++--- dhee/simple.py | 1 + tests/test_cognition_v3.py | 196 +++++++++++++++++++++++ 5 files changed, 520 insertions(+), 27 deletions(-) diff --git a/dhee/adapters/base.py b/dhee/adapters/base.py index b525cdf..60c797c 100644 --- a/dhee/adapters/base.py +++ b/dhee/adapters/base.py @@ -264,6 +264,7 @@ def checkpoint( user_id=uid, task_type=task_type or "general", what_worked=what_worked, what_failed=what_failed, key_decision=key_decision, + outcome_score=score if score is not None else None, ) result["insights_created"] = len(insights) diff --git a/dhee/core/buddhi.py b/dhee/core/buddhi.py index f2b8a60..0a8e225 100644 --- a/dhee/core/buddhi.py +++ b/dhee/core/buddhi.py @@ -1002,6 +1002,7 @@ def reflect( what_worked: Optional[str] = None, what_failed: Optional[str] = None, key_decision: Optional[str] = None, + outcome_score: Optional[float] = None, ) -> List[Insight]: """Agent-triggered reflection. Synthesizes insights from experience. @@ -1009,6 +1010,9 @@ def reflect( This is the explicit version of DGM-H's persistent memory — the agent tells Dhee what it learned, and Dhee stores it as transferable insight. + + If outcome_score is provided, policy utility is updated using the + performance delta between the moving-average baseline and actual score. 
""" new_insights = [] @@ -1076,16 +1080,31 @@ def reflect( except Exception: pass - # Phase 3: Extract policy from task outcomes + # Phase 3: Extract policy from task outcomes, with utility deltas + # Compute baseline from moving average for utility scoring (D2Skill) + baseline_score = None + if outcome_score is not None: + try: + key = f"{user_id}:{task_type}" + records = self._performance.get(key, []) + if len(records) >= 2: + recent = records[-min(10, len(records)):] + baseline_score = sum(r["score"] for r in recent) / len(recent) + except Exception: + pass + if what_worked: try: p_store = self._get_policy_store() - # Record success for any matching active policies matched = p_store.match_policies(user_id, task_type, f"{task_type} task") for policy in matched: - p_store.record_outcome(policy.id, success=True) + p_store.record_outcome( + policy.id, + success=True, + baseline_score=baseline_score, + actual_score=outcome_score, + ) - # If we have enough task history, try to extract a new policy ts_store = self._get_task_state_store() completed = ts_store.get_tasks_by_type(user_id, task_type, limit=10) if len(completed) >= 3: @@ -1099,7 +1118,12 @@ def reflect( p_store = self._get_policy_store() matched = p_store.match_policies(user_id, task_type, f"{task_type} task") for policy in matched: - p_store.record_outcome(policy.id, success=False) + p_store.record_outcome( + policy.id, + success=False, + baseline_score=baseline_score, + actual_score=outcome_score, + ) except Exception: pass diff --git a/dhee/core/policy.py b/dhee/core/policy.py index 0887763..d52c8d4 100644 --- a/dhee/core/policy.py +++ b/dhee/core/policy.py @@ -1,4 +1,4 @@ -"""PolicyCase — outcome-linked condition->action rules. +"""PolicyCase — outcome-linked condition->action rules with measured utility. A PolicyCase is NOT a text reflection like "I learned that X works better." 
It is a structured, executable rule: @@ -6,15 +6,23 @@ condition: When task_type matches AND context contains pattern action: Use approach X with parameters Y evidence: Won 7/10 times when applied (outcome-tracked) + utility: Measured +0.23 performance delta when applied vs baseline + +Policies exist at two granularities (from D2Skill, arXiv:2603.28716): + - TASK: high-level strategy for a task type ("for bug_fix tasks, start with git blame") + - STEP: local correction for a specific step-state ("when tests fail after a fix, check imports first") Policies are: - Extracted from TaskState outcomes (what plan succeeded for what task type) - - Validated by tracking win-rate across applications + - Validated by tracking win-rate AND measured utility across applications + - Retrieved using similarity + utility + exploration (not just match score) + - Pruned by utility to keep the store bounded and high-signal - Promoted/demoted based on performance (not just age) - Surfaced in HyperContext as actionable guidance The key difference from insights: insights are descriptive ("X works"), -policies are prescriptive ("when you see A, do B, because it won C% of the time"). +policies are prescriptive ("when you see A, do B, because it won C% of the time +and improved outcomes by +0.23"). 
Policy lifecycle: proposed -> active -> validated -> deprecated """ @@ -41,6 +49,11 @@ class PolicyStatus(str, Enum): DEPRECATED = "deprecated" # Win rate dropped below threshold +class PolicyGranularity(str, Enum): + TASK = "task" # High-level strategy guidance for a task type + STEP = "step" # Local correction/decision support for a specific step-state + + @dataclass class PolicyCondition: """When this policy should fire.""" @@ -48,8 +61,9 @@ class PolicyCondition: context_patterns: List[str] = field(default_factory=list) # keywords in task description min_confidence: float = 0.0 # only fire if policy confidence >= this exclude_patterns: List[str] = field(default_factory=list) # don't fire if these present + step_patterns: List[str] = field(default_factory=list) # for STEP policies: step-state keywords - def matches(self, task_type: str, task_description: str) -> float: + def matches(self, task_type: str, task_description: str, step_context: str = "") -> float: """Score how well this condition matches. 
Returns 0.0-1.0.""" if not self.task_types: return 0.0 @@ -81,7 +95,16 @@ def matches(self, task_type: str, task_description: str) -> float: else: context_score = 1.0 # No pattern constraint = always matches - return type_match * context_score + # Step pattern match (for STEP-granularity policies) + step_score = 1.0 + if self.step_patterns and step_context: + step_lower = step_context.lower() + step_matched = sum(1 for p in self.step_patterns if p.lower() in step_lower) + step_score = step_matched / len(self.step_patterns) + elif self.step_patterns and not step_context: + step_score = 0.3 # Weak match if step context not provided + + return type_match * context_score * step_score def to_dict(self) -> Dict[str, Any]: return { @@ -89,6 +112,7 @@ def to_dict(self) -> Dict[str, Any]: "context_patterns": self.context_patterns, "min_confidence": self.min_confidence, "exclude_patterns": self.exclude_patterns, + "step_patterns": self.step_patterns, } @classmethod @@ -98,6 +122,7 @@ def from_dict(cls, d: Dict[str, Any]) -> PolicyCondition: context_patterns=d.get("context_patterns", []), min_confidence=d.get("min_confidence", 0.0), exclude_patterns=d.get("exclude_patterns", []), + step_patterns=d.get("step_patterns", []), ) @@ -129,7 +154,7 @@ def from_dict(cls, d: Dict[str, Any]) -> PolicyAction: @dataclass class PolicyCase: - """A condition->action rule with outcome tracking.""" + """A condition->action rule with outcome tracking and measured utility.""" id: str user_id: str @@ -141,17 +166,28 @@ class PolicyCase: created_at: float updated_at: float + # Granularity (D2Skill dual-granularity) + granularity: PolicyGranularity = PolicyGranularity.TASK + # Outcome tracking apply_count: int = 0 # times this policy was applied success_count: int = 0 # times application led to success failure_count: int = 0 # times application led to failure + # Utility tracking (D2Skill measured performance delta) + utility: float = 0.0 # EMA of performance deltas + last_delta: float = 0.0 # 
most recent performance delta + cumulative_delta: float = 0.0 # sum of all deltas for lifetime tracking + + # Source tracking source_task_ids: List[str] = field(default_factory=list) source_episode_ids: List[str] = field(default_factory=list) tags: List[str] = field(default_factory=list) + # Utility EMA smoothing factor (plain class attribute, NOT a dataclass field) + _UTILITY_ALPHA = 0.3 + @property def win_rate(self) -> float: """Win rate with Laplace smoothing (add-1).""" @@ -170,8 +206,47 @@ def confidence(self) -> float: spread = z * math.sqrt((p * (1 - p) + z * z / (4 * n)) / n) return max(0.0, (center - spread) / denominator) + @property + def exploration_bonus(self) -> float: + """UCB-style exploration bonus for under-tested policies.""" + return 1.0 / math.sqrt(self.apply_count + 1) + + @property + def retrieval_score_components(self) -> Dict[str, float]: + """Return components for debugging retrieval ranking.""" + return { + "utility": self.utility, + "confidence": self.confidence, + "win_rate": self.win_rate, + "exploration_bonus": self.exploration_bonus, + "apply_count": float(self.apply_count), + } + def record_application(self, success: bool) -> None: - """Record an application of this policy and its outcome.""" + """Record an application of this policy and its outcome (no delta).""" + self.apply_count += 1 + if success: + self.success_count += 1 + else: + self.failure_count += 1 + self.updated_at = time.time() + self._update_status() + + def record_outcome( + self, + success: bool, + baseline_score: Optional[float] = None, + actual_score: Optional[float] = None, + ) -> float: + """Record an application with optional measured performance delta. + + If baseline_score and actual_score are provided, computes the performance + delta and updates utility via EMA. This is the core D2Skill insight: + skills/policies are not just stored, they are re-scored based on measured + contribution. + + Returns the computed delta (0.0 if no scores provided). 
+ """ self.apply_count += 1 if success: self.success_count += 1 @@ -179,8 +254,19 @@ def record_application(self, success: bool) -> None: self.failure_count += 1 self.updated_at = time.time() - # Auto-promote/demote based on evidence + delta = 0.0 + if baseline_score is not None and actual_score is not None: + delta = actual_score - baseline_score + self.last_delta = delta + self.cumulative_delta += delta + # EMA update: utility tracks the running average performance lift + self.utility = ( + self._UTILITY_ALPHA * delta + + (1 - self._UTILITY_ALPHA) * self.utility + ) + self._update_status() + return delta def _update_status(self) -> None: """Update status based on accumulated evidence.""" @@ -201,11 +287,15 @@ def to_dict(self) -> Dict[str, Any]: "condition": self.condition.to_dict(), "action": self.action.to_dict(), "status": self.status.value, + "granularity": self.granularity.value, "created_at": self.created_at, "updated_at": self.updated_at, "apply_count": self.apply_count, "success_count": self.success_count, "failure_count": self.failure_count, + "utility": self.utility, + "last_delta": self.last_delta, + "cumulative_delta": self.cumulative_delta, "source_task_ids": self.source_task_ids, "source_episode_ids": self.source_episode_ids, "tags": self.tags, @@ -215,9 +305,11 @@ def to_compact(self) -> Dict[str, Any]: """Compact format for HyperContext.""" result = { "name": self.name, + "level": self.granularity.value, "when": ", ".join(self.condition.task_types), "do": self.action.approach[:200], "win_rate": round(self.win_rate, 2), + "utility": round(self.utility, 3), "confidence": round(self.confidence, 2), "applied": self.apply_count, } @@ -236,11 +328,15 @@ def from_dict(cls, d: Dict[str, Any]) -> PolicyCase: condition=PolicyCondition.from_dict(d.get("condition", {})), action=PolicyAction.from_dict(d.get("action", {})), status=PolicyStatus(d.get("status", "proposed")), + granularity=PolicyGranularity(d.get("granularity", "task")), 
created_at=d.get("created_at", time.time()), updated_at=d.get("updated_at", time.time()), apply_count=d.get("apply_count", 0), success_count=d.get("success_count", 0), failure_count=d.get("failure_count", 0), + utility=d.get("utility", 0.0), + last_delta=d.get("last_delta", 0.0), + cumulative_delta=d.get("cumulative_delta", 0.0), source_task_ids=d.get("source_task_ids", []), source_episode_ids=d.get("source_episode_ids", []), tags=d.get("tags", []), @@ -250,11 +346,17 @@ def from_dict(cls, d: Dict[str, Any]) -> PolicyCase: class PolicyStore: """Manages policy lifecycle, matching, and learning from task outcomes. + Retrieval uses a three-signal ranking (from D2Skill, arXiv:2603.28716): + 1. Condition match (semantic similarity to the current task context) + 2. Utility score (measured performance delta when this policy was applied) + 3. Exploration bonus (UCB-style bonus for under-tested policies) + Policy extraction pipeline: - 1. TaskState completes with success → analyze plan steps - 2. Find similar completed tasks → extract common successful patterns + 1. TaskState completes with success -> analyze plan steps + 2. Find similar completed tasks -> extract common successful patterns 3. Generate PolicyCase with condition (task_type match) and action (plan pattern) - 4. Track applications and outcomes → promote/demote + 4. Track applications and outcomes -> promote/demote + 5. Prune low-utility policies to keep the store bounded This is NOT LLM-dependent. Policy extraction uses structural analysis of task plans and outcomes. LLM can optionally refine policy names/descriptions. 
@@ -262,6 +364,12 @@ class PolicyStore: MIN_TASKS_FOR_POLICY = 3 # Need at least 3 similar completed tasks SIMILARITY_THRESHOLD = 0.3 # Minimum overlap for "similar" tasks + MAX_POLICIES_PER_USER = 200 # Utility-based pruning threshold + + # Retrieval weights + MATCH_WEIGHT = 0.4 + UTILITY_WEIGHT = 0.35 + EXPLORATION_WEIGHT = 0.25 def __init__(self, data_dir: Optional[str] = None): self._dir = data_dir or os.path.join( @@ -277,9 +385,11 @@ def create_policy( name: str, task_types: List[str], approach: str, + granularity: PolicyGranularity = PolicyGranularity.TASK, steps: Optional[List[str]] = None, avoid: Optional[List[str]] = None, context_patterns: Optional[List[str]] = None, + step_patterns: Optional[List[str]] = None, source_task_ids: Optional[List[str]] = None, source_episode_ids: Optional[List[str]] = None, ) -> PolicyCase: @@ -292,6 +402,7 @@ def create_policy( condition=PolicyCondition( task_types=task_types, context_patterns=context_patterns or [], + step_patterns=step_patterns or [], ), action=PolicyAction( approach=approach, @@ -299,6 +410,7 @@ def create_policy( avoid=avoid or [], ), status=PolicyStatus.PROPOSED, + granularity=granularity, created_at=now, updated_at=now, source_task_ids=source_task_ids or [], @@ -309,6 +421,43 @@ def create_policy( self._save_policy(policy) return policy + def create_step_policy( + self, + user_id: str, + name: str, + task_types: List[str], + step_patterns: List[str], + approach: str, + avoid: Optional[List[str]] = None, + source_task_ids: Optional[List[str]] = None, + ) -> PolicyCase: + """Convenience: create a STEP-granularity policy for local correction. 
+ + Step policies fire when: + - task_type matches AND + - step_patterns match the current step context + + Example: + store.create_step_policy( + user_id="u1", + name="check_imports_on_test_fail", + task_types=["bug_fix"], + step_patterns=["test", "fail", "import"], + approach="Check for missing or circular imports before debugging logic", + avoid=["Don't rewrite the test to make it pass"], + ) + """ + return self.create_policy( + user_id=user_id, + name=name, + task_types=task_types, + granularity=PolicyGranularity.STEP, + approach=approach, + step_patterns=step_patterns, + avoid=avoid, + source_task_ids=source_task_ids, + ) + def extract_from_tasks( self, user_id: str, @@ -330,7 +479,6 @@ def extract_from_tasks( # Find common steps across successful plans step_freq: Dict[str, int] = {} - avoid_freq: Dict[str, int] = {} for task in successful: for step in task.get("plan", []): if step.get("status") == "completed": @@ -342,6 +490,7 @@ def extract_from_tasks( t for t in completed_tasks if t.get("outcome_score", 0) < 0.4 and t.get("plan") ] + avoid_freq: Dict[str, int] = {} for task in failed: for step in task.get("plan", []): if step.get("status") == "failed": @@ -391,12 +540,18 @@ def match_policies( user_id: str, task_type: str, task_description: str, + step_context: str = "", + granularity: Optional[PolicyGranularity] = None, limit: int = 3, ) -> List[PolicyCase]: """Find policies that match the current task context. - Returns policies sorted by (match_score * confidence). - Only returns non-deprecated policies. + Uses three-signal ranking (D2Skill): + score = match_weight * condition_match + + utility_weight * normalized_utility + + exploration_weight * exploration_bonus + + Only returns non-deprecated policies. Optionally filter by granularity. 
""" scored: List[tuple] = [] for policy in self._policies.values(): @@ -404,29 +559,129 @@ def match_policies( continue if policy.status == PolicyStatus.DEPRECATED: continue + if granularity is not None and policy.granularity != granularity: + continue + + match_score = policy.condition.matches(task_type, task_description, step_context) + if match_score <= 0: + continue + if policy.confidence < policy.condition.min_confidence: + continue - match_score = policy.condition.matches(task_type, task_description) - if match_score > 0 and policy.confidence >= policy.condition.min_confidence: - combined = match_score * (0.5 + 0.5 * policy.confidence) - scored.append((policy, combined)) + # Normalize utility to [0, 1] range using sigmoid + norm_utility = 1.0 / (1.0 + math.exp(-3.0 * policy.utility)) + + combined = ( + self.MATCH_WEIGHT * match_score + + self.UTILITY_WEIGHT * norm_utility + + self.EXPLORATION_WEIGHT * policy.exploration_bonus + ) + scored.append((policy, combined)) scored.sort(key=lambda x: x[1], reverse=True) return [p for p, _ in scored[:limit]] + def match_task_policies( + self, + user_id: str, + task_type: str, + task_description: str, + limit: int = 3, + ) -> List[PolicyCase]: + """Convenience: match only TASK-granularity policies.""" + return self.match_policies( + user_id=user_id, + task_type=task_type, + task_description=task_description, + granularity=PolicyGranularity.TASK, + limit=limit, + ) + + def match_step_policies( + self, + user_id: str, + task_type: str, + task_description: str, + step_context: str, + limit: int = 3, + ) -> List[PolicyCase]: + """Convenience: match only STEP-granularity policies for local correction.""" + return self.match_policies( + user_id=user_id, + task_type=task_type, + task_description=task_description, + step_context=step_context, + granularity=PolicyGranularity.STEP, + limit=limit, + ) + def record_outcome( self, policy_id: str, success: bool, task_id: Optional[str] = None, - ) -> None: - """Record the outcome of 
applying a policy.""" + baseline_score: Optional[float] = None, + actual_score: Optional[float] = None, + ) -> Optional[float]: + """Record the outcome of applying a policy. + + If baseline_score and actual_score are provided, computes the performance + delta and updates the policy's utility score. Returns the delta. + """ policy = self._policies.get(policy_id) if not policy: - return - policy.record_application(success) + return None + delta = policy.record_outcome( + success=success, + baseline_score=baseline_score, + actual_score=actual_score, + ) if task_id and task_id not in policy.source_task_ids: policy.source_task_ids.append(task_id) self._save_policy(policy) + return delta + + def prune(self, user_id: str, max_policies: Optional[int] = None) -> Dict[str, Any]: + """Prune low-utility policies to keep the store bounded. + + Removes deprecated policies first, then lowest-utility policies + until count is within budget. Policies with status VALIDATED are + protected from pruning. + + Returns stats about what was pruned. 
+ """ + budget = max_policies or self.MAX_POLICIES_PER_USER + user_policies = [ + p for p in self._policies.values() + if p.user_id == user_id + ] + + if len(user_policies) <= budget: + return {"pruned": 0, "total": len(user_policies)} + + # Sort by pruning priority: deprecated first, then by utility ascending + def prune_priority(p: PolicyCase) -> tuple: + # Protected: validated policies sort last + protected = 0 if p.status == PolicyStatus.VALIDATED else 1 + # Deprecated sort first (highest prune priority) + deprecated = 1 if p.status == PolicyStatus.DEPRECATED else 0 + return (protected, deprecated, -p.utility, -p.apply_count) + + candidates = sorted(user_policies, key=prune_priority, reverse=True) + + pruned = 0 + while len(user_policies) - pruned > budget and candidates: + victim = candidates.pop(0) + if victim.status == PolicyStatus.VALIDATED: + break # Don't prune validated policies + self._delete_policy(victim.id) + pruned += 1 + + return { + "pruned": pruned, + "total": len(user_policies) - pruned, + "budget": budget, + } def get_stats(self, user_id: Optional[str] = None) -> Dict[str, Any]: policies = list(self._policies.values()) @@ -434,18 +689,26 @@ def get_stats(self, user_id: Optional[str] = None) -> Dict[str, Any]: policies = [p for p in policies if p.user_id == user_id] by_status = {} + by_granularity = {} for p in policies: by_status[p.status.value] = by_status.get(p.status.value, 0) + 1 + by_granularity[p.granularity.value] = by_granularity.get(p.granularity.value, 0) + 1 validated = [p for p in policies if p.status == PolicyStatus.VALIDATED] + with_utility = [p for p in policies if p.apply_count > 0] return { "total": len(policies), "by_status": by_status, + "by_granularity": by_granularity, "validated_count": len(validated), "avg_win_rate": ( sum(p.win_rate for p in validated) / len(validated) if validated else 0.0 ), + "avg_utility": ( + sum(p.utility for p in with_utility) / len(with_utility) + if with_utility else 0.0 + ), } # 
------------------------------------------------------------------ @@ -478,6 +741,14 @@ def _find_similar_policy( return None + def _delete_policy(self, policy_id: str) -> None: + self._policies.pop(policy_id, None) + path = os.path.join(self._dir, f"{policy_id}.json") + try: + os.remove(path) + except OSError: + pass + # ------------------------------------------------------------------ # Persistence # ------------------------------------------------------------------ diff --git a/dhee/simple.py b/dhee/simple.py index 21c07d6..442340c 100644 --- a/dhee/simple.py +++ b/dhee/simple.py @@ -613,6 +613,7 @@ def checkpoint( what_worked=what_worked, what_failed=what_failed, key_decision=key_decision, + outcome_score=score if score is not None else None, ) result["insights_created"] = len(insights) diff --git a/tests/test_cognition_v3.py b/tests/test_cognition_v3.py index 6f9a4ba..22cfa40 100644 --- a/tests/test_cognition_v3.py +++ b/tests/test_cognition_v3.py @@ -664,6 +664,202 @@ def test_persistence(self, tmpdir): assert len(s2._policies) == 1 +# ═══════════════════════════════════════════════════════════════════════════ +# 8b. 
PolicyCase — dual-granularity + utility scoring (D2Skill) +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestPolicyUtility: + """Tests for D2Skill-inspired dual-granularity and utility scoring.""" + + def test_utility_update_from_delta(self, tmpdir): + from dhee.core.policy import PolicyStore + store = PolicyStore(data_dir=os.path.join(tmpdir, "policies")) + + policy = store.create_policy("u", "p1", ["bug_fix"], "git blame first") + + # Apply with a positive performance delta + delta = store.record_outcome( + policy.id, success=True, + baseline_score=0.6, actual_score=0.9, + ) + assert delta == pytest.approx(0.3) # 0.9 - 0.6 + p = store._policies[policy.id] + assert p.utility > 0 # EMA should be positive + assert p.last_delta == pytest.approx(0.3) + assert p.cumulative_delta == pytest.approx(0.3) + + def test_utility_ema_smoothing(self, tmpdir): + from dhee.core.policy import PolicyStore + store = PolicyStore(data_dir=os.path.join(tmpdir, "policies")) + + policy = store.create_policy("u", "p1", ["testing"], "write tests first") + + # Series of outcomes with varying deltas + store.record_outcome(policy.id, True, baseline_score=0.5, actual_score=0.8) # +0.3 + u1 = store._policies[policy.id].utility + store.record_outcome(policy.id, True, baseline_score=0.5, actual_score=0.9) # +0.4 + u2 = store._policies[policy.id].utility + store.record_outcome(policy.id, False, baseline_score=0.5, actual_score=0.3) # -0.2 + u3 = store._policies[policy.id].utility + + # Utility should have risen then fallen + assert u2 > u1 + assert u3 < u2 + + def test_negative_utility_from_bad_outcomes(self, tmpdir): + from dhee.core.policy import PolicyStore + store = PolicyStore(data_dir=os.path.join(tmpdir, "policies")) + + policy = store.create_policy("u", "bad", ["debug"], "random changes") + + # Consistently worse outcomes when applied + for _ in range(5): + store.record_outcome(policy.id, False, baseline_score=0.7, actual_score=0.3) + + p = 
store._policies[policy.id] + assert p.utility < 0 # Negative utility = policy makes things worse + assert p.cumulative_delta < 0 + + def test_dual_granularity_creation(self, tmpdir): + from dhee.core.policy import PolicyStore, PolicyGranularity + store = PolicyStore(data_dir=os.path.join(tmpdir, "policies")) + + task_policy = store.create_policy( + "u", "task_strategy", ["bug_fix"], + approach="Start with git blame", + granularity=PolicyGranularity.TASK, + ) + step_policy = store.create_step_policy( + "u", "import_check", ["bug_fix"], + step_patterns=["test", "fail", "import"], + approach="Check for missing imports before debugging logic", + ) + + assert task_policy.granularity == PolicyGranularity.TASK + assert step_policy.granularity == PolicyGranularity.STEP + + def test_match_by_granularity(self, tmpdir): + from dhee.core.policy import PolicyStore, PolicyGranularity + store = PolicyStore(data_dir=os.path.join(tmpdir, "policies")) + + store.create_policy( + "u", "task_strat", ["bug_fix"], + approach="Start with git blame", + granularity=PolicyGranularity.TASK, + ) + store.create_step_policy( + "u", "step_fix", ["bug_fix"], + step_patterns=["test", "fail"], + approach="Check test imports", + ) + + task_only = store.match_task_policies("u", "bug_fix", "fix auth bug") + step_only = store.match_step_policies( + "u", "bug_fix", "fix auth bug", step_context="test fail import error", + ) + all_policies = store.match_policies("u", "bug_fix", "fix auth bug") + + assert len(task_only) == 1 + assert task_only[0].granularity == PolicyGranularity.TASK + assert len(step_only) == 1 + assert step_only[0].granularity == PolicyGranularity.STEP + assert len(all_policies) == 2 + + def test_retrieval_ranking_utility_beats_no_utility(self, tmpdir): + from dhee.core.policy import PolicyStore + store = PolicyStore(data_dir=os.path.join(tmpdir, "policies")) + + # Create two policies with same match score + p_low = store.create_policy("u", "low_util", ["debug"], "approach A") + p_high = 
store.create_policy("u", "high_util", ["debug"], "approach B") + + # Give p_high strong positive utility (5 successes with big delta) + for _ in range(5): + store.record_outcome(p_high.id, True, baseline_score=0.5, actual_score=0.9) + # Give p_low weak utility (5 successes with small delta) — stays ACTIVE, not DEPRECATED + for _ in range(5): + store.record_outcome(p_low.id, True, baseline_score=0.5, actual_score=0.55) + + matched = store.match_policies("u", "debug", "debug the thing") + assert len(matched) >= 2 + # High utility should rank first + assert matched[0].id == p_high.id + + def test_exploration_bonus_for_new_policies(self, tmpdir): + from dhee.core.policy import PolicyStore + store = PolicyStore(data_dir=os.path.join(tmpdir, "policies")) + + p_new = store.create_policy("u", "new_pol", ["t"], "approach new") + p_old = store.create_policy("u", "old_pol", ["t"], "approach old") + + # Apply old policy many times but with zero utility + for _ in range(50): + store.record_outcome(p_old.id, True) + + # New policy should have higher exploration bonus + assert p_new.exploration_bonus > p_old.exploration_bonus + + def test_pruning(self, tmpdir): + from dhee.core.policy import PolicyStore, PolicyStatus + store = PolicyStore(data_dir=os.path.join(tmpdir, "policies")) + + # Create many policies + ids = [] + for i in range(10): + p = store.create_policy("u", f"p{i}", ["t"], f"approach {i}") + ids.append(p.id) + + # Deprecate first 5 + for pid in ids[:5]: + for _ in range(5): + store.record_outcome(pid, success=False) + + # Validate last 2 + for pid in ids[-2:]: + for _ in range(20): + store.record_outcome(pid, success=True) + + stats = store.prune("u", max_policies=5) + assert stats["pruned"] == 5 + assert stats["total"] == 5 + + # Validated policies should survive + for pid in ids[-2:]: + assert pid in store._policies + + def test_step_condition_matching(self): + from dhee.core.policy import PolicyCondition + cond = PolicyCondition( + task_types=["bug_fix"], + 
step_patterns=["test", "fail"], + ) + + # Good match with step context + score = cond.matches("bug_fix", "fix auth", step_context="test fail import") + assert score > 0.5 + + # No step context — weak match + score_weak = cond.matches("bug_fix", "fix auth", step_context="") + assert score_weak < score + + def test_compact_includes_utility_and_level(self, tmpdir): + from dhee.core.policy import PolicyStore, PolicyGranularity + store = PolicyStore(data_dir=os.path.join(tmpdir, "policies")) + + p = store.create_policy( + "u", "p1", ["t"], "approach", + granularity=PolicyGranularity.STEP, + ) + store.record_outcome(p.id, True, baseline_score=0.5, actual_score=0.8) + + compact = store._policies[p.id].to_compact() + assert "utility" in compact + assert "level" in compact + assert compact["level"] == "step" + assert compact["utility"] > 0 + + # ═══════════════════════════════════════════════════════════════════════════ # 9. BeliefNode — confidence + contradiction # ═══════════════════════════════════════════════════════════════════════════