diff --git a/akd/guardrails/providers/granite_guardian.py b/akd/guardrails/providers/granite_guardian.py index 5c87904..ffe8cad 100644 --- a/akd/guardrails/providers/granite_guardian.py +++ b/akd/guardrails/providers/granite_guardian.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import math import os import re from enum import StrEnum @@ -321,6 +322,17 @@ class MultiRiskGraniteGuardianToolConfig(GraniteGuardianBaseConfig): ], description="Harm categories to report (filters model output, defaults to all harmful).", ) + score_threshold: float = Field( + default=0.0, + description=( + "Minimum per-category score required to include a detected risk. " + "Range 0.0-1.0. Default 0.0 = no filtering (all detected categories pass). " + "Set (e.g., 0.5) to drop low-confidence detections and reduce false positives. " + "Has no effect if Ollama does not return logprobs (scores are None)." + ), + ge=0.0, + le=1.0, + ) @model_validator(mode="after") def validate_multi_risk_model(self) -> Self: @@ -384,12 +396,13 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput: # Step 1: Detect if harmful (Yes/No + confidence) step1_result = await self._call_harm_detection(step1_prompt) - is_harmful = step1_result.get("label", "").lower() == "yes" + label = step1_result.get("label", "").lower() confidence = step1_result.get("confidence", "") + is_harmful = label == "yes" if self.debug: logger.debug( - f"[{self.__class__.__name__}] Step 1 - is_harmful={is_harmful}, confidence={confidence}", + f"[{self.__class__.__name__}] Step 1 - is_harmful={is_harmful}, label={label}, confidence={confidence}", ) if not is_harmful: @@ -399,7 +412,7 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput: risk_results={}, provider=self.__class__.__name__, extra={ - "risk_label": "no", + "risk_label": label or "no", "confidence": confidence, "step1_raw": step1_result.get("raw_response"), }, @@ -410,22 +423,30 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput: step2_prompt = step1_prompt + step1_result["raw_response"] + _END_OF_TEXT + "\n" step2_result = await self._call_category_detection(step2_prompt) - # Filter to configured categories and exclude non-harmful markers + category_scores: dict[GraniteHarmCategory, float | None] = step2_result.get("categories", {}) + threshold = self.config.score_threshold + _non_harmful = (GraniteHarmCategory.NOT_HARMFUL_PROMPT, GraniteHarmCategory.NOT_HARMFUL_RESPONSE) + + # Filter: configured categories only, exclude non-harmful markers, drop categories + # whose per-category score is below the threshold. Categories without a score + # (e.g., when Ollama doesn't return logprobs) pass through unfiltered. detected = [ cat - for cat in step2_result.get("categories", []) - if cat in categories_to_check - and cat not in (GraniteHarmCategory.NOT_HARMFUL_PROMPT, GraniteHarmCategory.NOT_HARMFUL_RESPONSE) + for cat, score in category_scores.items() + if cat in categories_to_check and cat not in _non_harmful and (score is None or score >= threshold) ] if self.debug: logger.debug( f"[{self.__class__.__name__}] Step 2 - Detected {len(detected)} risks " - f"(filtered from {len(step2_result.get('categories', []))}): {[r.value for r in detected]}", + f"(filtered from {len(category_scores)}, threshold={threshold}): " + f"{[(r.value, category_scores.get(r)) for r in detected]}", ) - # Build per-risk results - risk_results: dict[RiskCategory, dict[str, Any]] = {cat: {"is_risky": True} for cat in detected} + # Build per-risk results with per-category score. + risk_results: dict[RiskCategory, dict[str, Any]] = { + cat: {"is_risky": True, "score": category_scores.get(cat)} for cat in detected + } return GuardrailOutput( detected_risks=detected, @@ -436,7 +457,7 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput: "confidence": confidence, "step1_raw": step1_result.get("raw_response"), "step2_raw": step2_result.get("raw_response"), - "unfiltered_categories": step2_result.get("categories", []), + "unfiltered_categories": category_scores, }, ) @@ -511,6 +532,10 @@ async def _call_category_detection(self, prompt: str) -> dict[str, Any]: "model": self.config.model.value, "prompt": prompt, "stream": False, + # logprobs let us derive per-category model confidence from the + # first token emitted for each category name. + "logprobs": True, + "top_logprobs": 5, "options": { "num_ctx": 8192, "temperature": 0, @@ -521,7 +546,9 @@ async def _call_category_detection(self, prompt: str) -> dict[str, Any]: ) response.raise_for_status() - content = response.json().get("response", "") + data = response.json() + content = data.get("response", "") + token_logprobs = data.get("logprobs") or [] if self.debug: logger.debug( @@ -529,11 +556,12 @@ async def _call_category_detection(self, prompt: str) -> dict[str, Any]: f"--- RESPONSE START ---\n{content}\n--- RESPONSE END ---", ) - categories = self._parse_categories(content) + categories = self._parse_categories(content, token_logprobs) if self.debug: logger.debug( - f"[{self.__class__.__name__}] Step 2 - Parsed categories: {[c.value for c in categories]}", + f"[{self.__class__.__name__}] Step 2 - Parsed categories with scores: " + f"{[(c.value, s) for c, s in categories.items()]}", ) return { @@ -542,21 +570,69 @@ async def _call_category_detection(self, prompt: str) -> dict[str, Any]: } except Exception as e: logger.error(f"[{self.__class__.__name__}] Step 2 error: {e}") - return {"error": str(e), "categories": []} + return {"error": str(e), "categories": {}} - def _parse_categories(self, content: str) -> list[GraniteHarmCategory]: - """Parse comma-separated categories from model output.""" - content = content.replace("", "").strip() - if not content: - return [] + def _parse_categories( + self, + content: str, + token_logprobs: list[dict[str, Any]] | None = None, + ) -> dict[GraniteHarmCategory, float | None]: + """Parse categories from model output into a ``{category: score}`` mapping. + + The score is ``exp(first_token_logprob)`` (model confidence in emitting that + category's first token), or ``None`` when ``token_logprobs`` is not provided. + Python dicts preserve insertion order, so the model's emission order is kept. + """ + cleaned = content.replace("", "").strip() + if not cleaned: + return {} + + raw_categories = [c.strip() for c in cleaned.split(",") if c.strip()] - raw_categories = [cat.strip() for cat in content.split(",") if cat.strip()] + # Walk the token stream and find, for each emitted category, the logprob of + # the first non-whitespace, non-comma token that starts it. + first_token_logprobs = self._first_token_logprob_per_category(token_logprobs or []) - categories = [] - for raw_cat in raw_categories: + results: dict[GraniteHarmCategory, float | None] = {} + for i, raw_cat in enumerate(raw_categories): try: - categories.append(GraniteHarmCategory(raw_cat)) + cat = GraniteHarmCategory(raw_cat) except ValueError: - logger.warning(f"[MultiHarmGraniteGuardianTool] Unknown category: {raw_cat}") + logger.warning(f"[{self.__class__.__name__}] Unknown category: {raw_cat}") + continue + lp = first_token_logprobs[i] if i < len(first_token_logprobs) else None + results[cat] = math.exp(lp) if lp is not None else None + + return results + + @staticmethod + def _first_token_logprob_per_category( + token_logprobs: list[dict[str, Any]], + ) -> list[float | None]: + """Extract the logprob of the first token of each comma-separated category. + + The Step 2 output looks like ``Violence, Unethical Behavior``. We skip leading + whitespace/commas and grab the logprob of each category's first token. Returns a + list of logprobs aligned with the comma-separated category order. + """ + if not token_logprobs: + return [] - return categories + first_lps: list[float | None] = [] + expecting_start = True # at position 0 and after each comma + for entry in token_logprobs: + tok = entry.get("token", "") + stripped = tok.strip() + # A comma ends the current category; next non-whitespace token starts the next one. + if "," in tok: + expecting_start = True + continue + # Skip pure whitespace/newline tokens between categories. + if not stripped: + continue + if expecting_start: + lp = entry.get("logprob") + first_lps.append(lp if isinstance(lp, (int, float)) else None) + expecting_start = False + + return first_lps