Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 3 additions & 86 deletions crates/goat-provider-xai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,20 @@ use goat_provider::{
StreamError, StreamEvent, WebSearchOutput,
};
use goat_provider_openai_compat::{
OpenAiCompatProvider, ResponsesProvider, enforce_https_host, no_efforts, no_vision,
OpenAiCompatProvider, ResponsesProvider, enforce_https_host, no_efforts,
};
use tokio::{sync::mpsc, task::JoinHandle};

pub const PROVIDER_ID: &str = "xai";

const BASE_URL: &str = "https://api.x.ai/v1";
const ALLOWED_HOST: &str = "api.x.ai";
const CLI_BASE_URL: &str = "https://cli-chat-proxy.grok.com/v1";
const CLI_ALLOWED_HOST: &str = "cli-chat-proxy.grok.com";

const SETUP: &[&str] = &[
"xAI Grok provider (API key or SuperGrok / X Premium+ OAuth).",
"API key: `goat provider login xai --key xai-...` or `XAI_API_KEY`.",
"OAuth: `goat provider login xai` (browser or device code; no API key).",
"OAuth includes Composer 2.5 (`grok-composer-2.5-fast`) via the Grok CLI proxy.",
];

const COMPOSER_CATALOG: &[&str] = &["grok-composer-2.5-fast"];

const GROK_CLI_CLIENT_VERSION: &str = "0.2.82";
const GROK_CLI_CLIENT_IDENTIFIER: &str = "xai-grok-cli";
const GROK_CLI_USER_AGENT: &str = "xai-grok-cli";

const GROK_CLI_HEADERS: &[(&str, &str)] = &[
("User-Agent", GROK_CLI_USER_AGENT),
("x-grok-client-version", GROK_CLI_CLIENT_VERSION),
("x-grok-client-identifier", GROK_CLI_CLIENT_IDENTIFIER),
"OAuth coding models (Composer, Grok Build) use the same api.x.ai Responses API.",
];

const OAUTH_CATALOG: &[&str] = &[
Expand Down Expand Up @@ -80,7 +66,6 @@ const API_KEY_CONTEXT: &[(&str, u32)] = &[("grok-4", 256_000), ("grok-3", 131_07

pub fn build(store: &CredentialStore, account: &str) -> XaiProvider {
enforce_https_host(BASE_URL, ALLOWED_HOST).expect("xai provider base URL");
enforce_https_host(CLI_BASE_URL, CLI_ALLOWED_HOST).expect("xai composer base URL");
XaiProvider::new(store.clone(), CredentialKey::model(PROVIDER_ID, account))
}

Expand Down Expand Up @@ -159,21 +144,6 @@ impl XaiProvider {
.with_model_filter(oauth_chat_model)
}

fn composer_provider(bearer: String) -> OpenAiCompatProvider {
OpenAiCompatProvider::new(
ProviderId::from(PROVIDER_ID),
CLI_BASE_URL,
Some(bearer),
AuthMethod::ApiKeyOrOAuth,
)
.with_catalog(COMPOSER_CATALOG)
.with_images(false)
.with_vision_filter(no_vision)
.with_efforts(no_efforts)
.with_reasoning_effort(false)
.with_extra_headers(GROK_CLI_HEADERS)
}

async fn emit_models(provider: &XaiProvider, out: &mpsc::Sender<Model>) -> bool {
for id in provider.list_models() {
if out
Expand Down Expand Up @@ -312,11 +282,7 @@ impl Provider for XaiProvider {
.await;
return;
}
if model.starts_with("grok-composer") {
XaiProvider::composer_provider(bearer).stream(req, events)
} else {
XaiProvider::responses_provider(bearer).stream(req, events)
}
XaiProvider::responses_provider(bearer).stream(req, events)
}
};
let _ = handle.await;
Expand Down Expand Up @@ -471,53 +437,4 @@ mod tests {
.expect("oauth should win for composer");
assert!(matches!(auth, XaiAuth::OAuth(token) if token == "oauth-access"));
}

#[tokio::test]
#[ignore = "live network and OAuth credentials required"]
async fn composer_proxy_passes_grok_cli_version_gate() {
use goat_auth::CredentialKey;
use goat_provider::{Message, MessageRole, Request, StreamEvent, ToolChoice};

let auth_path = std::env::var("HOME")
.map(|home| std::path::PathBuf::from(home).join(".goat-code/auth.json"))
.expect("HOME");
if !auth_path.is_file() {
return;
}
let store = CredentialStore::new(auth_path);
let key = CredentialKey::model(PROVIDER_ID, "default");
let Some(Credential::OAuth(_)) = store.get(&key) else {
return;
};
let provider = build(&store, "default");
let (events, mut rx) = tokio::sync::mpsc::channel(8);
let handle = provider.stream(
Request {
model: "grok-composer-2.5-fast".to_owned(),
messages: vec![Message::text(MessageRole::User, "Reply with exactly: ok")],
tools: Vec::new(),
effort: None,
tool_choice: ToolChoice::None,
},
events,
);
while let Some(event) = rx.recv().await {
if let StreamEvent::Failed { error } = event {
let message = error.to_string();
assert!(
!message.contains("426"),
"composer proxy rejected client version: {message}"
);
assert!(
!message.contains("outdated"),
"composer proxy rejected client version: {message}"
);
panic!("composer request failed: {message}");
}
if matches!(event, StreamEvent::Completed) {
break;
}
}
let _ = handle.await;
}
}
Loading