From f34fc56aae0c9a0b9b2bef25e7c4c25c1402f2b4 Mon Sep 17 00:00:00 2001 From: David Alberto Adler Date: Thu, 21 May 2026 20:11:49 +0100 Subject: [PATCH 1/3] feat(risk): dedup category classifier via per-connection temp table Follow-up to #2976. Four duplicated CASE expressions in server/internal/risk/queries.sql shared a 40-line categorisation block that the Go classifier in internal/risk/categories already owned. Replace each of them with the same correlated subquery, which classifies every risk_results row against a session-scoped TEMP TABLE risk_category_lookup populated from the canonical Go classifier on every pool connection. How it works: - internal/risk/categories.BootstrapConnection runs on AfterConnect for every pgxpool connection (cmd/gram/deps.go and the testenv pool helper for risk_results tests). It creates risk_category_lookup if needed, truncates it, and inserts rows derived from Definitions. - All four risk queries now resolve the category via: (SELECT rcl.category FROM risk_category_lookup rcl WHERE source-match OR rule_id-match OR LIKE prefix ORDER BY priority LIMIT 1) -- with COALESCE to 'custom' - internal/risk/categories/sqlc_schema.sql declares the temp table for sqlc's static analysis (not migrated; runtime owns table creation). Result: one source of truth. Adding or moving a rule means editing Definitions and restarting the server; SQL never has to be touched. Perf: ListRiskRulesByCategory goes from ~3.2ms to ~8.7ms on the local dev DB (2,410 findings) because of the SubPlan that scans the 14-row lookup table per result row. Well within budget; see PR description for the EXPLAIN ANALYZE comparison. 
--- server/cmd/gram/deps.go | 7 + server/database/sqlc.yaml | 8 +- server/internal/risk/categories/bootstrap.go | 108 +++++++++ .../internal/risk/categories/sqlc_schema.sql | 14 ++ server/internal/risk/queries.sql | 228 ++++-------------- server/internal/risk/repo/queries.sql.go | 228 ++++-------------- server/internal/testenv/postgresql.go | 10 +- 7 files changed, 245 insertions(+), 358 deletions(-) create mode 100644 server/internal/risk/categories/bootstrap.go create mode 100644 server/internal/risk/categories/sqlc_schema.sql diff --git a/server/cmd/gram/deps.go b/server/cmd/gram/deps.go index fefebe4c0e..44648351b1 100644 --- a/server/cmd/gram/deps.go +++ b/server/cmd/gram/deps.go @@ -57,6 +57,7 @@ import ( "github.com/speakeasy-api/gram/server/internal/must" "github.com/speakeasy-api/gram/server/internal/o11y" "github.com/speakeasy-api/gram/server/internal/productfeatures" + "github.com/speakeasy-api/gram/server/internal/risk/categories" "github.com/speakeasy-api/gram/server/internal/telemetry" "github.com/speakeasy-api/gram/server/internal/temporal" "github.com/speakeasy-api/gram/server/internal/thirdparty/polar" @@ -217,6 +218,12 @@ func newDBClient(ctx context.Context, logger *slog.Logger, meterProvider metric. o11y.NewPGXLogger(logger.With(attr.SlogComponent("pgx")), consoleLogLevel), ) + // Populate the risk category classifier temp table on every new connection. + // SQL queries that need to bucket findings into categories join against + // risk_category_lookup instead of carrying their own CASE expression; + // internal/risk/categories owns the canonical mapping. 
+ poolcfg.AfterConnect = categories.BootstrapConnection + pool, err := pgxpool.NewWithConfig(ctx, poolcfg) if err != nil { return nil, fmt.Errorf("failed to create pgx pool: %w", err) diff --git a/server/database/sqlc.yaml b/server/database/sqlc.yaml index d5561593ba..c3da7b511e 100644 --- a/server/database/sqlc.yaml +++ b/server/database/sqlc.yaml @@ -536,7 +536,13 @@ sql: sql_package: "pgx/v5" omit_unused_structs: true - - schema: schema.sql + - schema: + - schema.sql + # Declares the risk_category_lookup TEMP TABLE so sqlc's static analysis + # can resolve column references in queries.sql. The actual table is + # created at runtime by internal/risk/categories.BootstrapConnection on + # every pool connection. + - ../internal/risk/categories/sqlc_schema.sql queries: ../internal/risk/queries.sql engine: postgresql gen: diff --git a/server/internal/risk/categories/bootstrap.go b/server/internal/risk/categories/bootstrap.go new file mode 100644 index 0000000000..0f66177176 --- /dev/null +++ b/server/internal/risk/categories/bootstrap.go @@ -0,0 +1,108 @@ +package categories + +import ( + "context" + "fmt" + "strings" + + "github.com/jackc/pgx/v5" +) + +// BootstrapConnection creates a session-scoped TEMP TABLE +// `risk_category_lookup` on the connection and populates it from the +// canonical Definitions slice. +// +// Wire this into pgxpool.Config.AfterConnect so every connection in the pool +// owns one classifier table for its lifetime. Subsequent SQL queries that +// need to classify risk_results join against the temp table by name; the Go +// classifier (Definitions) is the single source of truth, the SQL CASE +// expressions that used to live in queries.sql are gone. +// +// Schema: +// +// priority INT — evaluation order; first match in ORDER BY priority ASC wins +// category TEXT — bucket the finding rolls up to (e.g. 'secrets') +// source TEXT — exact source match (e.g. 'shadow_mcp', 'gitleaks') +// rule_id TEXT — exact rule_id match (e.g. 
'pii.credit_card') +// rule_prefix TEXT — LIKE-prefix match (e.g. 'secret.') +// +// Each row populates one of source / rule_id / rule_prefix; the others are +// NULL. Findings that match none of the rows are treated as `custom` via a +// COALESCE in the consuming queries. +func BootstrapConnection(ctx context.Context, conn *pgx.Conn) error { + if _, err := conn.Exec(ctx, ` + CREATE TEMP TABLE IF NOT EXISTS risk_category_lookup ( + priority INTEGER NOT NULL, + category TEXT NOT NULL, + source TEXT, + rule_id TEXT, + rule_prefix TEXT + ) ON COMMIT PRESERVE ROWS; + TRUNCATE risk_category_lookup; + `); err != nil { + return fmt.Errorf("create risk_category_lookup temp table: %w", err) + } + + // Build a VALUES insert from the canonical classifier so SQL never has + // the mapping baked in. + var ( + args []any + rows []string + idx = 1 + prio = 0 + emit = func(category, source, ruleID, prefix string) { + placeholders := fmt.Sprintf("($%d::int, $%d::text, NULLIF($%d::text, ''), NULLIF($%d::text, ''), NULLIF($%d::text, ''))", + idx, idx+1, idx+2, idx+3, idx+4) + rows = append(rows, placeholders) + args = append(args, prio, category, source, ruleID, prefix) + idx += 5 + prio++ + } + ) + + for _, def := range Definitions { + switch { + case def.Source != "": + emit(string(def.Category), def.Source, "", "") + case len(def.RuleIDs) > 0: + for _, ruleID := range def.RuleIDs { + emit(string(def.Category), "", ruleID, "") + } + case def.RulePrefix != "": + emit(string(def.Category), "", "", def.RulePrefix) + } + } + + if len(rows) == 0 { + return nil + } + + insertSQL := fmt.Sprintf( + "INSERT INTO risk_category_lookup (priority, category, source, rule_id, rule_prefix) VALUES %s", + strings.Join(rows, ", "), + ) + if _, err := conn.Exec(ctx, insertSQL, args...); err != nil { + return fmt.Errorf("populate risk_category_lookup: %w", err) + } + return nil +} + +// ClassifySubquery returns the SQL fragment that resolves a risk_results row +// to its category by joining against 
risk_category_lookup. Embed inside a +// SELECT or use as a LATERAL join source. +// +// Returns 'custom' for unmatched rows so callers don't have to repeat that. +const ClassifySubquery = `( + SELECT COALESCE( + ( + SELECT rcl.category + FROM risk_category_lookup rcl + WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) + OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + ORDER BY rcl.priority ASC + LIMIT 1 + ), + 'custom' + ) +)` diff --git a/server/internal/risk/categories/sqlc_schema.sql b/server/internal/risk/categories/sqlc_schema.sql new file mode 100644 index 0000000000..a30a54b031 --- /dev/null +++ b/server/internal/risk/categories/sqlc_schema.sql @@ -0,0 +1,14 @@ +-- Schema fragment consumed ONLY by sqlc's static analysis so that queries in +-- internal/risk/queries.sql can reference the risk_category_lookup TEMP TABLE +-- without a migration. +-- +-- The actual table is created at runtime by BootstrapConnection (bootstrap.go) +-- on every pool connection, populated from the canonical Go classifier. +-- DO NOT migrate this file; it is intentionally not in server/migrations. +CREATE TABLE IF NOT EXISTS risk_category_lookup ( + priority INTEGER NOT NULL, + category TEXT NOT NULL, + source TEXT, + rule_id TEXT, + rule_prefix TEXT +); diff --git a/server/internal/risk/queries.sql b/server/internal/risk/queries.sql index 10277c9ee1..759b44eb11 100644 --- a/server/internal/risk/queries.sql +++ b/server/internal/risk/queries.sql @@ -163,50 +163,20 @@ LIMIT @row_limit; -- queries. 
WITH user_findings AS ( SELECT - CASE - WHEN rr.source IN ('shadow_mcp', 'destructive_tool', 'cli_destructive', 'prompt_injection') THEN rr.source - WHEN rr.rule_id LIKE 'secret.%' THEN 'secrets' - WHEN rr.rule_id IN ('pii.credit_card', 'pii.iban_code', 'pii.us_bank_number', 'pii.crypto') THEN 'financial' - WHEN rr.rule_id IN ( - 'pii.us_ssn' - , 'pii.us_passport' - , 'pii.us_driver_license' - , 'pii.us_itin' - , 'pii.uk_nhs' - , 'pii.uk_nino' - , 'pii.uk_passport' - , 'pii.es_nif' - , 'pii.it_fiscal_code' - , 'pii.au_tfn' - , 'pii.in_pan' - , 'pii.in_aadhaar' - , 'pii.sg_nric_fin' - ) THEN 'government_ids' - WHEN rr.rule_id IN ( - 'pii.medical_license' - , 'pii.us_mbi' - , 'pii.us_npi' - , 'pii.medical_disease_disorder' - , 'pii.medical_medication' - , 'pii.medical_therapeutic_procedure' - , 'pii.medical_clinical_event' - , 'pii.medical_biological_attribute' - , 'pii.medical_family_history' - ) THEN 'healthcare' - WHEN rr.rule_id IN ( - 'pii.harmful_content_request' - , 'pii.policy_violation' - , 'pii.unauthorized_action' - , 'pii.topic_boundary_violation' - ) THEN 'off_policy' - WHEN rr.rule_id LIKE 'pii.%' THEN 'pii' - -- Scanner-source fallbacks: keep these LAST so any prefixed - -- rule_id wins. Stay in sync with the Go classifier in - -- internal/risk/categories. - WHEN rr.source = 'gitleaks' THEN 'secrets' - WHEN rr.source = 'presidio' THEN 'pii' - ELSE 'custom' - END AS category + -- Classify via the risk_category_lookup TEMP TABLE populated from the + -- Go classifier in internal/risk/categories. See bootstrap.go and the + -- pgxpool AfterConnect hook in cmd/gram/deps.go. 
+ COALESCE( + ( + SELECT rcl.category FROM risk_category_lookup rcl + WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) + OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + ORDER BY rcl.priority ASC + LIMIT 1 + ), + 'custom' + )::text AS category FROM risk_results rr JOIN chat_messages cm ON cm.id = rr.chat_message_id LEFT JOIN chats c ON c.id = cm.chat_id AND c.deleted IS FALSE @@ -246,50 +216,20 @@ WITH categorized AS ( SELECT COALESCE(rr.rule_id, '')::TEXT AS rule_id, rr.source, - CASE - WHEN rr.source IN ('shadow_mcp', 'destructive_tool', 'cli_destructive', 'prompt_injection') THEN rr.source - WHEN rr.rule_id LIKE 'secret.%' THEN 'secrets' - WHEN rr.rule_id IN ('pii.credit_card', 'pii.iban_code', 'pii.us_bank_number', 'pii.crypto') THEN 'financial' - WHEN rr.rule_id IN ( - 'pii.us_ssn' - , 'pii.us_passport' - , 'pii.us_driver_license' - , 'pii.us_itin' - , 'pii.uk_nhs' - , 'pii.uk_nino' - , 'pii.uk_passport' - , 'pii.es_nif' - , 'pii.it_fiscal_code' - , 'pii.au_tfn' - , 'pii.in_pan' - , 'pii.in_aadhaar' - , 'pii.sg_nric_fin' - ) THEN 'government_ids' - WHEN rr.rule_id IN ( - 'pii.medical_license' - , 'pii.us_mbi' - , 'pii.us_npi' - , 'pii.medical_disease_disorder' - , 'pii.medical_medication' - , 'pii.medical_therapeutic_procedure' - , 'pii.medical_clinical_event' - , 'pii.medical_biological_attribute' - , 'pii.medical_family_history' - ) THEN 'healthcare' - WHEN rr.rule_id IN ( - 'pii.harmful_content_request' - , 'pii.policy_violation' - , 'pii.unauthorized_action' - , 'pii.topic_boundary_violation' - ) THEN 'off_policy' - WHEN rr.rule_id LIKE 'pii.%' THEN 'pii' - -- Scanner-source fallbacks: keep these LAST so any prefixed - -- rule_id wins. Stay in sync with the Go classifier in - -- internal/risk/categories. 
- WHEN rr.source = 'gitleaks' THEN 'secrets' - WHEN rr.source = 'presidio' THEN 'pii' - ELSE 'custom' - END AS category + -- Classify via the risk_category_lookup TEMP TABLE populated from the + -- Go classifier in internal/risk/categories. See bootstrap.go and the + -- pgxpool AfterConnect hook in cmd/gram/deps.go. + COALESCE( + ( + SELECT rcl.category FROM risk_category_lookup rcl + WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) + OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + ORDER BY rcl.priority ASC + LIMIT 1 + ), + 'custom' + )::text AS category FROM risk_results rr WHERE rr.project_id = @project_id AND rr.found IS TRUE @@ -338,50 +278,17 @@ WITH buckets AS ( categorized AS ( SELECT date_trunc('hour', rr.created_at)::timestamptz AS bucket_start - , CASE - WHEN rr.source IN ('shadow_mcp', 'destructive_tool', 'cli_destructive', 'prompt_injection') THEN rr.source - WHEN rr.rule_id LIKE 'secret.%' THEN 'secrets' - WHEN rr.rule_id IN ('pii.credit_card', 'pii.iban_code', 'pii.us_bank_number', 'pii.crypto') THEN 'financial' - WHEN rr.rule_id IN ( - 'pii.us_ssn' - , 'pii.us_passport' - , 'pii.us_driver_license' - , 'pii.us_itin' - , 'pii.uk_nhs' - , 'pii.uk_nino' - , 'pii.uk_passport' - , 'pii.es_nif' - , 'pii.it_fiscal_code' - , 'pii.au_tfn' - , 'pii.in_pan' - , 'pii.in_aadhaar' - , 'pii.sg_nric_fin' - ) THEN 'government_ids' - WHEN rr.rule_id IN ( - 'pii.medical_license' - , 'pii.us_mbi' - , 'pii.us_npi' - , 'pii.medical_disease_disorder' - , 'pii.medical_medication' - , 'pii.medical_therapeutic_procedure' - , 'pii.medical_clinical_event' - , 'pii.medical_biological_attribute' - , 'pii.medical_family_history' - ) THEN 'healthcare' - WHEN rr.rule_id IN ( - 'pii.harmful_content_request' - , 'pii.policy_violation' - , 'pii.unauthorized_action' - , 'pii.topic_boundary_violation' - ) THEN 'off_policy' - WHEN rr.rule_id LIKE 'pii.%' THEN 'pii' - -- Scanner-source 
fallbacks: keep these LAST so any prefixed - -- rule_id wins. Stay in sync with the Go classifier in - -- internal/risk/categories. - WHEN rr.source = 'gitleaks' THEN 'secrets' - WHEN rr.source = 'presidio' THEN 'pii' - ELSE 'custom' - END AS category + , COALESCE( + ( + SELECT rcl.category FROM risk_category_lookup rcl + WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) + OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + ORDER BY rcl.priority ASC + LIMIT 1 + ), + 'custom' + )::text AS category FROM risk_results rr WHERE rr.project_id = sqlc.arg(project_id)::uuid AND rr.found IS TRUE @@ -527,52 +434,17 @@ FROM ( AND (sqlc.narg(from_time)::timestamptz IS NULL OR cm.created_at >= sqlc.narg(from_time)::timestamptz) AND (sqlc.narg(to_time)::timestamptz IS NULL OR cm.created_at < sqlc.narg(to_time)::timestamptz) AND (@rule_id::text = '' OR rr.rule_id ILIKE '%' || @rule_id::text || '%') - AND (@category::text = '' OR ( - CASE - WHEN rr.source IN ('shadow_mcp', 'destructive_tool', 'cli_destructive', 'prompt_injection') THEN rr.source - WHEN rr.rule_id LIKE 'secret.%' THEN 'secrets' - WHEN rr.rule_id IN ('pii.credit_card', 'pii.iban_code', 'pii.us_bank_number', 'pii.crypto') THEN 'financial' - WHEN rr.rule_id IN ( - 'pii.us_ssn' - , 'pii.us_passport' - , 'pii.us_driver_license' - , 'pii.us_itin' - , 'pii.uk_nhs' - , 'pii.uk_nino' - , 'pii.uk_passport' - , 'pii.es_nif' - , 'pii.it_fiscal_code' - , 'pii.au_tfn' - , 'pii.in_pan' - , 'pii.in_aadhaar' - , 'pii.sg_nric_fin' - ) THEN 'government_ids' - WHEN rr.rule_id IN ( - 'pii.medical_license' - , 'pii.us_mbi' - , 'pii.us_npi' - , 'pii.medical_disease_disorder' - , 'pii.medical_medication' - , 'pii.medical_therapeutic_procedure' - , 'pii.medical_clinical_event' - , 'pii.medical_biological_attribute' - , 'pii.medical_family_history' - ) THEN 'healthcare' - WHEN rr.rule_id IN ( - 'pii.harmful_content_request' - , 
'pii.policy_violation' - , 'pii.unauthorized_action' - , 'pii.topic_boundary_violation' - ) THEN 'off_policy' - WHEN rr.rule_id LIKE 'pii.%' THEN 'pii' - -- Scanner-source fallbacks: keep these LAST so any prefixed - -- rule_id wins. Stay in sync with the Go classifier in - -- internal/risk/categories. - WHEN rr.source = 'gitleaks' THEN 'secrets' - WHEN rr.source = 'presidio' THEN 'pii' - ELSE 'custom' - END - ) = @category::text) + AND (@category::text = '' OR COALESCE( + ( + SELECT rcl.category FROM risk_category_lookup rcl + WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) + OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + ORDER BY rcl.priority ASC + LIMIT 1 + ), + 'custom' + ) = @category::text) ) sub WHERE sub.dedup_rank = 1 AND ( diff --git a/server/internal/risk/repo/queries.sql.go b/server/internal/risk/repo/queries.sql.go index 3c57816a13..1b656f1467 100644 --- a/server/internal/risk/repo/queries.sql.go +++ b/server/internal/risk/repo/queries.sql.go @@ -651,50 +651,17 @@ WITH buckets AS ( categorized AS ( SELECT date_trunc('hour', rr.created_at)::timestamptz AS bucket_start - , CASE - WHEN rr.source IN ('shadow_mcp', 'destructive_tool', 'cli_destructive', 'prompt_injection') THEN rr.source - WHEN rr.rule_id LIKE 'secret.%' THEN 'secrets' - WHEN rr.rule_id IN ('pii.credit_card', 'pii.iban_code', 'pii.us_bank_number', 'pii.crypto') THEN 'financial' - WHEN rr.rule_id IN ( - 'pii.us_ssn' - , 'pii.us_passport' - , 'pii.us_driver_license' - , 'pii.us_itin' - , 'pii.uk_nhs' - , 'pii.uk_nino' - , 'pii.uk_passport' - , 'pii.es_nif' - , 'pii.it_fiscal_code' - , 'pii.au_tfn' - , 'pii.in_pan' - , 'pii.in_aadhaar' - , 'pii.sg_nric_fin' - ) THEN 'government_ids' - WHEN rr.rule_id IN ( - 'pii.medical_license' - , 'pii.us_mbi' - , 'pii.us_npi' - , 'pii.medical_disease_disorder' - , 'pii.medical_medication' - , 'pii.medical_therapeutic_procedure' - , 
'pii.medical_clinical_event' - , 'pii.medical_biological_attribute' - , 'pii.medical_family_history' - ) THEN 'healthcare' - WHEN rr.rule_id IN ( - 'pii.harmful_content_request' - , 'pii.policy_violation' - , 'pii.unauthorized_action' - , 'pii.topic_boundary_violation' - ) THEN 'off_policy' - WHEN rr.rule_id LIKE 'pii.%' THEN 'pii' - -- Scanner-source fallbacks: keep these LAST so any prefixed - -- rule_id wins. Stay in sync with the Go classifier in - -- internal/risk/categories. - WHEN rr.source = 'gitleaks' THEN 'secrets' - WHEN rr.source = 'presidio' THEN 'pii' - ELSE 'custom' - END AS category + , COALESCE( + ( + SELECT rcl.category FROM risk_category_lookup rcl + WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) + OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + ORDER BY rcl.priority ASC + LIMIT 1 + ), + 'custom' + )::text AS category FROM risk_results rr WHERE rr.project_id = $3::uuid AND rr.found IS TRUE @@ -1143,52 +1110,17 @@ FROM ( AND ($3::timestamptz IS NULL OR cm.created_at >= $3::timestamptz) AND ($4::timestamptz IS NULL OR cm.created_at < $4::timestamptz) AND ($5::text = '' OR rr.rule_id ILIKE '%' || $5::text || '%') - AND ($6::text = '' OR ( - CASE - WHEN rr.source IN ('shadow_mcp', 'destructive_tool', 'cli_destructive', 'prompt_injection') THEN rr.source - WHEN rr.rule_id LIKE 'secret.%' THEN 'secrets' - WHEN rr.rule_id IN ('pii.credit_card', 'pii.iban_code', 'pii.us_bank_number', 'pii.crypto') THEN 'financial' - WHEN rr.rule_id IN ( - 'pii.us_ssn' - , 'pii.us_passport' - , 'pii.us_driver_license' - , 'pii.us_itin' - , 'pii.uk_nhs' - , 'pii.uk_nino' - , 'pii.uk_passport' - , 'pii.es_nif' - , 'pii.it_fiscal_code' - , 'pii.au_tfn' - , 'pii.in_pan' - , 'pii.in_aadhaar' - , 'pii.sg_nric_fin' - ) THEN 'government_ids' - WHEN rr.rule_id IN ( - 'pii.medical_license' - , 'pii.us_mbi' - , 'pii.us_npi' - , 'pii.medical_disease_disorder' - , 
'pii.medical_medication' - , 'pii.medical_therapeutic_procedure' - , 'pii.medical_clinical_event' - , 'pii.medical_biological_attribute' - , 'pii.medical_family_history' - ) THEN 'healthcare' - WHEN rr.rule_id IN ( - 'pii.harmful_content_request' - , 'pii.policy_violation' - , 'pii.unauthorized_action' - , 'pii.topic_boundary_violation' - ) THEN 'off_policy' - WHEN rr.rule_id LIKE 'pii.%' THEN 'pii' - -- Scanner-source fallbacks: keep these LAST so any prefixed - -- rule_id wins. Stay in sync with the Go classifier in - -- internal/risk/categories. - WHEN rr.source = 'gitleaks' THEN 'secrets' - WHEN rr.source = 'presidio' THEN 'pii' - ELSE 'custom' - END - ) = $6::text) + AND ($6::text = '' OR COALESCE( + ( + SELECT rcl.category FROM risk_category_lookup rcl + WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) + OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + ORDER BY rcl.priority ASC + LIMIT 1 + ), + 'custom' + ) = $6::text) ) sub WHERE sub.dedup_rank = 1 AND ( @@ -1365,50 +1297,20 @@ WITH categorized AS ( SELECT COALESCE(rr.rule_id, '')::TEXT AS rule_id, rr.source, - CASE - WHEN rr.source IN ('shadow_mcp', 'destructive_tool', 'cli_destructive', 'prompt_injection') THEN rr.source - WHEN rr.rule_id LIKE 'secret.%' THEN 'secrets' - WHEN rr.rule_id IN ('pii.credit_card', 'pii.iban_code', 'pii.us_bank_number', 'pii.crypto') THEN 'financial' - WHEN rr.rule_id IN ( - 'pii.us_ssn' - , 'pii.us_passport' - , 'pii.us_driver_license' - , 'pii.us_itin' - , 'pii.uk_nhs' - , 'pii.uk_nino' - , 'pii.uk_passport' - , 'pii.es_nif' - , 'pii.it_fiscal_code' - , 'pii.au_tfn' - , 'pii.in_pan' - , 'pii.in_aadhaar' - , 'pii.sg_nric_fin' - ) THEN 'government_ids' - WHEN rr.rule_id IN ( - 'pii.medical_license' - , 'pii.us_mbi' - , 'pii.us_npi' - , 'pii.medical_disease_disorder' - , 'pii.medical_medication' - , 'pii.medical_therapeutic_procedure' - , 'pii.medical_clinical_event' - , 
'pii.medical_biological_attribute' - , 'pii.medical_family_history' - ) THEN 'healthcare' - WHEN rr.rule_id IN ( - 'pii.harmful_content_request' - , 'pii.policy_violation' - , 'pii.unauthorized_action' - , 'pii.topic_boundary_violation' - ) THEN 'off_policy' - WHEN rr.rule_id LIKE 'pii.%' THEN 'pii' - -- Scanner-source fallbacks: keep these LAST so any prefixed - -- rule_id wins. Stay in sync with the Go classifier in - -- internal/risk/categories. - WHEN rr.source = 'gitleaks' THEN 'secrets' - WHEN rr.source = 'presidio' THEN 'pii' - ELSE 'custom' - END AS category + -- Classify via the risk_category_lookup TEMP TABLE populated from the + -- Go classifier in internal/risk/categories. See bootstrap.go and the + -- pgxpool AfterConnect hook in cmd/gram/deps.go. + COALESCE( + ( + SELECT rcl.category FROM risk_category_lookup rcl + WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) + OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + ORDER BY rcl.priority ASC + LIMIT 1 + ), + 'custom' + )::text AS category FROM risk_results rr WHERE rr.project_id = $2 AND rr.found IS TRUE @@ -1466,50 +1368,20 @@ func (q *Queries) ListRiskRulesByCategory(ctx context.Context, arg ListRiskRules const listRiskUserCategoryBreakdown = `-- name: ListRiskUserCategoryBreakdown :many WITH user_findings AS ( SELECT - CASE - WHEN rr.source IN ('shadow_mcp', 'destructive_tool', 'cli_destructive', 'prompt_injection') THEN rr.source - WHEN rr.rule_id LIKE 'secret.%' THEN 'secrets' - WHEN rr.rule_id IN ('pii.credit_card', 'pii.iban_code', 'pii.us_bank_number', 'pii.crypto') THEN 'financial' - WHEN rr.rule_id IN ( - 'pii.us_ssn' - , 'pii.us_passport' - , 'pii.us_driver_license' - , 'pii.us_itin' - , 'pii.uk_nhs' - , 'pii.uk_nino' - , 'pii.uk_passport' - , 'pii.es_nif' - , 'pii.it_fiscal_code' - , 'pii.au_tfn' - , 'pii.in_pan' - , 'pii.in_aadhaar' - , 'pii.sg_nric_fin' - ) THEN 'government_ids' - WHEN 
rr.rule_id IN ( - 'pii.medical_license' - , 'pii.us_mbi' - , 'pii.us_npi' - , 'pii.medical_disease_disorder' - , 'pii.medical_medication' - , 'pii.medical_therapeutic_procedure' - , 'pii.medical_clinical_event' - , 'pii.medical_biological_attribute' - , 'pii.medical_family_history' - ) THEN 'healthcare' - WHEN rr.rule_id IN ( - 'pii.harmful_content_request' - , 'pii.policy_violation' - , 'pii.unauthorized_action' - , 'pii.topic_boundary_violation' - ) THEN 'off_policy' - WHEN rr.rule_id LIKE 'pii.%' THEN 'pii' - -- Scanner-source fallbacks: keep these LAST so any prefixed - -- rule_id wins. Stay in sync with the Go classifier in - -- internal/risk/categories. - WHEN rr.source = 'gitleaks' THEN 'secrets' - WHEN rr.source = 'presidio' THEN 'pii' - ELSE 'custom' - END AS category + -- Classify via the risk_category_lookup TEMP TABLE populated from the + -- Go classifier in internal/risk/categories. See bootstrap.go and the + -- pgxpool AfterConnect hook in cmd/gram/deps.go. + COALESCE( + ( + SELECT rcl.category FROM risk_category_lookup rcl + WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) + OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + ORDER BY rcl.priority ASC + LIMIT 1 + ), + 'custom' + )::text AS category FROM risk_results rr JOIN chat_messages cm ON cm.id = rr.chat_message_id LEFT JOIN chats c ON c.id = cm.chat_id AND c.deleted IS FALSE diff --git a/server/internal/testenv/postgresql.go b/server/internal/testenv/postgresql.go index d1debf129a..91311b48f0 100644 --- a/server/internal/testenv/postgresql.go +++ b/server/internal/testenv/postgresql.go @@ -12,6 +12,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/speakeasy-api/gram/server/internal/o11y" + "github.com/speakeasy-api/gram/server/internal/risk/categories" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/postgres" ) @@ 
-94,7 +95,14 @@ func newPostgresCloneFunc(container *postgres.PostgresContainer) PostgresDBClone } cloneuri := strings.Replace(uri, "gotestdb", clonename, 1) - pool, err := pgxpool.New(ctx, cloneuri) + poolcfg, err := pgxpool.ParseConfig(cloneuri) + if err != nil { + return nil, fmt.Errorf("parse clone pool config: %w", err) + } + // Mirror the production AfterConnect hook so queries that rely on + // the risk_category_lookup TEMP TABLE work the same way under tests. + poolcfg.AfterConnect = categories.BootstrapConnection + pool, err := pgxpool.NewWithConfig(ctx, poolcfg) if err != nil { return nil, fmt.Errorf("create pgx pool: %w", err) } From 3310d89183398a4413d496180043186d8c167fbf Mon Sep 17 00:00:00 2001 From: David Alberto Adler Date: Thu, 21 May 2026 22:46:32 +0100 Subject: [PATCH 2/3] refactor(risk): pass categories as unnest arrays instead of a temp table The previous commit introduced a per-connection TEMP TABLE (risk_category_lookup) populated by a pgxpool AfterConnect hook so that the four risk-overview queries could share one category classifier without each carrying its own CASE expression. That worked but bound a runtime concern (connection bootstrap) to a build concern (sqlc static analysis), and it required a sqlc_schema.sql stub plus matching bootstrap code in cmd/gram/deps.go and internal/testenv. Switch to passing the classifier as five parallel arrays (priority, category, source, rule_id, rule_prefix) per query call. Each query builds the same risk_category_lookup CTE inline via parallel unnest() calls. No connection hook, no schema stub, no AfterConnect hook in tests; the Go classifier in internal/risk/categories remains the single source of truth. Drops the unused Filter type and FilterFor() helper that were leftovers from the previous per-query CASE design. 
--- server/.golangci.yaml | 1 - server/cmd/gram/deps.go | 7 - server/database/sqlc.yaml | 8 +- server/internal/risk/categories/bootstrap.go | 108 ----------- server/internal/risk/categories/categories.go | 54 +----- .../risk/categories/categories_test.go | 43 ----- server/internal/risk/categories/sql.go | 45 +++++ .../internal/risk/categories/sqlc_schema.sql | 14 -- server/internal/risk/impl.go | 38 +++- server/internal/risk/queries.sql | 83 +++++---- server/internal/risk/repo/queries.sql.go | 167 ++++++++++++------ server/internal/testenv/postgresql.go | 10 +- 12 files changed, 255 insertions(+), 323 deletions(-) delete mode 100644 server/internal/risk/categories/bootstrap.go create mode 100644 server/internal/risk/categories/sql.go delete mode 100644 server/internal/risk/categories/sqlc_schema.sql diff --git a/server/.golangci.yaml b/server/.golangci.yaml index 2b52ab7a48..cd8a143100 100644 --- a/server/.golangci.yaml +++ b/server/.golangci.yaml @@ -269,7 +269,6 @@ linters: # always-present metadata. Filter is the runtime output struct, same # pattern. - '^github\.com/speakeasy-api/gram/server/internal/risk/categories\.Definition$' - - '^github\.com/speakeasy-api/gram/server/internal/risk/categories\.Filter$' exhaustive: check: - switch diff --git a/server/cmd/gram/deps.go b/server/cmd/gram/deps.go index 44648351b1..fefebe4c0e 100644 --- a/server/cmd/gram/deps.go +++ b/server/cmd/gram/deps.go @@ -57,7 +57,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/must" "github.com/speakeasy-api/gram/server/internal/o11y" "github.com/speakeasy-api/gram/server/internal/productfeatures" - "github.com/speakeasy-api/gram/server/internal/risk/categories" "github.com/speakeasy-api/gram/server/internal/telemetry" "github.com/speakeasy-api/gram/server/internal/temporal" "github.com/speakeasy-api/gram/server/internal/thirdparty/polar" @@ -218,12 +217,6 @@ func newDBClient(ctx context.Context, logger *slog.Logger, meterProvider metric. 
o11y.NewPGXLogger(logger.With(attr.SlogComponent("pgx")), consoleLogLevel), ) - // Populate the risk category classifier temp table on every new connection. - // SQL queries that need to bucket findings into categories join against - // risk_category_lookup instead of carrying their own CASE expression; - // internal/risk/categories owns the canonical mapping. - poolcfg.AfterConnect = categories.BootstrapConnection - pool, err := pgxpool.NewWithConfig(ctx, poolcfg) if err != nil { return nil, fmt.Errorf("failed to create pgx pool: %w", err) diff --git a/server/database/sqlc.yaml b/server/database/sqlc.yaml index c3da7b511e..d5561593ba 100644 --- a/server/database/sqlc.yaml +++ b/server/database/sqlc.yaml @@ -536,13 +536,7 @@ sql: sql_package: "pgx/v5" omit_unused_structs: true - - schema: - - schema.sql - # Declares the risk_category_lookup TEMP TABLE so sqlc's static analysis - # can resolve column references in queries.sql. The actual table is - # created at runtime by internal/risk/categories.BootstrapConnection on - # every pool connection. - - ../internal/risk/categories/sqlc_schema.sql + - schema: schema.sql queries: ../internal/risk/queries.sql engine: postgresql gen: diff --git a/server/internal/risk/categories/bootstrap.go b/server/internal/risk/categories/bootstrap.go deleted file mode 100644 index 0f66177176..0000000000 --- a/server/internal/risk/categories/bootstrap.go +++ /dev/null @@ -1,108 +0,0 @@ -package categories - -import ( - "context" - "fmt" - "strings" - - "github.com/jackc/pgx/v5" -) - -// BootstrapConnection creates a session-scoped TEMP TABLE -// `risk_category_lookup` on the connection and populates it from the -// canonical Definitions slice. -// -// Wire this into pgxpool.Config.AfterConnect so every connection in the pool -// owns one classifier table for its lifetime. 
Subsequent SQL queries that -// need to classify risk_results join against the temp table by name; the Go -// classifier (Definitions) is the single source of truth, the SQL CASE -// expressions that used to live in queries.sql are gone. -// -// Schema: -// -// priority INT — evaluation order; first match in ORDER BY priority ASC wins -// category TEXT — bucket the finding rolls up to (e.g. 'secrets') -// source TEXT — exact source match (e.g. 'shadow_mcp', 'gitleaks') -// rule_id TEXT — exact rule_id match (e.g. 'pii.credit_card') -// rule_prefix TEXT — LIKE-prefix match (e.g. 'secret.') -// -// Each row populates one of source / rule_id / rule_prefix; the others are -// NULL. Findings that match none of the rows are treated as `custom` via a -// COALESCE in the consuming queries. -func BootstrapConnection(ctx context.Context, conn *pgx.Conn) error { - if _, err := conn.Exec(ctx, ` - CREATE TEMP TABLE IF NOT EXISTS risk_category_lookup ( - priority INTEGER NOT NULL, - category TEXT NOT NULL, - source TEXT, - rule_id TEXT, - rule_prefix TEXT - ) ON COMMIT PRESERVE ROWS; - TRUNCATE risk_category_lookup; - `); err != nil { - return fmt.Errorf("create risk_category_lookup temp table: %w", err) - } - - // Build a VALUES insert from the canonical classifier so SQL never has - // the mapping baked in. 
- var ( - args []any - rows []string - idx = 1 - prio = 0 - emit = func(category, source, ruleID, prefix string) { - placeholders := fmt.Sprintf("($%d::int, $%d::text, NULLIF($%d::text, ''), NULLIF($%d::text, ''), NULLIF($%d::text, ''))", - idx, idx+1, idx+2, idx+3, idx+4) - rows = append(rows, placeholders) - args = append(args, prio, category, source, ruleID, prefix) - idx += 5 - prio++ - } - ) - - for _, def := range Definitions { - switch { - case def.Source != "": - emit(string(def.Category), def.Source, "", "") - case len(def.RuleIDs) > 0: - for _, ruleID := range def.RuleIDs { - emit(string(def.Category), "", ruleID, "") - } - case def.RulePrefix != "": - emit(string(def.Category), "", "", def.RulePrefix) - } - } - - if len(rows) == 0 { - return nil - } - - insertSQL := fmt.Sprintf( - "INSERT INTO risk_category_lookup (priority, category, source, rule_id, rule_prefix) VALUES %s", - strings.Join(rows, ", "), - ) - if _, err := conn.Exec(ctx, insertSQL, args...); err != nil { - return fmt.Errorf("populate risk_category_lookup: %w", err) - } - return nil -} - -// ClassifySubquery returns the SQL fragment that resolves a risk_results row -// to its category by joining against risk_category_lookup. Embed inside a -// SELECT or use as a LATERAL join source. -// -// Returns 'custom' for unmatched rows so callers don't have to repeat that. 
-const ClassifySubquery = `( - SELECT COALESCE( - ( - SELECT rcl.category - FROM risk_category_lookup rcl - WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) - OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) - OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') - ORDER BY rcl.priority ASC - LIMIT 1 - ), - 'custom' - ) -)` diff --git a/server/internal/risk/categories/categories.go b/server/internal/risk/categories/categories.go index a2b6f2393c..daf25c9522 100644 --- a/server/internal/risk/categories/categories.go +++ b/server/internal/risk/categories/categories.go @@ -1,16 +1,17 @@ // Package categories is the single source of truth for the -// (source, rule_id) → risk category mapping shown across the dashboard. +// (source, rule_id) -> risk category mapping shown across the dashboard. // // Previously the mapping lived in four duplicated SQL CASE expressions // (queries.sql) and a separate TypeScript classifier (risk-utils.ts), // which silently drifted whenever a new rule was added. Now: // // - Definitions below is the canonical list. -// - Classify(source, ruleID) is the canonical lookup. -// - SQLFilter(category) produces the parameter set the queries use to -// filter without an in-query CASE. -// - JSONResult() is what /rpc/risk.categories returns so the dashboard -// can consume the same data instead of maintaining its own copy. +// - Classify(source, ruleID) is the canonical lookup used by Go callers. +// - SQLRows() projects Definitions into five parallel arrays that every +// SQL query in internal/risk inlines via an unnest CTE. Same source of +// truth, no in-query CASE. +// - /rpc/risk.categories serves the JSON form so the dashboard reads the +// same definitions instead of maintaining its own copy. // // When adding a new rule or category, edit Definitions and the matching // frontend RULE_CATEGORY_META label, and everything else follows. 
@@ -224,47 +225,6 @@ func matchesRule(def Definition, ruleID string) bool { return false } -// Filter is the parameter set a SQL query uses to express "rows belonging -// to this category" without an in-query CASE expression. Pass these -// directly to query params; a query is filtered by category iff at least -// one of Sources / RuleIDs / RulePrefix is non-empty. -type Filter struct { - Sources []string - RuleIDs []string - RulePrefix string -} - -// FilterFor returns the SQL filter for one category. Empty Filter means -// "match nothing" (unknown category); callers should distinguish that -// from "no filter applied" upstream. -func FilterFor(cat Category) Filter { - if cat == "" { - return Filter{} - } - if cat == CategoryCustom { - // Custom is the fallback: anything not matched by the explicit - // definitions. Caller composes it as NOT (any other category). - // In practice the dashboard never filters by "custom" since it's - // the "everything else" bucket; emit an empty filter and let it - // be a no-op rather than implementing the negation here. - return Filter{} - } - for _, def := range Definitions { - if def.Category != cat { - continue - } - out := Filter{RulePrefix: def.RulePrefix} - if def.Source != "" { - out.Sources = []string{def.Source} - } - if len(def.RuleIDs) > 0 { - out.RuleIDs = append(out.RuleIDs, def.RuleIDs...) - } - return out - } - return Filter{} -} - // All returns every Definition, with CustomDefinition appended last. // Used by the JSON endpoint and tests. 
func All() []Definition { diff --git a/server/internal/risk/categories/categories_test.go b/server/internal/risk/categories/categories_test.go index 4a433555d2..45e60af340 100644 --- a/server/internal/risk/categories/categories_test.go +++ b/server/internal/risk/categories/categories_test.go @@ -66,49 +66,6 @@ func TestClassify_PinsPriorCASEBehavior(t *testing.T) { } } -func TestFilterFor(t *testing.T) { - t.Parallel() - - t.Run("source category", func(t *testing.T) { - t.Parallel() - f := FilterFor(CategoryShadowMCP) - require.Equal(t, []string{"shadow_mcp"}, f.Sources) - require.Empty(t, f.RuleIDs) - require.Empty(t, f.RulePrefix) - }) - - t.Run("prefix category", func(t *testing.T) { - t.Parallel() - f := FilterFor(CategorySecrets) - require.Equal(t, "secret.", f.RulePrefix) - require.Empty(t, f.Sources) - require.Empty(t, f.RuleIDs) - }) - - t.Run("explicit-list category", func(t *testing.T) { - t.Parallel() - f := FilterFor(CategoryFinancial) - require.Contains(t, f.RuleIDs, "pii.credit_card") - require.Empty(t, f.RulePrefix) - }) - - t.Run("custom is no-op", func(t *testing.T) { - t.Parallel() - f := FilterFor(CategoryCustom) - require.Empty(t, f.Sources) - require.Empty(t, f.RuleIDs) - require.Empty(t, f.RulePrefix) - }) - - t.Run("empty is no-op", func(t *testing.T) { - t.Parallel() - f := FilterFor("") - require.Empty(t, f.Sources) - require.Empty(t, f.RuleIDs) - require.Empty(t, f.RulePrefix) - }) -} - func TestAll_IncludesCustom(t *testing.T) { t.Parallel() all := All() diff --git a/server/internal/risk/categories/sql.go b/server/internal/risk/categories/sql.go new file mode 100644 index 0000000000..0661391ff7 --- /dev/null +++ b/server/internal/risk/categories/sql.go @@ -0,0 +1,45 @@ +package categories + +// SQLRows returns the canonical classifier as five parallel arrays sized +// 1-to-1 by row. 
Pass them straight into any query that classifies findings +// via the standard CTE: +// +// WITH risk_category_lookup(priority, category, source, rule_id, rule_prefix) AS ( +// SELECT * FROM unnest( +// @cat_priority::int[], +// @cat_category::text[], +// @cat_source::text[], +// @cat_rule_id::text[], +// @cat_rule_prefix::text[] +// ) AS t(priority, category, source, rule_id, rule_prefix) +// ) +// +// The CTE is then joined per row with a small subquery that picks the +// first matching category by priority. See queries.sql. +// +// Returning parallel arrays (rather than a Definition struct slice and +// having every caller flatten it) keeps the boundary with sqlc clean: +// sqlc understands `unnest(text[], …)` natively; it does not understand +// per-row CTE composition. +func SQLRows() (priority []int32, category, source, ruleID, rulePrefix []string) { + for prio, def := range Definitions { + emit := func(src, id, prefix string) { + priority = append(priority, int32(prio)) //nolint:gosec // bounded by Definitions length, far below int32 max + category = append(category, string(def.Category)) + source = append(source, src) + ruleID = append(ruleID, id) + rulePrefix = append(rulePrefix, prefix) + } + switch { + case def.Source != "": + emit(def.Source, "", "") + case len(def.RuleIDs) > 0: + for _, id := range def.RuleIDs { + emit("", id, "") + } + case def.RulePrefix != "": + emit("", "", def.RulePrefix) + } + } + return priority, category, source, ruleID, rulePrefix +} diff --git a/server/internal/risk/categories/sqlc_schema.sql b/server/internal/risk/categories/sqlc_schema.sql deleted file mode 100644 index a30a54b031..0000000000 --- a/server/internal/risk/categories/sqlc_schema.sql +++ /dev/null @@ -1,14 +0,0 @@ --- Schema fragment consumed ONLY by sqlc's static analysis so that queries in --- internal/risk/queries.sql can reference the risk_category_lookup TEMP TABLE --- without a migration. 
--- --- The actual table is created at runtime by BootstrapConnection (bootstrap.go) --- on every pool connection, populated from the canonical Go classifier. --- DO NOT migrate this file; it is intentionally not in server/migrations. -CREATE TABLE IF NOT EXISTS risk_category_lookup ( - priority INTEGER NOT NULL, - category TEXT NOT NULL, - source TEXT, - rule_id TEXT, - rule_prefix TEXT -); diff --git a/server/internal/risk/impl.go b/server/internal/risk/impl.go index 4291bfd9b0..902e912faa 100644 --- a/server/internal/risk/impl.go +++ b/server/internal/risk/impl.go @@ -795,10 +795,16 @@ func (s *Service) GetRiskOverview(ctx context.Context, payload *gen.GetRiskOverv return nil, oops.E(oops.CodeUnexpected, err, "list risk overview top users").Log(ctx, s.logger) } + catPriority, catCategory, catSource, catRuleID, catRulePrefix := categories.SQLRows() timeSeriesRows, err := s.repo.ListRiskOverviewTimeSeriesFindings(ctx, repo.ListRiskOverviewTimeSeriesFindingsParams{ - ProjectID: *authCtx.ProjectID, - FromTime: window.from, - ToTime: window.to, + CatPriority: catPriority, + CatCategory: catCategory, + CatSource: catSource, + CatRuleID: catRuleID, + CatRulePrefix: catRulePrefix, + ProjectID: *authCtx.ProjectID, + FromTime: window.from, + ToTime: window.to, }) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "list risk overview time series findings").Log(ctx, s.logger) @@ -1011,7 +1017,13 @@ func (s *Service) listResultsByPolicy(ctx context.Context, projectID uuid.UUID, func (s *Service) listResultsByProject(ctx context.Context, projectID uuid.UUID, cursor *riskResultsCursor, pageSize int, totalCount int64, category string, ruleID string, uniqueMatch bool, fromTime, toTime pgtype.Timestamptz) (*gen.ListRiskResultsResult, error) { cursorCreatedAt, cursorID := cursorToParams(cursor) + catPriority, catCategory, catSource, catRuleID, catRulePrefix := categories.SQLRows() rows, err := s.repo.ListRiskResultsByProjectFound(ctx, 
repo.ListRiskResultsByProjectFoundParams{ + CatPriority: catPriority, + CatCategory: catCategory, + CatSource: catSource, + CatRuleID: catRuleID, + CatRulePrefix: catRulePrefix, ProjectID: projectID, FromTime: fromTime, ToTime: toTime, @@ -1079,7 +1091,13 @@ func (s *Service) GetRiskUserBreakdown(ctx context.Context, payload *gen.GetRisk } window := riskOverviewWindowParams(from, to) + catPriority, catCategory, catSource, catRuleID, catRulePrefix := categories.SQLRows() categoryRows, err := s.repo.ListRiskUserCategoryBreakdown(ctx, repo.ListRiskUserCategoryBreakdownParams{ + CatPriority: catPriority, + CatCategory: catCategory, + CatSource: catSource, + CatRuleID: catRuleID, + CatRulePrefix: catRulePrefix, ProjectID: *authCtx.ProjectID, FromTime: window.from, ToTime: window.to, @@ -1144,11 +1162,17 @@ func (s *Service) GetRiskRuleBreakdown(ctx context.Context, payload *gen.GetRisk } window := riskOverviewWindowParams(from, to) + catPriority, catCategory, catSource, catRuleID, catRulePrefix := categories.SQLRows() rows, err := s.repo.ListRiskRulesByCategory(ctx, repo.ListRiskRulesByCategoryParams{ - ProjectID: *authCtx.ProjectID, - FromTime: window.from, - ToTime: window.to, - Category: payload.Category, + CatPriority: catPriority, + CatCategory: catCategory, + CatSource: catSource, + CatRuleID: catRuleID, + CatRulePrefix: catRulePrefix, + ProjectID: *authCtx.ProjectID, + FromTime: window.from, + ToTime: window.to, + Category: payload.Category, }) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "list rule breakdown").Log(ctx, s.logger) diff --git a/server/internal/risk/queries.sql b/server/internal/risk/queries.sql index 759b44eb11..b4abe96838 100644 --- a/server/internal/risk/queries.sql +++ b/server/internal/risk/queries.sql @@ -159,19 +159,23 @@ LIMIT @row_limit; -- name: ListRiskUserCategoryBreakdown :many -- Per-category finding counts for a single external_user_id in a window. 
--- The category CASE expression must stay in sync with the other ListRisk* --- queries. -WITH user_findings AS ( +-- Categories are resolved against the canonical Go classifier passed in as +-- the @cat_* parallel arrays (see internal/risk/categories.SQLRows). +WITH risk_category_lookup AS ( + SELECT unnest(@cat_priority::int[]) AS priority, + unnest(@cat_category::text[]) AS category, + unnest(@cat_source::text[]) AS source, + unnest(@cat_rule_id::text[]) AS rule_id, + unnest(@cat_rule_prefix::text[]) AS rule_prefix +), +user_findings AS ( SELECT - -- Classify via the risk_category_lookup TEMP TABLE populated from the - -- Go classifier in internal/risk/categories. See bootstrap.go and the - -- pgxpool AfterConnect hook in cmd/gram/deps.go. COALESCE( ( SELECT rcl.category FROM risk_category_lookup rcl - WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) - OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) - OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + WHERE (rcl.source != '' AND rcl.source = rr.source) + OR (rcl.rule_id != '' AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix != '' AND rr.rule_id LIKE rcl.rule_prefix || '%') ORDER BY rcl.priority ASC LIMIT 1 ), @@ -209,22 +213,26 @@ GROUP BY rr.rule_id, rr.source ORDER BY findings DESC, rule_id ASC; -- name: ListRiskRulesByCategory :many --- Returns per-rule_id finding counts for a category within a window. --- The CASE expression must stay in sync with ListRiskOverviewTimeSeriesFindings --- and ListRiskResultsByProjectFound; all three classify rr.rule_id the same way. -WITH categorized AS ( +-- Per-rule_id finding counts for a category within a window. Categories are +-- resolved against the canonical Go classifier passed in as the @cat_* +-- parallel arrays (see internal/risk/categories.SQLRows). 
+WITH risk_category_lookup AS ( + SELECT unnest(@cat_priority::int[]) AS priority, + unnest(@cat_category::text[]) AS category, + unnest(@cat_source::text[]) AS source, + unnest(@cat_rule_id::text[]) AS rule_id, + unnest(@cat_rule_prefix::text[]) AS rule_prefix +), +categorized AS ( SELECT COALESCE(rr.rule_id, '')::TEXT AS rule_id, rr.source, - -- Classify via the risk_category_lookup TEMP TABLE populated from the - -- Go classifier in internal/risk/categories. See bootstrap.go and the - -- pgxpool AfterConnect hook in cmd/gram/deps.go. COALESCE( ( SELECT rcl.category FROM risk_category_lookup rcl - WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) - OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) - OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + WHERE (rcl.source != '' AND rcl.source = rr.source) + OR (rcl.rule_id != '' AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix != '' AND rr.rule_id LIKE rcl.rule_prefix || '%') ORDER BY rcl.priority ASC LIMIT 1 ), @@ -268,7 +276,16 @@ ORDER BY findings DESC, email ASC LIMIT @row_limit; -- name: ListRiskOverviewTimeSeriesFindings :many -WITH buckets AS ( +-- Categories are resolved against the canonical Go classifier passed in as +-- the @cat_* parallel arrays (see internal/risk/categories.SQLRows). 
+WITH risk_category_lookup AS ( + SELECT unnest(@cat_priority::int[]) AS priority, + unnest(@cat_category::text[]) AS category, + unnest(@cat_source::text[]) AS source, + unnest(@cat_rule_id::text[]) AS rule_id, + unnest(@cat_rule_prefix::text[]) AS rule_prefix +), +buckets AS ( SELECT generate_series( date_trunc('hour', sqlc.arg(from_time)::timestamptz) , date_trunc('hour', (sqlc.arg(to_time)::timestamptz - INTERVAL '1 microsecond')) @@ -281,9 +298,9 @@ categorized AS ( , COALESCE( ( SELECT rcl.category FROM risk_category_lookup rcl - WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) - OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) - OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + WHERE (rcl.source != '' AND rcl.source = rr.source) + OR (rcl.rule_id != '' AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix != '' AND rr.rule_id LIKE rcl.rule_prefix || '%') ORDER BY rcl.priority ASC LIMIT 1 ), @@ -393,17 +410,23 @@ WHERE risk_policy_id = @risk_policy_id -- Sort by the underlying chat message's created_at (the event time), NOT -- rr.created_at (the scan time). The background drain workflow analyzes -- historical messages in arbitrary order, so rr.created_at can put a --- finding for an old message ahead of one for a recent message — which is +-- finding for an old message ahead of one for a recent message, which is -- exactly the "random-seeming" order users see in Recent Findings. -- Cursor is (cm.created_at, rr.id) for stable pagination. --- The category CASE expression here must stay in sync with the one in --- ListRiskOverviewTimeSeriesFindings; both derive the user-facing category --- key from rr.source and rr.rule_id. +-- Categories are resolved against the canonical Go classifier passed in as +-- the @cat_* parallel arrays (see internal/risk/categories.SQLRows). 
-- -- When @unique_match is TRUE, dedup at the SQL layer: keep only one row per -- (risk_policy_id, rule_id, match), choosing the most recent occurrence. Done -- inside a subquery so pagination over the deduped stream stays correct -- (client-side dedup over paged data broke "Load more"). +WITH risk_category_lookup AS ( + SELECT unnest(@cat_priority::int[]) AS priority, + unnest(@cat_category::text[]) AS category, + unnest(@cat_source::text[]) AS source, + unnest(@cat_rule_id::text[]) AS rule_id, + unnest(@cat_rule_prefix::text[]) AS rule_prefix +) SELECT sub.id, sub.project_id, sub.organization_id, sub.risk_policy_id, sub.risk_policy_version, sub.chat_message_id, sub.source, sub.found, @@ -437,9 +460,9 @@ FROM ( AND (@category::text = '' OR COALESCE( ( SELECT rcl.category FROM risk_category_lookup rcl - WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) - OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) - OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + WHERE (rcl.source != '' AND rcl.source = rr.source) + OR (rcl.rule_id != '' AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix != '' AND rr.rule_id LIKE rcl.rule_prefix || '%') ORDER BY rcl.priority ASC LIMIT 1 ), diff --git a/server/internal/risk/repo/queries.sql.go b/server/internal/risk/repo/queries.sql.go index 1b656f1467..2259598e09 100644 --- a/server/internal/risk/repo/queries.sql.go +++ b/server/internal/risk/repo/queries.sql.go @@ -641,10 +641,17 @@ func (q *Queries) ListEnabledToolIdentityPoliciesByProject(ctx context.Context, } const listRiskOverviewTimeSeriesFindings = `-- name: ListRiskOverviewTimeSeriesFindings :many -WITH buckets AS ( +WITH risk_category_lookup AS ( + SELECT unnest($1::int[]) AS priority, + unnest($2::text[]) AS category, + unnest($3::text[]) AS source, + unnest($4::text[]) AS rule_id, + unnest($5::text[]) AS rule_prefix +), +buckets AS ( SELECT generate_series( - date_trunc('hour', $1::timestamptz) - , date_trunc('hour', ($2::timestamptz 
- INTERVAL '1 microsecond')) + date_trunc('hour', $6::timestamptz) + , date_trunc('hour', ($7::timestamptz - INTERVAL '1 microsecond')) , INTERVAL '1 hour' )::timestamptz AS bucket_start ), @@ -654,19 +661,19 @@ categorized AS ( , COALESCE( ( SELECT rcl.category FROM risk_category_lookup rcl - WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) - OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) - OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + WHERE (rcl.source != '' AND rcl.source = rr.source) + OR (rcl.rule_id != '' AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix != '' AND rr.rule_id LIKE rcl.rule_prefix || '%') ORDER BY rcl.priority ASC LIMIT 1 ), 'custom' )::text AS category FROM risk_results rr - WHERE rr.project_id = $3::uuid + WHERE rr.project_id = $8::uuid AND rr.found IS TRUE - AND rr.created_at >= $1 - AND rr.created_at < $2 + AND rr.created_at >= $6 + AND rr.created_at < $7 ), categories AS ( SELECT DISTINCT category @@ -691,9 +698,14 @@ ORDER BY buckets.bucket_start ASC, categories.category ASC ` type ListRiskOverviewTimeSeriesFindingsParams struct { - FromTime pgtype.Timestamptz - ToTime pgtype.Timestamptz - ProjectID uuid.UUID + CatPriority []int32 + CatCategory []string + CatSource []string + CatRuleID []string + CatRulePrefix []string + FromTime pgtype.Timestamptz + ToTime pgtype.Timestamptz + ProjectID uuid.UUID } type ListRiskOverviewTimeSeriesFindingsRow struct { @@ -702,8 +714,19 @@ type ListRiskOverviewTimeSeriesFindingsRow struct { Findings int64 } +// Categories are resolved against the canonical Go classifier passed in as +// the @cat_* parallel arrays (see internal/risk/categories.SQLRows). 
func (q *Queries) ListRiskOverviewTimeSeriesFindings(ctx context.Context, arg ListRiskOverviewTimeSeriesFindingsParams) ([]ListRiskOverviewTimeSeriesFindingsRow, error) { - rows, err := q.db.Query(ctx, listRiskOverviewTimeSeriesFindings, arg.FromTime, arg.ToTime, arg.ProjectID) + rows, err := q.db.Query(ctx, listRiskOverviewTimeSeriesFindings, + arg.CatPriority, + arg.CatCategory, + arg.CatSource, + arg.CatRuleID, + arg.CatRulePrefix, + arg.FromTime, + arg.ToTime, + arg.ProjectID, + ) if err != nil { return nil, err } @@ -1080,6 +1103,13 @@ func (q *Queries) ListRiskResultsByProjectAndPolicy(ctx context.Context, arg Lis } const listRiskResultsByProjectFound = `-- name: ListRiskResultsByProjectFound :many +WITH risk_category_lookup AS ( + SELECT unnest($10::int[]) AS priority, + unnest($11::text[]) AS category, + unnest($12::text[]) AS source, + unnest($13::text[]) AS rule_id, + unnest($14::text[]) AS rule_prefix +) SELECT sub.id, sub.project_id, sub.organization_id, sub.risk_policy_id, sub.risk_policy_version, sub.chat_message_id, sub.source, sub.found, @@ -1113,9 +1143,9 @@ FROM ( AND ($6::text = '' OR COALESCE( ( SELECT rcl.category FROM risk_category_lookup rcl - WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) - OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) - OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + WHERE (rcl.source != '' AND rcl.source = rr.source) + OR (rcl.rule_id != '' AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix != '' AND rr.rule_id LIKE rcl.rule_prefix || '%') ORDER BY rcl.priority ASC LIMIT 1 ), @@ -1141,6 +1171,11 @@ type ListRiskResultsByProjectFoundParams struct { CursorMessageCreatedAt pgtype.Timestamptz CursorID uuid.NullUUID PageLimit int32 + CatPriority []int32 + CatCategory []string + CatSource []string + CatRuleID []string + CatRulePrefix []string } type ListRiskResultsByProjectFoundRow struct { @@ -1170,12 +1205,11 @@ type ListRiskResultsByProjectFoundRow struct { // Sort by the 
underlying chat message's created_at (the event time), NOT // rr.created_at (the scan time). The background drain workflow analyzes // historical messages in arbitrary order, so rr.created_at can put a -// finding for an old message ahead of one for a recent message — which is +// finding for an old message ahead of one for a recent message, which is // exactly the "random-seeming" order users see in Recent Findings. // Cursor is (cm.created_at, rr.id) for stable pagination. -// The category CASE expression here must stay in sync with the one in -// ListRiskOverviewTimeSeriesFindings; both derive the user-facing category -// key from rr.source and rr.rule_id. +// Categories are resolved against the canonical Go classifier passed in as +// the @cat_* parallel arrays (see internal/risk/categories.SQLRows). // // When @unique_match is TRUE, dedup at the SQL layer: keep only one row per // (risk_policy_id, rule_id, match), choosing the most recent occurrence. Done @@ -1192,6 +1226,11 @@ func (q *Queries) ListRiskResultsByProjectFound(ctx context.Context, arg ListRis arg.CursorMessageCreatedAt, arg.CursorID, arg.PageLimit, + arg.CatPriority, + arg.CatCategory, + arg.CatSource, + arg.CatRuleID, + arg.CatRulePrefix, ) if err != nil { return nil, err @@ -1293,29 +1332,33 @@ func (q *Queries) ListRiskResultsGroupedByChat(ctx context.Context, arg ListRisk } const listRiskRulesByCategory = `-- name: ListRiskRulesByCategory :many -WITH categorized AS ( +WITH risk_category_lookup AS ( + SELECT unnest($2::int[]) AS priority, + unnest($3::text[]) AS category, + unnest($4::text[]) AS source, + unnest($5::text[]) AS rule_id, + unnest($6::text[]) AS rule_prefix +), +categorized AS ( SELECT COALESCE(rr.rule_id, '')::TEXT AS rule_id, rr.source, - -- Classify via the risk_category_lookup TEMP TABLE populated from the - -- Go classifier in internal/risk/categories. See bootstrap.go and the - -- pgxpool AfterConnect hook in cmd/gram/deps.go. 
COALESCE( ( SELECT rcl.category FROM risk_category_lookup rcl - WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) - OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) - OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + WHERE (rcl.source != '' AND rcl.source = rr.source) + OR (rcl.rule_id != '' AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix != '' AND rr.rule_id LIKE rcl.rule_prefix || '%') ORDER BY rcl.priority ASC LIMIT 1 ), 'custom' )::text AS category FROM risk_results rr - WHERE rr.project_id = $2 + WHERE rr.project_id = $7 AND rr.found IS TRUE - AND rr.created_at >= $3 - AND rr.created_at < $4 + AND rr.created_at >= $8 + AND rr.created_at < $9 ) SELECT rule_id, source, COUNT(*)::BIGINT AS findings FROM categorized @@ -1325,10 +1368,15 @@ ORDER BY findings DESC, rule_id ASC ` type ListRiskRulesByCategoryParams struct { - Category string - ProjectID uuid.UUID - FromTime pgtype.Timestamptz - ToTime pgtype.Timestamptz + Category string + CatPriority []int32 + CatCategory []string + CatSource []string + CatRuleID []string + CatRulePrefix []string + ProjectID uuid.UUID + FromTime pgtype.Timestamptz + ToTime pgtype.Timestamptz } type ListRiskRulesByCategoryRow struct { @@ -1337,12 +1385,17 @@ type ListRiskRulesByCategoryRow struct { Findings int64 } -// Returns per-rule_id finding counts for a category within a window. -// The CASE expression must stay in sync with ListRiskOverviewTimeSeriesFindings -// and ListRiskResultsByProjectFound; all three classify rr.rule_id the same way. +// Per-rule_id finding counts for a category within a window. Categories are +// resolved against the canonical Go classifier passed in as the @cat_* +// parallel arrays (see internal/risk/categories.SQLRows). 
func (q *Queries) ListRiskRulesByCategory(ctx context.Context, arg ListRiskRulesByCategoryParams) ([]ListRiskRulesByCategoryRow, error) { rows, err := q.db.Query(ctx, listRiskRulesByCategory, arg.Category, + arg.CatPriority, + arg.CatCategory, + arg.CatSource, + arg.CatRuleID, + arg.CatRulePrefix, arg.ProjectID, arg.FromTime, arg.ToTime, @@ -1366,17 +1419,21 @@ func (q *Queries) ListRiskRulesByCategory(ctx context.Context, arg ListRiskRules } const listRiskUserCategoryBreakdown = `-- name: ListRiskUserCategoryBreakdown :many -WITH user_findings AS ( +WITH risk_category_lookup AS ( + SELECT unnest($1::int[]) AS priority, + unnest($2::text[]) AS category, + unnest($3::text[]) AS source, + unnest($4::text[]) AS rule_id, + unnest($5::text[]) AS rule_prefix +), +user_findings AS ( SELECT - -- Classify via the risk_category_lookup TEMP TABLE populated from the - -- Go classifier in internal/risk/categories. See bootstrap.go and the - -- pgxpool AfterConnect hook in cmd/gram/deps.go. COALESCE( ( SELECT rcl.category FROM risk_category_lookup rcl - WHERE (rcl.source IS NOT NULL AND rcl.source = rr.source) - OR (rcl.rule_id IS NOT NULL AND rcl.rule_id = rr.rule_id) - OR (rcl.rule_prefix IS NOT NULL AND rr.rule_id LIKE rcl.rule_prefix || '%') + WHERE (rcl.source != '' AND rcl.source = rr.source) + OR (rcl.rule_id != '' AND rcl.rule_id = rr.rule_id) + OR (rcl.rule_prefix != '' AND rr.rule_id LIKE rcl.rule_prefix || '%') ORDER BY rcl.priority ASC LIMIT 1 ), @@ -1385,11 +1442,11 @@ WITH user_findings AS ( FROM risk_results rr JOIN chat_messages cm ON cm.id = rr.chat_message_id LEFT JOIN chats c ON c.id = cm.chat_id AND c.deleted IS FALSE - WHERE rr.project_id = $1 + WHERE rr.project_id = $6 AND rr.found IS TRUE - AND rr.created_at >= $2 - AND rr.created_at < $3 - AND COALESCE(NULLIF(cm.external_user_id, ''), NULLIF(c.external_user_id, ''), '') = $4::text + AND rr.created_at >= $7 + AND rr.created_at < $8 + AND COALESCE(NULLIF(cm.external_user_id, ''), NULLIF(c.external_user_id, 
''), '') = $9::text ) SELECT category, COUNT(*)::BIGINT AS findings FROM user_findings @@ -1398,6 +1455,11 @@ ORDER BY findings DESC, category ASC ` type ListRiskUserCategoryBreakdownParams struct { + CatPriority []int32 + CatCategory []string + CatSource []string + CatRuleID []string + CatRulePrefix []string ProjectID uuid.UUID FromTime pgtype.Timestamptz ToTime pgtype.Timestamptz @@ -1410,10 +1472,15 @@ type ListRiskUserCategoryBreakdownRow struct { } // Per-category finding counts for a single external_user_id in a window. -// The category CASE expression must stay in sync with the other ListRisk* -// queries. +// Categories are resolved against the canonical Go classifier passed in as +// the @cat_* parallel arrays (see internal/risk/categories.SQLRows). func (q *Queries) ListRiskUserCategoryBreakdown(ctx context.Context, arg ListRiskUserCategoryBreakdownParams) ([]ListRiskUserCategoryBreakdownRow, error) { rows, err := q.db.Query(ctx, listRiskUserCategoryBreakdown, + arg.CatPriority, + arg.CatCategory, + arg.CatSource, + arg.CatRuleID, + arg.CatRulePrefix, arg.ProjectID, arg.FromTime, arg.ToTime, diff --git a/server/internal/testenv/postgresql.go b/server/internal/testenv/postgresql.go index 91311b48f0..d1debf129a 100644 --- a/server/internal/testenv/postgresql.go +++ b/server/internal/testenv/postgresql.go @@ -12,7 +12,6 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/speakeasy-api/gram/server/internal/o11y" - "github.com/speakeasy-api/gram/server/internal/risk/categories" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/postgres" ) @@ -95,14 +94,7 @@ func newPostgresCloneFunc(container *postgres.PostgresContainer) PostgresDBClone } cloneuri := strings.Replace(uri, "gotestdb", clonename, 1) - poolcfg, err := pgxpool.ParseConfig(cloneuri) - if err != nil { - return nil, fmt.Errorf("parse clone pool config: %w", err) - } - // Mirror the production AfterConnect hook so queries 
that rely on - // the risk_category_lookup TEMP TABLE work the same way under tests. - poolcfg.AfterConnect = categories.BootstrapConnection - pool, err := pgxpool.NewWithConfig(ctx, poolcfg) + pool, err := pgxpool.New(ctx, cloneuri) if err != nil { return nil, fmt.Errorf("create pgx pool: %w", err) } From cc4fd531693a9b51921b5ad608cdec5c275a2926 Mon Sep 17 00:00:00 2001 From: David Alberto Adler Date: Fri, 22 May 2026 08:42:56 +0100 Subject: [PATCH 3/3] refactor(dashboard): classify findings via /rpc/risk.categories The dashboard's risk-utils held its own (source, rule_id) -> category classifier built from SOURCE_TO_CATEGORY and DETECTION_RULES. That was a parallel copy of the Go classifier in internal/risk/categories, guaranteed to drift whenever a rule was added on one side and not the other. Replace it with useFindingClassifier(), a hook that reads the canonical classifier through the existing useRiskCategories() React Query binding (/rpc/risk.categories, shipped in #2976). CategoryLabel renders nothing during the first fetch and falls back to "custom" for unmapped rules once data is loaded. Per-rule human-readable titles still come from DETECTION_RULES since the API does not expose them. 
--- .../dashboard/src/pages/security/risk-ui.tsx | 6 +- .../src/pages/security/risk-utils.ts | 70 ++++++++++--------- 2 files changed, 40 insertions(+), 36 deletions(-) diff --git a/client/dashboard/src/pages/security/risk-ui.tsx b/client/dashboard/src/pages/security/risk-ui.tsx index 1f7fd42069..0f51193e74 100644 --- a/client/dashboard/src/pages/security/risk-ui.tsx +++ b/client/dashboard/src/pages/security/risk-ui.tsx @@ -8,7 +8,7 @@ import { type ReactNode, } from "react"; import { RULE_CATEGORY_META } from "./policy-data"; -import { getCategoryForFinding, getRuleTitleFallback } from "./risk-utils"; +import { getRuleTitleFallback, useFindingClassifier } from "./risk-utils"; import { Badge } from "@speakeasy-api/moonshine"; import { SimpleTooltip } from "@/components/ui/tooltip"; import { @@ -24,7 +24,9 @@ export function CategoryLabel({ source?: string; ruleId?: string; }) { - const category = getCategoryForFinding(source, ruleId); + const classify = useFindingClassifier(); + if (!classify) return null; + const category = classify(source, ruleId); const meta = category ? RULE_CATEGORY_META[category] : RULE_CATEGORY_META.custom; diff --git a/client/dashboard/src/pages/security/risk-utils.ts b/client/dashboard/src/pages/security/risk-utils.ts index efc0555821..ff76e48399 100644 --- a/client/dashboard/src/pages/security/risk-utils.ts +++ b/client/dashboard/src/pages/security/risk-utils.ts @@ -1,52 +1,54 @@ +import { useMemo } from "react"; +import { useRiskCategories } from "@gram/client/react-query/index.js"; import { DETECTION_RULES, type RuleCategory } from "./policy-data"; import { humanizeRuleId } from "./rule-ids"; -const SOURCE_TO_CATEGORY: ReadonlyMap = new Map< - string, - RuleCategory ->([ - ["destructive_tool", "destructive_tool"], - ["shadow_mcp", "shadow_mcp"], - ["prompt_injection", "prompt_injection"], - ["cli_destructive", "cli_destructive"], - // Scanner-source fallbacks: when a rule_id doesn't carry its category - // prefix (e.g. 
gitleaks' bare "generic-api-key"), classify by source so - // we never leak the scanner name to the UI. Keep in sync with the Go - // classifier in server/internal/risk/categories. - ["gitleaks", "secrets"], - ["presidio", "pii"], -]); - -const ruleIdToCategory = new Map(); const ruleIdToTitle = new Map(); - -for (const [category, rules] of Object.entries(DETECTION_RULES)) { +for (const rules of Object.values(DETECTION_RULES)) { for (const rule of rules) { - ruleIdToCategory.set(rule.id, category as RuleCategory); ruleIdToTitle.set(rule.id, rule.title); } } - -// DETECTION_RULES.id is the canonical rule_id the backend writes to -// risk_results, so lookup maps key by it directly. -const RULE_ID_TO_CATEGORY: ReadonlyMap = ruleIdToCategory; const RULE_ID_TO_TITLE: ReadonlyMap = ruleIdToTitle; +// Per-rule human-readable titles aren't returned by /rpc/risk.categories +// (the API exposes only the canonical classification: source / rule_ids / +// rule_id_prefix). Keep the static title map for label display. export function getRuleTitleFallback(ruleId: string | undefined): string { if (!ruleId) return "-"; return RULE_ID_TO_TITLE.get(ruleId) ?? humanizeRuleId(ruleId); } -export function getCategoryForFinding( +export type FindingClassifier = ( source?: string, ruleId?: string, -): RuleCategory | null { - if (ruleId) { - const byRule = RULE_ID_TO_CATEGORY.get(ruleId); - if (byRule) return byRule; - } - if (source) { - return SOURCE_TO_CATEGORY.get(source) ?? null; - } - return null; +) => RuleCategory | null; + +// useFindingClassifier returns a (source, rule_id) -> category lookup +// backed by the canonical Go classifier served at /rpc/risk.categories. +// React Query dedupes across the page so calling this per table row is +// cheap. Returns null while the first fetch is in flight; consumers +// should treat that as "category unknown yet" and render nothing. 
+export function useFindingClassifier(): FindingClassifier | null { + const { data } = useRiskCategories(undefined, undefined, { + staleTime: Number.POSITIVE_INFINITY, + }); + return useMemo(() => { + const defs = data?.categories; + if (!defs) return null; + return (source, ruleId) => { + for (const def of defs) { + if (def.source && def.source === source) { + return def.key as RuleCategory; + } + if (def.ruleIds.length > 0 && ruleId && def.ruleIds.includes(ruleId)) { + return def.key as RuleCategory; + } + if (def.ruleIdPrefix && ruleId && ruleId.startsWith(def.ruleIdPrefix)) { + return def.key as RuleCategory; + } + } + return null; + }; + }, [data]); }