From 0ea08ee8269c1c89844ff770a5c47e9278e74979 Mon Sep 17 00:00:00 2001 From: NISH1001 Date: Mon, 13 Apr 2026 16:45:45 -0500 Subject: [PATCH 1/2] Add per-category score to MultiRiskGraniteGuardianTool via Step 2 logprobs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes I have done - Added `score_threshold` config field to `MultiRiskGraniteGuardianToolConfig` to optionally drop low-confidence detections and reduce false positives. Defaults to `0.0` (no filter) to preserve current behavior. - Each entry in `risk_results` now includes a per-category `score` in [0, 1], derived from the logprob of the category's first emitted token in Step 2. - Categories whose per-category score is below `score_threshold` are dropped from `detected_risks` and `risk_results`. Categories without a score (e.g., when Ollama does not return logprobs) pass through unfiltered as a graceful fallback. - Step 1 decision remains label-based (`Yes`/`No`) as per the model card; no extra Ollama calls are added. ### Why The multi-harm model's self-reported text confidence (e.g., `"High"`, `"Not Harmful"`) is effectively binary in practice and useless for thresholding. False positives on some categories (e.g., `Harmful` flagged on sarcasm) could not be filtered without manual category exclusion. Per-category logprob-derived scores give callers a real numeric signal to threshold on (e.g., `Violence=0.97` vs `Harmful=0.35` for the same input). ### How I made the changes - `akd/guardrails/providers/granite_guardian.py`: - `MultiRiskGraniteGuardianToolConfig`: added `score_threshold: float = 0.0` with `ge=0.0, le=1.0` validation. - `_call_category_detection`: added top-level `logprobs: True` and `top_logprobs: 5` to the Ollama `/api/generate` request body (Ollama accepts these as top-level params, not inside `options`). - `_parse_categories_with_scores`: new helper that parses comma-separated categories and computes per-category scores as `exp(first_token_logprob)`. - `_first_token_logprob_per_category`: new static helper that walks the token stream, skipping whitespace/commas, and returns the logprob of the first token of each emitted category. - `_parse_categories`: kept as a thin wrapper over `_parse_categories_with_scores` for backward compatibility. - `_arun`: applies `score_threshold` as a filter in the Step 2 detected-categories list comprehension; builds `risk_results` with `{"is_risky": True, "score": }` per category. ### How to test - `uv run pytest tests/guardrails/` — existing tests still pass (8 failures in `test_granite_think.py` are pre-existing and require a live `granite3.3-guardian:8b` model). - `uv run python scripts/test_multi_harm.py` with live Ollama + `granite-guardian-3.2-5b-multi-harm-GGUF` — verifies per-category scores appear in `risk_results` (e.g., Violence=0.97, Unethical Behavior=0.76 for the same violent input; Harmful=0.35 on sarcasm, correctly flagged as low-confidence). --- akd/guardrails/providers/granite_guardian.py | 128 ++++++++++++++++--- 1 file changed, 107 insertions(+), 21 deletions(-) diff --git a/akd/guardrails/providers/granite_guardian.py b/akd/guardrails/providers/granite_guardian.py index 5c87904..b7a39eb 100644 --- a/akd/guardrails/providers/granite_guardian.py +++ b/akd/guardrails/providers/granite_guardian.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import math import os import re from enum import StrEnum @@ -321,6 +322,17 @@ class MultiRiskGraniteGuardianToolConfig(GraniteGuardianBaseConfig): ], description="Harm categories to report (filters model output, defaults to all harmful).", ) + score_threshold: float = Field( + default=0.0, + description=( + "Minimum per-category score required to include a detected risk. " + "Range 0.0-1.0. Default 0.0 = no filtering (all detected categories pass). " + "Set (e.g., 0.5) to drop low-confidence detections and reduce false positives. " + "Has no effect if Ollama does not return logprobs (scores are None)." + ), + ge=0.0, + le=1.0, + ) @model_validator(mode="after") def validate_multi_risk_model(self) -> Self: @@ -384,12 +396,13 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput: # Step 1: Detect if harmful (Yes/No + confidence) step1_result = await self._call_harm_detection(step1_prompt) - is_harmful = step1_result.get("label", "").lower() == "yes" + label = step1_result.get("label", "").lower() confidence = step1_result.get("confidence", "") + is_harmful = label == "yes" if self.debug: logger.debug( - f"[{self.__class__.__name__}] Step 1 - is_harmful={is_harmful}, confidence={confidence}", + f"[{self.__class__.__name__}] Step 1 - is_harmful={is_harmful}, label={label}, confidence={confidence}", ) if not is_harmful: @@ -399,7 +412,7 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput: risk_results={}, provider=self.__class__.__name__, extra={ - "risk_label": "no", + "risk_label": label or "no", "confidence": confidence, "step1_raw": step1_result.get("raw_response"), }, @@ -410,22 +423,32 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput: step2_prompt = step1_prompt + step1_result["raw_response"] + _END_OF_TEXT + "\n" step2_result = await self._call_category_detection(step2_prompt) - # Filter to configured categories and exclude non-harmful markers + per_category_scores: dict[GraniteHarmCategory, float | None] = step2_result.get("scores", {}) + threshold = self.config.score_threshold + _non_harmful = (GraniteHarmCategory.NOT_HARMFUL_PROMPT, GraniteHarmCategory.NOT_HARMFUL_RESPONSE) + + # Filter: configured categories only, exclude non-harmful markers, drop categories + # whose per-category score is below the threshold. Categories without a score + # (e.g., when Ollama doesn't return logprobs) pass through unfiltered. detected = [ cat for cat in step2_result.get("categories", []) if cat in categories_to_check - and cat not in (GraniteHarmCategory.NOT_HARMFUL_PROMPT, GraniteHarmCategory.NOT_HARMFUL_RESPONSE) + and cat not in _non_harmful + and (per_category_scores.get(cat) is None or per_category_scores[cat] >= threshold) ] if self.debug: logger.debug( f"[{self.__class__.__name__}] Step 2 - Detected {len(detected)} risks " - f"(filtered from {len(step2_result.get('categories', []))}): {[r.value for r in detected]}", + f"(filtered from {len(step2_result.get('categories', []))}, threshold={threshold}): " + f"{[(r.value, per_category_scores.get(r)) for r in detected]}", ) - # Build per-risk results - risk_results: dict[RiskCategory, dict[str, Any]] = {cat: {"is_risky": True} for cat in detected} + # Build per-risk results with per-category score. + risk_results: dict[RiskCategory, dict[str, Any]] = { + cat: {"is_risky": True, "score": per_category_scores.get(cat)} for cat in detected + } return GuardrailOutput( detected_risks=detected, @@ -511,6 +534,10 @@ async def _call_category_detection(self, prompt: str) -> dict[str, Any]: "model": self.config.model.value, "prompt": prompt, "stream": False, + # logprobs let us derive per-category model confidence from the + # first token emitted for each category name. + "logprobs": True, + "top_logprobs": 5, "options": { "num_ctx": 8192, "temperature": 0, @@ -521,7 +548,9 @@ async def _call_category_detection(self, prompt: str) -> dict[str, Any]: ) response.raise_for_status() - content = response.json().get("response", "") + data = response.json() + content = data.get("response", "") + token_logprobs = data.get("logprobs") or [] if self.debug: logger.debug( @@ -529,34 +558,91 @@ async def _call_category_detection(self, prompt: str) -> dict[str, Any]: f"--- RESPONSE START ---\n{content}\n--- RESPONSE END ---", ) - categories = self._parse_categories(content) + categories_with_scores = self._parse_categories_with_scores(content, token_logprobs) + categories = [cat for cat, _ in categories_with_scores] + scores = {cat: s for cat, s in categories_with_scores} if self.debug: logger.debug( - f"[{self.__class__.__name__}] Step 2 - Parsed categories: {[c.value for c in categories]}", + f"[{self.__class__.__name__}] Step 2 - Parsed categories with scores: " + f"{[(c.value, s) for c, s in categories_with_scores]}", ) return { "categories": categories, + "scores": scores, "raw_response": content, } except Exception as e: logger.error(f"[{self.__class__.__name__}] Step 2 error: {e}") - return {"error": str(e), "categories": []} + return {"error": str(e), "categories": [], "scores": {}} def _parse_categories(self, content: str) -> list[GraniteHarmCategory]: - """Parse comma-separated categories from model output.""" - content = content.replace("", "").strip() - if not content: + """Parse comma-separated categories from model output (no scores).""" + return [cat for cat, _ in self._parse_categories_with_scores(content, [])] + + def _parse_categories_with_scores( + self, + content: str, + token_logprobs: list[dict[str, Any]], + ) -> list[tuple[GraniteHarmCategory, float | None]]: + """Parse categories from model output and compute per-category scores from logprobs. + + For each emitted category, the score is the probability of its first token + (``exp(logprob)`` of the token that starts the category name). None if logprobs + are unavailable. Categories are returned in the order they appear in ``content``. + """ + cleaned = content.replace("", "").strip() + if not cleaned: return [] - raw_categories = [cat.strip() for cat in content.split(",") if cat.strip()] + raw_categories = [c.strip() for c in cleaned.split(",") if c.strip()] + + # Walk the token stream and find, for each emitted category, the logprob of + # the first non-whitespace, non-comma token that starts it. + first_token_logprobs = self._first_token_logprob_per_category(token_logprobs) - categories = [] - for raw_cat in raw_categories: + results: list[tuple[GraniteHarmCategory, float | None]] = [] + for i, raw_cat in enumerate(raw_categories): try: - categories.append(GraniteHarmCategory(raw_cat)) + cat = GraniteHarmCategory(raw_cat) except ValueError: - logger.warning(f"[MultiHarmGraniteGuardianTool] Unknown category: {raw_cat}") + logger.warning(f"[{self.__class__.__name__}] Unknown category: {raw_cat}") + continue + lp = first_token_logprobs[i] if i < len(first_token_logprobs) else None + score = math.exp(lp) if lp is not None else None + results.append((cat, score)) + + return results + + @staticmethod + def _first_token_logprob_per_category( + token_logprobs: list[dict[str, Any]], + ) -> list[float | None]: + """Extract the logprob of the first token of each comma-separated category. + + The Step 2 output looks like ``Violence, Unethical Behavior``. We skip leading + whitespace/commas and grab the logprob of each category's first token. Returns a + list of logprobs aligned with the comma-separated category order. + """ + if not token_logprobs: + return [] - return categories + first_lps: list[float | None] = [] + expecting_start = True # at position 0 and after each comma + for entry in token_logprobs: + tok = entry.get("token", "") + stripped = tok.strip() + # A comma ends the current category; next non-whitespace token starts the next one. + if "," in tok: + expecting_start = True + continue + # Skip pure whitespace/newline tokens between categories. + if not stripped: + continue + if expecting_start: + lp = entry.get("logprob") + first_lps.append(lp if isinstance(lp, (int, float)) else None) + expecting_start = False + + return first_lps From abd39dd5b7477ce80b4db988513d7323609b86e5 Mon Sep 17 00:00:00 2001 From: NISH1001 Date: Mon, 13 Apr 2026 19:27:02 -0500 Subject: [PATCH 2/2] Consolidate _parse_categories into single dict-returning function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What - Merged `_parse_categories` and `_parse_categories_with_scores` into a single `_parse_categories` method that returns `dict[GraniteHarmCategory, float | None]` (category -> per-category score). - Removed the redundant `"scores"` key from the Step 2 return dict; `"categories"` now holds both the categories and their scores as a dict mapping. - `unfiltered_categories` in `extra` now preserves per-category scores alongside the category list (previously was a bare list). ### Why The previous split had `_parse_categories` as a thin list-returning wrapper over `_parse_categories_with_scores` purely for backward compatibility, but `_parse_categories` was only called internally and had no external consumers — dead weight. A single dict return is also a more natural fit: `risk_results` is already a dict of category -> metadata, and downstream consumers need both iteration and score lookup. Dicts preserve insertion order in Python 3.7+, so the model's emission order is kept. ### How - `akd/guardrails/providers/granite_guardian.py`: - `_parse_categories`: now takes optional `token_logprobs`, returns `dict[GraniteHarmCategory, float | None]`. When `token_logprobs` is None/empty, scores are `None` (same behavior as the pre-logprobs version, just wrapped in dict keys instead of a list). - `_call_category_detection`: returns `{"categories": , "raw_response": ...}` (removed the separate `"scores"` key). - `_arun`: renamed local from `per_category_scores` to `category_scores`; iterates `category_scores.items()` directly in the filter comprehension; passes the whole `category_scores` dict as `extra["unfiltered_categories"]` to preserve scores in observability output. ### How to test - `uv run pytest tests/guardrails/ --ignore=tests/guardrails/test_granite_think.py` — all 21 tests pass. - `uv run python scripts/test_multi_harm.py` with live Ollama — confirms `risk_results` still has `score` per category and `extra["unfiltered_categories"]` now includes scores. --- akd/guardrails/providers/granite_guardian.py | 52 ++++++++------------ 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/akd/guardrails/providers/granite_guardian.py b/akd/guardrails/providers/granite_guardian.py index b7a39eb..ffe8cad 100644 --- a/akd/guardrails/providers/granite_guardian.py +++ b/akd/guardrails/providers/granite_guardian.py @@ -423,7 +423,7 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput: step2_prompt = step1_prompt + step1_result["raw_response"] + _END_OF_TEXT + "\n" step2_result = await self._call_category_detection(step2_prompt) - per_category_scores: dict[GraniteHarmCategory, float | None] = step2_result.get("scores", {}) + category_scores: dict[GraniteHarmCategory, float | None] = step2_result.get("categories", {}) threshold = self.config.score_threshold _non_harmful = (GraniteHarmCategory.NOT_HARMFUL_PROMPT, GraniteHarmCategory.NOT_HARMFUL_RESPONSE) @@ -432,22 +432,20 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput: # (e.g., when Ollama doesn't return logprobs) pass through unfiltered. detected = [ cat - for cat in step2_result.get("categories", []) - if cat in categories_to_check - and cat not in _non_harmful - and (per_category_scores.get(cat) is None or per_category_scores[cat] >= threshold) + for cat, score in category_scores.items() + if cat in categories_to_check and cat not in _non_harmful and (score is None or score >= threshold) ] if self.debug: logger.debug( f"[{self.__class__.__name__}] Step 2 - Detected {len(detected)} risks " - f"(filtered from {len(step2_result.get('categories', []))}, threshold={threshold}): " - f"{[(r.value, per_category_scores.get(r)) for r in detected]}", + f"(filtered from {len(category_scores)}, threshold={threshold}): " + f"{[(r.value, category_scores.get(r)) for r in detected]}", ) # Build per-risk results with per-category score. risk_results: dict[RiskCategory, dict[str, Any]] = { - cat: {"is_risky": True, "score": per_category_scores.get(cat)} for cat in detected + cat: {"is_risky": True, "score": category_scores.get(cat)} for cat in detected } return GuardrailOutput( @@ -459,7 +457,7 @@ async def _arun(self, params: GuardrailInput, **kwargs) -> GuardrailOutput: "confidence": confidence, "step1_raw": step1_result.get("raw_response"), "step2_raw": step2_result.get("raw_response"), - "unfiltered_categories": step2_result.get("categories", []), + "unfiltered_categories": category_scores, }, ) @@ -558,51 +556,44 @@ async def _call_category_detection(self, prompt: str) -> dict[str, Any]: f"--- RESPONSE START ---\n{content}\n--- RESPONSE END ---", ) - categories_with_scores = self._parse_categories_with_scores(content, token_logprobs) - categories = [cat for cat, _ in categories_with_scores] - scores = {cat: s for cat, s in categories_with_scores} + categories = self._parse_categories(content, token_logprobs) if self.debug: logger.debug( f"[{self.__class__.__name__}] Step 2 - Parsed categories with scores: " - f"{[(c.value, s) for c, s in categories_with_scores]}", + f"{[(c.value, s) for c, s in categories.items()]}", ) return { "categories": categories, - "scores": scores, "raw_response": content, } except Exception as e: logger.error(f"[{self.__class__.__name__}] Step 2 error: {e}") - return {"error": str(e), "categories": [], "scores": {}} + return {"error": str(e), "categories": {}} - def _parse_categories(self, content: str) -> list[GraniteHarmCategory]: - """Parse comma-separated categories from model output (no scores).""" - return [cat for cat, _ in self._parse_categories_with_scores(content, [])] - - def _parse_categories_with_scores( + def _parse_categories( self, content: str, - token_logprobs: list[dict[str, Any]], - ) -> list[tuple[GraniteHarmCategory, float | None]]: - """Parse categories from model output and compute per-category scores from logprobs. + token_logprobs: list[dict[str, Any]] | None = None, + ) -> dict[GraniteHarmCategory, float | None]: + """Parse categories from model output into a ``{category: score}`` mapping. - For each emitted category, the score is the probability of its first token - (``exp(logprob)`` of the token that starts the category name). None if logprobs - are unavailable. Categories are returned in the order they appear in ``content``. + The score is ``exp(first_token_logprob)`` (model confidence in emitting that + category's first token), or ``None`` when ``token_logprobs`` is not provided. + Python dicts preserve insertion order, so the model's emission order is kept. """ cleaned = content.replace("", "").strip() if not cleaned: - return [] + return {} raw_categories = [c.strip() for c in cleaned.split(",") if c.strip()] # Walk the token stream and find, for each emitted category, the logprob of # the first non-whitespace, non-comma token that starts it. - first_token_logprobs = self._first_token_logprob_per_category(token_logprobs) + first_token_logprobs = self._first_token_logprob_per_category(token_logprobs or []) - results: list[tuple[GraniteHarmCategory, float | None]] = [] + results: dict[GraniteHarmCategory, float | None] = {} for i, raw_cat in enumerate(raw_categories): try: cat = GraniteHarmCategory(raw_cat) @@ -610,8 +601,7 @@ def _parse_categories_with_scores( logger.warning(f"[{self.__class__.__name__}] Unknown category: {raw_cat}") continue lp = first_token_logprobs[i] if i < len(first_token_logprobs) else None - score = math.exp(lp) if lp is not None else None - results.append((cat, score)) + results[cat] = math.exp(lp) if lp is not None else None return results