Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions client/dashboard/src/pages/security/risk-ui.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
type ReactNode,
} from "react";
import { RULE_CATEGORY_META } from "./policy-data";
import { getCategoryForFinding, getRuleTitleFallback } from "./risk-utils";
import { getRuleTitleFallback, useFindingClassifier } from "./risk-utils";
import { Badge } from "@speakeasy-api/moonshine";
import { SimpleTooltip } from "@/components/ui/tooltip";
import {
Expand All @@ -24,7 +24,9 @@ export function CategoryLabel({
source?: string;
ruleId?: string;
}) {
const category = getCategoryForFinding(source, ruleId);
const classify = useFindingClassifier();
if (!classify) return null;
const category = classify(source, ruleId);
const meta = category
? RULE_CATEGORY_META[category]
: RULE_CATEGORY_META.custom;
Expand Down
70 changes: 36 additions & 34 deletions client/dashboard/src/pages/security/risk-utils.ts
Original file line number Diff line number Diff line change
@@ -1,52 +1,54 @@
import { useMemo } from "react";
import { useRiskCategories } from "@gram/client/react-query/index.js";
import { DETECTION_RULES, type RuleCategory } from "./policy-data";
import { humanizeRuleId } from "./rule-ids";

// Lookup from a finding's `source` to the category shown in the UI.
const SOURCE_TO_CATEGORY: ReadonlyMap<string, RuleCategory> = (() => {
  const lookup = new Map<string, RuleCategory>();

  // Sources whose name is itself the category key.
  for (const selfNamed of [
    "destructive_tool",
    "shadow_mcp",
    "prompt_injection",
    "cli_destructive",
  ] as const) {
    lookup.set(selfNamed, selfNamed);
  }

  // Scanner-source fallbacks: a rule_id may lack its category prefix
  // (gitleaks emits a bare "generic-api-key", for example). Classifying by
  // source keeps the scanner name out of the UI. Keep in sync with the Go
  // classifier in server/internal/risk/categories.
  lookup.set("gitleaks", "secrets");
  lookup.set("presidio", "pii");

  return lookup;
})();

const ruleIdToCategory = new Map<string, RuleCategory>();
const ruleIdToTitle = new Map<string, string>();

for (const [category, rules] of Object.entries(DETECTION_RULES)) {
for (const rules of Object.values(DETECTION_RULES)) {
for (const rule of rules) {
ruleIdToCategory.set(rule.id, category as RuleCategory);
ruleIdToTitle.set(rule.id, rule.title);
}
}

// DETECTION_RULES.id is the canonical rule_id the backend writes to
// risk_results, so lookup maps key by it directly.
//
// Both constants re-expose the module's mutable builder maps under a
// ReadonlyMap type, so the rest of the module can read the lookups but the
// type checker rejects any further `.set` / `.delete` on them.
const RULE_ID_TO_CATEGORY: ReadonlyMap<string, RuleCategory> = ruleIdToCategory;
const RULE_ID_TO_TITLE: ReadonlyMap<string, string> = ruleIdToTitle;

// getRuleTitleFallback resolves the label to display for a rule id.
//
// /rpc/risk.categories only exposes the canonical classification
// (source / rule_ids / rule_id_prefix) and carries no per-rule titles, so
// the static title map remains the source for display labels.
//
// Resolution order: a missing or empty id renders as "-", a known id uses
// its static title, and anything else is a humanized form of the raw id.
export function getRuleTitleFallback(ruleId: string | undefined): string {
  if (ruleId) {
    const knownTitle = RULE_ID_TO_TITLE.get(ruleId);
    return knownTitle ?? humanizeRuleId(ruleId);
  }
  return "-";
}

export function getCategoryForFinding(
export type FindingClassifier = (
source?: string,
ruleId?: string,
): RuleCategory | null {
if (ruleId) {
const byRule = RULE_ID_TO_CATEGORY.get(ruleId);
if (byRule) return byRule;
}
if (source) {
return SOURCE_TO_CATEGORY.get(source) ?? null;
}
return null;
) => RuleCategory | null;

// useFindingClassifier exposes the category definitions served by
// /rpc/risk.categories (the canonical Go classifier) as a
// (source, rule_id) -> category function.
//
// The query is requested with an infinite staleTime, and React Query
// dedupes it across the page, so calling this hook once per table row is
// cheap. The hook yields null until the definitions are available; callers
// should read that as "category unknown yet" and render nothing.
//
// Definitions are tried in the order the server returned them. Within one
// definition the checks run source first, then the explicit rule id list,
// then the rule id prefix, and the first hit wins. A finding that matches
// no definition classifies as null.
export function useFindingClassifier(): FindingClassifier | null {
  const { data: response } = useRiskCategories(undefined, undefined, {
    staleTime: Number.POSITIVE_INFINITY,
  });

  return useMemo<FindingClassifier | null>(() => {
    const definitions = response?.categories;
    if (!definitions) {
      return null;
    }

    const classify: FindingClassifier = (source, ruleId) => {
      for (const entry of definitions) {
        const category = entry.key as RuleCategory;

        const bySource = entry.source ? entry.source === source : false;
        if (bySource) {
          return category;
        }

        const byListedId =
          entry.ruleIds.length > 0 &&
          !!ruleId &&
          entry.ruleIds.includes(ruleId);
        if (byListedId) {
          return category;
        }

        const byPrefix =
          !!entry.ruleIdPrefix &&
          !!ruleId &&
          ruleId.startsWith(entry.ruleIdPrefix);
        if (byPrefix) {
          return category;
        }
      }
      return null;
    };

    return classify;
  }, [response]);
}
1 change: 0 additions & 1 deletion server/.golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@ linters:
# always-present metadata. Filter is the runtime output struct, same
# pattern.
- '^github\.com/speakeasy-api/gram/server/internal/risk/categories\.Definition$'
- '^github\.com/speakeasy-api/gram/server/internal/risk/categories\.Filter$'
exhaustive:
check:
- switch
Expand Down
54 changes: 7 additions & 47 deletions server/internal/risk/categories/categories.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
// Package categories is the single source of truth for the
// (source, rule_id) risk category mapping shown across the dashboard.
// (source, rule_id) -> risk category mapping shown across the dashboard.
//
// Previously the mapping lived in four duplicated SQL CASE expressions
// (queries.sql) and a separate TypeScript classifier (risk-utils.ts),
// which silently drifted whenever a new rule was added. Now:
//
// - Definitions below is the canonical list.
// - Classify(source, ruleID) is the canonical lookup.
// - SQLFilter(category) produces the parameter set the queries use to
// filter without an in-query CASE.
// - JSONResult() is what /rpc/risk.categories returns so the dashboard
// can consume the same data instead of maintaining its own copy.
// - Classify(source, ruleID) is the canonical lookup used by Go callers.
// - SQLRows() projects Definitions into five parallel arrays that every
// SQL query in internal/risk inlines via an unnest CTE. Same source of
// truth, no in-query CASE.
// - /rpc/risk.categories serves the JSON form so the dashboard reads the
// same definitions instead of maintaining its own copy.
//
// When adding a new rule or category, edit Definitions and the matching
// frontend RULE_CATEGORY_META label, and everything else follows.
Expand Down Expand Up @@ -224,47 +225,6 @@ func matchesRule(def Definition, ruleID string) bool {
return false
}

// Filter carries the query parameters that select the rows of a single
// category, so the SQL needs no in-query CASE expression. The fields are
// meant to be handed to the query as-is. A query counts as filtered by
// category iff at least one of Sources / RuleIDs / RulePrefix is
// non-empty.
type Filter struct {
	// Sources and RuleIDs are exact-match lists.
	Sources, RuleIDs []string
	// RulePrefix is a rule_id prefix; the empty string means "no prefix".
	RulePrefix string
}

// FilterFor builds the SQL filter for a single category from the first
// entry in Definitions that carries it.
//
// The zero Filter is returned for the empty category, for CategoryCustom,
// and for any category with no entry in Definitions. A zero Filter means
// "match nothing", so callers must tell that apart from "no filter
// applied" before they reach this function.
//
// CategoryCustom is the fallback bucket: whatever the explicit definitions
// do not claim. Expressing it would require NOT (any other category), and
// in practice the dashboard never filters by "custom", so it is left as a
// no-op here instead of implementing the negation.
func FilterFor(cat Category) Filter {
	if cat == "" || cat == CategoryCustom {
		return Filter{}
	}

	for _, def := range Definitions {
		if def.Category == cat {
			var sources, ruleIDs []string
			if def.Source != "" {
				sources = []string{def.Source}
			}
			if len(def.RuleIDs) > 0 {
				// Copy so the caller never aliases the Definitions table.
				ruleIDs = append(ruleIDs, def.RuleIDs...)
			}
			return Filter{
				Sources:    sources,
				RuleIDs:    ruleIDs,
				RulePrefix: def.RulePrefix,
			}
		}
	}

	return Filter{}
}

// All returns every Definition, with CustomDefinition appended last.
// Used by the JSON endpoint and tests.
func All() []Definition {
Expand Down
43 changes: 0 additions & 43 deletions server/internal/risk/categories/categories_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,49 +66,6 @@ func TestClassify_PinsPriorCASEBehavior(t *testing.T) {
}
}

func TestFilterFor(t *testing.T) {
t.Parallel()

t.Run("source category", func(t *testing.T) {
t.Parallel()
f := FilterFor(CategoryShadowMCP)
require.Equal(t, []string{"shadow_mcp"}, f.Sources)
require.Empty(t, f.RuleIDs)
require.Empty(t, f.RulePrefix)
})

t.Run("prefix category", func(t *testing.T) {
t.Parallel()
f := FilterFor(CategorySecrets)
require.Equal(t, "secret.", f.RulePrefix)
require.Empty(t, f.Sources)
require.Empty(t, f.RuleIDs)
})

t.Run("explicit-list category", func(t *testing.T) {
t.Parallel()
f := FilterFor(CategoryFinancial)
require.Contains(t, f.RuleIDs, "pii.credit_card")
require.Empty(t, f.RulePrefix)
})

t.Run("custom is no-op", func(t *testing.T) {
t.Parallel()
f := FilterFor(CategoryCustom)
require.Empty(t, f.Sources)
require.Empty(t, f.RuleIDs)
require.Empty(t, f.RulePrefix)
})

t.Run("empty is no-op", func(t *testing.T) {
t.Parallel()
f := FilterFor("")
require.Empty(t, f.Sources)
require.Empty(t, f.RuleIDs)
require.Empty(t, f.RulePrefix)
})
}

func TestAll_IncludesCustom(t *testing.T) {
t.Parallel()
all := All()
Expand Down
45 changes: 45 additions & 0 deletions server/internal/risk/categories/sql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package categories

// SQLRows returns the canonical classifier as five parallel arrays sized
// 1-to-1 by row. Pass them straight into any query that classifies findings
// via the standard CTE:
//
//	WITH risk_category_lookup(priority, category, source, rule_id, rule_prefix) AS (
//	    SELECT * FROM unnest(
//	        @cat_priority::int[],
//	        @cat_category::text[],
//	        @cat_source::text[],
//	        @cat_rule_id::text[],
//	        @cat_rule_prefix::text[]
//	    ) AS t(priority, category, source, rule_id, rule_prefix)
//	)
//
// The CTE is then joined per row with a small subquery that picks the
// first matching category by priority. See queries.sql.
//
// Row layout: each row carries exactly one matcher (source, rule_id, or
// rule_prefix) and leaves the other two columns as empty strings. A
// definition contributes one row per matcher it declares: one for Source,
// one per entry in RuleIDs, one for RulePrefix. All rows of a definition
// share its index in Definitions as their priority. A definition that
// declares no matcher contributes no rows.
//
// Every declared matcher is emitted, not only the first non-empty kind, so
// a definition that combines matchers classifies the same way in SQL as it
// does in the lookups that test source, rule ids and prefix in turn.
//
// Returning parallel arrays (rather than a Definition struct slice and
// having every caller flatten it) keeps the boundary with sqlc clean:
// sqlc understands `unnest(text[], …)` natively; it does not understand
// per-row CTE composition.
func SQLRows() (priority []int32, category, source, ruleID, rulePrefix []string) {
	for prio, def := range Definitions {
		// emit appends one lookup row for the current definition.
		emit := func(src, id, prefix string) {
			priority = append(priority, int32(prio)) //nolint:gosec // bounded by Definitions length, far below int32 max
			category = append(category, string(def.Category))
			source = append(source, src)
			ruleID = append(ruleID, id)
			rulePrefix = append(rulePrefix, prefix)
		}

		// Independent checks, not a switch: a definition may declare more
		// than one matcher and none of them may be dropped.
		if def.Source != "" {
			emit(def.Source, "", "")
		}
		for _, id := range def.RuleIDs {
			emit("", id, "")
		}
		if def.RulePrefix != "" {
			emit("", "", def.RulePrefix)
		}
	}
	return priority, category, source, ruleID, rulePrefix
}
38 changes: 31 additions & 7 deletions server/internal/risk/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -795,10 +795,16 @@ func (s *Service) GetRiskOverview(ctx context.Context, payload *gen.GetRiskOverv
return nil, oops.E(oops.CodeUnexpected, err, "list risk overview top users").Log(ctx, s.logger)
}

catPriority, catCategory, catSource, catRuleID, catRulePrefix := categories.SQLRows()
timeSeriesRows, err := s.repo.ListRiskOverviewTimeSeriesFindings(ctx, repo.ListRiskOverviewTimeSeriesFindingsParams{
ProjectID: *authCtx.ProjectID,
FromTime: window.from,
ToTime: window.to,
CatPriority: catPriority,
CatCategory: catCategory,
CatSource: catSource,
CatRuleID: catRuleID,
CatRulePrefix: catRulePrefix,
ProjectID: *authCtx.ProjectID,
FromTime: window.from,
ToTime: window.to,
})
if err != nil {
return nil, oops.E(oops.CodeUnexpected, err, "list risk overview time series findings").Log(ctx, s.logger)
Expand Down Expand Up @@ -1011,7 +1017,13 @@ func (s *Service) listResultsByPolicy(ctx context.Context, projectID uuid.UUID,

func (s *Service) listResultsByProject(ctx context.Context, projectID uuid.UUID, cursor *riskResultsCursor, pageSize int, totalCount int64, category string, ruleID string, uniqueMatch bool, fromTime, toTime pgtype.Timestamptz) (*gen.ListRiskResultsResult, error) {
cursorCreatedAt, cursorID := cursorToParams(cursor)
catPriority, catCategory, catSource, catRuleID, catRulePrefix := categories.SQLRows()
rows, err := s.repo.ListRiskResultsByProjectFound(ctx, repo.ListRiskResultsByProjectFoundParams{
CatPriority: catPriority,
CatCategory: catCategory,
CatSource: catSource,
CatRuleID: catRuleID,
CatRulePrefix: catRulePrefix,
ProjectID: projectID,
FromTime: fromTime,
ToTime: toTime,
Expand Down Expand Up @@ -1079,7 +1091,13 @@ func (s *Service) GetRiskUserBreakdown(ctx context.Context, payload *gen.GetRisk
}
window := riskOverviewWindowParams(from, to)

catPriority, catCategory, catSource, catRuleID, catRulePrefix := categories.SQLRows()
categoryRows, err := s.repo.ListRiskUserCategoryBreakdown(ctx, repo.ListRiskUserCategoryBreakdownParams{
CatPriority: catPriority,
CatCategory: catCategory,
CatSource: catSource,
CatRuleID: catRuleID,
CatRulePrefix: catRulePrefix,
ProjectID: *authCtx.ProjectID,
FromTime: window.from,
ToTime: window.to,
Expand Down Expand Up @@ -1144,11 +1162,17 @@ func (s *Service) GetRiskRuleBreakdown(ctx context.Context, payload *gen.GetRisk
}

window := riskOverviewWindowParams(from, to)
catPriority, catCategory, catSource, catRuleID, catRulePrefix := categories.SQLRows()
rows, err := s.repo.ListRiskRulesByCategory(ctx, repo.ListRiskRulesByCategoryParams{
ProjectID: *authCtx.ProjectID,
FromTime: window.from,
ToTime: window.to,
Category: payload.Category,
CatPriority: catPriority,
CatCategory: catCategory,
CatSource: catSource,
CatRuleID: catRuleID,
CatRulePrefix: catRulePrefix,
ProjectID: *authCtx.ProjectID,
FromTime: window.from,
ToTime: window.to,
Category: payload.Category,
})
if err != nil {
return nil, oops.E(oops.CodeUnexpected, err, "list rule breakdown").Log(ctx, s.logger)
Expand Down
Loading
Loading