From 262f9b5bdfa16eb76cc621a6b97839f01ee67ffb Mon Sep 17 00:00:00 2001
From: Vijay Yadav
Date: Sat, 4 Apr 2026 14:01:54 -0400
Subject: [PATCH] feat: add Databricks AI Gateway as LLM provider

Implements Databricks serving endpoints support with PAT auth,
workspace URL resolution, and OpenAI-compatible request handling.

- Add databricks ProviderID to schema
- Create auth plugin with PAT parsing and host validation
- Add custom loader with env var fallback (DATABRICKS_HOST + DATABRICKS_TOKEN)
- Register 10 foundation models (Llama, Claude, GPT, Gemini, DBRX, Mixtral)
- Add 23 unit tests for host validation, PAT parsing, body transforms
- E2E tests included (skipped without credentials)

Closes #602

Co-Authored-By: Vijay Yadav
---
 .../src/altimate/plugin/databricks.ts         | 134 ++++++++++
 packages/opencode/src/plugin/index.ts         |   7 +-
 packages/opencode/src/provider/provider.ts    |  93 +++++++
 packages/opencode/src/provider/schema.ts      |   3 +
 .../test/altimate/databricks-provider.test.ts | 230 ++++++++++++++++++
 5 files changed, 465 insertions(+), 2 deletions(-)
 create mode 100644 packages/opencode/src/altimate/plugin/databricks.ts
 create mode 100644 packages/opencode/test/altimate/databricks-provider.test.ts

