Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/add-provider-router.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"ai-gateway-provider": minor
---

Add `createProviderRouter()` for routing model IDs to native SDK providers by prefix, preserving provider-specific features like Anthropic prompt caching that are lost through the unified OpenAI-compatible path.
5 changes: 5 additions & 0 deletions packages/ai-gateway-provider/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@
"import": "./dist/providers/openrouter.mjs",
"require": "./dist/providers/openrouter.js"
},
"./providers/router": {
"types": "./dist/providers/router.d.ts",
"import": "./dist/providers/router.mjs",
"require": "./dist/providers/router.js"
},
"./providers/unified": {
"types": "./dist/providers/unified.d.ts",
"import": "./dist/providers/unified.mjs",
Expand Down
1 change: 1 addition & 0 deletions packages/ai-gateway-provider/src/providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ export { createGroq } from "./groq";
export { createMistral } from "./mistral";
export { createOpenAI } from "./openai";
export { createPerplexity } from "./perplexity";
export { createProviderRouter, type ProviderRouterConfig } from "./router";
export { createXai } from "./xai";
41 changes: 41 additions & 0 deletions packages/ai-gateway-provider/src/providers/router.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import type { LanguageModelV3 } from "@ai-sdk/provider";
import { createUnified } from "./unified";

type ProviderFactory = (modelId: string) => LanguageModelV3;

export type ProviderRouterConfig = {
/**
* Map of model ID prefix to a provider instance.
* When a model ID matches "prefix/model-name", the corresponding provider
* is used to create the model with the bare name (prefix stripped).
*/
providers: Record<string, ProviderFactory>;
/**
* Optional fallback for model IDs that don't match any provider prefix.
* Defaults to createUnified() (OpenAI-compatible format).
*/
fallback?: ProviderFactory;
};

/**
* Creates a model router that selects native provider SDKs based on
* the model ID prefix. This preserves provider-specific features like
* Anthropic's cache_control that are lost through the unified OpenAI-compatible path.
*/
export function createProviderRouter(config: ProviderRouterConfig): ProviderFactory {
const fallback = config.fallback ?? createUnified();

return (modelId: string): LanguageModelV3 => {
const slashIndex = modelId.indexOf("/");
if (slashIndex > 0) {
const prefix = modelId.slice(0, slashIndex);
const bareId = modelId.slice(slashIndex + 1);
const provider = config.providers[prefix];
if (provider) {
return provider(bareId);
}
}

return fallback(modelId);
};
}
160 changes: 160 additions & 0 deletions packages/ai-gateway-provider/test/router-integration.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import { createAnthropic } from "@ai-sdk/anthropic";
import { generateText } from "ai";
import { HttpResponse, http } from "msw";
import { setupServer } from "msw/node";
import { afterAll, afterEach, beforeAll, describe, expect, it } from "vitest";
import { createAiGateway } from "../src";
import { createProviderRouter } from "../src/providers/router";

const TEST_ACCOUNT_ID = "test-account-id";
const TEST_API_KEY = "test-api-key";
const TEST_GATEWAY = "my-gateway";

let capturedBody: any = null;

const gatewayHandler = http.post(
`https://gateway.ai.cloudflare.com/v1/${TEST_ACCOUNT_ID}/${TEST_GATEWAY}`,
async ({ request }) => {
capturedBody = await request.json();

return HttpResponse.json(
{
id: "msg_test123",
type: "message",
role: "assistant",
model: "claude-sonnet-4-5-20250514",
content: [{ type: "text", text: "Hello" }],
stop_reason: "end_turn",
stop_sequence: null,
usage: { input_tokens: 25, output_tokens: 1 },
},
{
headers: {
"cf-aig-step": "0",
},
},
);
},
);

const server = setupServer(gatewayHandler);

describe("Provider Router Integration", () => {
beforeAll(() => server.listen());
afterEach(() => {
server.resetHandlers();
capturedBody = null;
});
afterAll(() => server.close());

it("preserves cache_control on user messages through router → native Anthropic → gateway", async () => {
const anthropic = createAnthropic({ apiKey: TEST_API_KEY });
const router = createProviderRouter({
providers: {
anthropic: (modelId) => anthropic.languageModel(modelId),
},
});

const gateway = createAiGateway({
accountId: TEST_ACCOUNT_ID,
apiKey: TEST_API_KEY,
gateway: TEST_GATEWAY,
});

await generateText({
model: gateway(router("anthropic/claude-sonnet-4-5-20250514")),
messages: [
{
role: "user",
content: [
{
type: "text",
text: "What is 2+2?",
providerOptions: {
anthropic: { cacheControl: { type: "ephemeral" } },
},
},
],
},
],
});

expect(capturedBody).toBeDefined();
expect(capturedBody).toHaveLength(1);

const query = capturedBody[0].query;
expect(query.model).toBe("claude-sonnet-4-5-20250514");
expect(query.messages[0].role).toBe("user");
expect(query.messages[0].content[0].cache_control).toEqual({
type: "ephemeral",
});
});

it("preserves cache_control on system messages through router → native Anthropic → gateway", async () => {
const anthropic = createAnthropic({ apiKey: TEST_API_KEY });
const router = createProviderRouter({
providers: {
anthropic: (modelId) => anthropic.languageModel(modelId),
},
});

const gateway = createAiGateway({
accountId: TEST_ACCOUNT_ID,
apiKey: TEST_API_KEY,
gateway: TEST_GATEWAY,
});

await generateText({
model: gateway(router("anthropic/claude-sonnet-4-5-20250514")),
messages: [
{
role: "system",
content: "You are a helpful assistant",
providerOptions: {
anthropic: { cacheControl: { type: "ephemeral" } },
},
},
{
role: "user",
content: "Hello",
},
],
});

expect(capturedBody).toBeDefined();
expect(capturedBody).toHaveLength(1);

const query = capturedBody[0].query;
expect(query.system).toBeDefined();
const systemBlocks = Array.isArray(query.system) ? query.system : [query.system];
const lastBlock = systemBlocks[systemBlocks.length - 1];
expect(lastBlock.cache_control).toEqual({ type: "ephemeral" });
});

it("routes through native Anthropic SDK (not unified) based on model prefix", async () => {
const anthropic = createAnthropic({ apiKey: TEST_API_KEY });
const router = createProviderRouter({
providers: {
anthropic: (modelId) => anthropic.languageModel(modelId),
},
});

const gateway = createAiGateway({
accountId: TEST_ACCOUNT_ID,
apiKey: TEST_API_KEY,
gateway: TEST_GATEWAY,
});

await generateText({
model: gateway(router("anthropic/claude-sonnet-4-5-20250514")),
prompt: "Hello",
});

expect(capturedBody).toBeDefined();
expect(capturedBody).toHaveLength(1);

expect(capturedBody[0].provider).toBe("anthropic");
expect(capturedBody[0].endpoint).toBe("v1/messages");
expect(capturedBody[0].query.model).toBe("claude-sonnet-4-5-20250514");
});
});
116 changes: 116 additions & 0 deletions packages/ai-gateway-provider/test/router.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import type { LanguageModelV3 } from "@ai-sdk/provider";
import { beforeEach, describe, expect, it, vi } from "vitest";

vi.mock("../src/providers/unified", () => ({
createUnified: vi.fn(
() =>
vi.fn((modelId: string) =>
({ source: "unified", modelId }) as unknown as LanguageModelV3,
),
),
}));

import { createUnified } from "../src/providers/unified";
import { createProviderRouter } from "../src/providers/router";

type MockModel = {
source: string;
modelId: string;
};

const asMockModel = (model: LanguageModelV3) => model as unknown as MockModel;

const makeProvider = (source: string) =>
vi.fn((modelId: string) =>
({ source, modelId }) as unknown as LanguageModelV3,
);

describe("createProviderRouter", () => {
beforeEach(() => {
vi.clearAllMocks();
});

it("selects native provider when model ID has matching prefix", () => {
const anthropic = makeProvider("anthropic");
const router = createProviderRouter({ providers: { anthropic } });

const model = asMockModel(router("anthropic/claude-sonnet-4-5"));

expect(model).toEqual({
source: "anthropic",
modelId: "claude-sonnet-4-5",
});
expect(anthropic).toHaveBeenCalledWith("claude-sonnet-4-5");
});

it("strips prefix and passes bare model name to native provider", () => {
const anthropic = makeProvider("anthropic");
const router = createProviderRouter({ providers: { anthropic } });

router("anthropic/claude-3-5-haiku-latest");

expect(anthropic).toHaveBeenCalledWith("claude-3-5-haiku-latest");
});

it("falls back to unified for unknown prefixes", () => {
const anthropic = makeProvider("anthropic");
const router = createProviderRouter({ providers: { anthropic } });

const model = asMockModel(router("unknown/model-id"));

expect(model).toEqual({ source: "unified", modelId: "unknown/model-id" });
expect(anthropic).not.toHaveBeenCalled();
});

it("falls back to unified for model IDs without a slash", () => {
const anthropic = makeProvider("anthropic");
const router = createProviderRouter({ providers: { anthropic } });

const model = asMockModel(router("claude-sonnet-4-5"));

expect(model).toEqual({ source: "unified", modelId: "claude-sonnet-4-5" });
expect(anthropic).not.toHaveBeenCalled();
});

it("uses custom fallback when provided", () => {
const anthropic = makeProvider("anthropic");
const customFallback = makeProvider("custom-fallback");
const router = createProviderRouter({
providers: { anthropic },
fallback: customFallback,
});

const model = asMockModel(router("unknown/model-id"));

expect(model).toEqual({ source: "custom-fallback", modelId: "unknown/model-id" });
expect(customFallback).toHaveBeenCalledWith("unknown/model-id");
expect(createUnified).not.toHaveBeenCalled();
});

it("routes correctly when multiple providers are registered", () => {
const anthropic = makeProvider("anthropic");
const openai = makeProvider("openai");
const router = createProviderRouter({ providers: { anthropic, openai } });

const anthropicModel = asMockModel(router("anthropic/claude-3-opus"));
const openaiModel = asMockModel(router("openai/gpt-4o-mini"));

expect(anthropicModel).toEqual({ source: "anthropic", modelId: "claude-3-opus" });
expect(openaiModel).toEqual({ source: "openai", modelId: "gpt-4o-mini" });
expect(anthropic).toHaveBeenCalledWith("claude-3-opus");
expect(openai).toHaveBeenCalledWith("gpt-4o-mini");
});

it("falls back to unified for every model when providers map is empty", () => {
const router = createProviderRouter({ providers: {} });

const modelA = asMockModel(router("anthropic/claude-sonnet-4-5"));
const modelB = asMockModel(router("gpt-4o-mini"));

expect(modelA).toEqual({
source: "unified",
modelId: "anthropic/claude-sonnet-4-5",
});
expect(modelB).toEqual({ source: "unified", modelId: "gpt-4o-mini" });
});
});
Loading