Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 102 additions & 26 deletions akd/guardrails/providers/granite_guardian.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
import math
import os
import re
from enum import StrEnum
Expand Down Expand Up @@ -321,6 +322,17 @@ class MultiRiskGraniteGuardianToolConfig(GraniteGuardianBaseConfig):
],
description="Harm categories to report (filters model output, defaults to all harmful).",
)
score_threshold: float = Field(
default=0.0,
description=(
"Minimum per-category score required to include a detected risk. "
"Range 0.0-1.0. Default 0.0 = no filtering (all detected categories pass). "
"Set (e.g., 0.5) to drop low-confidence detections and reduce false positives. "
"Has no effect if Ollama does not return logprobs (scores are None)."
),
ge=0.0,
le=1.0,
)
Comment thread
muthukumaranR marked this conversation as resolved.

@model_validator(mode="after")
def validate_multi_risk_model(self) -> Self:
Expand Down Expand Up @@ -384,12 +396,13 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput:
# Step 1: Detect if harmful (Yes/No + confidence)
step1_result = await self._call_harm_detection(step1_prompt)

is_harmful = step1_result.get("label", "").lower() == "yes"
label = step1_result.get("label", "").lower()
confidence = step1_result.get("confidence", "")
is_harmful = label == "yes"

if self.debug:
logger.debug(
f"[{self.__class__.__name__}] Step 1 - is_harmful={is_harmful}, confidence={confidence}",
f"[{self.__class__.__name__}] Step 1 - is_harmful={is_harmful}, label={label}, confidence={confidence}",
)

if not is_harmful:
Expand All @@ -399,7 +412,7 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput:
risk_results={},
provider=self.__class__.__name__,
extra={
"risk_label": "no",
"risk_label": label or "no",
"confidence": confidence,
"step1_raw": step1_result.get("raw_response"),
},
Expand All @@ -410,22 +423,30 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput:
step2_prompt = step1_prompt + step1_result["raw_response"] + _END_OF_TEXT + "\n<categories>"
step2_result = await self._call_category_detection(step2_prompt)

# Filter to configured categories and exclude non-harmful markers
category_scores: dict[GraniteHarmCategory, float | None] = step2_result.get("categories", {})
threshold = self.config.score_threshold
_non_harmful = (GraniteHarmCategory.NOT_HARMFUL_PROMPT, GraniteHarmCategory.NOT_HARMFUL_RESPONSE)

# Filter: configured categories only, exclude non-harmful markers, drop categories
# whose per-category score is below the threshold. Categories without a score
# (e.g., when Ollama doesn't return logprobs) pass through unfiltered.
detected = [
cat
for cat in step2_result.get("categories", [])
if cat in categories_to_check
and cat not in (GraniteHarmCategory.NOT_HARMFUL_PROMPT, GraniteHarmCategory.NOT_HARMFUL_RESPONSE)
for cat, score in category_scores.items()
if cat in categories_to_check and cat not in _non_harmful and (score is None or score >= threshold)
]

if self.debug:
logger.debug(
f"[{self.__class__.__name__}] Step 2 - Detected {len(detected)} risks "
f"(filtered from {len(step2_result.get('categories', []))}): {[r.value for r in detected]}",
f"(filtered from {len(category_scores)}, threshold={threshold}): "
f"{[(r.value, category_scores.get(r)) for r in detected]}",
)

# Build per-risk results
risk_results: dict[RiskCategory, dict[str, Any]] = {cat: {"is_risky": True} for cat in detected}
# Build per-risk results with per-category score.
risk_results: dict[RiskCategory, dict[str, Any]] = {
cat: {"is_risky": True, "score": category_scores.get(cat)} for cat in detected
}

return GuardrailOutput(
detected_risks=detected,
Expand All @@ -436,7 +457,7 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput:
"confidence": confidence,
"step1_raw": step1_result.get("raw_response"),
"step2_raw": step2_result.get("raw_response"),
"unfiltered_categories": step2_result.get("categories", []),
"unfiltered_categories": category_scores,
},
)

