From 3aca672f9d80c6a2ef6fb0a31bd4b9cfad7f975a Mon Sep 17 00:00:00 2001 From: Ryan Skidmore Date: Fri, 13 Mar 2026 14:12:11 -0500 Subject: [PATCH] feat(ai-gateway-provider): add createProviderRouter for native SDK routing Add a provider router that maps model ID prefixes to native AI SDK provider instances, enabling consumers to route requests through the correct native SDK (e.g. Anthropic, OpenAI) based on model name. This preserves provider-specific features like prompt caching that are lost when using the unified provider, while maintaining the convenience of a single provider interface. - createProviderRouter() accepts a map of model prefix -> provider - Falls back to a configurable default provider - Unit tests for routing logic (7 cases) - Integration tests proving cache_control survives the full pipeline: router -> native Anthropic SDK -> createAiGateway -> gateway request --- .changeset/add-provider-router.md | 5 + packages/ai-gateway-provider/package.json | 5 + .../src/providers/index.ts | 1 + .../src/providers/router.ts | 41 +++++ .../test/router-integration.test.ts | 160 ++++++++++++++++++ .../ai-gateway-provider/test/router.test.ts | 116 +++++++++++++ 6 files changed, 328 insertions(+) create mode 100644 .changeset/add-provider-router.md create mode 100644 packages/ai-gateway-provider/src/providers/router.ts create mode 100644 packages/ai-gateway-provider/test/router-integration.test.ts create mode 100644 packages/ai-gateway-provider/test/router.test.ts diff --git a/.changeset/add-provider-router.md b/.changeset/add-provider-router.md new file mode 100644 index 000000000..d88d6af5d --- /dev/null +++ b/.changeset/add-provider-router.md @@ -0,0 +1,5 @@ +--- +"ai-gateway-provider": minor +--- + +Add `createProviderRouter()` for routing model IDs to native SDK providers by prefix, preserving provider-specific features like Anthropic prompt caching that are lost through the unified OpenAI-compatible path. diff --git a/packages/ai-gateway-provider/package.json b/packages/ai-gateway-provider/package.json index 391324d9f..d33d17db0 100644 --- a/packages/ai-gateway-provider/package.json +++ b/packages/ai-gateway-provider/package.json @@ -119,6 +119,11 @@ "import": "./dist/providers/openrouter.mjs", "require": "./dist/providers/openrouter.js" }, + "./providers/router": { + "types": "./dist/providers/router.d.ts", + "import": "./dist/providers/router.mjs", + "require": "./dist/providers/router.js" + }, "./providers/unified": { "types": "./dist/providers/unified.d.ts", "import": "./dist/providers/unified.mjs", diff --git a/packages/ai-gateway-provider/src/providers/index.ts b/packages/ai-gateway-provider/src/providers/index.ts index a31b24b69..084ff48fa 100644 --- a/packages/ai-gateway-provider/src/providers/index.ts +++ b/packages/ai-gateway-provider/src/providers/index.ts @@ -13,4 +13,5 @@ export { createGroq } from "./groq"; export { createMistral } from "./mistral"; export { createOpenAI } from "./openai"; export { createPerplexity } from "./perplexity"; +export { createProviderRouter, type ProviderRouterConfig } from "./router"; export { createXai } from "./xai"; diff --git a/packages/ai-gateway-provider/src/providers/router.ts b/packages/ai-gateway-provider/src/providers/router.ts new file mode 100644 index 000000000..2baf40503 --- /dev/null +++ b/packages/ai-gateway-provider/src/providers/router.ts @@ -0,0 +1,41 @@ +import type { LanguageModelV3 } from "@ai-sdk/provider"; +import { createUnified } from "./unified"; + +type ProviderFactory = (modelId: string) => LanguageModelV3; + +export type ProviderRouterConfig = { + /** + * Map of model ID prefix to a provider instance. + * When a model ID matches "prefix/model-name", the corresponding provider + * is used to create the model with the bare name (prefix stripped). + */ + providers: Record; + /** + * Optional fallback for model IDs that don't match any provider prefix. + * Defaults to createUnified() (OpenAI-compatible format). + */ + fallback?: ProviderFactory; +}; + +/** + * Creates a model router that selects native provider SDKs based on + * the model ID prefix. This preserves provider-specific features like + * Anthropic's cache_control that are lost through the unified OpenAI-compatible path. + */ +export function createProviderRouter(config: ProviderRouterConfig): ProviderFactory { + const fallback = config.fallback ?? createUnified(); + + return (modelId: string): LanguageModelV3 => { + const slashIndex = modelId.indexOf("/"); + if (slashIndex > 0) { + const prefix = modelId.slice(0, slashIndex); + const bareId = modelId.slice(slashIndex + 1); + const provider = config.providers[prefix]; + if (provider) { + return provider(bareId); + } + } + + return fallback(modelId); + }; +} diff --git a/packages/ai-gateway-provider/test/router-integration.test.ts b/packages/ai-gateway-provider/test/router-integration.test.ts new file mode 100644 index 000000000..3000f6f26 --- /dev/null +++ b/packages/ai-gateway-provider/test/router-integration.test.ts @@ -0,0 +1,160 @@ +import { createAnthropic } from "@ai-sdk/anthropic"; +import { generateText } from "ai"; +import { HttpResponse, http } from "msw"; +import { setupServer } from "msw/node"; +import { afterAll, afterEach, beforeAll, describe, expect, it } from "vitest"; +import { createAiGateway } from "../src"; +import { createProviderRouter } from "../src/providers/router"; + +const TEST_ACCOUNT_ID = "test-account-id"; +const TEST_API_KEY = "test-api-key"; +const TEST_GATEWAY = "my-gateway"; + +let capturedBody: any = null; + +const gatewayHandler = http.post( + `https://gateway.ai.cloudflare.com/v1/${TEST_ACCOUNT_ID}/${TEST_GATEWAY}`, + async ({ request }) => { + capturedBody = await request.json(); + + return HttpResponse.json( + { + id: "msg_test123", + type: "message", + role: "assistant", + model: "claude-sonnet-4-5-20250514", + content: [{ type: "text", text: "Hello" }], + stop_reason: "end_turn", + stop_sequence: null, + usage: { input_tokens: 25, output_tokens: 1 }, + }, + { + headers: { + "cf-aig-step": "0", + }, + }, + ); + }, +); + +const server = setupServer(gatewayHandler); + +describe("Provider Router Integration", () => { + beforeAll(() => server.listen()); + afterEach(() => { + server.resetHandlers(); + capturedBody = null; + }); + afterAll(() => server.close()); + + it("preserves cache_control on user messages through router → native Anthropic → gateway", async () => { + const anthropic = createAnthropic({ apiKey: TEST_API_KEY }); + const router = createProviderRouter({ + providers: { + anthropic: (modelId) => anthropic.languageModel(modelId), + }, + }); + + const gateway = createAiGateway({ + accountId: TEST_ACCOUNT_ID, + apiKey: TEST_API_KEY, + gateway: TEST_GATEWAY, + }); + + await generateText({ + model: gateway(router("anthropic/claude-sonnet-4-5-20250514")), + messages: [ + { + role: "user", + content: [ + { + type: "text", + text: "What is 2+2?", + providerOptions: { + anthropic: { cacheControl: { type: "ephemeral" } }, + }, + }, + ], + }, + ], + }); + + expect(capturedBody).toBeDefined(); + expect(capturedBody).toHaveLength(1); + + const query = capturedBody[0].query; + expect(query.model).toBe("claude-sonnet-4-5-20250514"); + expect(query.messages[0].role).toBe("user"); + expect(query.messages[0].content[0].cache_control).toEqual({ + type: "ephemeral", + }); + }); + + it("preserves cache_control on system messages through router → native Anthropic → gateway", async () => { + const anthropic = createAnthropic({ apiKey: TEST_API_KEY }); + const router = createProviderRouter({ + providers: { + anthropic: (modelId) => anthropic.languageModel(modelId), + }, + }); + + const gateway = createAiGateway({ + accountId: TEST_ACCOUNT_ID, + apiKey: TEST_API_KEY, + gateway: TEST_GATEWAY, + }); + + await generateText({ + model: gateway(router("anthropic/claude-sonnet-4-5-20250514")), + messages: [ + { + role: "system", + content: "You are a helpful assistant", + providerOptions: { + anthropic: { cacheControl: { type: "ephemeral" } }, + }, + }, + { + role: "user", + content: "Hello", + }, + ], + }); + + expect(capturedBody).toBeDefined(); + expect(capturedBody).toHaveLength(1); + + const query = capturedBody[0].query; + expect(query.system).toBeDefined(); + const systemBlocks = Array.isArray(query.system) ? query.system : [query.system]; + const lastBlock = systemBlocks[systemBlocks.length - 1]; + expect(lastBlock.cache_control).toEqual({ type: "ephemeral" }); + }); + + it("routes through native Anthropic SDK (not unified) based on model prefix", async () => { + const anthropic = createAnthropic({ apiKey: TEST_API_KEY }); + const router = createProviderRouter({ + providers: { + anthropic: (modelId) => anthropic.languageModel(modelId), + }, + }); + + const gateway = createAiGateway({ + accountId: TEST_ACCOUNT_ID, + apiKey: TEST_API_KEY, + gateway: TEST_GATEWAY, + }); + + await generateText({ + model: gateway(router("anthropic/claude-sonnet-4-5-20250514")), + prompt: "Hello", + }); + + expect(capturedBody).toBeDefined(); + expect(capturedBody).toHaveLength(1); + + expect(capturedBody[0].provider).toBe("anthropic"); + expect(capturedBody[0].endpoint).toBe("v1/messages"); + expect(capturedBody[0].query.model).toBe("claude-sonnet-4-5-20250514"); + }); +}); diff --git a/packages/ai-gateway-provider/test/router.test.ts b/packages/ai-gateway-provider/test/router.test.ts new file mode 100644 index 000000000..35fa99d59 --- /dev/null +++ b/packages/ai-gateway-provider/test/router.test.ts @@ -0,0 +1,116 @@ +import type { LanguageModelV3 } from "@ai-sdk/provider"; +import { beforeEach, describe, expect, it, vi } from "vitest"; + +vi.mock("../src/providers/unified", () => ({ + createUnified: vi.fn( + () => + vi.fn((modelId: string) => + ({ source: "unified", modelId }) as unknown as LanguageModelV3, + ), + ), +})); + +import { createUnified } from "../src/providers/unified"; +import { createProviderRouter } from "../src/providers/router"; + +type MockModel = { + source: string; + modelId: string; +}; + +const asMockModel = (model: LanguageModelV3) => model as unknown as MockModel; + +const makeProvider = (source: string) => + vi.fn((modelId: string) => + ({ source, modelId }) as unknown as LanguageModelV3, + ); + +describe("createProviderRouter", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("selects native provider when model ID has matching prefix", () => { + const anthropic = makeProvider("anthropic"); + const router = createProviderRouter({ providers: { anthropic } }); + + const model = asMockModel(router("anthropic/claude-sonnet-4-5")); + + expect(model).toEqual({ + source: "anthropic", + modelId: "claude-sonnet-4-5", + }); + expect(anthropic).toHaveBeenCalledWith("claude-sonnet-4-5"); + }); + + it("strips prefix and passes bare model name to native provider", () => { + const anthropic = makeProvider("anthropic"); + const router = createProviderRouter({ providers: { anthropic } }); + + router("anthropic/claude-3-5-haiku-latest"); + + expect(anthropic).toHaveBeenCalledWith("claude-3-5-haiku-latest"); + }); + + it("falls back to unified for unknown prefixes", () => { + const anthropic = makeProvider("anthropic"); + const router = createProviderRouter({ providers: { anthropic } }); + + const model = asMockModel(router("unknown/model-id")); + + expect(model).toEqual({ source: "unified", modelId: "unknown/model-id" }); + expect(anthropic).not.toHaveBeenCalled(); + }); + + it("falls back to unified for model IDs without a slash", () => { + const anthropic = makeProvider("anthropic"); + const router = createProviderRouter({ providers: { anthropic } }); + + const model = asMockModel(router("claude-sonnet-4-5")); + + expect(model).toEqual({ source: "unified", modelId: "claude-sonnet-4-5" }); + expect(anthropic).not.toHaveBeenCalled(); + }); + + it("uses custom fallback when provided", () => { + const anthropic = makeProvider("anthropic"); + const customFallback = makeProvider("custom-fallback"); + const router = createProviderRouter({ + providers: { anthropic }, + fallback: customFallback, + }); + + const model = asMockModel(router("unknown/model-id")); + + expect(model).toEqual({ source: "custom-fallback", modelId: "unknown/model-id" }); + expect(customFallback).toHaveBeenCalledWith("unknown/model-id"); + expect(createUnified).not.toHaveBeenCalled(); + }); + + it("routes correctly when multiple providers are registered", () => { + const anthropic = makeProvider("anthropic"); + const openai = makeProvider("openai"); + const router = createProviderRouter({ providers: { anthropic, openai } }); + + const anthropicModel = asMockModel(router("anthropic/claude-3-opus")); + const openaiModel = asMockModel(router("openai/gpt-4o-mini")); + + expect(anthropicModel).toEqual({ source: "anthropic", modelId: "claude-3-opus" }); + expect(openaiModel).toEqual({ source: "openai", modelId: "gpt-4o-mini" }); + expect(anthropic).toHaveBeenCalledWith("claude-3-opus"); + expect(openai).toHaveBeenCalledWith("gpt-4o-mini"); + }); + + it("falls back to unified for every model when providers map is empty", () => { + const router = createProviderRouter({ providers: {} }); + + const modelA = asMockModel(router("anthropic/claude-sonnet-4-5")); + const modelB = asMockModel(router("gpt-4o-mini")); + + expect(modelA).toEqual({ + source: "unified", + modelId: "anthropic/claude-sonnet-4-5", + }); + expect(modelB).toEqual({ source: "unified", modelId: "gpt-4o-mini" }); + }); +});