diff --git a/AGENTS.md b/AGENTS.md index 1bb7712..b5d3421 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -18,7 +18,7 @@ Before calling any change done, `cargo fmt --all`, the `clippy` line above, and ## Workspace -37 crates organized into six layers, with `goat-protocol` at the bottom of the dependency DAG: +56 crates organized into six layers, with `goat-protocol` at the bottom of the dependency DAG: **Infrastructure** - `goat-protocol` — shared wire contract (`Op`, `Event`, `TaskId`); serde only; leaf. @@ -37,9 +37,19 @@ Before calling any change done, `cargo fmt --all`, the `clippy` line above, and - `goat-provider` — the `Provider` trait; leaf. Key types: `Provider`, `Request` (incl. `ToolChoice`), `StreamEvent`, `StreamError`, `Message`, `Capabilities`, `Model`, `ProviderId`, `ContentBlock`. Providers classify their own wire errors into `StreamError` structurally (`error.rs` per provider); the engine never inspects error strings. - `goat-provider-anthropic` — Anthropic Claude API provider; per-model context windows, prompt-caching `cache_control` breakpoints (tools + system + last two messages), `stop_reason` overflow detection. - `goat-provider-gemini` — Google Gemini provider; API key (Generative Language API) or OAuth (Code Assist free tier, gemini-cli compatible); four modules: `lib` (provider orchestration), `wire` (Gemini request/response format), `oauth` (Google OAuth PKCE flow), `codeassist` (Code Assist envelope + project onboarding). -- `goat-provider-openai-compat` — OpenAI-family HTTP clients; three modules: `chat` (Chat Completions API, used by local providers), `responses` (Responses API, used by OpenAI and Codex), `common` (shared client/validate/discovery helpers). +- `goat-provider-openai-compat` — OpenAI-family HTTP clients; modules: `chat` (Chat Completions API), `responses` (Responses API), `hosted` (API-key builder + HTTPS host pinning), `common`, `vision`. - `goat-provider-openai` — OpenAI provider (wraps `responses` module). - `goat-provider-openai-codex` — OpenAI Codex provider (wraps `responses` module). +- `goat-provider-openrouter` — OpenRouter API-key provider; Chat Completions via `hosted::api_key`. +- `goat-provider-groq` — Groq API-key provider. +- `goat-provider-deepseek` — DeepSeek API-key provider. +- `goat-provider-mistral` — Mistral API-key provider. +- `goat-provider-zai` — Z.AI API-key provider; catalog-only validation/discovery. +- `goat-provider-zai-coding` — Z.AI Coding Plan API-key provider (distinct credential from `zai`). +- `goat-provider-kimi` — Moonshot Kimi API-key provider. +- `goat-provider-kimi-code` — Kimi Code OAuth device-code provider; owns `oauth` module and `KimiCodeProvider`. +- `goat-provider-qwen` — Qwen DashScope API-key provider; optional `--endpoint` for non-US workspaces. +- `goat-provider-xai` — xAI Grok provider; API key (Chat Completions) or SuperGrok/X Premium+ OAuth (Responses API); owns `oauth` module. - `goat-provider-local` — table-driven local-inference provider (Ollama, LM Studio, llama.cpp); wraps `chat` module. - `goat-providers` — provider registry; wires all provider crates. `Registry::new(store)` for default account, `Registry::load(store, account)` for explicit. `Registry::login(provider, status)` dispatches OAuth login through the `Provider::login` trait method. diff --git a/Cargo.lock b/Cargo.lock index cf1c4c8..f2a66da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2150,6 +2150,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "goat-provider-deepseek" +version = "0.1.18" +dependencies = [ + "goat-auth", + "goat-provider", + "goat-provider-openai-compat", +] + [[package]] name = "goat-provider-gemini" version = "0.1.18" @@ -2168,7 +2177,25 @@ dependencies = [ ] [[package]] -name = "goat-provider-hosted" +name = "goat-provider-groq" +version = "0.1.18" +dependencies = [ + "goat-auth", + "goat-provider", + "goat-provider-openai-compat", +] + +[[package]] +name = "goat-provider-kimi" +version = "0.1.18" +dependencies = [ + "goat-auth", + "goat-provider", + "goat-provider-openai-compat", +] + +[[package]] +name = "goat-provider-kimi-code" version = "0.1.18" dependencies = [ "goat-auth", @@ -2176,7 +2203,6 @@ dependencies = [ "goat-provider-openai-compat", "reqwest 0.12.28", "serde", - "serde_json", "thiserror 2.0.18", "tokio", ] @@ -2189,6 +2215,15 @@ dependencies = [ "goat-provider-openai-compat", ] +[[package]] +name = "goat-provider-mistral" +version = "0.1.18" +dependencies = [ + "goat-auth", + "goat-provider", + "goat-provider-openai-compat", +] + [[package]] name = "goat-provider-openai" version = "0.1.18" @@ -2221,6 +2256,7 @@ version = "0.1.18" dependencies = [ "eventsource-stream", "futures", + "goat-auth", "goat-provider", "reqwest 0.12.28", "serde", @@ -2228,6 +2264,57 @@ dependencies = [ "tokio", ] +[[package]] +name = "goat-provider-openrouter" +version = "0.1.18" +dependencies = [ + "goat-auth", + "goat-provider", + "goat-provider-openai-compat", +] + +[[package]] +name = "goat-provider-qwen" +version = "0.1.18" +dependencies = [ + "goat-auth", + "goat-provider", + "goat-provider-openai-compat", + "reqwest 0.12.28", +] + +[[package]] +name = "goat-provider-xai" +version = "0.1.18" +dependencies = [ + "goat-auth", + "goat-provider", + "goat-provider-openai-compat", + "open", + "reqwest 0.12.28", + "serde", + "thiserror 2.0.18", + "tokio", +] + +[[package]] +name = "goat-provider-zai" +version = "0.1.18" +dependencies = [ + "goat-auth", + "goat-provider", + "goat-provider-openai-compat", +] + +[[package]] +name = "goat-provider-zai-coding" +version = "0.1.18" +dependencies = [ + "goat-auth", + "goat-provider", + "goat-provider-openai-compat", +] + [[package]] name = "goat-providers" version = "0.1.18" @@ -2235,11 +2322,20 @@ dependencies = [ "goat-auth", "goat-provider", "goat-provider-anthropic", + "goat-provider-deepseek", "goat-provider-gemini", - "goat-provider-hosted", + "goat-provider-groq", + "goat-provider-kimi", + "goat-provider-kimi-code", "goat-provider-local", + "goat-provider-mistral", "goat-provider-openai", "goat-provider-openai-codex", + "goat-provider-openrouter", + "goat-provider-qwen", + "goat-provider-xai", + "goat-provider-zai", + "goat-provider-zai-coding", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 90baad8..1eeed3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,16 @@ goat-provider = { path = "crates/goat-provider" } goat-auth = { path = "crates/goat-auth" } goat-store = { path = "crates/goat-store" } goat-provider-openai-compat = { path = "crates/goat-provider-openai-compat" } -goat-provider-hosted = { path = "crates/goat-provider-hosted" } +goat-provider-openrouter = { path = "crates/goat-provider-openrouter" } +goat-provider-groq = { path = "crates/goat-provider-groq" } +goat-provider-deepseek = { path = "crates/goat-provider-deepseek" } +goat-provider-mistral = { path = "crates/goat-provider-mistral" } +goat-provider-zai = { path = "crates/goat-provider-zai" } +goat-provider-zai-coding = { path = "crates/goat-provider-zai-coding" } +goat-provider-kimi = { path = "crates/goat-provider-kimi" } +goat-provider-qwen = { path = "crates/goat-provider-qwen" } +goat-provider-kimi-code = { path = "crates/goat-provider-kimi-code" } +goat-provider-xai = { path = "crates/goat-provider-xai" } goat-provider-openai = { path = "crates/goat-provider-openai" } goat-provider-openai-codex = { path = "crates/goat-provider-openai-codex" } goat-provider-anthropic = { path = "crates/goat-provider-anthropic" } diff --git a/crates/goat-provider-deepseek/Cargo.toml b/crates/goat-provider-deepseek/Cargo.toml new file mode 100644 index 0000000..616978f --- /dev/null +++ b/crates/goat-provider-deepseek/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "goat-provider-deepseek" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +publish = false + +[dependencies] +goat-provider = { workspace = true } +goat-provider-openai-compat = { workspace = true } +goat-auth = { workspace = true } + +[lints] +workspace = true \ No newline at end of file diff --git a/crates/goat-provider-deepseek/src/lib.rs b/crates/goat-provider-deepseek/src/lib.rs new file mode 100644 index 0000000..d126cd3 --- /dev/null +++ b/crates/goat-provider-deepseek/src/lib.rs @@ -0,0 +1,20 @@ +use goat_auth::CredentialStore; +use goat_provider_openai_compat::{OpenAiCompatProvider, api_key}; + +pub const PROVIDER_ID: &str = "deepseek"; +const BASE_URL: &str = "https://api.deepseek.com"; +const HOST: &str = "api.deepseek.com"; +const ENV_VAR: &str = "DEEPSEEK_API_KEY"; + +const CATALOG: &[&str] = &["deepseek-chat", "deepseek-reasoner"]; + +const CONTEXT_WINDOWS: &[(&str, u32)] = + &[("deepseek-chat", 128_000), ("deepseek-reasoner", 128_000)]; + +pub fn build(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { + api_key(store, account, PROVIDER_ID, BASE_URL, HOST, ENV_VAR) + .with_catalog(CATALOG) + .with_context_windows(CONTEXT_WINDOWS) + .with_images(false) + .with_reasoning_effort(false) +} diff --git a/crates/goat-provider-groq/Cargo.toml b/crates/goat-provider-groq/Cargo.toml new file mode 100644 index 0000000..5d49288 --- /dev/null +++ b/crates/goat-provider-groq/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "goat-provider-groq" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +publish = false + +[dependencies] +goat-provider = { workspace = true } +goat-provider-openai-compat = { workspace = true } +goat-auth = { workspace = true } + +[lints] +workspace = true \ No newline at end of file diff --git a/crates/goat-provider-groq/src/lib.rs b/crates/goat-provider-groq/src/lib.rs new file mode 100644 index 0000000..970a1f8 --- /dev/null +++ b/crates/goat-provider-groq/src/lib.rs @@ -0,0 +1,39 @@ +use goat_auth::CredentialStore; +use goat_provider_openai_compat::{OpenAiCompatProvider, api_key}; + +pub const PROVIDER_ID: &str = "groq"; +const BASE_URL: &str = "https://api.groq.com/openai/v1"; +const HOST: &str = "api.groq.com"; +const ENV_VAR: &str = "GROQ_API_KEY"; + +const CATALOG: &[&str] = &[ + "openai/gpt-oss-120b", + "openai/gpt-oss-20b", + "moonshotai/kimi-k2-instruct-0905", + "qwen/qwen3-32b", + "deepseek-r1-distill-llama-70b", + "llama-3.3-70b-versatile", +]; + +const CONTEXT_WINDOWS: &[(&str, u32)] = &[ + ("openai/gpt-oss-120b", 131_072), + ("openai/gpt-oss-20b", 131_072), + ("moonshotai/kimi", 131_072), + ("qwen/qwen3", 131_072), + ("llama-3.3", 131_072), +]; + +fn is_chat_model(id: &str) -> bool { + let id = id.to_ascii_lowercase(); + !(id.contains("whisper") || id.contains("tts") || id.contains("embedding")) +} + +pub fn build(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { + api_key(store, account, PROVIDER_ID, BASE_URL, HOST, ENV_VAR) + .with_catalog(CATALOG) + .with_context_windows(CONTEXT_WINDOWS) + .with_model_filter(is_chat_model) + .with_images(false) + .with_stream_options(false) + .with_reasoning_effort(false) +} diff --git a/crates/goat-provider-hosted/src/lib.rs b/crates/goat-provider-hosted/src/lib.rs deleted file mode 100644 index b592db6..0000000 --- a/crates/goat-provider-hosted/src/lib.rs +++ /dev/null @@ -1,1248 +0,0 @@ -use std::path::PathBuf; - -use goat_auth::{Credential, CredentialKey, CredentialStore, TokenSet, ensure_valid}; -use goat_provider::{ - AuthMethod, Capabilities, Effort, LoginEndpointMetadata, Model, Provider, ProviderId, - ProviderMetadata, Request, StreamError, StreamEvent, WebSearchOutput, -}; -use goat_provider_openai_compat::{ChatDiscovery, ChatValidation, OpenAiCompatProvider}; -use reqwest::header::{ACCEPT, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, USER_AGENT}; -use serde::Deserialize; -use tokio::{sync::mpsc, task::JoinHandle}; - -const KIMI_SETUP: &[&str] = &[ - "Kimi Platform API key provider.", - "For Kimi Code OAuth, use `goat provider login kimi-code`.", - "API-key setup: `goat provider login kimi --key sk-...`.", -]; -const KIMI_CODE_SETUP: &[&str] = &[ - "Kimi Code OAuth device-code login.", - "Run `goat provider login kimi-code`, open the URL, and enter the code.", -]; -const QWEN_DEFAULT_ENDPOINT: &str = "https://dashscope-us.aliyuncs.com/compatible-mode/v1"; -const QWEN_SETUP: &[&str] = &[ - "Qwen DashScope API-key provider.", - "Default endpoint: https://dashscope-us.aliyuncs.com/compatible-mode/v1", - "Non-US workspaces: `goat provider login qwen --endpoint --key sk-...`.", - "Qwen OAuth enrollment is discontinued upstream.", -]; -const ZAI_CODING_SETUP: &[&str] = &[ - "Z.AI Coding Plan API-key provider.", - "Use `ZAI_CODING_API_KEY` or `goat provider login zai-coding --key sk-...`.", - "This is not OAuth and does not reuse the standard `zai` credential.", -]; - -pub const OPENROUTER: &str = "openrouter"; -pub const GROQ: &str = "groq"; -pub const DEEPSEEK: &str = "deepseek"; -pub const XAI: &str = "xai"; -pub const MISTRAL: &str = "mistral"; -pub const ZAI: &str = "zai"; -pub const ZAI_CODING: &str = "zai-coding"; -pub const KIMI: &str = "kimi"; -pub const KIMI_CODE: &str = "kimi-code"; -pub const QWEN: &str = "qwen"; - -const KIMI_CODE_BASE_URL: &str = "https://api.kimi.com/coding/v1"; -const KIMI_CODE_OAUTH_HOST: &str = "https://auth.kimi.com"; -const KIMI_CODE_CLIENT_ID: &str = "17e5f671-d194-4dfb-9706-5516cb48c098"; - -const OPENROUTER_CATALOG: &[&str] = &[ - "anthropic/claude-sonnet-4.5", - "openai/gpt-5.1", - "google/gemini-2.5-pro", - "deepseek/deepseek-chat-v3.1", - "qwen/qwen3-coder", - "moonshotai/kimi-k2", -]; - -const GROQ_CATALOG: &[&str] = &[ - "openai/gpt-oss-120b", - "openai/gpt-oss-20b", - "moonshotai/kimi-k2-instruct-0905", - "qwen/qwen3-32b", - "deepseek-r1-distill-llama-70b", - "llama-3.3-70b-versatile", -]; - -const DEEPSEEK_CATALOG: &[&str] = &["deepseek-chat", "deepseek-reasoner"]; - -const XAI_CATALOG: &[&str] = &[ - "grok-4", - "grok-4-fast-reasoning", - "grok-4-fast-non-reasoning", - "grok-3", - "grok-3-fast", - "grok-3-mini", -]; - -const MISTRAL_CATALOG: &[&str] = &[ - "mistral-large-latest", - "mistral-medium-latest", - "mistral-small-latest", - "devstral-medium-latest", - "codestral-latest", - "ministral-8b-latest", -]; - -const ZAI_CATALOG: &[&str] = &[ - "glm-5.2", - "glm-5.1", - "glm-5-turbo", - "glm-5", - "glm-4.7", - "glm-4.6", - "glm-4.5", - "glm-4-32b-0414-128k", -]; - -const ZAI_CODING_CATALOG: &[&str] = &["glm-5.2", "glm-5-turbo", "glm-4.7"]; - -const KIMI_CATALOG: &[&str] = &[ - "kimi-k2.7-code", - "kimi-k2.7-code-highspeed", - "kimi-k2.6", - "kimi-k2.5", - "moonshot-v1-128k", - "moonshot-v1-32k", - "moonshot-v1-8k", -]; - -const KIMI_CODE_CATALOG: &[&str] = &[ - "kimi-k2.7-code", - "kimi-k2.7-code-highspeed", - "kimi-k2.6", - "kimi-k2.5", -]; - -const QWEN_CATALOG: &[&str] = &[ - "qwen-plus", - "qwen-max", - "qwen-turbo", - "qwen3-coder-plus", - "qwen3-coder-flash", - "qwen-vl-plus", -]; - -const OPENROUTER_CONTEXT: &[(&str, u32)] = &[ - ("anthropic/claude-sonnet-4.5", 200_000), - ("openai/gpt-5", 400_000), - ("google/gemini-2.5", 1_000_000), - ("deepseek/deepseek", 128_000), - ("qwen/qwen3-coder", 256_000), - ("moonshotai/kimi", 256_000), -]; - -const GROQ_CONTEXT: &[(&str, u32)] = &[ - ("openai/gpt-oss-120b", 131_072), - ("openai/gpt-oss-20b", 131_072), - ("moonshotai/kimi", 131_072), - ("qwen/qwen3", 131_072), - ("llama-3.3", 131_072), -]; - -const DEEPSEEK_CONTEXT: &[(&str, u32)] = - &[("deepseek-chat", 128_000), ("deepseek-reasoner", 128_000)]; - -const XAI_CONTEXT: &[(&str, u32)] = &[("grok-4", 256_000), ("grok-3", 131_072)]; - -const MISTRAL_CONTEXT: &[(&str, u32)] = &[ - ("mistral-large", 131_072), - ("mistral-medium", 131_072), - ("mistral-small", 131_072), - ("devstral-medium", 131_072), - ("codestral", 256_000), -]; - -const ZAI_CONTEXT: &[(&str, u32)] = &[ - ("glm-5.2", 128_000), - ("glm-5.1", 128_000), - ("glm-5", 128_000), - ("glm-4", 128_000), -]; - -const ZAI_CODING_CONTEXT: &[(&str, u32)] = &[ - ("glm-5.2", 1_000_000), - ("glm-5-turbo", 128_000), - ("glm-4.7", 128_000), -]; - -const KIMI_CONTEXT: &[(&str, u32)] = &[ - ("kimi-k2.7", 256_000), - ("kimi-k2.6", 256_000), - ("kimi-k2.5", 256_000), - ("moonshot-v1-128k", 128_000), - ("moonshot-v1-32k", 32_000), - ("moonshot-v1-8k", 8_000), -]; - -const KIMI_CODE_CONTEXT: &[(&str, u32)] = &[ - ("kimi-k2.7", 256_000), - ("kimi-k2.6", 256_000), - ("kimi-k2.5", 256_000), -]; - -const QWEN_CONTEXT: &[(&str, u32)] = &[ - ("qwen-plus", 131_072), - ("qwen-max", 131_072), - ("qwen-turbo", 1_000_000), - ("qwen3-coder", 1_000_000), - ("qwen-vl", 129_024), -]; - -const HOSTS: &[(&str, &str)] = &[ - (OPENROUTER, "openrouter.ai"), - (GROQ, "api.groq.com"), - (DEEPSEEK, "api.deepseek.com"), - (XAI, "api.x.ai"), - (MISTRAL, "api.mistral.ai"), - (ZAI, "api.z.ai"), - (ZAI_CODING, "api.z.ai"), - (KIMI, "api.moonshot.ai"), - (KIMI_CODE, "api.kimi.com"), - (QWEN, "dashscope-us.aliyuncs.com"), -]; - -#[derive(Debug, thiserror::Error)] -pub enum HostedOAuthError { - #[error("http error: {0}")] - Http(#[from] reqwest::Error), - #[error("oauth error: {0}")] - OAuth(String), - #[error("io error: {0}")] - Io(#[from] std::io::Error), -} - -pub fn all(store: &CredentialStore, account: &str) -> Vec { - vec![ - build_openrouter(store, account), - build_groq(store, account), - build_deepseek(store, account), - build_xai(store, account), - build_mistral(store, account), - build_zai(store, account), - build_zai_coding(store, account), - build_kimi(store, account), - build_qwen(store, account), - ] -} - -pub fn build_openrouter(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { - hosted( - OPENROUTER, - "https://openrouter.ai/api/v1", - "OPENROUTER_API_KEY", - store, - account, - ) - .with_catalog(OPENROUTER_CATALOG) - .with_context_windows(OPENROUTER_CONTEXT) - .with_model_filter(openrouter_chat_model) - .with_vision_filter(openrouter_vision_model) - .with_reasoning_effort(false) -} - -pub fn build_groq(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { - hosted( - GROQ, - "https://api.groq.com/openai/v1", - "GROQ_API_KEY", - store, - account, - ) - .with_catalog(GROQ_CATALOG) - .with_context_windows(GROQ_CONTEXT) - .with_model_filter(groq_chat_model) - .with_images(false) - .with_stream_options(false) - .with_reasoning_effort(false) -} - -pub fn build_deepseek(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { - hosted( - DEEPSEEK, - "https://api.deepseek.com", - "DEEPSEEK_API_KEY", - store, - account, - ) - .with_catalog(DEEPSEEK_CATALOG) - .with_context_windows(DEEPSEEK_CONTEXT) - .with_images(false) - .with_reasoning_effort(false) -} - -pub fn build_xai(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { - hosted(XAI, "https://api.x.ai/v1", "XAI_API_KEY", store, account) - .with_catalog(XAI_CATALOG) - .with_context_windows(XAI_CONTEXT) - .with_vision_filter(xai_vision_model) - .with_efforts(no_efforts) - .with_reasoning_effort(false) -} - -pub fn build_mistral(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { - hosted( - MISTRAL, - "https://api.mistral.ai/v1", - "MISTRAL_API_KEY", - store, - account, - ) - .with_catalog(MISTRAL_CATALOG) - .with_context_windows(MISTRAL_CONTEXT) - .with_vision_filter(mistral_vision_model) - .with_efforts(no_efforts) - .with_reasoning_effort(false) -} - -pub fn build_zai(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { - hosted( - ZAI, - "https://api.z.ai/api/paas/v4", - "ZAI_API_KEY", - store, - account, - ) - .with_catalog(ZAI_CATALOG) - .with_context_windows(ZAI_CONTEXT) - .with_vision_filter(zai_vision_model) - .with_efforts(zai_efforts) - .with_effort_wire(zai_effort_wire) - .with_validation(ChatValidation::CatalogOnly) - .with_discovery(ChatDiscovery::CatalogOnly) - .with_metadata(ProviderMetadata { - env_var: Some("ZAI_API_KEY"), - validation: "catalog-only", - endpoint: None, - oauth: Some("not supported by Z.AI API docs"), - login_endpoint: None, - setup: &[], - }) -} - -pub fn build_zai_coding(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { - hosted( - ZAI_CODING, - "https://api.z.ai/api/coding/paas/v4", - "ZAI_CODING_API_KEY", - store, - account, - ) - .with_catalog(ZAI_CODING_CATALOG) - .with_context_windows(ZAI_CODING_CONTEXT) - .with_vision_filter(no_vision) - .with_efforts(zai_efforts) - .with_effort_wire(zai_effort_wire) - .with_validation(ChatValidation::CatalogOnly) - .with_discovery(ChatDiscovery::CatalogOnly) - .with_metadata(ProviderMetadata { - env_var: Some("ZAI_CODING_API_KEY"), - validation: "catalog-only", - endpoint: Some("https://api.z.ai/api/coding/paas/v4"), - oauth: Some("not OAuth; uses Z.AI Coding Plan API key"), - login_endpoint: None, - setup: ZAI_CODING_SETUP, - }) -} - -pub fn build_kimi(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { - hosted( - KIMI, - "https://api.moonshot.ai/v1", - "MOONSHOT_API_KEY", - store, - account, - ) - .with_catalog(KIMI_CATALOG) - .with_context_windows(KIMI_CONTEXT) - .with_vision_filter(no_vision) - .with_efforts(no_efforts) - .with_reasoning_effort(false) - .with_validation(ChatValidation::CatalogOnly) - .with_discovery(ChatDiscovery::CatalogOnly) - .with_metadata(ProviderMetadata { - env_var: Some("MOONSHOT_API_KEY"), - validation: "catalog-only", - endpoint: None, - oauth: Some("Kimi Code OAuth is provider id kimi-code"), - login_endpoint: None, - setup: KIMI_SETUP, - }) -} - -pub fn build_kimi_code(store: &CredentialStore, account: &str) -> KimiCodeProvider { - enforce_host(KIMI_CODE, KIMI_CODE_BASE_URL).expect("kimi-code provider base URL"); - KimiCodeProvider::new(store.clone(), CredentialKey::model(KIMI_CODE, account)) -} - -pub fn build_qwen(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { - let key = CredentialKey::model(QWEN, account); - let stored = store.get(&key); - let endpoint_source = std::env::var("QWEN_BASE_URL").ok().or_else(|| { - stored - .as_ref() - .and_then(goat_auth::Credential::endpoint) - .map(str::to_owned) - }); - let endpoint = match endpoint_source { - Some(raw) => validate_qwen_endpoint(&raw).ok(), - None => Some(QWEN_DEFAULT_ENDPOINT.to_owned()), - }; - let bearer = endpoint.as_ref().and_then(|_| { - store - .resolve(&key, Some("DASHSCOPE_API_KEY")) - .map(|cred| cred.bearer().to_owned()) - }); - OpenAiCompatProvider::new( - ProviderId::from(QWEN), - endpoint.unwrap_or_else(|| QWEN_DEFAULT_ENDPOINT.to_owned()), - bearer, - AuthMethod::ApiKey, - ) - .with_catalog(QWEN_CATALOG) - .with_context_windows(QWEN_CONTEXT) - .with_vision_filter(qwen_vision_model) - .with_efforts(no_efforts) - .with_reasoning_effort(false) - .with_metadata(ProviderMetadata { - env_var: Some("DASHSCOPE_API_KEY"), - validation: "network", - endpoint: Some("required for non-US DashScope workspaces"), - oauth: Some("Qwen OAuth enrollment discontinued"), - login_endpoint: Some(LoginEndpointMetadata { - env_var: Some("QWEN_BASE_URL"), - default: Some(QWEN_DEFAULT_ENDPOINT), - validate: Some(validate_qwen_endpoint), - }), - setup: QWEN_SETUP, - }) -} - -fn hosted( - provider_id: &'static str, - base_url: &'static str, - env_var: &'static str, - store: &CredentialStore, - account: &str, -) -> OpenAiCompatProvider { - enforce_host(provider_id, base_url).expect("hosted provider base URL"); - let key = CredentialKey::model(provider_id, account); - let bearer = store - .resolve(&key, Some(env_var)) - .map(|cred| cred.bearer().to_owned()); - OpenAiCompatProvider::new( - ProviderId::from(provider_id), - base_url, - bearer, - AuthMethod::ApiKey, - ) - .with_metadata(ProviderMetadata { - env_var: Some(env_var), - validation: "network", - endpoint: None, - oauth: Some("not supported"), - login_endpoint: None, - setup: &[], - }) -} - -fn enforce_host(provider_id: &str, base_url: &str) -> Result<(), String> { - let Some(host) = HOSTS - .iter() - .find_map(|(id, host)| (*id == provider_id).then_some(*host)) - else { - return Err(format!("unknown hosted provider: {provider_id}")); - }; - let url = base_url.trim_end_matches('/'); - let rest = url - .strip_prefix("https://") - .ok_or_else(|| "hosted providers require https".to_owned())?; - let actual = rest.split('/').next().unwrap_or_default(); - if actual == host || actual.ends_with(&format!(".{host}")) { - Ok(()) - } else { - Err(format!("invalid hosted provider host: {actual}")) - } -} - -pub struct KimiCodeProvider { - store: CredentialStore, - key: CredentialKey, - client: reqwest::Client, -} - -impl KimiCodeProvider { - pub fn new(store: CredentialStore, key: CredentialKey) -> Self { - Self { - store, - key, - client: oauth_client(), - } - } -} - -impl Provider for KimiCodeProvider { - fn id(&self) -> ProviderId { - ProviderId::from(KIMI_CODE) - } - - fn capabilities(&self) -> Capabilities { - Capabilities { - tools: true, - auth: AuthMethod::OAuth, - images: false, - } - } - - fn metadata(&self) -> ProviderMetadata { - kimi_code_metadata() - } - - fn authenticated(&self) -> bool { - self.store - .get(&self.key) - .is_some_and(|cred| matches!(cred, Credential::OAuth(_))) - } - - fn catalog(&self) -> &'static [&'static str] { - KIMI_CODE_CATALOG - } - - fn efforts(&self, _model: &str) -> Vec { - Vec::new() - } - - fn context_window(&self, model: &str) -> Option { - KIMI_CODE_CONTEXT - .iter() - .find_map(|(prefix, window)| model.starts_with(prefix).then_some(*window)) - } - - fn supports_images(&self, _model: &str) -> bool { - false - } - - fn verifies_credentials(&self) -> bool { - true - } - - fn validate(&self) -> JoinHandle> { - let store = self.store.clone(); - let key = self.key.clone(); - let client = self.client.clone(); - tokio::spawn(async move { - let Some(token) = current_kimi_code_token(&store, &key).await else { - return Err("no credentials".to_owned()); - }; - let response = client - .get(format!("{KIMI_CODE_BASE_URL}/models")) - .bearer_auth(token) - .send() - .await - .map_err(|_| "could not reach provider".to_owned())?; - let status = response.status(); - if status.is_success() { - Ok(()) - } else if status == reqwest::StatusCode::UNAUTHORIZED - || status == reqwest::StatusCode::FORBIDDEN - { - Err("invalid credentials".to_owned()) - } else { - Err(format!("could not reach provider: {status}")) - } - }) - } - - fn stream(&self, req: Request, events: mpsc::Sender) -> JoinHandle<()> { - let store = self.store.clone(); - let key = self.key.clone(); - tokio::spawn(async move { - let Some(token) = current_kimi_code_token(&store, &key).await else { - let _ = events - .send(StreamEvent::Failed { - error: StreamError::auth("no credentials"), - }) - .await; - return; - }; - let provider = OpenAiCompatProvider::new( - ProviderId::from(KIMI_CODE), - KIMI_CODE_BASE_URL, - Some(token), - AuthMethod::OAuth, - ) - .with_catalog(KIMI_CODE_CATALOG) - .with_context_windows(KIMI_CODE_CONTEXT) - .with_vision_filter(no_vision) - .with_efforts(no_efforts) - .with_reasoning_effort(false); - let handle = provider.stream(req, events); - let _ = handle.await; - }) - } - - fn discover(&self, out: mpsc::Sender) -> JoinHandle<()> { - let store = self.store.clone(); - let key = self.key.clone(); - tokio::spawn(async move { - let Some(token) = current_kimi_code_token(&store, &key).await else { - for id in KIMI_CODE_CATALOG { - if out - .send(Model { - id: (*id).to_owned(), - supports_images: false, - }) - .await - .is_err() - { - return; - } - } - return; - }; - let provider = OpenAiCompatProvider::new( - ProviderId::from(KIMI_CODE), - KIMI_CODE_BASE_URL, - Some(token), - AuthMethod::OAuth, - ) - .with_model_filter(kimi_code_chat_model) - .with_vision_filter(no_vision); - let handle = provider.discover(out); - let _ = handle.await; - }) - } - - fn login(&self, status: mpsc::Sender) -> JoinHandle> { - tokio::spawn(async move { - kimi_code_login(&status) - .await - .map_err(|err| err.to_string()) - }) - } - - fn web_search(&self, query: String) -> JoinHandle> { - let _ = query; - tokio::spawn(async { Err(StreamError::other("web search is not supported")) }) - } -} - -fn kimi_code_metadata() -> ProviderMetadata { - ProviderMetadata { - env_var: None, - validation: "network", - endpoint: Some(KIMI_CODE_BASE_URL), - oauth: Some("device code"), - login_endpoint: None, - setup: KIMI_CODE_SETUP, - } -} - -async fn current_kimi_code_token(store: &CredentialStore, key: &CredentialKey) -> Option { - let Credential::OAuth(tokens) = store.get(key)? else { - return None; - }; - let tokens = ensure_valid(tokens, store, key, kimi_code_refresh).await?; - Some(tokens.access_token.expose().to_owned()) -} - -fn oauth_client() -> reqwest::Client { - reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)) - .connect_timeout(std::time::Duration::from_secs(10)) - .redirect(reqwest::redirect::Policy::none()) - .build() - .expect("reqwest client") -} - -#[derive(Deserialize)] -struct DeviceAuthorizationResponse { - user_code: String, - device_code: String, - verification_uri: Option, - verification_uri_complete: String, - expires_in: Option, - interval: Option, -} - -#[derive(Deserialize)] -struct TokenResponse { - access_token: String, - refresh_token: String, - expires_in: i64, - scope: Option, - token_type: Option, -} - -#[derive(Deserialize)] -struct OAuthErrorResponse { - error: Option, - #[serde(rename = "error_description")] - _error_description: Option, -} - -async fn kimi_code_login(status: &mpsc::Sender) -> Result { - let client = oauth_client(); - let device = request_device_authorization(&client).await?; - let url = if device.verification_uri_complete.is_empty() { - device.verification_uri.as_deref().unwrap_or("") - } else { - &device.verification_uri_complete - }; - if !valid_kimi_verification_url(url) { - return Err(HostedOAuthError::OAuth( - "device authorization returned an invalid verification URL".to_owned(), - )); - } - let _ = status - .send(format!("open {url} and enter code: {}", device.user_code)) - .await; - poll_device_token(&client, &device).await -} - -async fn request_device_authorization( - client: &reqwest::Client, -) -> Result { - let response = client - .post(format!( - "{KIMI_CODE_OAUTH_HOST}/api/oauth/device_authorization" - )) - .headers(kimi_headers()) - .form(&[("client_id", KIMI_CODE_CLIENT_ID)]) - .send() - .await?; - let status = response.status(); - if !status.is_success() { - return Err(HostedOAuthError::OAuth(format!( - "device authorization failed: {status}" - ))); - } - let device: DeviceAuthorizationResponse = response.json().await?; - if device.user_code.is_empty() - || device.device_code.is_empty() - || device.verification_uri_complete.is_empty() - { - return Err(HostedOAuthError::OAuth( - "device authorization response is missing required fields".to_owned(), - )); - } - Ok(device) -} - -async fn poll_device_token( - client: &reqwest::Client, - device: &DeviceAuthorizationResponse, -) -> Result { - let mut interval = device.interval.unwrap_or(5).max(1); - let deadline = - goat_auth::now_secs() + i64::try_from(device.expires_in.unwrap_or(900)).unwrap_or(900); - loop { - if goat_auth::now_secs() > deadline { - return Err(HostedOAuthError::OAuth("device login timed out".to_owned())); - } - tokio::time::sleep(std::time::Duration::from_secs(interval)).await; - let response = client - .post(format!("{KIMI_CODE_OAUTH_HOST}/api/oauth/token")) - .headers(kimi_headers()) - .form(&[ - ("client_id", KIMI_CODE_CLIENT_ID), - ("device_code", device.device_code.as_str()), - ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), - ]) - .send() - .await?; - let status = response.status(); - if status.is_success() { - let tokens: TokenResponse = response.json().await?; - return parse_token_response(tokens); - } - let error = response.json::().await.ok(); - match error.as_ref().and_then(|err| err.error.as_deref()) { - Some("authorization_pending") => {} - Some("slow_down") => interval = interval.saturating_add(5).min(30), - Some("expired_token") => { - return Err(HostedOAuthError::OAuth( - "device login code expired".to_owned(), - )); - } - Some("access_denied") => { - return Err(HostedOAuthError::OAuth( - "device login access denied".to_owned(), - )); - } - Some(code) => { - return Err(HostedOAuthError::OAuth(format!( - "device token polling failed: {code}" - ))); - } - None => { - return Err(HostedOAuthError::OAuth(format!( - "device token polling failed: {status}" - ))); - } - } - } -} - -async fn kimi_code_refresh(refresh_token: String) -> Result { - let client = oauth_client(); - let response = client - .post(format!("{KIMI_CODE_OAUTH_HOST}/api/oauth/token")) - .headers(kimi_headers()) - .form(&[ - ("client_id", KIMI_CODE_CLIENT_ID), - ("grant_type", "refresh_token"), - ("refresh_token", refresh_token.as_str()), - ]) - .send() - .await - .map_err(|_| "token refresh request failed".to_owned())?; - let status = response.status(); - if !status.is_success() { - return Err(format!("token refresh failed: {status}")); - } - response - .json::() - .await - .map_err(|_| "token refresh returned invalid JSON".to_owned()) - .and_then(|tokens| parse_token_response(tokens).map_err(|err| err.to_string())) -} - -fn parse_token_response(tokens: TokenResponse) -> Result { - let _ = (&tokens.scope, &tokens.token_type); - if tokens.access_token.is_empty() || tokens.refresh_token.is_empty() || tokens.expires_in <= 0 { - return Err(HostedOAuthError::OAuth( - "token response is missing required fields".to_owned(), - )); - } - Ok(TokenSet::from_parts( - tokens.access_token, - Some(tokens.refresh_token), - Some(tokens.expires_in), - None, - )) -} - -fn kimi_headers() -> HeaderMap { - let mut headers = HeaderMap::new(); - headers.insert(ACCEPT, HeaderValue::from_static("application/json")); - headers.insert( - CONTENT_TYPE, - HeaderValue::from_static("application/x-www-form-urlencoded"), - ); - headers.insert(USER_AGENT, HeaderValue::from_static("goat-code/0.1.14")); - insert_header(&mut headers, "X-Msh-Platform", "kimi_code_cli"); - insert_header(&mut headers, "X-Msh-Version", env!("CARGO_PKG_VERSION")); - insert_header(&mut headers, "X-Msh-Device-Name", &device_name()); - insert_header(&mut headers, "X-Msh-Device-Model", &device_model()); - insert_header(&mut headers, "X-Msh-Os-Version", std::env::consts::OS); - insert_header(&mut headers, "X-Msh-Device-Id", &device_id()); - headers -} - -fn insert_header(headers: &mut HeaderMap, name: &'static str, value: &str) { - if let Ok(value) = HeaderValue::from_str(&ascii_header(value)) { - headers.insert( - HeaderName::from_static(name.to_ascii_lowercase().leak()), - value, - ); - } -} - -fn ascii_header(value: &str) -> String { - let cleaned: String = value - .chars() - .filter(|ch| matches!(*ch as u32, 0x20..=0x7e)) - .collect::() - .trim() - .to_owned(); - if cleaned.is_empty() { - "unknown".to_owned() - } else { - cleaned - } -} - -fn device_name() -> String { - std::env::var("HOSTNAME") - .or_else(|_| std::env::var("COMPUTERNAME")) - .unwrap_or_else(|_| "unknown".to_owned()) -} - -fn device_model() -> String { - format!("{} {}", std::env::consts::OS, std::env::consts::ARCH) -} - -fn device_id() -> String { - let path = device_id_path(); - if let Ok(value) = std::fs::read_to_string(&path) { - let value = value.trim(); - if !value.is_empty() { - return value.to_owned(); - } - } - let id = goat_auth::random_state(); - if let Some(parent) = path.parent() { - let _ = std::fs::create_dir_all(parent); - set_private_dir(parent); - } - if std::fs::write(&path, &id).is_ok() { - set_private_file(&path); - } - id -} - -fn device_id_path() -> PathBuf { - std::env::home_dir().map_or_else( - || PathBuf::from(".goat-code-kimi-code-device-id"), - |home| home.join(".goat-code").join("kimi-code-device-id"), - ) -} - -#[cfg(unix)] -fn set_private_dir(path: &std::path::Path) { - use std::os::unix::fs::PermissionsExt; - let _ = std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o700)); -} - -#[cfg(not(unix))] -fn set_private_dir(_path: &std::path::Path) {} - -#[cfg(unix)] -fn set_private_file(path: &std::path::Path) { - use std::os::unix::fs::PermissionsExt; - let _ = std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600)); -} - -#[cfg(not(unix))] -fn set_private_file(_path: &std::path::Path) {} - -fn valid_kimi_verification_url(url: &str) -> bool { - reqwest::Url::parse(url) - .ok() - .and_then(|url| { - (url.scheme() == "https") - .then(|| url.host_str().is_some_and(|host| host == "auth.kimi.com")) - }) - .unwrap_or(false) -} - -fn openrouter_chat_model(id: &str) -> bool { - let id = id.to_ascii_lowercase(); - !(id.contains("embedding") - || id.contains("moderation") - || id.contains("image") - || id.contains("tts") - || id.contains("whisper")) -} - -fn groq_chat_model(id: &str) -> bool { - let id = id.to_ascii_lowercase(); - !(id.contains("whisper") || id.contains("tts") || id.contains("embedding")) -} - -fn kimi_code_chat_model(id: &str) -> bool { - let id = id.to_ascii_lowercase(); - !id.contains("embedding") && !id.contains("image") && !id.contains("video") -} - -fn openrouter_vision_model(id: &str) -> bool { - let id = id.to_ascii_lowercase(); - goat_provider_openai_compat::known_openai_compatible_vision_model(&id) - || id.contains("claude") - || id.contains("gemini") - || id.contains("grok-4") -} - -fn xai_vision_model(id: &str) -> bool { - let id = id.to_ascii_lowercase(); - id.starts_with("grok-4") || id.contains("vision") -} - -fn mistral_vision_model(id: &str) -> bool { - let id = id.to_ascii_lowercase(); - id.contains("pixtral") -} - -fn zai_vision_model(id: &str) -> bool { - let id = id.to_ascii_lowercase(); - id.contains("glm-4v") || id.contains("vision") -} - -fn qwen_vision_model(id: &str) -> bool { - let id = id.to_ascii_lowercase(); - id.contains("qwen-vl") || id.contains("qwen2-vl") || id.contains("qwen2.5-vl") -} - -pub fn validate_qwen_endpoint(endpoint: &str) -> Result { - let trimmed = endpoint.trim().trim_end_matches('/'); - let url = reqwest::Url::parse(trimmed).map_err(|err| err.to_string())?; - if url.scheme() != "https" { - return Err("qwen endpoint must use https".to_owned()); - } - if !url.username().is_empty() || url.password().is_some() { - return Err("qwen endpoint must not include userinfo".to_owned()); - } - let Some(host) = url.host_str() else { - return Err("qwen endpoint must include a host".to_owned()); - }; - if host.ends_with('.') { - return Err("qwen endpoint host must not end with a dot".to_owned()); - } - let allowed_static = [ - "dashscope.aliyuncs.com", - "dashscope-intl.aliyuncs.com", - "dashscope-us.aliyuncs.com", - ]; - let allowed_regions = [ - "cn-beijing.maas.aliyuncs.com", - "ap-southeast-1.maas.aliyuncs.com", - "ap-northeast-1.maas.aliyuncs.com", - ]; - let allowed = allowed_static.contains(&host) - || allowed_regions.iter().any(|region| { - host.strip_suffix(region) - .and_then(|prefix| prefix.strip_suffix('.')) - .is_some_and(valid_workspace_id) - }); - if !allowed { - return Err("qwen endpoint host is not an allowed Alibaba Model Studio host".to_owned()); - } - if url.port().is_some() { - return Err("qwen endpoint must not include a custom port".to_owned()); - } - if url.path() != "/compatible-mode/v1" { - return Err("qwen endpoint path must be /compatible-mode/v1".to_owned()); - } - if url.query().is_some() || url.fragment().is_some() { - return Err("qwen endpoint must not include query or fragment".to_owned()); - } - Ok(trimmed.to_owned()) -} - -fn valid_workspace_id(value: &str) -> bool { - !value.is_empty() - && value - .bytes() - .all(|b| b.is_ascii_alphanumeric() || b == b'-') -} - -fn no_vision(_id: &str) -> bool { - false -} - -fn no_efforts(_model: &str) -> Vec { - Vec::new() -} - -fn zai_efforts(model: &str) -> Vec { - if model == "glm-5.2" { - vec![ - Effort::Off, - Effort::Low, - Effort::Medium, - Effort::High, - Effort::Xhigh, - Effort::Max, - ] - } else { - Vec::new() - } -} - -fn zai_effort_wire(effort: Effort) -> Option<&'static str> { - let wire = match effort { - Effort::Off => "none", - Effort::Low => "low", - Effort::Medium => "medium", - Effort::High => "high", - Effort::Xhigh => "xhigh", - Effort::Max => "max", - }; - (!wire.is_empty()).then_some(wire) -} - -#[cfg(test)] -mod tests { - use goat_auth::{Credential, SecretString}; - use goat_provider::{AuthMethod, Effort, Provider}; - - use super::*; - - fn store(name: &str) -> CredentialStore { - let _ = std::fs::remove_file(std::env::temp_dir().join(name)); - CredentialStore::new(std::env::temp_dir().join(name)) - } - - #[test] - fn enforces_https_and_provider_owned_hosts() { - assert!(enforce_host(OPENROUTER, "https://openrouter.ai/api/v1/").is_ok()); - assert!(enforce_host(OPENROUTER, "http://openrouter.ai/api/v1").is_err()); - assert!(enforce_host(OPENROUTER, "https://example.com/api/v1").is_err()); - assert!(enforce_host(ZAI_CODING, "https://api.z.ai/api/coding/paas/v4").is_ok()); - assert!(enforce_host(KIMI_CODE, KIMI_CODE_BASE_URL).is_ok()); - } - - #[test] - fn resolves_stored_credential() { - let store = store("goat-provider-hosted-resolves.json"); - store - .store( - &CredentialKey::model(OPENROUTER, "default"), - Credential::ApiKey(SecretString::from("key".to_owned())), - ) - .unwrap(); - let provider = build_openrouter(&store, "default"); - assert!(provider.authenticated()); - assert_eq!(provider.capabilities().auth, AuthMethod::ApiKey); - } - - #[test] - fn metadata_is_exposed() { - let store = store("goat-provider-hosted-metadata.json"); - let provider = build_zai(&store, "default"); - assert_eq!(provider.catalog(), ZAI_CATALOG); - assert_eq!(provider.context_window("glm-5.2"), Some(128_000)); - assert_eq!( - provider.efforts("glm-5.2"), - vec![ - Effort::Off, - Effort::Low, - Effort::Medium, - Effort::High, - Effort::Xhigh, - Effort::Max - ] - ); - assert!(!provider.verifies_credentials()); - } - - #[test] - fn zai_coding_is_distinct_api_key_provider() { - let store = store("goat-provider-hosted-zai-coding.json"); - let provider = build_zai_coding(&store, "default"); - assert_eq!(provider.capabilities().auth, AuthMethod::ApiKey); - assert_eq!(provider.metadata().env_var, Some("ZAI_CODING_API_KEY")); - assert_eq!( - provider.metadata().endpoint, - Some("https://api.z.ai/api/coding/paas/v4") - ); - assert_eq!(provider.catalog(), ZAI_CODING_CATALOG); - assert_eq!(provider.context_window("glm-5.2"), Some(1_000_000)); - } - - #[test] - fn kimi_code_is_oauth_provider() { - let store = store("goat-provider-hosted-kimi-code.json"); - let provider = build_kimi_code(&store, "default"); - assert_eq!(provider.capabilities().auth, AuthMethod::OAuth); - assert_eq!(provider.metadata().oauth, Some("device code")); - assert!(!provider.authenticated()); - assert_eq!(provider.catalog(), KIMI_CODE_CATALOG); - assert!(valid_kimi_verification_url( - "https://auth.kimi.com/device?code=abc" - )); - assert!(!valid_kimi_verification_url( - "https://example.com/device?code=abc" - )); - } - - #[test] - fn parses_kimi_token_response_without_leaking_secrets() { - let token = parse_token_response(TokenResponse { - access_token: "access-secret".to_owned(), - refresh_token: "refresh-secret".to_owned(), - expires_in: 3600, - scope: Some("scope".to_owned()), - token_type: Some("Bearer".to_owned()), - }) - .unwrap(); - assert_eq!(token.access_token.expose(), "access-secret"); - assert_eq!(token.refresh_token.unwrap().expose(), "refresh-secret"); - let error = parse_token_response(TokenResponse { - access_token: String::new(), - refresh_token: "refresh-secret".to_owned(), - expires_in: 3600, - scope: None, - token_type: None, - }) - .unwrap_err() - .to_string(); - assert!(!error.contains("refresh-secret")); - assert!(!error.contains("access-secret")); - } - - #[test] - fn validates_qwen_endpoints() { - for endpoint in [ - "https://dashscope-us.aliyuncs.com/compatible-mode/v1", - "https://dashscope.aliyuncs.com/compatible-mode/v1", - "https://workspace-1.cn-beijing.maas.aliyuncs.com/compatible-mode/v1", - "https://abc123.ap-southeast-1.maas.aliyuncs.com/compatible-mode/v1/", - ] { - assert_eq!( - validate_qwen_endpoint(endpoint).unwrap(), - endpoint.trim_end_matches('/') - ); - } - for endpoint in [ - "http://dashscope-us.aliyuncs.com/compatible-mode/v1", - "https://dashscope-us.aliyuncs.com.evil.test/compatible-mode/v1", - "https://user@dashscope-us.aliyuncs.com/compatible-mode/v1", - "https://dashscope-us.aliyuncs.com:444/compatible-mode/v1", - "https://dashscope-us.aliyuncs.com/v1", - "https://dashscope-us.aliyuncs.com/compatible-mode/v1?x=1", - "https://workspace_1.cn-beijing.maas.aliyuncs.com/compatible-mode/v1", - "https://workspace-1.cn-hangzhou.maas.aliyuncs.com/compatible-mode/v1", - ] { - assert!( - validate_qwen_endpoint(endpoint).is_err(), - "expected rejection for {endpoint}" - ); - } - } - - #[test] - fn invalid_qwen_endpoint_does_not_authenticate() { - let store = store("goat-provider-hosted-qwen-invalid.json"); - store - .store( - &CredentialKey::model(QWEN, "default"), - Credential::ApiKeyWithEndpoint { - secret: SecretString::from("key".to_owned()), - endpoint: "https://example.com/compatible-mode/v1".to_owned(), - }, - ) - .unwrap(); - let provider = build_qwen(&store, "default"); - assert!(!provider.authenticated()); - } - - #[test] - fn qwen_endpoint_credential_authenticates() { - let store = store("goat-provider-hosted-qwen-valid.json"); - store - .store( - &CredentialKey::model(QWEN, "default"), - Credential::ApiKeyWithEndpoint { - secret: SecretString::from("key".to_owned()), - endpoint: "https://dashscope-us.aliyuncs.com/compatible-mode/v1".to_owned(), - }, - ) - .unwrap(); - let provider = build_qwen(&store, "default"); - assert!(provider.authenticated()); - assert_eq!( - provider.base_url(), - "https://dashscope-us.aliyuncs.com/compatible-mode/v1" - ); - } - - #[test] - fn local_no_auth_behavior_is_not_host_checked() { - let store = store("goat-provider-hosted-local.json"); - let provider = build_deepseek(&store, "default"); - assert_eq!(provider.capabilities().auth, AuthMethod::ApiKey); - } -} diff --git a/crates/goat-provider-hosted/Cargo.toml b/crates/goat-provider-kimi-code/Cargo.toml similarity index 84% rename from crates/goat-provider-hosted/Cargo.toml rename to crates/goat-provider-kimi-code/Cargo.toml index aed83ad..24b7510 100644 --- a/crates/goat-provider-hosted/Cargo.toml +++ b/crates/goat-provider-kimi-code/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "goat-provider-hosted" +name = "goat-provider-kimi-code" version.workspace = true edition.workspace = true rust-version.workspace = true @@ -13,9 +13,8 @@ goat-provider = { workspace = true } goat-provider-openai-compat = { workspace = true } reqwest = { workspace = true } serde = { workspace = true } -serde_json = { workspace = true } tokio = { workspace = true } thiserror = { workspace = true } [lints] -workspace = true +workspace = true \ No newline at end of file diff --git a/crates/goat-provider-kimi-code/src/lib.rs b/crates/goat-provider-kimi-code/src/lib.rs new file mode 100644 index 0000000..a951689 --- /dev/null +++ b/crates/goat-provider-kimi-code/src/lib.rs @@ -0,0 +1,239 @@ +mod oauth; + +use goat_auth::{Credential, CredentialKey, CredentialStore, TokenSet}; +use goat_provider::{ + AuthMethod, Capabilities, Effort, Model, Provider, ProviderId, ProviderMetadata, Request, + StreamError, StreamEvent, WebSearchOutput, +}; +use goat_provider_openai_compat::{ + OpenAiCompatProvider, enforce_https_host, no_efforts, no_vision, +}; +use tokio::{sync::mpsc, task::JoinHandle}; + +pub const PROVIDER_ID: &str = "kimi-code"; + +const BASE_URL: &str = "https://api.kimi.com/coding/v1"; +const ALLOWED_HOST: &str = "api.kimi.com"; + +const SETUP: &[&str] = &[ + "Kimi Code OAuth device-code login.", + "Run `goat provider login kimi-code`, open the URL, and enter the code.", +]; + +const CATALOG: &[&str] = &[ + "kimi-k2.7-code", + "kimi-k2.7-code-highspeed", + "kimi-k2.6", + "kimi-k2.5", +]; + +const CONTEXT_WINDOWS: &[(&str, u32)] = &[ + ("kimi-k2.7", 256_000), + ("kimi-k2.6", 256_000), + ("kimi-k2.5", 256_000), +]; + +pub fn build(store: &CredentialStore, account: &str) -> KimiCodeProvider { + enforce_https_host(BASE_URL, ALLOWED_HOST).expect("kimi-code provider base URL"); + KimiCodeProvider::new(store.clone(), CredentialKey::model(PROVIDER_ID, account)) +} + +pub struct KimiCodeProvider { + store: CredentialStore, + key: CredentialKey, + client: reqwest::Client, +} + +impl KimiCodeProvider { + pub fn new(store: CredentialStore, key: CredentialKey) -> Self { + Self { + store, + key, + client: oauth::oauth_client(), + } + } +} + +impl Provider for KimiCodeProvider { + fn id(&self) -> ProviderId { + ProviderId::from(PROVIDER_ID) + } + + fn capabilities(&self) -> Capabilities { + Capabilities { + tools: true, + auth: AuthMethod::OAuth, + images: false, + } + } + + fn metadata(&self) -> ProviderMetadata { + ProviderMetadata { + env_var: None, + validation: "network", + endpoint: Some(BASE_URL), + oauth: Some("device code"), + login_endpoint: None, + setup: SETUP, + } + } + + fn authenticated(&self) -> bool { + self.store + .get(&self.key) + .is_some_and(|cred| matches!(cred, Credential::OAuth(_))) + } + + fn catalog(&self) -> &'static [&'static str] { + CATALOG + } + + fn efforts(&self, _model: &str) -> Vec { + Vec::new() + } + + fn context_window(&self, model: &str) -> Option { + CONTEXT_WINDOWS + .iter() + .find_map(|(prefix, window)| model.starts_with(prefix).then_some(*window)) + } + + fn supports_images(&self, _model: &str) -> bool { + false + } + + fn verifies_credentials(&self) -> bool { + true + } + + fn validate(&self) -> JoinHandle> { + let store = self.store.clone(); + let key = self.key.clone(); + let client = self.client.clone(); + tokio::spawn(async move { + let Some(token) = oauth::current_token(&store, &key).await else { + return Err("no credentials".to_owned()); + }; + let response = client + .get(format!("{BASE_URL}/models")) + .bearer_auth(token) + .send() + .await + .map_err(|_| "could not reach provider".to_owned())?; + let status = response.status(); + if status.is_success() { + Ok(()) + } else if status == reqwest::StatusCode::UNAUTHORIZED + || status == reqwest::StatusCode::FORBIDDEN + { + Err("invalid credentials".to_owned()) + } else { + Err(format!("could not reach provider: {status}")) + } + }) + } + + fn stream(&self, req: Request, events: mpsc::Sender) -> JoinHandle<()> { + let store = self.store.clone(); + let key = self.key.clone(); + tokio::spawn(async move { + let Some(token) = oauth::current_token(&store, &key).await else { + let _ = events + .send(StreamEvent::Failed { + error: StreamError::auth("no credentials"), + }) + .await; + return; + }; + let provider = OpenAiCompatProvider::new( + ProviderId::from(PROVIDER_ID), + BASE_URL, + Some(token), + AuthMethod::OAuth, + ) + .with_catalog(CATALOG) + .with_context_windows(CONTEXT_WINDOWS) + .with_vision_filter(no_vision) + .with_efforts(no_efforts) + .with_reasoning_effort(false); + let handle = provider.stream(req, events); + let _ = handle.await; + }) + } + + fn discover(&self, out: mpsc::Sender) -> JoinHandle<()> { + let store = self.store.clone(); + let key = self.key.clone(); + tokio::spawn(async move { + let Some(token) = oauth::current_token(&store, &key).await else { + for id in CATALOG { + if out + .send(Model { + id: (*id).to_owned(), + supports_images: false, + }) + .await + .is_err() + { + return; + } + } + return; + }; + let provider = OpenAiCompatProvider::new( + ProviderId::from(PROVIDER_ID), + BASE_URL, + Some(token), + AuthMethod::OAuth, + ) + .with_model_filter(chat_model) + .with_vision_filter(no_vision); + let handle = provider.discover(out); + let _ = handle.await; + }) + } + + fn login(&self, status: mpsc::Sender) -> JoinHandle> { + tokio::spawn(async move { oauth::login(&status).await.map_err(|err| err.to_string()) }) + } + + fn web_search(&self, query: String) -> JoinHandle> { + let _ = query; + tokio::spawn(async { Err(StreamError::other("web search is not supported")) }) + } +} + +fn chat_model(id: &str) -> bool { + let id = id.to_ascii_lowercase(); + !id.contains("embedding") && !id.contains("image") && !id.contains("video") +} + +#[cfg(test)] +mod tests { + use goat_auth::CredentialStore; + use goat_provider::{AuthMethod, Provider}; + + use super::*; + use crate::oauth::valid_kimi_verification_url; + + fn store(name: &str) -> CredentialStore { + let _ = std::fs::remove_file(std::env::temp_dir().join(name)); + CredentialStore::new(std::env::temp_dir().join(name)) + } + + #[test] + fn kimi_code_is_oauth_provider() { + let store = store("goat-provider-kimi-code.json"); + let provider = build(&store, "default"); + assert_eq!(provider.capabilities().auth, AuthMethod::OAuth); + assert_eq!(provider.metadata().oauth, Some("device code")); + assert!(!provider.authenticated()); + assert_eq!(provider.catalog(), CATALOG); + assert!(valid_kimi_verification_url( + "https://auth.kimi.com/device?code=abc" + )); + assert!(!valid_kimi_verification_url( + "https://example.com/device?code=abc" + )); + } +} diff --git a/crates/goat-provider-kimi-code/src/oauth.rs b/crates/goat-provider-kimi-code/src/oauth.rs new file mode 100644 index 0000000..41fbfbf --- /dev/null +++ b/crates/goat-provider-kimi-code/src/oauth.rs @@ -0,0 +1,337 @@ +use std::path::PathBuf; + +use goat_auth::{Credential, CredentialKey, CredentialStore, TokenSet, ensure_valid, now_secs}; +use reqwest::header::{ACCEPT, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue, USER_AGENT}; +use serde::Deserialize; +use tokio::sync::mpsc; + +const OAUTH_HOST: &str = "https://auth.kimi.com"; +const CLIENT_ID: &str = "17e5f671-d194-4dfb-9706-5516cb48c098"; + +#[derive(Debug, thiserror::Error)] +pub enum KimiCodeOAuthError { + #[error("http error: {0}")] + Http(#[from] reqwest::Error), + #[error("oauth error: {0}")] + OAuth(String), + #[error("io error: {0}")] + Io(#[from] std::io::Error), +} + +#[derive(Deserialize)] +struct DeviceAuthorizationResponse { + user_code: String, + device_code: String, + verification_uri: Option, + verification_uri_complete: String, + expires_in: Option, + interval: Option, +} + +#[derive(Deserialize)] +pub(crate) struct TokenResponse { + access_token: String, + refresh_token: String, + expires_in: i64, + scope: Option, + token_type: Option, +} + +#[derive(Deserialize)] +struct OAuthErrorResponse { + error: Option, + #[serde(rename = "error_description")] + _error_description: Option, +} + +pub fn oauth_client() -> reqwest::Client { + reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .connect_timeout(std::time::Duration::from_secs(10)) + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("reqwest client") +} + +pub async fn login(status: &mpsc::Sender) -> Result { + let client = oauth_client(); + let device = request_device_authorization(&client).await?; + let url = if device.verification_uri_complete.is_empty() { + device.verification_uri.as_deref().unwrap_or("") + } else { + &device.verification_uri_complete + }; + if !valid_kimi_verification_url(url) { + return Err(KimiCodeOAuthError::OAuth( + "device authorization returned an invalid verification URL".to_owned(), + )); + } + let _ = status + .send(format!("open {url} and enter code: {}", device.user_code)) + .await; + poll_device_token(&client, &device).await +} + +pub async fn current_token(store: &CredentialStore, key: &CredentialKey) -> Option { + let Credential::OAuth(tokens) = store.get(key)? else { + return None; + }; + let tokens = ensure_valid(tokens, store, key, refresh).await?; + Some(tokens.access_token.expose().to_owned()) +} + +async fn request_device_authorization( + client: &reqwest::Client, +) -> Result { + let response = client + .post(format!("{OAUTH_HOST}/api/oauth/device_authorization")) + .headers(kimi_headers()) + .form(&[("client_id", CLIENT_ID)]) + .send() + .await?; + let status = response.status(); + if !status.is_success() { + return Err(KimiCodeOAuthError::OAuth(format!( + "device authorization failed: {status}" + ))); + } + let device: DeviceAuthorizationResponse = response.json().await?; + if device.user_code.is_empty() + || device.device_code.is_empty() + || device.verification_uri_complete.is_empty() + { + return Err(KimiCodeOAuthError::OAuth( + "device authorization response is missing required fields".to_owned(), + )); + } + Ok(device) +} + +async fn poll_device_token( + client: &reqwest::Client, + device: &DeviceAuthorizationResponse, +) -> Result { + let mut interval = device.interval.unwrap_or(5).max(1); + let deadline = now_secs() + i64::try_from(device.expires_in.unwrap_or(900)).unwrap_or(900); + loop { + if now_secs() > deadline { + return Err(KimiCodeOAuthError::OAuth( + "device login timed out".to_owned(), + )); + } + tokio::time::sleep(std::time::Duration::from_secs(interval)).await; + let response = client + .post(format!("{OAUTH_HOST}/api/oauth/token")) + .headers(kimi_headers()) + .form(&[ + ("client_id", CLIENT_ID), + ("device_code", device.device_code.as_str()), + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ]) + .send() + .await?; + let status = response.status(); + if status.is_success() { + let tokens: TokenResponse = response.json().await?; + return parse_token_response(tokens); + } + let error = response.json::().await.ok(); + match error.as_ref().and_then(|err| err.error.as_deref()) { + Some("authorization_pending") => {} + Some("slow_down") => interval = interval.saturating_add(5).min(30), + Some("expired_token") => { + return Err(KimiCodeOAuthError::OAuth( + "device login code expired".to_owned(), + )); + } + Some("access_denied") => { + return Err(KimiCodeOAuthError::OAuth( + "device login access denied".to_owned(), + )); + } + Some(code) => { + return Err(KimiCodeOAuthError::OAuth(format!( + "device token polling failed: {code}" + ))); + } + None => { + return Err(KimiCodeOAuthError::OAuth(format!( + "device token polling failed: {status}" + ))); + } + } + } +} + +async fn refresh(refresh_token: String) -> Result { + let client = oauth_client(); + let response = client + .post(format!("{OAUTH_HOST}/api/oauth/token")) + .headers(kimi_headers()) + .form(&[ + ("client_id", CLIENT_ID), + ("grant_type", "refresh_token"), + ("refresh_token", refresh_token.as_str()), + ]) + .send() + .await + .map_err(|_| "token refresh request failed".to_owned())?; + let status = response.status(); + if !status.is_success() { + return Err(format!("token refresh failed: {status}")); + } + response + .json::() + .await + .map_err(|_| "token refresh returned invalid JSON".to_owned()) + .and_then(|tokens| parse_token_response(tokens).map_err(|err| err.to_string())) +} + +pub(crate) fn parse_token_response(tokens: TokenResponse) -> Result { + let _ = (&tokens.scope, &tokens.token_type); + if tokens.access_token.is_empty() || tokens.refresh_token.is_empty() || tokens.expires_in <= 0 { + return Err(KimiCodeOAuthError::OAuth( + "token response is missing required fields".to_owned(), + )); + } + Ok(TokenSet::from_parts( + tokens.access_token, + Some(tokens.refresh_token), + Some(tokens.expires_in), + None, + )) +} + +fn kimi_headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert(ACCEPT, HeaderValue::from_static("application/json")); + headers.insert( + CONTENT_TYPE, + HeaderValue::from_static("application/x-www-form-urlencoded"), + ); + headers.insert(USER_AGENT, HeaderValue::from_static("goat-code/0.1.14")); + insert_header(&mut headers, "X-Msh-Platform", "kimi_code_cli"); + insert_header(&mut headers, "X-Msh-Version", env!("CARGO_PKG_VERSION")); + insert_header(&mut headers, "X-Msh-Device-Name", &device_name()); + insert_header(&mut headers, "X-Msh-Device-Model", &device_model()); + insert_header(&mut headers, "X-Msh-Os-Version", std::env::consts::OS); + insert_header(&mut headers, "X-Msh-Device-Id", &device_id()); + headers +} + +fn insert_header(headers: &mut HeaderMap, name: &'static str, value: &str) { + if let Ok(value) = HeaderValue::from_str(&ascii_header(value)) { + headers.insert( + HeaderName::from_static(name.to_ascii_lowercase().leak()), + value, + ); + } +} + +fn ascii_header(value: &str) -> String { + let cleaned: String = value + .chars() + .filter(|ch| matches!(*ch as u32, 0x20..=0x7e)) + .collect::() + .trim() + .to_owned(); + if cleaned.is_empty() { + "unknown".to_owned() + } else { + cleaned + } +} + +fn device_name() -> String { + std::env::var("HOSTNAME") + .or_else(|_| std::env::var("COMPUTERNAME")) + .unwrap_or_else(|_| "unknown".to_owned()) +} + +fn device_model() -> String { + format!("{} {}", std::env::consts::OS, std::env::consts::ARCH) +} + +fn device_id() -> String { + let path = device_id_path(); + if let Ok(value) = std::fs::read_to_string(&path) { + let value = value.trim(); + if !value.is_empty() { + return value.to_owned(); + } + } + let id = goat_auth::random_state(); + if let Some(parent) = path.parent() { + let _ = std::fs::create_dir_all(parent); + set_private_dir(parent); + } + if std::fs::write(&path, &id).is_ok() { + set_private_file(&path); + } + id +} + +fn device_id_path() -> PathBuf { + std::env::home_dir().map_or_else( + || PathBuf::from(".goat-code-kimi-code-device-id"), + |home| home.join(".goat-code").join("kimi-code-device-id"), + ) +} + +#[cfg(unix)] +fn set_private_dir(path: &std::path::Path) { + use std::os::unix::fs::PermissionsExt; + let _ = std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o700)); +} + +#[cfg(not(unix))] +fn set_private_dir(_path: &std::path::Path) {} + +#[cfg(unix)] +fn set_private_file(path: &std::path::Path) { + use std::os::unix::fs::PermissionsExt; + let _ = std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600)); +} + +#[cfg(not(unix))] +fn set_private_file(_path: &std::path::Path) {} + +pub fn valid_kimi_verification_url(url: &str) -> bool { + reqwest::Url::parse(url) + .ok() + .and_then(|url| { + (url.scheme() == "https") + .then(|| url.host_str().is_some_and(|host| host == "auth.kimi.com")) + }) + .unwrap_or(false) +} + +#[cfg(test)] +mod tests { + use super::{TokenResponse, parse_token_response}; + + #[test] + fn parses_kimi_token_response_without_leaking_secrets() { + let token = parse_token_response(TokenResponse { + access_token: "access-secret".to_owned(), + refresh_token: "refresh-secret".to_owned(), + expires_in: 3600, + scope: Some("scope".to_owned()), + token_type: Some("Bearer".to_owned()), + }) + .unwrap(); + assert_eq!(token.access_token.expose(), "access-secret"); + assert_eq!(token.refresh_token.unwrap().expose(), "refresh-secret"); + let error = parse_token_response(TokenResponse { + access_token: String::new(), + refresh_token: "refresh-secret".to_owned(), + expires_in: 3600, + scope: None, + token_type: None, + }) + .unwrap_err() + .to_string(); + assert!(!error.contains("refresh-secret")); + assert!(!error.contains("access-secret")); + } +} diff --git a/crates/goat-provider-kimi/Cargo.toml b/crates/goat-provider-kimi/Cargo.toml new file mode 100644 index 0000000..d2de9f9 --- /dev/null +++ b/crates/goat-provider-kimi/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "goat-provider-kimi" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +publish = false + +[dependencies] +goat-auth = { workspace = true } +goat-provider = { workspace = true } +goat-provider-openai-compat = { workspace = true } + +[lints] +workspace = true \ No newline at end of file diff --git a/crates/goat-provider-kimi/src/lib.rs b/crates/goat-provider-kimi/src/lib.rs new file mode 100644 index 0000000..47882e3 --- /dev/null +++ b/crates/goat-provider-kimi/src/lib.rs @@ -0,0 +1,55 @@ +use goat_auth::CredentialStore; +use goat_provider::ProviderMetadata; +use goat_provider_openai_compat::{ + ChatDiscovery, ChatValidation, OpenAiCompatProvider, api_key, no_efforts, no_vision, +}; + +pub const PROVIDER_ID: &str = "kimi"; + +const BASE_URL: &str = "https://api.moonshot.ai/v1"; +const HOST: &str = "api.moonshot.ai"; +const ENV_VAR: &str = "MOONSHOT_API_KEY"; + +const KIMI_SETUP: &[&str] = &[ + "Kimi Platform API key provider.", + "For Kimi Code OAuth, use `goat provider login kimi-code`.", + "API-key setup: `goat provider login kimi --key sk-...`.", +]; + +const CATALOG: &[&str] = &[ + "kimi-k2.7-code", + "kimi-k2.7-code-highspeed", + "kimi-k2.6", + "kimi-k2.5", + "moonshot-v1-128k", + "moonshot-v1-32k", + "moonshot-v1-8k", +]; + +const CONTEXT: &[(&str, u32)] = &[ + ("kimi-k2.7", 256_000), + ("kimi-k2.6", 256_000), + ("kimi-k2.5", 256_000), + ("moonshot-v1-128k", 128_000), + ("moonshot-v1-32k", 32_000), + ("moonshot-v1-8k", 8_000), +]; + +pub fn build(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { + api_key(store, account, PROVIDER_ID, BASE_URL, HOST, ENV_VAR) + .with_catalog(CATALOG) + .with_context_windows(CONTEXT) + .with_vision_filter(no_vision) + .with_efforts(no_efforts) + .with_reasoning_effort(false) + .with_validation(ChatValidation::CatalogOnly) + .with_discovery(ChatDiscovery::CatalogOnly) + .with_metadata(ProviderMetadata { + env_var: Some(ENV_VAR), + validation: "catalog-only", + endpoint: None, + oauth: Some("Kimi Code OAuth is provider id kimi-code"), + login_endpoint: None, + setup: KIMI_SETUP, + }) +} diff --git a/crates/goat-provider-mistral/Cargo.toml b/crates/goat-provider-mistral/Cargo.toml new file mode 100644 index 0000000..e8900ad --- /dev/null +++ b/crates/goat-provider-mistral/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "goat-provider-mistral" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +publish = false + +[dependencies] +goat-provider = { workspace = true } +goat-provider-openai-compat = { workspace = true } +goat-auth = { workspace = true } + +[lints] +workspace = true \ No newline at end of file diff --git a/crates/goat-provider-mistral/src/lib.rs b/crates/goat-provider-mistral/src/lib.rs new file mode 100644 index 0000000..acd6f8b --- /dev/null +++ b/crates/goat-provider-mistral/src/lib.rs @@ -0,0 +1,37 @@ +use goat_auth::CredentialStore; +use goat_provider_openai_compat::{OpenAiCompatProvider, api_key, no_efforts}; + +pub const PROVIDER_ID: &str = "mistral"; +const BASE_URL: &str = "https://api.mistral.ai/v1"; +const HOST: &str = "api.mistral.ai"; +const ENV_VAR: &str = "MISTRAL_API_KEY"; + +const CATALOG: &[&str] = &[ + "mistral-large-latest", + "mistral-medium-latest", + "mistral-small-latest", + "devstral-medium-latest", + "codestral-latest", + "ministral-8b-latest", +]; + +const CONTEXT_WINDOWS: &[(&str, u32)] = &[ + ("mistral-large", 131_072), + ("mistral-medium", 131_072), + ("mistral-small", 131_072), + ("devstral-medium", 131_072), + ("codestral", 256_000), +]; + +fn is_vision_model(id: &str) -> bool { + id.to_ascii_lowercase().contains("pixtral") +} + +pub fn build(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { + api_key(store, account, PROVIDER_ID, BASE_URL, HOST, ENV_VAR) + .with_catalog(CATALOG) + .with_context_windows(CONTEXT_WINDOWS) + .with_vision_filter(is_vision_model) + .with_efforts(no_efforts) + .with_reasoning_effort(false) +} diff --git a/crates/goat-provider-openai-compat/Cargo.toml b/crates/goat-provider-openai-compat/Cargo.toml index 6a6389f..6a50dfb 100644 --- a/crates/goat-provider-openai-compat/Cargo.toml +++ b/crates/goat-provider-openai-compat/Cargo.toml @@ -9,6 +9,7 @@ publish = false [dependencies] goat-provider = { workspace = true } +goat-auth = { workspace = true } reqwest = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/crates/goat-provider-openai-compat/src/hosted.rs b/crates/goat-provider-openai-compat/src/hosted.rs new file mode 100644 index 0000000..f9a42c4 --- /dev/null +++ b/crates/goat-provider-openai-compat/src/hosted.rs @@ -0,0 +1,67 @@ +use goat_auth::{CredentialKey, CredentialStore}; +use goat_provider::{AuthMethod, Effort, ProviderId, ProviderMetadata}; + +use crate::OpenAiCompatProvider; + +pub fn enforce_https_host(base_url: &str, allowed_host: &str) -> Result<(), String> { + let url = base_url.trim_end_matches('/'); + let rest = url + .strip_prefix("https://") + .ok_or_else(|| "hosted providers require https".to_owned())?; + let actual = rest.split('/').next().unwrap_or_default(); + if actual == allowed_host || actual.ends_with(&format!(".{allowed_host}")) { + Ok(()) + } else { + Err(format!("invalid hosted provider host: {actual}")) + } +} + +pub fn api_key( + store: &CredentialStore, + account: &str, + provider_id: &'static str, + base_url: &'static str, + allowed_host: &'static str, + env_var: &'static str, +) -> OpenAiCompatProvider { + enforce_https_host(base_url, allowed_host).expect("hosted provider base URL"); + let key = CredentialKey::model(provider_id, account); + let bearer = store + .resolve(&key, Some(env_var)) + .map(|cred| cred.bearer().to_owned()); + OpenAiCompatProvider::new( + ProviderId::from(provider_id), + base_url, + bearer, + AuthMethod::ApiKey, + ) + .with_metadata(ProviderMetadata { + env_var: Some(env_var), + validation: "network", + endpoint: None, + oauth: Some("not supported"), + login_endpoint: None, + setup: &[], + }) +} + +pub fn no_vision(_id: &str) -> bool { + false +} + +pub fn no_efforts(_model: &str) -> Vec { + Vec::new() +} + +#[cfg(test)] +mod tests { + use super::enforce_https_host; + + #[test] + fn enforces_https_and_allowed_host() { + assert!(enforce_https_host("https://openrouter.ai/api/v1/", "openrouter.ai").is_ok()); + assert!(enforce_https_host("http://openrouter.ai/api/v1", "openrouter.ai").is_err()); + assert!(enforce_https_host("https://example.com/api/v1", "openrouter.ai").is_err()); + assert!(enforce_https_host("https://api.z.ai/api/coding/paas/v4", "api.z.ai").is_ok()); + } +} diff --git a/crates/goat-provider-openai-compat/src/lib.rs b/crates/goat-provider-openai-compat/src/lib.rs index d9a3a2b..d0bd271 100644 --- a/crates/goat-provider-openai-compat/src/lib.rs +++ b/crates/goat-provider-openai-compat/src/lib.rs @@ -1,11 +1,13 @@ pub mod chat; pub mod common; pub mod headers; +pub mod hosted; pub mod responses; pub mod vision; pub use chat::{ChatDiscovery, ChatValidation, OpenAiCompatProvider}; pub use headers::parse_codex_ratelimits; +pub use hosted::{api_key, enforce_https_host, no_efforts, no_vision}; pub use responses::{ ResponsesProvider, build_body, responses_efforts, run_request, run_web_search, }; diff --git a/crates/goat-provider-openrouter/Cargo.toml b/crates/goat-provider-openrouter/Cargo.toml new file mode 100644 index 0000000..ebede01 --- /dev/null +++ b/crates/goat-provider-openrouter/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "goat-provider-openrouter" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +publish = false + +[dependencies] +goat-provider = { workspace = true } +goat-provider-openai-compat = { workspace = true } +goat-auth = { workspace = true } + +[lints] +workspace = true \ No newline at end of file diff --git a/crates/goat-provider-openrouter/src/lib.rs b/crates/goat-provider-openrouter/src/lib.rs new file mode 100644 index 0000000..170a524 --- /dev/null +++ b/crates/goat-provider-openrouter/src/lib.rs @@ -0,0 +1,53 @@ +use goat_auth::CredentialStore; +use goat_provider_openai_compat::{ + OpenAiCompatProvider, api_key, known_openai_compatible_vision_model, +}; + +pub const PROVIDER_ID: &str = "openrouter"; +const BASE_URL: &str = "https://openrouter.ai/api/v1"; +const HOST: &str = "openrouter.ai"; +const ENV_VAR: &str = "OPENROUTER_API_KEY"; + +const CATALOG: &[&str] = &[ + "anthropic/claude-sonnet-4.5", + "openai/gpt-5.1", + "google/gemini-2.5-pro", + "deepseek/deepseek-chat-v3.1", + "qwen/qwen3-coder", + "moonshotai/kimi-k2", +]; + +const CONTEXT_WINDOWS: &[(&str, u32)] = &[ + ("anthropic/claude-sonnet-4.5", 200_000), + ("openai/gpt-5", 400_000), + ("google/gemini-2.5", 1_000_000), + ("deepseek/deepseek", 128_000), + ("qwen/qwen3-coder", 256_000), + ("moonshotai/kimi", 256_000), +]; + +fn is_chat_model(id: &str) -> bool { + let id = id.to_ascii_lowercase(); + !(id.contains("embedding") + || id.contains("moderation") + || id.contains("image") + || id.contains("tts") + || id.contains("whisper")) +} + +fn is_vision_model(id: &str) -> bool { + let id = id.to_ascii_lowercase(); + known_openai_compatible_vision_model(&id) + || id.contains("claude") + || id.contains("gemini") + || id.contains("grok-4") +} + +pub fn build(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { + api_key(store, account, PROVIDER_ID, BASE_URL, HOST, ENV_VAR) + .with_catalog(CATALOG) + .with_context_windows(CONTEXT_WINDOWS) + .with_model_filter(is_chat_model) + .with_vision_filter(is_vision_model) + .with_reasoning_effort(false) +} diff --git a/crates/goat-provider-qwen/Cargo.toml b/crates/goat-provider-qwen/Cargo.toml new file mode 100644 index 0000000..a9ac0b4 --- /dev/null +++ b/crates/goat-provider-qwen/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "goat-provider-qwen" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +publish = false + +[dependencies] +goat-auth = { workspace = true } +goat-provider = { workspace = true } +goat-provider-openai-compat = { workspace = true } +reqwest = { workspace = true } + +[lints] +workspace = true \ No newline at end of file diff --git a/crates/goat-provider-qwen/src/lib.rs b/crates/goat-provider-qwen/src/lib.rs new file mode 100644 index 0000000..c1451d8 --- /dev/null +++ b/crates/goat-provider-qwen/src/lib.rs @@ -0,0 +1,211 @@ +use goat_auth::{Credential, CredentialKey, CredentialStore}; +use goat_provider::{AuthMethod, LoginEndpointMetadata, ProviderId, ProviderMetadata}; +use goat_provider_openai_compat::{OpenAiCompatProvider, no_efforts}; + +pub const PROVIDER_ID: &str = "qwen"; + +const QWEN_DEFAULT_ENDPOINT: &str = "https://dashscope-us.aliyuncs.com/compatible-mode/v1"; + +const QWEN_SETUP: &[&str] = &[ + "Qwen DashScope API-key provider.", + "Default endpoint: https://dashscope-us.aliyuncs.com/compatible-mode/v1", + "Non-US workspaces: `goat provider login qwen --endpoint --key sk-...`.", + "Qwen OAuth enrollment is discontinued upstream.", +]; + +const CATALOG: &[&str] = &[ + "qwen-plus", + "qwen-max", + "qwen-turbo", + "qwen3-coder-plus", + "qwen3-coder-flash", + "qwen-vl-plus", +]; + +const CONTEXT: &[(&str, u32)] = &[ + ("qwen-plus", 131_072), + ("qwen-max", 131_072), + ("qwen-turbo", 1_000_000), + ("qwen3-coder", 1_000_000), + ("qwen-vl", 129_024), +]; + +pub fn build(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { + let key = CredentialKey::model(PROVIDER_ID, account); + let stored = store.get(&key); + let endpoint_source = std::env::var("QWEN_BASE_URL").ok().or_else(|| { + stored + .as_ref() + .and_then(Credential::endpoint) + .map(str::to_owned) + }); + let endpoint = match endpoint_source { + Some(raw) => validate_qwen_endpoint(&raw).ok(), + None => Some(QWEN_DEFAULT_ENDPOINT.to_owned()), + }; + let bearer = endpoint.as_ref().and_then(|_| { + store + .resolve(&key, Some("DASHSCOPE_API_KEY")) + .map(|cred| cred.bearer().to_owned()) + }); + OpenAiCompatProvider::new( + ProviderId::from(PROVIDER_ID), + endpoint.unwrap_or_else(|| QWEN_DEFAULT_ENDPOINT.to_owned()), + bearer, + AuthMethod::ApiKey, + ) + .with_catalog(CATALOG) + .with_context_windows(CONTEXT) + .with_vision_filter(qwen_vision_model) + .with_efforts(no_efforts) + .with_reasoning_effort(false) + .with_metadata(ProviderMetadata { + env_var: Some("DASHSCOPE_API_KEY"), + validation: "network", + endpoint: Some("required for non-US DashScope workspaces"), + oauth: Some("Qwen OAuth enrollment discontinued"), + login_endpoint: Some(LoginEndpointMetadata { + env_var: Some("QWEN_BASE_URL"), + default: Some(QWEN_DEFAULT_ENDPOINT), + validate: Some(validate_qwen_endpoint), + }), + setup: QWEN_SETUP, + }) +} + +pub fn validate_qwen_endpoint(endpoint: &str) -> Result { + let trimmed = endpoint.trim().trim_end_matches('/'); + let url = reqwest::Url::parse(trimmed).map_err(|err| err.to_string())?; + if url.scheme() != "https" { + return Err("qwen endpoint must use https".to_owned()); + } + if !url.username().is_empty() || url.password().is_some() { + return Err("qwen endpoint must not include userinfo".to_owned()); + } + let Some(host) = url.host_str() else { + return Err("qwen endpoint must include a host".to_owned()); + }; + if host.ends_with('.') { + return Err("qwen endpoint host must not end with a dot".to_owned()); + } + let allowed_static = [ + "dashscope.aliyuncs.com", + "dashscope-intl.aliyuncs.com", + "dashscope-us.aliyuncs.com", + ]; + let allowed_regions = [ + "cn-beijing.maas.aliyuncs.com", + "ap-southeast-1.maas.aliyuncs.com", + "ap-northeast-1.maas.aliyuncs.com", + ]; + let allowed = allowed_static.contains(&host) + || allowed_regions.iter().any(|region| { + host.strip_suffix(region) + .and_then(|prefix| prefix.strip_suffix('.')) + .is_some_and(valid_workspace_id) + }); + if !allowed { + return Err("qwen endpoint host is not an allowed Alibaba Model Studio host".to_owned()); + } + if url.port().is_some() { + return Err("qwen endpoint must not include a custom port".to_owned()); + } + if url.path() != "/compatible-mode/v1" { + return Err("qwen endpoint path must be /compatible-mode/v1".to_owned()); + } + if url.query().is_some() || url.fragment().is_some() { + return Err("qwen endpoint must not include query or fragment".to_owned()); + } + Ok(trimmed.to_owned()) +} + +fn valid_workspace_id(value: &str) -> bool { + !value.is_empty() + && value + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-') +} + +fn qwen_vision_model(id: &str) -> bool { + let id = id.to_ascii_lowercase(); + id.contains("qwen-vl") || id.contains("qwen2-vl") || id.contains("qwen2.5-vl") +} + +#[cfg(test)] +mod tests { + use goat_auth::{Credential, CredentialStore, SecretString}; + use goat_provider::Provider; + + use super::*; + + fn store(name: &str) -> CredentialStore { + let _ = std::fs::remove_file(std::env::temp_dir().join(name)); + CredentialStore::new(std::env::temp_dir().join(name)) + } + + #[test] + fn validates_qwen_endpoints() { + for endpoint in [ + "https://dashscope-us.aliyuncs.com/compatible-mode/v1", + "https://dashscope.aliyuncs.com/compatible-mode/v1", + "https://workspace-1.cn-beijing.maas.aliyuncs.com/compatible-mode/v1", + "https://abc123.ap-southeast-1.maas.aliyuncs.com/compatible-mode/v1/", + ] { + assert_eq!( + validate_qwen_endpoint(endpoint).unwrap(), + endpoint.trim_end_matches('/') + ); + } + for endpoint in [ + "http://dashscope-us.aliyuncs.com/compatible-mode/v1", + "https://dashscope-us.aliyuncs.com.evil.test/compatible-mode/v1", + "https://user@dashscope-us.aliyuncs.com/compatible-mode/v1", + "https://dashscope-us.aliyuncs.com:444/compatible-mode/v1", + "https://dashscope-us.aliyuncs.com/v1", + "https://dashscope-us.aliyuncs.com/compatible-mode/v1?x=1", + "https://workspace_1.cn-beijing.maas.aliyuncs.com/compatible-mode/v1", + "https://workspace-1.cn-hangzhou.maas.aliyuncs.com/compatible-mode/v1", + ] { + assert!( + validate_qwen_endpoint(endpoint).is_err(), + "expected rejection for {endpoint}" + ); + } + } + + #[test] + fn invalid_qwen_endpoint_does_not_authenticate() { + let store = store("goat-provider-qwen-invalid.json"); + store + .store( + &CredentialKey::model(PROVIDER_ID, "default"), + Credential::ApiKeyWithEndpoint { + secret: SecretString::from("key".to_owned()), + endpoint: "https://example.com/compatible-mode/v1".to_owned(), + }, + ) + .unwrap(); + let provider = build(&store, "default"); + assert!(!provider.authenticated()); + } + + #[test] + fn qwen_endpoint_credential_authenticates() { + let store = store("goat-provider-qwen-valid.json"); + store + .store( + &CredentialKey::model(PROVIDER_ID, "default"), + Credential::ApiKeyWithEndpoint { + secret: SecretString::from("key".to_owned()), + endpoint: "https://dashscope-us.aliyuncs.com/compatible-mode/v1".to_owned(), + }, + ) + .unwrap(); + let provider = build(&store, "default"); + assert!(provider.authenticated()); + assert_eq!( + provider.base_url(), + "https://dashscope-us.aliyuncs.com/compatible-mode/v1" + ); + } +} diff --git a/crates/goat-provider-xai/Cargo.toml b/crates/goat-provider-xai/Cargo.toml new file mode 100644 index 0000000..e4e6baa --- /dev/null +++ b/crates/goat-provider-xai/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "goat-provider-xai" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +publish = false + +[dependencies] +goat-auth = { workspace = true } +goat-provider = { workspace = true } +goat-provider-openai-compat = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +tokio = { workspace = true } +thiserror = { workspace = true } +open = { workspace = true } + +[lints] +workspace = true \ No newline at end of file diff --git a/crates/goat-provider-xai/src/lib.rs b/crates/goat-provider-xai/src/lib.rs new file mode 100644 index 0000000..af3b2d0 --- /dev/null +++ b/crates/goat-provider-xai/src/lib.rs @@ -0,0 +1,336 @@ +mod oauth; + +use goat_auth::{Credential, CredentialKey, CredentialStore, TokenSet}; +use goat_provider::{ + AuthMethod, Capabilities, Effort, Model, Provider, ProviderId, ProviderMetadata, Request, + StreamError, StreamEvent, WebSearchOutput, +}; +use goat_provider_openai_compat::{ + OpenAiCompatProvider, ResponsesProvider, enforce_https_host, no_efforts, +}; +use tokio::{sync::mpsc, task::JoinHandle}; + +pub const PROVIDER_ID: &str = "xai"; + +const BASE_URL: &str = "https://api.x.ai/v1"; +const ALLOWED_HOST: &str = "api.x.ai"; + +const SETUP: &[&str] = &[ + "xAI Grok provider (API key or SuperGrok / X Premium+ OAuth).", + "API key: `goat provider login xai --key xai-...` or `XAI_API_KEY`.", + "OAuth: `goat provider login xai` (browser or device code; no API key).", +]; + +const OAUTH_CATALOG: &[&str] = &[ + "grok-4.3", + "grok-build-0.1", + "grok-4.20-beta-latest-reasoning", + "grok-4.20-beta-latest-non-reasoning", +]; + +const API_KEY_CATALOG: &[&str] = &[ + "grok-4", + "grok-4-fast-reasoning", + "grok-4-fast-non-reasoning", + "grok-3", + "grok-3-fast", + "grok-3-mini", +]; + +const CATALOG: &[&str] = &[ + "grok-4.3", + "grok-build-0.1", + "grok-4.20-beta-latest-reasoning", + "grok-4.20-beta-latest-non-reasoning", + "grok-4", + "grok-4-fast-reasoning", + "grok-4-fast-non-reasoning", + "grok-3", + "grok-3-fast", + "grok-3-mini", +]; + +const OAUTH_CONTEXT: &[(&str, u32)] = &[ + ("grok-4.3", 1_000_000), + ("grok-build", 512_000), + ("grok-4.20", 2_000_000), +]; + +const API_KEY_CONTEXT: &[(&str, u32)] = &[("grok-4", 256_000), ("grok-3", 131_072)]; + +pub fn build(store: &CredentialStore, account: &str) -> XaiProvider { + enforce_https_host(BASE_URL, ALLOWED_HOST).expect("xai provider base URL"); + XaiProvider::new(store.clone(), CredentialKey::model(PROVIDER_ID, account)) +} + +enum XaiAuth { + ApiKey(String), + OAuth(String), +} + +pub struct XaiProvider { + store: CredentialStore, + key: CredentialKey, +} + +impl XaiProvider { + pub fn new(store: CredentialStore, key: CredentialKey) -> Self { + Self { store, key } + } + + async fn resolve_auth(&self) -> Option { + let cred = self.store.resolve(&self.key, Some("XAI_API_KEY"))?; + match cred { + Credential::ApiKey(secret) | Credential::ApiKeyWithEndpoint { secret, .. } => { + Some(XaiAuth::ApiKey(secret.expose().to_owned())) + } + Credential::OAuth(_) => oauth::current_oauth_token(&self.store, &self.key) + .await + .map(XaiAuth::OAuth), + } + } + + fn is_oauth_model(model: &str) -> bool { + OAUTH_CATALOG + .iter() + .any(|id| *id == model || model.starts_with(id)) + } + + fn chat_provider(bearer: String) -> OpenAiCompatProvider { + OpenAiCompatProvider::new( + ProviderId::from(PROVIDER_ID), + BASE_URL, + Some(bearer), + AuthMethod::ApiKeyOrOAuth, + ) + .with_catalog(API_KEY_CATALOG) + .with_context_windows(API_KEY_CONTEXT) + .with_vision_filter(vision_model) + .with_efforts(no_efforts) + .with_reasoning_effort(false) + } + + fn responses_provider(bearer: String) -> ResponsesProvider { + ResponsesProvider::new( + ProviderId::from(PROVIDER_ID), + BASE_URL, + Some(bearer), + AuthMethod::ApiKeyOrOAuth, + ) + .with_catalog(OAUTH_CATALOG) + .with_context_windows(OAUTH_CONTEXT) + .with_vision_filter(vision_model) + .with_model_filter(oauth_chat_model) + } +} + +impl Provider for XaiProvider { + fn id(&self) -> ProviderId { + ProviderId::from(PROVIDER_ID) + } + + fn capabilities(&self) -> Capabilities { + Capabilities { + tools: true, + auth: AuthMethod::ApiKeyOrOAuth, + images: true, + } + } + + fn metadata(&self) -> ProviderMetadata { + ProviderMetadata { + env_var: Some("XAI_API_KEY"), + validation: "network", + endpoint: None, + oauth: Some("browser or device code (SuperGrok / X Premium+)"), + login_endpoint: None, + setup: SETUP, + } + } + + fn authenticated(&self) -> bool { + self.store.resolve(&self.key, Some("XAI_API_KEY")).is_some() + } + + fn catalog(&self) -> &'static [&'static str] { + CATALOG + } + + fn efforts(&self, model: &str) -> Vec { + if Self::is_oauth_model(model) { + oauth_efforts(model) + } else { + no_efforts(model) + } + } + + fn context_window(&self, model: &str) -> Option { + if Self::is_oauth_model(model) { + OAUTH_CONTEXT + .iter() + .find_map(|(prefix, window)| model.starts_with(prefix).then_some(*window)) + } else { + API_KEY_CONTEXT + .iter() + .find_map(|(prefix, window)| model.starts_with(prefix).then_some(*window)) + } + } + + fn supports_images(&self, model: &str) -> bool { + vision_model(model) + } + + fn verifies_credentials(&self) -> bool { + true + } + + fn validate(&self) -> JoinHandle> { + let store = self.store.clone(); + let key = self.key.clone(); + tokio::spawn(async move { + let provider = XaiProvider { store, key }; + let Some(auth) = provider.resolve_auth().await else { + return Err("no credentials".to_owned()); + }; + match auth { + XaiAuth::ApiKey(bearer) => { + let handle = XaiProvider::chat_provider(bearer).validate(); + handle.await.unwrap_or(Err("validation failed".to_owned())) + } + XaiAuth::OAuth(bearer) => { + let handle = XaiProvider::responses_provider(bearer).validate(); + handle.await.unwrap_or(Err("validation failed".to_owned())) + } + } + }) + } + + fn stream(&self, req: Request, events: mpsc::Sender) -> JoinHandle<()> { + let store = self.store.clone(); + let key = self.key.clone(); + let model = req.model.clone(); + tokio::spawn(async move { + let provider = XaiProvider { store, key }; + let Some(auth) = provider.resolve_auth().await else { + let _ = events + .send(StreamEvent::Failed { + error: StreamError::auth("no credentials"), + }) + .await; + return; + }; + let handle = match auth { + XaiAuth::ApiKey(bearer) => XaiProvider::chat_provider(bearer).stream(req, events), + XaiAuth::OAuth(bearer) => { + if !XaiProvider::is_oauth_model(&model) + && API_KEY_CATALOG.contains(&model.as_str()) + { + let _ = events + .send(StreamEvent::Failed { + error: StreamError::invalid_request(format!( + "model {model} requires an xAI API key; OAuth supports {}", + OAUTH_CATALOG.join(", ") + )), + }) + .await; + return; + } + XaiProvider::responses_provider(bearer).stream(req, events) + } + }; + let _ = handle.await; + }) + } + + fn discover(&self, out: mpsc::Sender) -> JoinHandle<()> { + let store = self.store.clone(); + let key = self.key.clone(); + tokio::spawn(async move { + let provider = XaiProvider { store, key }; + let Some(auth) = provider.resolve_auth().await else { + for id in CATALOG { + if out + .send(Model { + id: (*id).to_owned(), + supports_images: vision_model(id), + }) + .await + .is_err() + { + return; + } + } + return; + }; + let handle = match auth { + XaiAuth::ApiKey(bearer) => XaiProvider::chat_provider(bearer).discover(out), + XaiAuth::OAuth(bearer) => XaiProvider::responses_provider(bearer).discover(out), + }; + let _ = handle.await; + }) + } + + fn login(&self, status: mpsc::Sender) -> JoinHandle> { + tokio::spawn(async move { oauth::login(&status).await.map_err(|err| err.to_string()) }) + } + + fn web_search(&self, query: String) -> JoinHandle> { + let _ = query; + tokio::spawn(async { Err(StreamError::other("web search is not supported")) }) + } +} + +fn oauth_chat_model(id: &str) -> bool { + let id = id.to_ascii_lowercase(); + !(id.contains("embedding") + || id.contains("tts") + || id.contains("whisper") + || id.contains("image") + || id.contains("video")) +} + +fn oauth_efforts(model: &str) -> Vec { + let id = model.to_ascii_lowercase(); + if id.starts_with("grok-4.3") { + vec![Effort::Low, Effort::Medium, Effort::High] + } else { + Vec::new() + } +} + +fn vision_model(id: &str) -> bool { + let id = id.to_ascii_lowercase(); + id.starts_with("grok-4") || id.contains("vision") +} + +#[cfg(test)] +mod tests { + use goat_auth::CredentialStore; + use goat_provider::{AuthMethod, Effort, Provider}; + + use super::*; + + fn store(name: &str) -> CredentialStore { + let _ = std::fs::remove_file(std::env::temp_dir().join(name)); + CredentialStore::new(std::env::temp_dir().join(name)) + } + + #[test] + fn xai_supports_api_key_and_oauth() { + let store = store("goat-provider-xai.json"); + let provider = build(&store, "default"); + assert_eq!(provider.capabilities().auth, AuthMethod::ApiKeyOrOAuth); + assert_eq!( + provider.metadata().oauth, + Some("browser or device code (SuperGrok / X Premium+)") + ); + assert!(!provider.authenticated()); + assert_eq!(provider.catalog(), CATALOG); + assert_eq!(provider.context_window("grok-4.3"), Some(1_000_000)); + assert_eq!(provider.context_window("grok-4"), Some(256_000)); + assert_eq!( + provider.efforts("grok-4.3"), + vec![Effort::Low, Effort::Medium, Effort::High] + ); + } +} diff --git a/crates/goat-provider-xai/src/oauth.rs b/crates/goat-provider-xai/src/oauth.rs new file mode 100644 index 0000000..5455424 --- /dev/null +++ b/crates/goat-provider-xai/src/oauth.rs @@ -0,0 +1,523 @@ +use std::time::Duration; + +use goat_auth::{ + Credential, CredentialKey, CredentialStore, Pkce, TokenSet, capture_loopback_code, + ensure_valid, now_secs, random_state, +}; +use reqwest::header::{ACCEPT, CONTENT_TYPE, HeaderMap, HeaderValue, USER_AGENT}; +use serde::Deserialize; +use tokio::sync::mpsc; + +const CLIENT_ID: &str = "b1a00492-073a-47ea-816f-4c329264a828"; +const SCOPE: &str = "openid profile email offline_access grok-cli:access api:access"; +const DISCOVERY_URL: &str = "https://auth.x.ai/.well-known/openid-configuration"; +const CALLBACK_PORT: u16 = 56121; +const REDIRECT_URI: &str = "http://127.0.0.1:56121/callback"; +const DEVICE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code"; +const LOGIN_TIMEOUT_SECS: i64 = 300; + +#[derive(Debug, thiserror::Error)] +pub enum XaiOAuthError { + #[error("http error: {0}")] + Http(#[from] reqwest::Error), + #[error("auth error: {0}")] + Auth(#[from] goat_auth::AuthError), + #[error("oauth error: {0}")] + OAuth(String), + #[error("no browser available")] + NoBrowser, +} + +struct OAuthDiscovery { + authorization_endpoint: String, + token_endpoint: String, +} + +struct DeviceDiscovery { + device_authorization_endpoint: String, + token_endpoint: String, +} + +#[derive(Deserialize)] +struct DiscoveryDocument { + #[serde(rename = "authorization_endpoint")] + authorization: Option, + #[serde(rename = "token_endpoint")] + token: Option, + #[serde(rename = "device_authorization_endpoint")] + device_authorization: Option, +} + +#[derive(Deserialize)] +struct DeviceAuthorizationResponse { + device_code: String, + user_code: String, + verification_uri: Option, + verification_uri_complete: Option, + expires_in: Option, + interval: Option, +} + +#[derive(Deserialize)] +struct TokenResponse { + access_token: String, + refresh_token: Option, + expires_in: Option, +} + +#[derive(Deserialize)] +struct OAuthErrorResponse { + error: Option, +} + +pub fn trusted_xai_host(endpoint: &str) -> bool { + let Ok(url) = reqwest::Url::parse(endpoint) else { + return false; + }; + if url.scheme() != "https" { + return false; + } + let Some(host) = url.host_str() else { + return false; + }; + host == "x.ai" || host.ends_with(".x.ai") +} + +fn require_trusted_endpoint(endpoint: &str, label: &str) -> Result { + if trusted_xai_host(endpoint) { + Ok(endpoint.to_owned()) + } else { + Err(XaiOAuthError::OAuth(format!( + "xAI OAuth discovery returned untrusted {label}" + ))) + } +} + +fn oauth_client() -> reqwest::Client { + reqwest::Client::builder() + .timeout(Duration::from_secs(30)) + .connect_timeout(Duration::from_secs(10)) + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("reqwest client") +} + +fn oauth_headers() -> HeaderMap { + let mut headers = HeaderMap::new(); + headers.insert(ACCEPT, HeaderValue::from_static("application/json")); + headers.insert( + CONTENT_TYPE, + HeaderValue::from_static("application/x-www-form-urlencoded"), + ); + let ua = format!("goat-code/{}", env!("CARGO_PKG_VERSION")); + if let Ok(value) = HeaderValue::from_str(&ua) { + headers.insert(USER_AGENT, value); + } + headers +} + +async fn fetch_discovery() -> Result { + let client = oauth_client(); + let response = client + .get(DISCOVERY_URL) + .headers(oauth_headers()) + .send() + .await?; + let status = response.status(); + if !status.is_success() { + return Err(XaiOAuthError::OAuth(format!( + "OAuth discovery failed: {status}" + ))); + } + let doc: DiscoveryDocument = response.json().await?; + let authorization_endpoint = doc + .authorization + .ok_or_else(|| XaiOAuthError::OAuth("missing authorization_endpoint".to_owned()))?; + let token_endpoint = doc + .token + .ok_or_else(|| XaiOAuthError::OAuth("missing token_endpoint".to_owned()))?; + Ok(OAuthDiscovery { + authorization_endpoint: require_trusted_endpoint( + &authorization_endpoint, + "authorization endpoint", + )?, + token_endpoint: require_trusted_endpoint(&token_endpoint, "token endpoint")?, + }) +} + +async fn fetch_device_discovery() -> Result { + let client = oauth_client(); + let response = client + .get(DISCOVERY_URL) + .headers(oauth_headers()) + .send() + .await?; + let status = response.status(); + if !status.is_success() { + return Err(XaiOAuthError::OAuth(format!( + "OAuth discovery failed: {status}" + ))); + } + let doc: DiscoveryDocument = response.json().await?; + let device_authorization_endpoint = doc + .device_authorization + .ok_or_else(|| XaiOAuthError::OAuth("missing device_authorization_endpoint".to_owned()))?; + let token_endpoint = doc + .token + .ok_or_else(|| XaiOAuthError::OAuth("missing token_endpoint".to_owned()))?; + Ok(DeviceDiscovery { + device_authorization_endpoint: require_trusted_endpoint( + &device_authorization_endpoint, + "device authorization endpoint", + )?, + token_endpoint: require_trusted_endpoint(&token_endpoint, "token endpoint")?, + }) +} + +pub fn build_authorize_url( + authorization_endpoint: &str, + challenge: &str, + state: &str, + nonce: &str, +) -> Result { + let endpoint = require_trusted_endpoint(authorization_endpoint, "authorization endpoint")?; + reqwest::Url::parse_with_params( + &endpoint, + &[ + ("response_type", "code"), + ("client_id", CLIENT_ID), + ("redirect_uri", REDIRECT_URI), + ("scope", SCOPE), + ("state", state), + ("nonce", nonce), + ("code_challenge", challenge), + ("code_challenge_method", "S256"), + ("plan", "generic"), + ("referrer", "goat-code"), + ], + ) + .map(|url| url.to_string()) + .map_err(|err| XaiOAuthError::OAuth(err.to_string())) +} + +fn random_nonce() -> String { + random_state() +} + +fn parse_token_response( + tokens: TokenResponse, + require_refresh: bool, +) -> Result { + if tokens.access_token.is_empty() { + return Err(XaiOAuthError::OAuth( + "token response is missing access_token".to_owned(), + )); + } + if require_refresh && tokens.refresh_token.as_deref().is_none_or(str::is_empty) { + return Err(XaiOAuthError::OAuth( + "token response is missing refresh_token".to_owned(), + )); + } + Ok(TokenSet::from_parts( + tokens.access_token, + tokens.refresh_token, + tokens.expires_in, + None, + )) +} + +async fn exchange_authorization_code( + token_endpoint: &str, + code: &str, + pkce: &Pkce, +) -> Result { + let endpoint = require_trusted_endpoint(token_endpoint, "token endpoint")?; + let client = oauth_client(); + let response = client + .post(endpoint) + .headers(oauth_headers()) + .form(&[ + ("grant_type", "authorization_code"), + ("code", code), + ("redirect_uri", REDIRECT_URI), + ("client_id", CLIENT_ID), + ("code_verifier", pkce.verifier.as_str()), + ("code_challenge", pkce.challenge.as_str()), + ("code_challenge_method", "S256"), + ]) + .send() + .await?; + let status = response.status(); + if !status.is_success() { + return Err(XaiOAuthError::OAuth(format!( + "token exchange failed: {status}" + ))); + } + let tokens: TokenResponse = response.json().await?; + parse_token_response(tokens, true) +} + +pub async fn refresh_token(refresh_token: String) -> Result { + let discovery = fetch_discovery().await.map_err(|err| err.to_string())?; + let client = oauth_client(); + let response = client + .post(&discovery.token_endpoint) + .headers(oauth_headers()) + .form(&[ + ("grant_type", "refresh_token"), + ("client_id", CLIENT_ID), + ("refresh_token", refresh_token.as_str()), + ]) + .send() + .await + .map_err(|err| err.to_string())?; + let status = response.status(); + if !status.is_success() { + return Err(format!("token refresh failed: {status}")); + } + let tokens: TokenResponse = response.json().await.map_err(|err| err.to_string())?; + parse_token_response(tokens, false).map_err(|err| err.to_string()) +} + +pub async fn current_oauth_token(store: &CredentialStore, key: &CredentialKey) -> Option { + let Credential::OAuth(tokens) = store.get(key)? else { + return None; + }; + let tokens = ensure_valid(tokens, store, key, refresh_token).await?; + Some(tokens.access_token.expose().to_owned()) +} + +fn browser_available() -> bool { + if cfg!(any(target_os = "macos", target_os = "windows")) { + return true; + } + std::env::var_os("DISPLAY").is_some() || std::env::var_os("WAYLAND_DISPLAY").is_some() +} + +async fn login_browser(status: &mpsc::Sender) -> Result { + let discovery = fetch_discovery().await?; + let pkce = Pkce::generate(); + let state = random_state(); + let nonce = random_nonce(); + let url = build_authorize_url( + &discovery.authorization_endpoint, + &pkce.challenge, + &state, + &nonce, + )?; + let _ = status + .send(format!( + "opening browser to sign in\u{2026} if it does not open, visit:\n{url}" + )) + .await; + if open::that(&url).is_err() { + return Err(XaiOAuthError::NoBrowser); + } + let code = capture_loopback_code(CALLBACK_PORT, &state).await?; + exchange_authorization_code(&discovery.token_endpoint, &code, &pkce).await +} + +async fn request_device_authorization( + device_authorization_endpoint: &str, +) -> Result { + let endpoint = require_trusted_endpoint( + device_authorization_endpoint, + "device authorization endpoint", + )?; + let client = oauth_client(); + let response = client + .post(endpoint) + .headers(oauth_headers()) + .form(&[("client_id", CLIENT_ID), ("scope", SCOPE)]) + .send() + .await?; + let status = response.status(); + if !status.is_success() { + return Err(XaiOAuthError::OAuth(format!( + "device authorization failed: {status}" + ))); + } + let device: DeviceAuthorizationResponse = response.json().await?; + if device.user_code.is_empty() || device.device_code.is_empty() { + return Err(XaiOAuthError::OAuth( + "device authorization response is missing required fields".to_owned(), + )); + } + Ok(device) +} + +pub fn valid_device_verification_url(url: &str) -> bool { + trusted_xai_host(url) +} + +async fn poll_device_token( + token_endpoint: &str, + device_code: &str, + expires_in: u64, + interval: u64, +) -> Result { + let endpoint = require_trusted_endpoint(token_endpoint, "token endpoint")?; + let client = oauth_client(); + let mut interval_secs = interval.max(1); + let deadline = now_secs() + i64::try_from(expires_in).unwrap_or(LOGIN_TIMEOUT_SECS); + loop { + if now_secs() > deadline { + return Err(XaiOAuthError::OAuth("device login timed out".to_owned())); + } + tokio::time::sleep(Duration::from_secs(interval_secs)).await; + let response = client + .post(&endpoint) + .headers(oauth_headers()) + .form(&[ + ("grant_type", DEVICE_GRANT_TYPE), + ("client_id", CLIENT_ID), + ("device_code", device_code), + ]) + .send() + .await?; + let status = response.status(); + if status.is_success() { + let tokens: TokenResponse = response.json().await?; + return parse_token_response(tokens, true); + } + let body: OAuthErrorResponse = response + .json() + .await + .unwrap_or(OAuthErrorResponse { error: None }); + match body.error.as_deref() { + Some("authorization_pending") => {} + Some("slow_down") => interval_secs = interval_secs.saturating_add(5), + Some("access_denied" | "authorization_denied") => { + return Err(XaiOAuthError::OAuth( + "device login access denied".to_owned(), + )); + } + Some("expired_token") => { + return Err(XaiOAuthError::OAuth("device login code expired".to_owned())); + } + Some(code) => { + return Err(XaiOAuthError::OAuth(format!( + "device token polling failed: {code}" + ))); + } + None => { + return Err(XaiOAuthError::OAuth(format!( + "device token polling failed: {status}" + ))); + } + } + } +} + +async fn login_device(status: &mpsc::Sender) -> Result { + let discovery = fetch_device_discovery().await?; + let device = request_device_authorization(&discovery.device_authorization_endpoint).await?; + let url = device + .verification_uri_complete + .as_deref() + .or(device.verification_uri.as_deref()) + .unwrap_or_default(); + if !valid_device_verification_url(url) { + return Err(XaiOAuthError::OAuth( + "device authorization returned an invalid verification URL".to_owned(), + )); + } + let _ = open::that(url); + let _ = status + .send(format!("open {url} and enter code: {}", device.user_code)) + .await; + poll_device_token( + &discovery.token_endpoint, + &device.device_code, + device.expires_in.unwrap_or(900), + device.interval.unwrap_or(5), + ) + .await +} + +pub async fn login(status: &mpsc::Sender) -> Result { + if browser_available() { + match login_browser(status).await { + Err(XaiOAuthError::NoBrowser) => login_device(status).await, + other => other, + } + } else { + login_device(status).await + } +} + +#[cfg(test)] +mod tests { + use super::{ + CLIENT_ID, REDIRECT_URI, SCOPE, build_authorize_url, parse_token_response, + trusted_xai_host, valid_device_verification_url, + }; + + #[test] + fn authorize_url_contains_required_params() { + let url = build_authorize_url( + "https://auth.x.ai/oauth2/authorize", + "challenge", + "state-value", + "nonce-value", + ) + .unwrap(); + let parsed = reqwest::Url::parse(&url).unwrap(); + assert_eq!(parsed.origin().ascii_serialization(), "https://auth.x.ai"); + let pairs: std::collections::HashMap<_, _> = parsed.query_pairs().collect(); + let value = |key: &str| pairs.get(key).map(|value| value.as_ref().to_owned()); + assert_eq!(value("client_id"), Some(CLIENT_ID.to_owned())); + assert_eq!(value("redirect_uri"), Some(REDIRECT_URI.to_owned())); + assert_eq!(value("scope"), Some(SCOPE.to_owned())); + assert_eq!(value("code_challenge"), Some("challenge".to_owned())); + assert_eq!(value("state"), Some("state-value".to_owned())); + assert_eq!(value("nonce"), Some("nonce-value".to_owned())); + assert_eq!(value("referrer"), Some("goat-code".to_owned())); + } + + #[test] + fn rejects_untrusted_authorize_host() { + let err = build_authorize_url( + "https://evil.example/oauth2/authorize", + "challenge", + "state", + "nonce", + ) + .unwrap_err() + .to_string(); + assert!(err.contains("untrusted")); + } + + #[test] + fn trusted_xai_hosts() { + assert!(trusted_xai_host("https://auth.x.ai/oauth2/authorize")); + assert!(trusted_xai_host("https://accounts.x.ai/sign-in")); + assert!(!trusted_xai_host("https://evil.example/oauth")); + assert!(!trusted_xai_host("http://auth.x.ai/oauth")); + } + + #[test] + fn validates_device_verification_url() { + assert!(valid_device_verification_url( + "https://accounts.x.ai/device?code=abc" + )); + assert!(!valid_device_verification_url( + "https://example.com/device?code=abc" + )); + } + + #[test] + fn token_parse_does_not_leak_secrets() { + let err = parse_token_response( + super::TokenResponse { + access_token: String::new(), + refresh_token: Some("refresh-secret".to_owned()), + expires_in: Some(3600), + }, + true, + ) + .unwrap_err() + .to_string(); + assert!(!err.contains("refresh-secret")); + } +} diff --git a/crates/goat-provider-zai-coding/Cargo.toml b/crates/goat-provider-zai-coding/Cargo.toml new file mode 100644 index 0000000..8036e2f --- /dev/null +++ b/crates/goat-provider-zai-coding/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "goat-provider-zai-coding" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +publish = false + +[dependencies] +goat-auth = { workspace = true } +goat-provider = { workspace = true } +goat-provider-openai-compat = { workspace = true } + +[lints] +workspace = true \ No newline at end of file diff --git a/crates/goat-provider-zai-coding/src/lib.rs b/crates/goat-provider-zai-coding/src/lib.rs new file mode 100644 index 0000000..7a7c2da --- /dev/null +++ b/crates/goat-provider-zai-coding/src/lib.rs @@ -0,0 +1,95 @@ +use goat_auth::CredentialStore; +use goat_provider::{Effort, ProviderMetadata}; +use goat_provider_openai_compat::{ + ChatDiscovery, ChatValidation, OpenAiCompatProvider, api_key, no_vision, +}; + +pub const PROVIDER_ID: &str = "zai-coding"; + +const BASE_URL: &str = "https://api.z.ai/api/coding/paas/v4"; +const HOST: &str = "api.z.ai"; +const ENV_VAR: &str = "ZAI_CODING_API_KEY"; + +const ZAI_CODING_SETUP: &[&str] = &[ + "Z.AI Coding Plan API-key provider.", + "Use `ZAI_CODING_API_KEY` or `goat provider login zai-coding --key sk-...`.", + "This is not OAuth and does not reuse the standard `zai` credential.", +]; + +const CATALOG: &[&str] = &["glm-5.2", "glm-5-turbo", "glm-4.7"]; + +const CONTEXT: &[(&str, u32)] = &[ + ("glm-5.2", 1_000_000), + ("glm-5-turbo", 128_000), + ("glm-4.7", 128_000), +]; + +pub fn build(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { + api_key(store, account, PROVIDER_ID, BASE_URL, HOST, ENV_VAR) + .with_catalog(CATALOG) + .with_context_windows(CONTEXT) + .with_vision_filter(no_vision) + .with_efforts(zai_efforts) + .with_effort_wire(zai_effort_wire) + .with_validation(ChatValidation::CatalogOnly) + .with_discovery(ChatDiscovery::CatalogOnly) + .with_metadata(ProviderMetadata { + env_var: Some(ENV_VAR), + validation: "catalog-only", + endpoint: Some(BASE_URL), + oauth: Some("not OAuth; uses Z.AI Coding Plan API key"), + login_endpoint: None, + setup: ZAI_CODING_SETUP, + }) +} + +fn zai_efforts(model: &str) -> Vec { + if model == "glm-5.2" { + vec![ + Effort::Off, + Effort::Low, + Effort::Medium, + Effort::High, + Effort::Xhigh, + Effort::Max, + ] + } else { + Vec::new() + } +} + +fn zai_effort_wire(effort: Effort) -> Option<&'static str> { + let wire = match effort { + Effort::Off => "none", + Effort::Low => "low", + Effort::Medium => "medium", + Effort::High => "high", + Effort::Xhigh => "xhigh", + Effort::Max => "max", + }; + (!wire.is_empty()).then_some(wire) +} + +#[cfg(test)] +mod tests { + use goat_auth::CredentialStore; + use goat_provider::{AuthMethod, Provider}; + + use super::*; + + fn store(name: &str) -> CredentialStore { + let _ = std::fs::remove_file(std::env::temp_dir().join(name)); + CredentialStore::new(std::env::temp_dir().join(name)) + } + + #[test] + fn zai_coding_is_distinct_api_key_provider() { + let store = store("goat-provider-zai-coding.json"); + let provider = build(&store, "default"); + assert_eq!(provider.capabilities().auth, AuthMethod::ApiKey); + assert_eq!(provider.metadata().env_var, Some("ZAI_CODING_API_KEY")); + assert_eq!(provider.metadata().endpoint, Some(BASE_URL)); + assert_eq!(provider.catalog(), CATALOG); + assert_eq!(provider.context_window("glm-5.2"), Some(1_000_000)); + } +} diff --git a/crates/goat-provider-zai/Cargo.toml b/crates/goat-provider-zai/Cargo.toml new file mode 100644 index 0000000..ef541ea --- /dev/null +++ b/crates/goat-provider-zai/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "goat-provider-zai" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +publish = false + +[dependencies] +goat-auth = { workspace = true } +goat-provider = { workspace = true } +goat-provider-openai-compat = { workspace = true } + +[lints] +workspace = true \ No newline at end of file diff --git a/crates/goat-provider-zai/src/lib.rs b/crates/goat-provider-zai/src/lib.rs new file mode 100644 index 0000000..59ecc72 --- /dev/null +++ b/crates/goat-provider-zai/src/lib.rs @@ -0,0 +1,111 @@ +use goat_auth::CredentialStore; +use goat_provider::{Effort, ProviderMetadata}; +use goat_provider_openai_compat::{ChatDiscovery, ChatValidation, OpenAiCompatProvider, api_key}; + +pub const PROVIDER_ID: &str = "zai"; + +const BASE_URL: &str = "https://api.z.ai/api/paas/v4"; +const HOST: &str = "api.z.ai"; +const ENV_VAR: &str = "ZAI_API_KEY"; + +const CATALOG: &[&str] = &[ + "glm-5.2", + "glm-5.1", + "glm-5-turbo", + "glm-5", + "glm-4.7", + "glm-4.6", + "glm-4.5", + "glm-4-32b-0414-128k", +]; + +const CONTEXT: &[(&str, u32)] = &[ + ("glm-5.2", 128_000), + ("glm-5.1", 128_000), + ("glm-5", 128_000), + ("glm-4", 128_000), +]; + +pub fn build(store: &CredentialStore, account: &str) -> OpenAiCompatProvider { + api_key(store, account, PROVIDER_ID, BASE_URL, HOST, ENV_VAR) + .with_catalog(CATALOG) + .with_context_windows(CONTEXT) + .with_vision_filter(zai_vision_model) + .with_efforts(zai_efforts) + .with_effort_wire(zai_effort_wire) + .with_validation(ChatValidation::CatalogOnly) + .with_discovery(ChatDiscovery::CatalogOnly) + .with_metadata(ProviderMetadata { + env_var: Some(ENV_VAR), + validation: "catalog-only", + endpoint: None, + oauth: Some("not supported by Z.AI API docs"), + login_endpoint: None, + setup: &[], + }) +} + +fn zai_vision_model(id: &str) -> bool { + let id = id.to_ascii_lowercase(); + id.contains("glm-4v") || id.contains("vision") +} + +fn zai_efforts(model: &str) -> Vec { + if model == "glm-5.2" { + vec![ + Effort::Off, + Effort::Low, + Effort::Medium, + Effort::High, + Effort::Xhigh, + Effort::Max, + ] + } else { + Vec::new() + } +} + +fn zai_effort_wire(effort: Effort) -> Option<&'static str> { + let wire = match effort { + Effort::Off => "none", + Effort::Low => "low", + Effort::Medium => "medium", + Effort::High => "high", + Effort::Xhigh => "xhigh", + Effort::Max => "max", + }; + (!wire.is_empty()).then_some(wire) +} + +#[cfg(test)] +mod tests { + use goat_auth::CredentialStore; + use goat_provider::{Effort, Provider}; + + use super::*; + + fn store(name: &str) -> CredentialStore { + let _ = std::fs::remove_file(std::env::temp_dir().join(name)); + CredentialStore::new(std::env::temp_dir().join(name)) + } + + #[test] + fn metadata_is_exposed() { + let store = store("goat-provider-zai-metadata.json"); + let provider = build(&store, "default"); + assert_eq!(provider.catalog(), CATALOG); + assert_eq!(provider.context_window("glm-5.2"), Some(128_000)); + assert_eq!( + provider.efforts("glm-5.2"), + vec![ + Effort::Off, + Effort::Low, + Effort::Medium, + Effort::High, + Effort::Xhigh, + Effort::Max + ] + ); + assert!(!provider.verifies_credentials()); + } +} diff --git a/crates/goat-providers/Cargo.toml b/crates/goat-providers/Cargo.toml index d4a3898..5e6b488 100644 --- a/crates/goat-providers/Cargo.toml +++ b/crates/goat-providers/Cargo.toml @@ -16,7 +16,16 @@ goat-provider-openai-codex = { workspace = true } goat-provider-anthropic = { workspace = true } goat-provider-gemini = { workspace = true } goat-provider-local = { workspace = true } -goat-provider-hosted = { workspace = true } +goat-provider-openrouter = { workspace = true } +goat-provider-groq = { workspace = true } +goat-provider-deepseek = { workspace = true } +goat-provider-mistral = { workspace = true } +goat-provider-zai = { workspace = true } +goat-provider-zai-coding = { workspace = true } +goat-provider-kimi = { workspace = true } +goat-provider-kimi-code = { workspace = true } +goat-provider-qwen = { workspace = true } +goat-provider-xai = { workspace = true } [lints] -workspace = true +workspace = true \ No newline at end of file diff --git a/crates/goat-providers/src/lib.rs b/crates/goat-providers/src/lib.rs index c53dd23..eafe770 100644 --- a/crates/goat-providers/src/lib.rs +++ b/crates/goat-providers/src/lib.rs @@ -15,25 +15,25 @@ impl Registry { } pub fn load(store: &CredentialStore, account: &str) -> Self { - let mut providers: Vec> = vec![ + let providers: Vec> = vec![ Arc::new(goat_provider_openai::build(store, account)), Arc::new(goat_provider_openai_codex::build(store, account)), Arc::new(goat_provider_anthropic::build(store, account)), Arc::new(goat_provider_gemini::build(store, account)), + Arc::new(goat_provider_openrouter::build(store, account)), + Arc::new(goat_provider_groq::build(store, account)), + Arc::new(goat_provider_deepseek::build(store, account)), + Arc::new(goat_provider_xai::build(store, account)), + Arc::new(goat_provider_mistral::build(store, account)), + Arc::new(goat_provider_zai::build(store, account)), + Arc::new(goat_provider_zai_coding::build(store, account)), + Arc::new(goat_provider_kimi::build(store, account)), + Arc::new(goat_provider_kimi_code::build(store, account)), + Arc::new(goat_provider_qwen::build(store, account)), + Arc::new(goat_provider_local::ollama()), + Arc::new(goat_provider_local::lmstudio()), + Arc::new(goat_provider_local::llama_cpp()), ]; - providers.extend( - goat_provider_hosted::all(store, account) - .into_iter() - .map(|provider| Arc::new(provider) as Arc), - ); - providers.push(Arc::new(goat_provider_hosted::build_kimi_code( - store, account, - ))); - providers.extend([ - Arc::new(goat_provider_local::ollama()) as Arc, - Arc::new(goat_provider_local::lmstudio()) as Arc, - Arc::new(goat_provider_local::llama_cpp()) as Arc, - ]); Self { providers } } @@ -65,7 +65,7 @@ impl Registry { #[cfg(test)] mod tests { - use goat_provider::ProviderId; + use goat_provider::{AuthMethod, ProviderId}; use super::Registry; @@ -80,7 +80,14 @@ mod tests { assert!(registry.get(&ProviderId::from("openrouter")).is_some()); assert!(registry.get(&ProviderId::from("groq")).is_some()); assert!(registry.get(&ProviderId::from("deepseek")).is_some()); - assert!(registry.get(&ProviderId::from("xai")).is_some()); + let xai = registry + .get(&ProviderId::from("xai")) + .expect("xai provider"); + assert_eq!(xai.capabilities().auth, AuthMethod::ApiKeyOrOAuth); + assert_eq!( + xai.metadata().oauth, + Some("browser or device code (SuperGrok / X Premium+)") + ); assert!(registry.get(&ProviderId::from("mistral")).is_some()); assert!(registry.get(&ProviderId::from("zai")).is_some()); assert!(registry.get(&ProviderId::from("zai-coding")).is_some());