Expand Down Expand Up @@ -511,6 +532,10 @@ async def _call_category_detection(self, prompt: str) -> dict[str, Any]:
"model": self.config.model.value,
"prompt": prompt,
"stream": False,
# logprobs let us derive per-category model confidence from the
# first token emitted for each category name.
"logprobs": True,
"top_logprobs": 5,
"options": {
"num_ctx": 8192,
"temperature": 0,
Expand All @@ -521,19 +546,22 @@ async def _call_category_detection(self, prompt: str) -> dict[str, Any]:
)
response.raise_for_status()

content = response.json().get("response", "")
data = response.json()
content = data.get("response", "")
token_logprobs = data.get("logprobs") or []

if self.debug:
logger.debug(
f"[{self.__class__.__name__}] Step 2 - Ollama response:\n"
f"--- RESPONSE START ---\n{content}\n--- RESPONSE END ---",
)

categories = self._parse_categories(content)
categories = self._parse_categories(content, token_logprobs)

if self.debug:
logger.debug(
f"[{self.__class__.__name__}] Step 2 - Parsed categories: {[c.value for c in categories]}",
f"[{self.__class__.__name__}] Step 2 - Parsed categories with scores: "
f"{[(c.value, s) for c, s in categories.items()]}",
)

return {
Expand All @@ -542,21 +570,69 @@ async def _call_category_detection(self, prompt: str) -> dict[str, Any]:
}
except Exception as e:
logger.error(f"[{self.__class__.__name__}] Step 2 error: {e}")
return {"error": str(e), "categories": []}
return {"error": str(e), "categories": {}}

def _parse_categories(self, content: str) -> list[GraniteHarmCategory]:
"""Parse comma-separated categories from model output."""
content = content.replace("</categories>", "").strip()
if not content:
return []
def _parse_categories(
self,
content: str,
token_logprobs: list[dict[str, Any]] | None = None,
) -> dict[GraniteHarmCategory, float | None]:
"""Parse categories from model output into a ``{category: score}`` mapping.

The score is ``exp(first_token_logprob)`` (model confidence in emitting that
category's first token), or ``None`` when ``token_logprobs`` is not provided.
Python dicts preserve insertion order, so the model's emission order is kept.
"""
cleaned = content.replace("</categories>", "").strip()
if not cleaned:
return {}

raw_categories = [c.strip() for c in cleaned.split(",") if c.strip()]

raw_categories = [cat.strip() for cat in content.split(",") if cat.strip()]
# Walk the token stream and find, for each emitted category, the logprob of
# the first non-whitespace, non-comma token that starts it.
first_token_logprobs = self._first_token_logprob_per_category(token_logprobs or [])

categories = []
for raw_cat in raw_categories:
results: dict[GraniteHarmCategory, float | None] = {}
for i, raw_cat in enumerate(raw_categories):
try:
categories.append(GraniteHarmCategory(raw_cat))
cat = GraniteHarmCategory(raw_cat)
except ValueError:
logger.warning(f"[MultiHarmGraniteGuardianTool] Unknown category: {raw_cat}")
logger.warning(f"[{self.__class__.__name__}] Unknown category: {raw_cat}")
continue
lp = first_token_logprobs[i] if i < len(first_token_logprobs) else None
results[cat] = math.exp(lp) if lp is not None else None

return results

@staticmethod
def _first_token_logprob_per_category(
token_logprobs: list[dict[str, Any]],
) -> list[float | None]:
"""Extract the logprob of the first token of each comma-separated category.

The Step 2 output looks like ``Violence, Unethical Behavior``. We skip leading
whitespace/commas and grab the logprob of each category's first token. Returns a
list of logprobs aligned with the comma-separated category order.
"""
if not token_logprobs:
return []

return categories
first_lps: list[float | None] = []
expecting_start = True # at position 0 and after each comma
for entry in token_logprobs:
tok = entry.get("token", "")
stripped = tok.strip()
# A comma ends the current category; next non-whitespace token starts the next one.
if "," in tok:
expecting_start = True
continue
# Skip pure whitespace/newline tokens between categories.
if not stripped:
continue
if expecting_start:
lp = entry.get("logprob")
first_lps.append(lp if isinstance(lp, (int, float)) else None)
expecting_start = False

return first_lps
Loading