diff --git a/crates/goat-agent/src/accounts.rs b/crates/goat-agent/src/accounts.rs index 4672ad0..b6d136f 100644 --- a/crates/goat-agent/src/accounts.rs +++ b/crates/goat-agent/src/accounts.rs @@ -526,6 +526,15 @@ pub(crate) async fn discover_ready( model_list_entries(&providers, credentials).await } +pub(crate) async fn refresh_model_list( + events: &mpsc::Sender, + registry: &Registry, + credentials: &CredentialStore, +) { + let entries = discover_ready(registry, credentials).await; + let _ = events.send(Event::ModelListChanged { entries }).await; +} + pub(crate) fn clear_account_registries(cache: &std::sync::Mutex>>) { cache .lock() diff --git a/crates/goat-agent/src/lib.rs b/crates/goat-agent/src/lib.rs index 39ee809..b2951a5 100644 --- a/crates/goat-agent/src/lib.rs +++ b/crates/goat-agent/src/lib.rs @@ -39,6 +39,11 @@ mod websearch; pub use agent::{AgentRegistry, AgentSpec, ToolSelection}; +pub async fn model_list_entries(credentials: &CredentialStore) -> Vec { + let registry = Registry::new(credentials); + accounts::discover_ready(®istry, credentials).await +} + const CHILD_ID_BASE: u64 = 1 << 32; pub struct GoatAgent { @@ -387,6 +392,7 @@ async fn run(agent: GoatAgent, mut ops: mpsc::Receiver, events: mpsc::Sender &events, ) .await; + accounts::refresh_model_list(&events, ®istry, &credentials).await; } Op::ResumeLatest {} => { threads::handle_resume_latest( @@ -400,6 +406,7 @@ async fn run(agent: GoatAgent, mut ops: mpsc::Receiver, events: mpsc::Sender &events, ) .await; + accounts::refresh_model_list(&events, ®istry, &credentials).await; } Op::RenameThread { title } => { threads::handle_rename(&store, state.thread_id, title, &events).await; @@ -1566,17 +1573,27 @@ mod tests { assert_eq!(messages[0].role, "shell"); ops.send(Op::Resume { thread_id: 1 }).await.unwrap(); + let mut restored = false; + let mut refreshed = false; while let Some(event) = events.recv().await { - if let Event::ConversationRestored { entries, .. } = event { - assert!(entries.iter().any(|entry| matches!( - entry, - goat_protocol::TranscriptEntry::Shell { command, output } - if command == "echo persisted" && output.contains("persisted") - ))); - return; + match event { + Event::ConversationRestored { entries, .. } => { + assert!(entries.iter().any(|entry| matches!( + entry, + goat_protocol::TranscriptEntry::Shell { command, output } + if command == "echo persisted" && output.contains("persisted") + ))); + restored = true; + } + Event::ModelListChanged { .. } if restored => { + refreshed = true; + break; + } + _ => {} } } - panic!("expected ConversationRestored"); + assert!(restored, "expected ConversationRestored"); + assert!(refreshed, "expected ModelListChanged after resume"); } #[tokio::test] diff --git a/crates/goat-daemon/src/manager.rs b/crates/goat-daemon/src/manager.rs index 900b16e..88719d7 100644 --- a/crates/goat-daemon/src/manager.rs +++ b/crates/goat-daemon/src/manager.rs @@ -267,6 +267,15 @@ impl Manager { break; } } + let catalog_refresh = { + let inner = live.inner.lock().await; + if inner.snapshot.is_some() { + let credentials = CredentialStore::new(self.inner.auth_path.clone()); + Some(goat_agent::model_list_entries(&credentials).await) + } else { + None + } + }; let (backlog, live_rx) = { let mut inner = live.inner.lock().await; let mut backlog = Vec::new(); @@ -292,6 +301,17 @@ impl Manager { event: event.clone(), }); } + if let Some(entries) = catalog_refresh { + let event = goat_protocol::Event::ModelListChanged { entries }; + let seq = inner.next_seq; + inner.next_seq += 1; + inner.log.push_back((seq, event.clone())); + backlog.push(ServerFrame::Event { + session, + seq, + event, + }); + } let (bridge_tx, bridge_rx) = mpsc::channel(SUBSCRIBER_QUEUE); crate::session::subscriber_upsert(&mut inner.subscribers, client, bridge_tx); let clients = inner.presence(); diff --git a/crates/goat-provider-openai-compat/src/chat.rs b/crates/goat-provider-openai-compat/src/chat.rs index 2a8df8a..6c46f78 100644 --- a/crates/goat-provider-openai-compat/src/chat.rs +++ b/crates/goat-provider-openai-compat/src/chat.rs @@ -40,6 +40,7 @@ struct ChatOptions { discovery: ChatDiscovery, model_list_source: Option, metadata: ProviderMetadata, + extra_headers: &'static [(&'static str, &'static str)], } impl Default for ChatOptions { @@ -59,6 +60,7 @@ impl Default for ChatOptions { discovery: ChatDiscovery::ModelsEndpoint, model_list_source: None, metadata: ProviderMetadata::default(), + extra_headers: &[], } } } @@ -193,6 +195,12 @@ impl OpenAiCompatProvider { self.options.metadata = metadata; self } + + #[must_use] + pub fn with_extra_headers(mut self, headers: &'static [(&'static str, &'static str)]) -> Self { + self.options.extra_headers = headers; + self + } } fn normalize_base_url(base_url: &str) -> String { @@ -610,6 +618,9 @@ impl Provider for OpenAiCompatProvider { if let Some(token) = &bearer { builder = builder.bearer_auth(token); } + for (name, value) in options.extra_headers { + builder = builder.header(*name, *value); + } let resp = match builder.send().await { Ok(resp) => resp, Err(err) => { @@ -875,4 +886,90 @@ mod tests { ] ); } + + #[tokio::test] + async fn stream_sends_extra_headers() { + use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }; + + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpListener, + sync::mpsc, + }; + + const HEADERS: &[(&str, &str)] = &[ + ("User-Agent", "xai-grok-cli"), + ("x-grok-client-version", "0.2.82"), + ("x-grok-client-identifier", "xai-grok-cli"), + ]; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let saw_version = Arc::new(AtomicBool::new(false)); + let saw_identifier = Arc::new(AtomicBool::new(false)); + let saw_user_agent = Arc::new(AtomicBool::new(false)); + let saw_version_server = saw_version.clone(); + let saw_identifier_server = saw_identifier.clone(); + let saw_user_agent_server = saw_user_agent.clone(); + + let server = tokio::spawn(async move { + let (mut socket, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 16_384]; + let n = socket.read(&mut buf).await.unwrap(); + let request = String::from_utf8_lossy(&buf[..n]); + if request.contains("x-grok-client-version: 0.2.82") { + saw_version_server.store(true, Ordering::SeqCst); + } + if request.contains("x-grok-client-identifier: xai-grok-cli") { + saw_identifier_server.store(true, Ordering::SeqCst); + } + if request + .to_ascii_lowercase() + .contains("user-agent: xai-grok-cli") + { + saw_user_agent_server.store(true, Ordering::SeqCst); + } + let body = concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"ok\"},\"finish_reason\":null}]}\n\n", + "data: [DONE]\n\n" + ); + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\n\r\n{body}", + body.len() + ); + socket.write_all(response.as_bytes()).await.unwrap(); + }); + + let provider = OpenAiCompatProvider::new( + "test".into(), + format!("http://{addr}/v1"), + None, + AuthMethod::None, + ) + .with_extra_headers(HEADERS); + let (events, mut rx) = mpsc::channel(8); + let handle = provider.stream( + Request { + model: "grok-composer-2.5-fast".to_owned(), + messages: vec![Message::text(MessageRole::User, "hi")], + tools: Vec::new(), + effort: None, + tool_choice: ToolChoice::None, + }, + events, + ); + let _ = handle.await; + server.await.unwrap(); + assert!(saw_version.load(Ordering::SeqCst)); + assert!(saw_identifier.load(Ordering::SeqCst)); + assert!(saw_user_agent.load(Ordering::SeqCst)); + assert!(matches!( + rx.recv().await, + Some(StreamEvent::TextDelta { .. }) + )); + assert!(matches!(rx.recv().await, Some(StreamEvent::Completed))); + } } diff --git a/crates/goat-provider-xai/src/lib.rs b/crates/goat-provider-xai/src/lib.rs index c2d674d..e6a11d9 100644 --- a/crates/goat-provider-xai/src/lib.rs +++ b/crates/goat-provider-xai/src/lib.rs @@ -26,6 +26,16 @@ const SETUP: &[&str] = &[ const COMPOSER_CATALOG: &[&str] = &["grok-composer-2.5-fast"]; +const GROK_CLI_CLIENT_VERSION: &str = "0.2.82"; +const GROK_CLI_CLIENT_IDENTIFIER: &str = "xai-grok-cli"; +const GROK_CLI_USER_AGENT: &str = "xai-grok-cli"; + +const GROK_CLI_HEADERS: &[(&str, &str)] = &[ + ("User-Agent", GROK_CLI_USER_AGENT), + ("x-grok-client-version", GROK_CLI_CLIENT_VERSION), + ("x-grok-client-identifier", GROK_CLI_CLIENT_IDENTIFIER), +]; + const OAUTH_CATALOG: &[&str] = &[ "grok-composer-2.5-fast", "grok-4.3", @@ -91,11 +101,28 @@ impl XaiProvider { async fn resolve_auth(&self) -> Option { let cred = self.store.resolve(&self.key, Some("XAI_API_KEY"))?; + Self::auth_from_credential(&self.store, &self.key, cred).await + } + + async fn resolve_auth_for_model(&self, model: &str) -> Option { + if Self::is_oauth_model(model) + && let Some(cred @ Credential::OAuth(_)) = self.store.get(&self.key) + { + return Self::auth_from_credential(&self.store, &self.key, cred).await; + } + self.resolve_auth().await + } + + async fn auth_from_credential( + store: &CredentialStore, + key: &CredentialKey, + cred: Credential, + ) -> Option { match cred { Credential::ApiKey(secret) | Credential::ApiKeyWithEndpoint { secret, .. } => { Some(XaiAuth::ApiKey(secret.expose().to_owned())) } - Credential::OAuth(_) => oauth::current_oauth_token(&self.store, &self.key) + Credential::OAuth(_) => oauth::current_oauth_token(store, key) .await .map(XaiAuth::OAuth), } @@ -144,6 +171,7 @@ impl XaiProvider { .with_vision_filter(no_vision) .with_efforts(no_efforts) .with_reasoning_effort(false) + .with_extra_headers(GROK_CLI_HEADERS) } async fn emit_models(provider: &XaiProvider, out: &mpsc::Sender) -> bool { @@ -260,7 +288,7 @@ impl Provider for XaiProvider { let model = req.model.clone(); tokio::spawn(async move { let provider = XaiProvider { store, key }; - let Some(auth) = provider.resolve_auth().await else { + let Some(auth) = provider.resolve_auth_for_model(&model).await else { let _ = events .send(StreamEvent::Failed { error: StreamError::auth("no credentials"), @@ -419,4 +447,77 @@ mod tests { .collect::>() ); } + + #[tokio::test] + async fn oauth_model_prefers_stored_oauth_for_composer() { + use goat_auth::{Credential, CredentialKey, TokenSet}; + + let store = store("goat-provider-xai-oauth-pref.json"); + store + .store( + &CredentialKey::model(PROVIDER_ID, "default"), + Credential::OAuth(TokenSet::from_parts( + "oauth-access".to_owned(), + None, + Some(3600), + None, + )), + ) + .unwrap(); + let provider = XaiProvider::new(store, CredentialKey::model(PROVIDER_ID, "default")); + let auth = provider + .resolve_auth_for_model("grok-composer-2.5-fast") + .await + .expect("oauth should win for composer"); + assert!(matches!(auth, XaiAuth::OAuth(token) if token == "oauth-access")); + } + + #[tokio::test] + #[ignore = "live network and OAuth credentials required"] + async fn composer_proxy_passes_grok_cli_version_gate() { + use goat_auth::CredentialKey; + use goat_provider::{Message, MessageRole, Request, StreamEvent, ToolChoice}; + + let auth_path = std::env::var("HOME") + .map(|home| std::path::PathBuf::from(home).join(".goat-code/auth.json")) + .expect("HOME"); + if !auth_path.is_file() { + return; + } + let store = CredentialStore::new(auth_path); + let key = CredentialKey::model(PROVIDER_ID, "default"); + let Some(Credential::OAuth(_)) = store.get(&key) else { + return; + }; + let provider = build(&store, "default"); + let (events, mut rx) = tokio::sync::mpsc::channel(8); + let handle = provider.stream( + Request { + model: "grok-composer-2.5-fast".to_owned(), + messages: vec![Message::text(MessageRole::User, "Reply with exactly: ok")], + tools: Vec::new(), + effort: None, + tool_choice: ToolChoice::None, + }, + events, + ); + while let Some(event) = rx.recv().await { + if let StreamEvent::Failed { error } = event { + let message = error.to_string(); + assert!( + !message.contains("426"), + "composer proxy rejected client version: {message}" + ); + assert!( + !message.contains("outdated"), + "composer proxy rejected client version: {message}" + ); + panic!("composer request failed: {message}"); + } + if matches!(event, StreamEvent::Completed) { + break; + } + } + let _ = handle.await; + } }