Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion packages/proxy/src/providers/databricks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export async function getDatabricksOAuthAccessToken({
digest,
cacheGet,
cachePut,
fetch = globalThis.fetch,
}: {
secret: z.infer<typeof DatabricksOAuthSecretSchema>;
apiBase: string;
Expand All @@ -29,6 +30,7 @@ export async function getDatabricksOAuthAccessToken({
value: string,
ttl_seconds?: number,
) => Promise<void>;
fetch?: typeof globalThis.fetch;
}): Promise<string> {
const { client_id, client_secret } = secret;
const tokenUrl = `${apiBase}/oidc/v1/token`;
Expand Down Expand Up @@ -59,7 +61,9 @@ export async function getDatabricksOAuthAccessToken({
});
if (!res.ok) {
throw new Error(
`Databricks OAuth error (${res.status}): ${res.statusText} ${await res.text()}`,
`Databricks OAuth error (${res.status}): ${
res.statusText
} ${await res.text()}`,
);
}

Expand Down
123 changes: 123 additions & 0 deletions packages/proxy/src/providers/openai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
OpenAIChatCompletionChunk,
OpenAIChatCompletionCreateParams,
} from "@types";
import { type APISecret } from "@schema";
import { bypass, http, HttpResponse, JsonBodyType } from "msw";
import { setupServer } from "msw/node";
import { ChatCompletionContentPart } from "openai/resources";
Expand All @@ -18,6 +19,7 @@ import {
} from "vitest";
import { callProxyV1, createCapturingFetch } from "../../utils/tests";
import * as proxyUtil from "../util";
import { type FetchFn } from "../proxy";
import { normalizeOpenAIContent } from "./openai";
import * as util from "./util";
import {
Expand All @@ -30,6 +32,26 @@ import {
CSV_DATA_URL,
} from "../../tests/fixtures/base64";

function fetchInputUrl(input: Parameters<FetchFn>[0]): string {
if (typeof input === "string") {
return input;
}
if (input instanceof URL) {
return input.toString();
}
return input.url;
}

function fetchHeaderValue(
headers: HeadersInit | undefined,
name: string,
): string | null {
if (!headers) {
return null;
}
return new Headers(headers).get(name);
}

it("should deny reasoning_effort for unsupported models non-streaming", async () => {
const { json } = await callProxyV1<
OpenAIChatCompletionCreateParams,
Expand Down Expand Up @@ -379,6 +401,107 @@ it("handles /responses as endpoint_path", async () => {
});
});

it("uses injected fetch when supportsStreaming is false", async () => {
const { fetch, requests } = createCapturingFetch({ captureOnly: true });

await callProxyV1<OpenAIChatCompletionCreateParams, OpenAIChatCompletion>({
body: {
model: "gpt-4o-mini",
messages: [{ role: "user", content: "hello" }],
stream: true,
},
fetch,
getApiSecrets: async () => [
{
type: "openai",
name: "openai",
secret: "provider-secret",
metadata: {
api_base: "http://test.com/v1",
supportsStreaming: false,
},
},
],
});

expect(requests.length).toBe(1);
expect(requests[0].url).toBe("http://test.com/v1/chat/completions");
});

it("uses injected fetch for Databricks OAuth token exchange", async () => {
vi.stubGlobal("fetch", async () => {
throw new Error("global fetch was called");
});

try {
const calls: Array<{
url: string;
body: BodyInit | null | undefined;
headers: HeadersInit | undefined;
}> = [];
const fetch: FetchFn = async (input, init) => {
const url = fetchInputUrl(input);
calls.push({ url, body: init?.body, headers: init?.headers });

if (url === "https://dbc.example/oidc/v1/token") {
return new Response(
JSON.stringify({
access_token: "databricks-token",
token_type: "Bearer",
expires_in: 3600,
}),
{ headers: { "content-type": "application/json" } },
);
}

return new Response(JSON.stringify({ choices: [] }), {
headers: { "content-type": "application/json" },
});
};
const getApiSecrets = async (): Promise<APISecret[]> => [
{
type: "databricks",
name: "databricks",
secret: JSON.stringify({
client_id: "client-id",
client_secret: "client-secret",
}),
metadata: {
api_base: "https://dbc.example",
auth_type: "service_principal_oauth",
supportsStreaming: true,
},
},
];

await callProxyV1<OpenAIChatCompletionCreateParams, OpenAIChatCompletion>({
body: {
model: "databricks-model",
messages: [{ role: "user", content: "hello" }],
},
fetch,
...{ getApiSecrets },
});

expect(calls.map((call) => call.url)).toEqual([
"https://dbc.example/oidc/v1/token",
"https://dbc.example/serving-endpoints/databricks-model/invocations",
]);
expect(calls[0]?.body?.toString()).toBe(
"grant_type=client_credentials&scope=all-apis",
);
expect(fetchHeaderValue(calls[0]?.headers, "authorization")).toBe(
// Base64 for the fake "client-id:client-secret" fixture above.
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=", // gitleaks:allow
);
expect(fetchHeaderValue(calls[1]?.headers, "authorization")).toBe(
"Bearer databricks-token",
);
} finally {
vi.unstubAllGlobals();
}
});

it("uses model path for azure when metadata.deployment is non-string", async () => {
const { fetch, requests } = createCapturingFetch({ captureOnly: true });

Expand Down
24 changes: 21 additions & 3 deletions packages/proxy/src/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2135,7 +2135,7 @@ async function fetchOpenAI(
bearerToken = secret.secret;
} else {
// authType === "service_account_key"
bearerToken = await getGoogleAccessToken(secret.secret);
bearerToken = await getGoogleAccessToken(secret.secret, fetch);
}
} else {
const metadataApiBase =
Expand Down Expand Up @@ -2215,6 +2215,7 @@ async function fetchOpenAI(
digest,
cacheGet,
cachePut,
fetch,
});
} else {
bearerToken = secret.secret;
Expand Down Expand Up @@ -2336,6 +2337,7 @@ async function fetchOpenAI(
bodyData,
setHeader,
signal,
fetch,
});
}

Expand Down Expand Up @@ -2494,13 +2496,15 @@ async function fetchOpenAIFakeStream({
bodyData,
setHeader,
signal,
fetch,
}: {
method: "GET" | "POST";
fullURL: URL;
headers: Record<string, string>;
bodyData: null | any;
setHeader: (name: string, value: string) => void;
signal?: AbortSignal;
fetch: FetchFn;
}): Promise<ModelResponse> {
let isStream = false;
if (bodyData) {
Expand Down Expand Up @@ -2581,10 +2585,12 @@ async function vertexEndpointInfo({
secret: { secret, metadata },
modelSpec,
defaultLocation,
fetch = globalThis.fetch,
}: {
secret: APISecret;
modelSpec: ModelSpec | null;
defaultLocation: string;
fetch?: FetchFn;
}): Promise<VertexEndpointInfo> {
const { project, location, authType, api_base } =
VertexMetadataSchema.parse(metadata);
Expand All @@ -2595,7 +2601,9 @@ async function vertexEndpointInfo({
});
const apiBase = getVertexBaseUrl(api_base, resolvedLocation);
const accessToken =
authType === "access_token" ? secret : await getGoogleAccessToken(secret);
authType === "access_token"
? secret
: await getGoogleAccessToken(secret, fetch);
if (!accessToken) {
throw new Error("Failed to get Google access token");
}
Expand All @@ -2610,16 +2618,19 @@ async function fetchVertexAnthropicMessages({
modelSpec,
body,
signal,
fetch,
}: {
secret: APISecret;
modelSpec: ModelSpec | null;
body: unknown;
signal?: AbortSignal;
fetch: FetchFn;
}): Promise<ModelResponse> {
const { baseUrl, accessToken } = await vertexEndpointInfo({
secret,
modelSpec,
defaultLocation: "us-east5",
fetch,
});
const { model, ...rest } = z
.object({
Expand Down Expand Up @@ -2689,6 +2700,7 @@ async function fetchAnthropicMessages({
modelSpec,
body,
signal,
fetch: customFetch,
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sure if upstream we need to override the fetch that we're passing that custom fetch everywhere

Comment thread
ibolmo marked this conversation as resolved.
});
default:
throw new ProxyBadRequestError(
Expand Down Expand Up @@ -2962,6 +2974,7 @@ async function fetchAnthropicChatCompletions({
secret,
modelSpec,
defaultLocation: "us-east5",
fetch: customFetch,
});
fullURL = new URL(
`${baseUrl}/${params.model}:${
Expand Down Expand Up @@ -3150,7 +3163,10 @@ async function openAIToolsToGoogleTools(params: {
return out;
}

async function getGoogleAccessToken(secret: string): Promise<string> {
async function getGoogleAccessToken(
secret: string,
fetch: FetchFn = globalThis.fetch,
): Promise<string> {
const {
private_key_id: kid,
private_key: pk,
Expand Down Expand Up @@ -3234,6 +3250,7 @@ async function fetchGoogleGenerateContent({
secret,
modelSpec,
defaultLocation: "us-central1",
fetch,
});
const url = new URL(`${baseUrl}/${model}:${method}`);
if (method === "streamGenerateContent") {
Expand Down Expand Up @@ -3370,6 +3387,7 @@ async function fetchGoogleChatCompletions({
secret,
modelSpec,
defaultLocation: "us-central1",
fetch,
});
fullURL = new URL(
`${baseUrl}/${model}:${
Expand Down
Loading