diff --git a/.changeset/add-provider-router.md b/.changeset/add-provider-router.md new file mode 100644 index 000000000..d88d6af5d --- /dev/null +++ b/.changeset/add-provider-router.md @@ -0,0 +1,5 @@ +--- +"ai-gateway-provider": minor +--- + +Add `createProviderRouter()` for routing model IDs to native SDK providers by prefix, preserving provider-specific features like Anthropic prompt caching that are lost through the unified OpenAI-compatible path. diff --git a/packages/ai-gateway-provider/package.json b/packages/ai-gateway-provider/package.json index 391324d9f..d33d17db0 100644 --- a/packages/ai-gateway-provider/package.json +++ b/packages/ai-gateway-provider/package.json @@ -119,6 +119,11 @@ "import": "./dist/providers/openrouter.mjs", "require": "./dist/providers/openrouter.js" }, + "./providers/router": { + "types": "./dist/providers/router.d.ts", + "import": "./dist/providers/router.mjs", + "require": "./dist/providers/router.js" + }, "./providers/unified": { "types": "./dist/providers/unified.d.ts", "import": "./dist/providers/unified.mjs", diff --git a/packages/ai-gateway-provider/src/providers/index.ts b/packages/ai-gateway-provider/src/providers/index.ts index a31b24b69..084ff48fa 100644 --- a/packages/ai-gateway-provider/src/providers/index.ts +++ b/packages/ai-gateway-provider/src/providers/index.ts @@ -13,4 +13,5 @@ export { createGroq } from "./groq"; export { createMistral } from "./mistral"; export { createOpenAI } from "./openai"; export { createPerplexity } from "./perplexity"; +export { createProviderRouter, type ProviderRouterConfig } from "./router"; export { createXai } from "./xai"; diff --git a/packages/ai-gateway-provider/src/providers/router.ts b/packages/ai-gateway-provider/src/providers/router.ts new file mode 100644 index 000000000..2baf40503 --- /dev/null +++ b/packages/ai-gateway-provider/src/providers/router.ts @@ -0,0 +1,41 @@ +import type { LanguageModelV3 } from "@ai-sdk/provider"; +import { createUnified } from "./unified"; + 
+type ProviderFactory = (modelId: string) => LanguageModelV3;
+
+export type ProviderRouterConfig = {
+	/**
+	 * Map of model ID prefix to a provider instance.
+	 * When a model ID matches "prefix/model-name", the corresponding provider
+	 * is used to create the model with the bare name (prefix stripped).
+	 */
+	providers: Record<string, ProviderFactory>;
+	/**
+	 * Optional fallback for model IDs that don't match any provider prefix.
+	 * Defaults to createUnified() (OpenAI-compatible format).
+	 */
+	fallback?: ProviderFactory;
+};
+
+/**
+ * Creates a model router that selects native provider SDKs based on
+ * the model ID prefix. This preserves provider-specific features like
+ * Anthropic's cache_control that are lost through the unified OpenAI-compatible path.
+ */
+export function createProviderRouter(config: ProviderRouterConfig): ProviderFactory {
+	// Resolve the fallback once so the returned factory does not rebuild it on every call.
+	const fallback = config.fallback ?? createUnified();
+
+	return (modelId: string): LanguageModelV3 => {
+		const slashIndex = modelId.indexOf("/");
+		// slashIndex > 0 requires a non-empty prefix; a bare ID or a leading "/" falls through.
+		if (slashIndex > 0) {
+			const prefix = modelId.slice(0, slashIndex);
+			const bareId = modelId.slice(slashIndex + 1);
+			const provider = config.providers[prefix];
+			if (provider) {
+				return provider(bareId);
+			}
+		}
+
+		// Unknown prefix (or no prefix): pass the model ID through to the fallback unchanged.
+		return fallback(modelId);
+	};
+}
diff --git a/packages/ai-gateway-provider/test/router-integration.test.ts b/packages/ai-gateway-provider/test/router-integration.test.ts
new file mode 100644
index 000000000..3000f6f26
--- /dev/null
+++ b/packages/ai-gateway-provider/test/router-integration.test.ts
@@ -0,0 +1,160 @@
+import { createAnthropic } from "@ai-sdk/anthropic";
+import { generateText } from "ai";
+import { HttpResponse, http } from "msw";
+import { setupServer } from "msw/node";
+import { afterAll, afterEach, beforeAll, describe, expect, it } from "vitest";
+import { createAiGateway } from "../src";
+import { createProviderRouter } from "../src/providers/router";
+
+const TEST_ACCOUNT_ID = "test-account-id";
+const TEST_API_KEY = "test-api-key";
+const TEST_GATEWAY =
+	"my-gateway";
+
+// Captured JSON body of the most recent request the mock gateway received.
+// Typed `any` because each test asserts exactly the shape it cares about.
+let capturedBody: any = null;
+
+const gatewayHandler = http.post(
+	`https://gateway.ai.cloudflare.com/v1/${TEST_ACCOUNT_ID}/${TEST_GATEWAY}`,
+	async ({ request }) => {
+		capturedBody = await request.json();
+
+		// Minimal Anthropic Messages API success response.
+		return HttpResponse.json(
+			{
+				id: "msg_test123",
+				type: "message",
+				role: "assistant",
+				model: "claude-sonnet-4-5-20250514",
+				content: [{ type: "text", text: "Hello" }],
+				stop_reason: "end_turn",
+				stop_sequence: null,
+				usage: { input_tokens: 25, output_tokens: 1 },
+			},
+			{
+				headers: {
+					"cf-aig-step": "0",
+				},
+			},
+		);
+	},
+);
+
+const server = setupServer(gatewayHandler);
+
+describe("Provider Router Integration", () => {
+	beforeAll(() => server.listen());
+	afterEach(() => {
+		server.resetHandlers();
+		capturedBody = null;
+	});
+	afterAll(() => server.close());
+
+	it("preserves cache_control on user messages through router → native Anthropic → gateway", async () => {
+		const anthropic = createAnthropic({ apiKey: TEST_API_KEY });
+		const router = createProviderRouter({
+			providers: {
+				anthropic: (modelId) => anthropic.languageModel(modelId),
+			},
+		});
+
+		const gateway = createAiGateway({
+			accountId: TEST_ACCOUNT_ID,
+			apiKey: TEST_API_KEY,
+			gateway: TEST_GATEWAY,
+		});
+
+		await generateText({
+			model: gateway(router("anthropic/claude-sonnet-4-5-20250514")),
+			messages: [
+				{
+					role: "user",
+					content: [
+						{
+							type: "text",
+							text: "What is 2+2?",
+							providerOptions: {
+								anthropic: { cacheControl: { type: "ephemeral" } },
+							},
+						},
+					],
+				},
+			],
+		});
+
+		expect(capturedBody).toBeDefined();
+		expect(capturedBody).toHaveLength(1);
+
+		const query = capturedBody[0].query;
+		expect(query.model).toBe("claude-sonnet-4-5-20250514");
+		expect(query.messages[0].role).toBe("user");
+		expect(query.messages[0].content[0].cache_control).toEqual({
+			type: "ephemeral",
+		});
+	});
+
+	it("preserves cache_control on system messages through router → native Anthropic → gateway", async () => {
+		const anthropic = createAnthropic({ apiKey: TEST_API_KEY });
+		const router = createProviderRouter({
+			providers: {
+				anthropic: (modelId) => anthropic.languageModel(modelId),
+			},
+		});
+
+		const gateway = createAiGateway({
+			accountId: TEST_ACCOUNT_ID,
+			apiKey: TEST_API_KEY,
+			gateway: TEST_GATEWAY,
+		});
+
+		await generateText({
+			model: gateway(router("anthropic/claude-sonnet-4-5-20250514")),
+			messages: [
+				{
+					role: "system",
+					content: "You are a helpful assistant",
+					providerOptions: {
+						anthropic: { cacheControl: { type: "ephemeral" } },
+					},
+				},
+				{
+					role: "user",
+					content: "Hello",
+				},
+			],
+		});
+
+		expect(capturedBody).toBeDefined();
+		expect(capturedBody).toHaveLength(1);
+
+		const query = capturedBody[0].query;
+		expect(query.system).toBeDefined();
+		// The serialized `system` field may be a string or an array of blocks;
+		// normalize before checking the cache marker on the last block.
+		const systemBlocks = Array.isArray(query.system) ? query.system : [query.system];
+		const lastBlock = systemBlocks[systemBlocks.length - 1];
+		expect(lastBlock.cache_control).toEqual({ type: "ephemeral" });
+	});
+
+	it("routes through native Anthropic SDK (not unified) based on model prefix", async () => {
+		const anthropic = createAnthropic({ apiKey: TEST_API_KEY });
+		const router = createProviderRouter({
+			providers: {
+				anthropic: (modelId) => anthropic.languageModel(modelId),
+			},
+		});
+
+		const gateway = createAiGateway({
+			accountId: TEST_ACCOUNT_ID,
+			apiKey: TEST_API_KEY,
+			gateway: TEST_GATEWAY,
+		});
+
+		await generateText({
+			model: gateway(router("anthropic/claude-sonnet-4-5-20250514")),
+			prompt: "Hello",
+		});
+
+		expect(capturedBody).toBeDefined();
+		expect(capturedBody).toHaveLength(1);
+
+		expect(capturedBody[0].provider).toBe("anthropic");
+		expect(capturedBody[0].endpoint).toBe("v1/messages");
+		expect(capturedBody[0].query.model).toBe("claude-sonnet-4-5-20250514");
+	});
+});
diff --git a/packages/ai-gateway-provider/test/router.test.ts b/packages/ai-gateway-provider/test/router.test.ts
new file mode 100644
index 000000000..35fa99d59
--- /dev/null
+++ b/packages/ai-gateway-provider/test/router.test.ts
@@ -0,0
+1,116 @@
+import type { LanguageModelV3 } from "@ai-sdk/provider";
+import { beforeEach, describe, expect, it, vi } from "vitest";
+
+// NOTE: vi.mock is hoisted above the imports at runtime, so router.ts
+// resolves this stub instead of the real createUnified.
+vi.mock("../src/providers/unified", () => ({
+	createUnified: vi.fn(
+		() =>
+			vi.fn((modelId: string) =>
+				({ source: "unified", modelId }) as unknown as LanguageModelV3,
+			),
+	),
+}));
+
+import { createUnified } from "../src/providers/unified";
+import { createProviderRouter } from "../src/providers/router";
+
+// Shape produced by the stub providers below, so tests can assert on plain objects.
+type MockModel = {
+	source: string;
+	modelId: string;
+};
+
+const asMockModel = (model: LanguageModelV3) => model as unknown as MockModel;
+
+// Builds a spy provider that tags each created "model" with its source name.
+const makeProvider = (source: string) =>
+	vi.fn((modelId: string) =>
+		({ source, modelId }) as unknown as LanguageModelV3,
+	);
+
+describe("createProviderRouter", () => {
+	beforeEach(() => {
+		vi.clearAllMocks();
+	});
+
+	it("selects native provider when model ID has matching prefix", () => {
+		const anthropic = makeProvider("anthropic");
+		const router = createProviderRouter({ providers: { anthropic } });
+
+		const model = asMockModel(router("anthropic/claude-sonnet-4-5"));
+
+		expect(model).toEqual({
+			source: "anthropic",
+			modelId: "claude-sonnet-4-5",
+		});
+		expect(anthropic).toHaveBeenCalledWith("claude-sonnet-4-5");
+	});
+
+	it("strips prefix and passes bare model name to native provider", () => {
+		const anthropic = makeProvider("anthropic");
+		const router = createProviderRouter({ providers: { anthropic } });
+
+		router("anthropic/claude-3-5-haiku-latest");
+
+		expect(anthropic).toHaveBeenCalledWith("claude-3-5-haiku-latest");
+	});
+
+	it("falls back to unified for unknown prefixes", () => {
+		const anthropic = makeProvider("anthropic");
+		const router = createProviderRouter({ providers: { anthropic } });
+
+		const model = asMockModel(router("unknown/model-id"));
+
+		// Fallback receives the FULL model ID, prefix included.
+		expect(model).toEqual({ source: "unified", modelId: "unknown/model-id" });
+		expect(anthropic).not.toHaveBeenCalled();
+	});
+
+	it("falls back to unified for model IDs without a slash", () => {
+		const anthropic = makeProvider("anthropic");
+		const router = createProviderRouter({ providers: { anthropic } });
+
+		const model = asMockModel(router("claude-sonnet-4-5"));
+
+		expect(model).toEqual({ source: "unified", modelId: "claude-sonnet-4-5" });
+		expect(anthropic).not.toHaveBeenCalled();
+	});
+
+	it("uses custom fallback when provided", () => {
+		const anthropic = makeProvider("anthropic");
+		const customFallback = makeProvider("custom-fallback");
+		const router = createProviderRouter({
+			providers: { anthropic },
+			fallback: customFallback,
+		});
+
+		const model = asMockModel(router("unknown/model-id"));
+
+		expect(model).toEqual({ source: "custom-fallback", modelId: "unknown/model-id" });
+		expect(customFallback).toHaveBeenCalledWith("unknown/model-id");
+		// With an explicit fallback the unified default must never be constructed.
+		expect(createUnified).not.toHaveBeenCalled();
+	});
+
+	it("routes correctly when multiple providers are registered", () => {
+		const anthropic = makeProvider("anthropic");
+		const openai = makeProvider("openai");
+		const router = createProviderRouter({ providers: { anthropic, openai } });
+
+		const anthropicModel = asMockModel(router("anthropic/claude-3-opus"));
+		const openaiModel = asMockModel(router("openai/gpt-4o-mini"));
+
+		expect(anthropicModel).toEqual({ source: "anthropic", modelId: "claude-3-opus" });
+		expect(openaiModel).toEqual({ source: "openai", modelId: "gpt-4o-mini" });
+		expect(anthropic).toHaveBeenCalledWith("claude-3-opus");
+		expect(openai).toHaveBeenCalledWith("gpt-4o-mini");
+	});
+
+	it("falls back to unified for every model when providers map is empty", () => {
+		const router = createProviderRouter({ providers: {} });
+
+		const modelA = asMockModel(router("anthropic/claude-sonnet-4-5"));
+		const modelB = asMockModel(router("gpt-4o-mini"));
+
+		expect(modelA).toEqual({
+			source: "unified",
+			modelId: "anthropic/claude-sonnet-4-5",
+		});
+		expect(modelB).toEqual({ source: "unified", modelId: "gpt-4o-mini" });
+	});
+});