diff --git a/packages/opencode/src/altimate/plugin/databricks.ts b/packages/opencode/src/altimate/plugin/databricks.ts
new file mode 100644
index 0000000000..c76cbb978d
--- /dev/null
+++ b/packages/opencode/src/altimate/plugin/databricks.ts
@@ -0,0 +1,134 @@
+import type { Hooks, PluginInput } from "@opencode-ai/plugin"
+import { Auth, OAUTH_DUMMY_KEY } from "@/auth"
+
+/**
+ * Databricks workspace host regex.
+ * Matches patterns like: myworkspace.cloud.databricks.com, adb-1234567890.12.azuredatabricks.net
+ */
+export const VALID_HOST_RE = /^[a-zA-Z0-9._-]+\.(cloud\.databricks\.com|azuredatabricks\.net|gcp\.databricks\.com)$/
+
+/**
+ * Parse a `host::token` credential string for Databricks PAT auth.
+ * Splits on the first `::`, trims both parts, and requires the host to match VALID_HOST_RE.
+ * Returns null when the separator, host, or token is missing or the host is not recognized.
+ */
+export function parseDatabricksPAT(code: string): { host: string; token: string } | null {
+  const sep = code.indexOf("::")
+  if (sep === -1) return null
+  const host = code.substring(0, sep).trim()
+  const token = code.substring(sep + 2).trim()
+  if (!host || !token) return null
+  if (!VALID_HOST_RE.test(host)) return null
+  return { host, token }
+}
+
+/**
+ * Normalize a chat-completions JSON body for Databricks serving endpoints.
+ * Renames `max_completion_tokens` to `max_tokens` when only the former is set;
+ * non-object JSON is returned re-serialized. Throws if `bodyText` is not valid JSON.
+ */
+export function transformDatabricksBody(bodyText: string): { body: string } {
+  const parsed = JSON.parse(bodyText)
+
+  // Only JSON objects can carry token-limit fields; `in` throws on primitives and null,
+  // so anything else falls through and is re-serialized untouched.
+  if (parsed && typeof parsed === "object" && "max_completion_tokens" in parsed && !("max_tokens" in parsed)) {
+    parsed.max_tokens = parsed.max_completion_tokens
+    delete parsed.max_completion_tokens
+  }
+
+  return { body: JSON.stringify(parsed) }
+}
+
+/**
+ * Auth plugin for the "databricks" provider.
+ * The loader zeroes model costs and wraps fetch to attach the stored PAT as a bearer
+ * token and normalize the request body; the single auth method collects `host::token`.
+ */
+export async function DatabricksAuthPlugin(_input: PluginInput): Promise<Hooks> {
+  return {
+    auth: {
+      provider: "databricks",
+      async loader(getAuth, provider) {
+        const auth = await getAuth()
+        if (auth.type !== "oauth") return {}
+
+        for (const model of Object.values(provider.models)) {
+          model.cost = { input: 0, output: 0, cache: { read: 0, write: 0 } }
+        }
+
+        return {
+          apiKey: OAUTH_DUMMY_KEY,
+          async fetch(requestInput: RequestInfo | URL, init?: RequestInit) {
+            const currentAuth = await getAuth()
+            if (currentAuth.type !== "oauth") return fetch(requestInput, init)
+
+            const headers = new Headers()
+            if (init?.headers) {
+              if (init.headers instanceof Headers) {
+                init.headers.forEach((value, key) => headers.set(key, value))
+              } else if (Array.isArray(init.headers)) {
+                for (const [key, value] of init.headers) {
+                  if (value !== undefined) headers.set(key, String(value))
+                }
+              } else {
+                for (const [key, value] of Object.entries(init.headers)) {
+                  if (value !== undefined) headers.set(key, String(value))
+                }
+              }
+            }
+
+            headers.set("authorization", `Bearer ${currentAuth.access}`)
+
+            let body = init?.body
+            if (body) {
+              try {
+                let text: string
+                if (typeof body === "string") {
+                  text = body
+                } else if (body instanceof Uint8Array || body instanceof ArrayBuffer) {
+                  text = new TextDecoder().decode(body)
+                } else {
+                  text = ""
+                }
+                if (text) {
+                  const result = transformDatabricksBody(text)
+                  body = result.body
+                  headers.delete("content-length") // body was re-serialized, so the old length may be wrong
+                }
+              } catch {
+                // JSON parse error — pass original body through untransformed
+              }
+            }
+
+            return fetch(requestInput, { ...init, headers, body })
+          },
+        }
+      },
+      methods: [
+        {
+          label: "Databricks PAT",
+          type: "oauth",
+          authorize: async () => ({
+            url: "https://accounts.cloud.databricks.com",
+            instructions:
+              "Enter your credentials as: <workspace-host>::<token>\n e.g. myworkspace.cloud.databricks.com::dapi1234567890abcdef\n Create a PAT in Databricks: Settings → Developer → Access Tokens → Generate New Token",
+            method: "code" as const,
+            callback: async (code: string) => {
+              const parsed = parseDatabricksPAT(code)
+              if (!parsed) return { type: "failed" as const }
+              return {
+                type: "success" as const,
+                access: parsed.token,
+                refresh: "",
+                // Databricks PATs can be configured with custom TTLs; use 90-day default
+                expires: Date.now() + 90 * 24 * 60 * 60 * 1000,
+                accountId: parsed.host,
+              }
+            },
+          }),
+        },
+      ],
+    },
+  }
+}
diff --git a/packages/opencode/src/plugin/index.ts b/packages/opencode/src/plugin/index.ts
index 7c8f0dbf18..9bc1976c95 100644
--- a/packages/opencode/src/plugin/index.ts
+++ b/packages/opencode/src/plugin/index.ts
@@ -15,6 +15,9 @@ import { gitlabAuthPlugin as GitlabAuthPlugin } from "@gitlab/opencode-gitlab-au
 // altimate_change start — snowflake cortex plugin import
 import { SnowflakeCortexAuthPlugin } from "../altimate/plugin/snowflake"
 // altimate_change end
+// altimate_change start — databricks plugin import
+import { DatabricksAuthPlugin } from "../altimate/plugin/databricks"
+// altimate_change end
 // altimate_change start — altimate backend auth plugin
 import { AltimateAuthPlugin } from "../altimate/plugin/altimate"
 // altimate_change end
@@ -28,8 +31,8 @@ export namespace Plugin {
   // GitlabAuthPlugin uses a different version of @opencode-ai/plugin (from npm)
   // vs the workspace version, causing a type mismatch on internal HeyApiClient.
   // The types are structurally compatible at runtime.
-  // altimate_change start — snowflake cortex and altimate backend internal plugins
-  const INTERNAL_PLUGINS: PluginInstance[] = [CodexAuthPlugin, CopilotAuthPlugin, GitlabAuthPlugin as unknown as PluginInstance, SnowflakeCortexAuthPlugin, AltimateAuthPlugin]
+  // altimate_change start — snowflake cortex, databricks, and altimate backend internal plugins
+  const INTERNAL_PLUGINS: PluginInstance[] = [CodexAuthPlugin, CopilotAuthPlugin, GitlabAuthPlugin as unknown as PluginInstance, SnowflakeCortexAuthPlugin, DatabricksAuthPlugin, AltimateAuthPlugin]
   // altimate_change end
 
   const state = Instance.state(async () => {
diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts
index 9e81a4ff38..381e59e5a0 100644
--- a/packages/opencode/src/provider/provider.ts
+++ b/packages/opencode/src/provider/provider.ts
@@ -50,6 +50,9 @@ import { ModelID, ProviderID } from "./schema"
 // altimate_change start — snowflake cortex account validation
 import { VALID_ACCOUNT_RE } from "../altimate/plugin/snowflake"
 // altimate_change end
+// altimate_change start — databricks host validation
+import { VALID_HOST_RE } from "../altimate/plugin/databricks"
+// altimate_change end
 
 const DEFAULT_CHUNK_TIMEOUT = 120_000
 
@@ -733,6 +736,32 @@ export namespace Provider {
       }
     },
     // altimate_change end
+    // altimate_change start — databricks provider loader
+    databricks: async () => {
+      // Env fallback: DATABRICKS_HOST is often a full URL, so reduce it to a bare hostname
+      const envHost = Env.get("DATABRICKS_HOST")?.trim().replace(/^https?:\/\//i, "").replace(/\/+$/, "")
+      const auth = await Auth.get("databricks")
+      if (auth?.type !== "oauth") {
+        const token = Env.get("DATABRICKS_TOKEN")
+        if (!envHost || !token) return { autoload: false }
+        return {
+          autoload: true,
+          options: {
+            baseURL: `https://${envHost}/serving-endpoints`,
+            apiKey: token,
+          },
+        }
+      }
+      const host = auth.accountId ?? envHost
+      if (!host || !VALID_HOST_RE.test(host)) return { autoload: false }
+      return {
+        autoload: true,
+        options: {
+          baseURL: `https://${host}/serving-endpoints`,
+        },
+      }
+    },
+    // altimate_change end
   }
 
   export const Model = z
@@ -1019,6 +1048,70 @@ export namespace Provider {
     }
     // altimate_change end
 
+    // altimate_change start — databricks provider models
+    function makeDatabricksModel(
+      id: string,
+      name: string,
+      limits: { context: number; output: number },
+      caps?: { reasoning?: boolean; attachment?: boolean; toolcall?: boolean; image?: boolean },
+    ): Model {
+      const m: Model = {
+        id: ModelID.make(id),
+        providerID: ProviderID.databricks,
+        api: {
+          id,
+          url: "",
+          npm: "@ai-sdk/openai-compatible",
+        },
+        name,
+        capabilities: {
+          temperature: true,
+          reasoning: caps?.reasoning ?? false,
+          attachment: caps?.attachment ?? false,
+          toolcall: caps?.toolcall ?? true,
+          input: { text: true, audio: false, image: caps?.image ?? false, video: false, pdf: false },
+          output: { text: true, audio: false, image: false, video: false, pdf: false },
+          interleaved: false,
+        },
+        cost: { input: 0, output: 0, cache: { read: 0, write: 0 } },
+        limit: { context: limits.context, output: limits.output },
+        status: "active" as const,
+        options: {},
+        headers: {},
+        release_date: "2024-01-01",
+        variants: {},
+      }
+      m.variants = mapValues(ProviderTransform.variants(m), (v) => v)
+      return m
+    }
+
+    database["databricks"] = {
+      id: ProviderID.databricks,
+      source: "custom",
+      name: "Databricks",
+      env: ["DATABRICKS_TOKEN"],
+      options: {},
+      models: {
+        // Meta Llama models — tool calling supported
+        "databricks-meta-llama-3-1-405b-instruct": makeDatabricksModel("databricks-meta-llama-3-1-405b-instruct", "Meta Llama 3.1 405B Instruct", { context: 128000, output: 4096 }),
+        "databricks-meta-llama-3-1-70b-instruct": makeDatabricksModel("databricks-meta-llama-3-1-70b-instruct", "Meta Llama 3.1 70B Instruct", { context: 128000, output: 4096 }),
+        "databricks-meta-llama-3-1-8b-instruct": makeDatabricksModel("databricks-meta-llama-3-1-8b-instruct", "Meta Llama 3.1 8B Instruct", { context: 128000, output: 4096 }),
+        // Claude models via Databricks AI Gateway
+        "databricks-claude-sonnet-4-6": makeDatabricksModel("databricks-claude-sonnet-4-6", "Claude Sonnet 4.6", { context: 200000, output: 64000 }),
+        "databricks-claude-opus-4-6": makeDatabricksModel("databricks-claude-opus-4-6", "Claude Opus 4.6", { context: 200000, output: 32000 }),
+        // GPT models via Databricks AI Gateway
+        "databricks-gpt-5-4": makeDatabricksModel("databricks-gpt-5-4", "GPT-5-4", { context: 128000, output: 16384 }),
+        "databricks-gpt-5-mini": makeDatabricksModel("databricks-gpt-5-mini", "GPT-5 Mini", { context: 128000, output: 16384 }),
+        // Gemini models via Databricks AI Gateway
+        "databricks-gemini-3-1-pro": makeDatabricksModel("databricks-gemini-3-1-pro", "Gemini 3.1 Pro", { context: 1000000, output: 8192 }),
+        // DBRX — Databricks native model
+        "databricks-dbrx-instruct": makeDatabricksModel("databricks-dbrx-instruct", "DBRX Instruct", { context: 32768, output: 4096 }),
+        // Mixtral via Databricks
+        "databricks-mixtral-8x7b-instruct": makeDatabricksModel("databricks-mixtral-8x7b-instruct", "Mixtral 8x7B Instruct", { context: 32768, output: 4096 }, { toolcall: false }),
+      },
+    }
+    // altimate_change end
+
     // altimate_change start — register altimate-backend as an OpenAI-compatible provider
     if (!database["altimate-backend"]) {
       const backendModels: Record = {
diff --git a/packages/opencode/src/provider/schema.ts b/packages/opencode/src/provider/schema.ts
index 4e53acd6a6..d4f17c33a6 100644
--- a/packages/opencode/src/provider/schema.ts
+++ b/packages/opencode/src/provider/schema.ts
@@ -26,6 +26,9 @@ export const ProviderID = providerIdSchema.pipe(
     // altimate_change start — snowflake cortex provider ID
     snowflakeCortex: schema.makeUnsafe("snowflake-cortex"),
     // altimate_change end
+    // altimate_change start — databricks provider ID
+    databricks: schema.makeUnsafe("databricks"),
+    // altimate_change end
   })),
 )
 
diff --git a/packages/opencode/test/altimate/databricks-provider.test.ts b/packages/opencode/test/altimate/databricks-provider.test.ts
new file mode 100644
index 0000000000..33cce3a4f7
--- /dev/null
+++ b/packages/opencode/test/altimate/databricks-provider.test.ts
@@ -0,0 +1,230 @@
+/**
+ * Databricks AI Gateway Provider Tests
+ *
+ * Unit tests for PAT parsing, host validation, and request body transforms.
+ * E2E tests for the serving endpoints API (skipped without credentials).
+ *
+ * For E2E tests, set:
+ *   export DATABRICKS_HOST="myworkspace.cloud.databricks.com"
+ *   export DATABRICKS_TOKEN="dapi1234567890abcdef"
+ *
+ * Run:
+ *   bun test test/altimate/databricks-provider.test.ts
+ */
+
+import { describe, expect, test } from "bun:test"
+import {
+  parseDatabricksPAT,
+  transformDatabricksBody,
+  VALID_HOST_RE,
+} from "../../src/altimate/plugin/databricks"
+
+// ---------------------------------------------------------------------------
+// Host validation regex
+// ---------------------------------------------------------------------------
+
+describe("VALID_HOST_RE", () => {
+  test("accepts standard AWS workspace host", () => {
+    expect(VALID_HOST_RE.test("myworkspace.cloud.databricks.com")).toBe(true)
+  })
+
+  test("accepts Azure workspace host", () => {
+    expect(VALID_HOST_RE.test("adb-1234567890.12.azuredatabricks.net")).toBe(true)
+  })
+
+  test("accepts GCP workspace host", () => {
+    expect(VALID_HOST_RE.test("myworkspace.gcp.databricks.com")).toBe(true)
+  })
+
+  test("accepts hyphenated workspace names", () => {
+    expect(VALID_HOST_RE.test("my-workspace-123.cloud.databricks.com")).toBe(true)
+  })
+
+  test("rejects bare hostname without domain", () => {
+    expect(VALID_HOST_RE.test("myworkspace")).toBe(false)
+  })
+
+  test("rejects non-databricks domain", () => {
+    expect(VALID_HOST_RE.test("myworkspace.cloud.example.com")).toBe(false)
+  })
+
+  test("rejects empty string", () => {
+    expect(VALID_HOST_RE.test("")).toBe(false)
+  })
+
+  test("rejects URL with protocol", () => {
+    expect(VALID_HOST_RE.test("https://myworkspace.cloud.databricks.com")).toBe(false)
+  })
+})
+
+// ---------------------------------------------------------------------------
+// PAT parsing
+// ---------------------------------------------------------------------------
+
+describe("parseDatabricksPAT", () => {
+  test("parses valid AWS host::token", () => {
+    const result = parseDatabricksPAT("myworkspace.cloud.databricks.com::dapi1234567890abcdef")
+    expect(result).toEqual({
+      host: "myworkspace.cloud.databricks.com",
+      token: "dapi1234567890abcdef",
+    })
+  })
+
+  test("parses valid Azure host::token", () => {
+    const result = parseDatabricksPAT("adb-123.45.azuredatabricks.net::dapi-token-here")
+    expect(result).toEqual({
+      host: "adb-123.45.azuredatabricks.net",
+      token: "dapi-token-here",
+    })
+  })
+
+  test("parses valid GCP host::token", () => {
+    const result = parseDatabricksPAT("my-ws.gcp.databricks.com::dapiABCDEF123")
+    expect(result).toEqual({
+      host: "my-ws.gcp.databricks.com",
+      token: "dapiABCDEF123",
+    })
+  })
+
+  test("trims whitespace from host and token", () => {
+    const result = parseDatabricksPAT(" myworkspace.cloud.databricks.com :: dapi123 ")
+    expect(result).toEqual({
+      host: "myworkspace.cloud.databricks.com",
+      token: "dapi123",
+    })
+  })
+
+  test("returns null for missing separator", () => {
+    expect(parseDatabricksPAT("myworkspace.cloud.databricks.com:dapi123")).toBeNull()
+  })
+
+  test("returns null for empty host", () => {
+    expect(parseDatabricksPAT("::dapi123")).toBeNull()
+  })
+
+  test("returns null for empty token", () => {
+    expect(parseDatabricksPAT("myworkspace.cloud.databricks.com::")).toBeNull()
+  })
+
+  test("returns null for invalid host domain", () => {
+    expect(parseDatabricksPAT("example.com::dapi123")).toBeNull()
+  })
+
+  test("returns null for empty string", () => {
+    expect(parseDatabricksPAT("")).toBeNull()
+  })
+
+  test("returns null for single colon separator", () => {
+    expect(parseDatabricksPAT("host.cloud.databricks.com:token")).toBeNull()
+  })
+})
+
+// ---------------------------------------------------------------------------
+// Request body transforms
+// ---------------------------------------------------------------------------
+
+describe("transformDatabricksBody", () => {
+  test("converts max_completion_tokens to max_tokens", () => {
+    const input = JSON.stringify({
+      model: "databricks-meta-llama-3-1-70b-instruct",
+      messages: [{ role: "user", content: "hello" }],
+      max_completion_tokens: 4096,
+    })
+    const result = JSON.parse(transformDatabricksBody(input).body)
+    expect(result.max_tokens).toBe(4096)
+    expect(result.max_completion_tokens).toBeUndefined()
+  })
+
+  test("preserves max_tokens if already present", () => {
+    const input = JSON.stringify({
+      model: "databricks-meta-llama-3-1-70b-instruct",
+      messages: [{ role: "user", content: "hello" }],
+      max_tokens: 2048,
+    })
+    const result = JSON.parse(transformDatabricksBody(input).body)
+    expect(result.max_tokens).toBe(2048)
+  })
+
+  test("does not convert when both max_tokens and max_completion_tokens exist", () => {
+    const input = JSON.stringify({
+      model: "databricks-meta-llama-3-1-70b-instruct",
+      messages: [{ role: "user", content: "hello" }],
+      max_tokens: 2048,
+      max_completion_tokens: 4096,
+    })
+    const result = JSON.parse(transformDatabricksBody(input).body)
+    expect(result.max_tokens).toBe(2048)
+    expect(result.max_completion_tokens).toBe(4096)
+  })
+
+  test("passes through body without max token fields unchanged", () => {
+    const input = JSON.stringify({
+      model: "databricks-dbrx-instruct",
+      messages: [{ role: "user", content: "hello" }],
+      stream: true,
+    })
+    const result = JSON.parse(transformDatabricksBody(input).body)
+    expect(result.model).toBe("databricks-dbrx-instruct")
+    expect(result.stream).toBe(true)
+    expect(result.max_tokens).toBeUndefined()
+  })
+
+  test("leaves non-object JSON bodies untouched", () => {
+    expect(transformDatabricksBody("null").body).toBe("null")
+    expect(transformDatabricksBody("42").body).toBe("42")
+  })
+})
+
+// ---------------------------------------------------------------------------
+// E2E tests (skipped without credentials)
+// ---------------------------------------------------------------------------
+
+const DATABRICKS_HOST = process.env.DATABRICKS_HOST
+const DATABRICKS_TOKEN = process.env.DATABRICKS_TOKEN
+const HAS_DATABRICKS = !!(DATABRICKS_HOST && DATABRICKS_TOKEN)
+
+describe("Databricks Serving Endpoints E2E", () => {
+  test.skipIf(!HAS_DATABRICKS)("chat completion with foundation model", async () => {
+    const baseURL = `https://${DATABRICKS_HOST}/serving-endpoints`
+    const res = await fetch(`${baseURL}/databricks-meta-llama-3-1-8b-instruct/invocations`, {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+        Authorization: `Bearer ${DATABRICKS_TOKEN}`,
+      },
+      body: JSON.stringify({
+        messages: [{ role: "user", content: "Say hello in one word." }],
+        max_tokens: 32,
+      }),
+    })
+
+    expect(res.ok).toBe(true)
+    const data = await res.json()
+    expect(data.choices).toBeDefined()
+    expect(data.choices.length).toBeGreaterThan(0)
+    expect(data.choices[0].message.content).toBeTruthy()
+  })
+
+  test.skipIf(!HAS_DATABRICKS)("streaming chat completion", async () => {
+    const baseURL = `https://${DATABRICKS_HOST}/serving-endpoints`
+    const res = await fetch(`${baseURL}/databricks-meta-llama-3-1-8b-instruct/invocations`, {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+        Authorization: `Bearer ${DATABRICKS_TOKEN}`,
+      },
+      body: JSON.stringify({
+        messages: [{ role: "user", content: "Say hello." }],
+        max_tokens: 32,
+        stream: true,
+      }),
+    })
+
+    expect(res.ok).toBe(true)
+    expect(res.headers.get("content-type")).toContain("text/event-stream")
+
+    const text = await res.text()
+    expect(text).toContain("data:")
+    expect(text).toContain("[DONE]")
+  })
+})