Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions crates/goat-agent/src/accounts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,15 @@ pub(crate) async fn discover_ready(
model_list_entries(&providers, credentials).await
}

pub(crate) async fn refresh_model_list(
events: &mpsc::Sender<Event>,
registry: &Registry,
credentials: &CredentialStore,
) {
let entries = discover_ready(registry, credentials).await;
let _ = events.send(Event::ModelListChanged { entries }).await;
}

pub(crate) fn clear_account_registries(cache: &std::sync::Mutex<HashMap<String, Arc<Registry>>>) {
cache
.lock()
Expand Down
33 changes: 25 additions & 8 deletions crates/goat-agent/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ mod websearch;

pub use agent::{AgentRegistry, AgentSpec, ToolSelection};

pub async fn model_list_entries(credentials: &CredentialStore) -> Vec<goat_protocol::ModelEntry> {
let registry = Registry::new(credentials);
accounts::discover_ready(&registry, credentials).await
}

const CHILD_ID_BASE: u64 = 1 << 32;

pub struct GoatAgent {
Expand Down Expand Up @@ -387,6 +392,7 @@ async fn run(agent: GoatAgent, mut ops: mpsc::Receiver<Op>, events: mpsc::Sender
&events,
)
.await;
accounts::refresh_model_list(&events, &registry, &credentials).await;
}
Op::ResumeLatest {} => {
threads::handle_resume_latest(
Expand All @@ -400,6 +406,7 @@ async fn run(agent: GoatAgent, mut ops: mpsc::Receiver<Op>, events: mpsc::Sender
&events,
)
.await;
accounts::refresh_model_list(&events, &registry, &credentials).await;
}
Op::RenameThread { title } => {
threads::handle_rename(&store, state.thread_id, title, &events).await;
Expand Down Expand Up @@ -1566,17 +1573,27 @@ mod tests {
assert_eq!(messages[0].role, "shell");

ops.send(Op::Resume { thread_id: 1 }).await.unwrap();
let mut restored = false;
let mut refreshed = false;
while let Some(event) = events.recv().await {
if let Event::ConversationRestored { entries, .. } = event {
assert!(entries.iter().any(|entry| matches!(
entry,
goat_protocol::TranscriptEntry::Shell { command, output }
if command == "echo persisted" && output.contains("persisted")
)));
return;
match event {
Event::ConversationRestored { entries, .. } => {
assert!(entries.iter().any(|entry| matches!(
entry,
goat_protocol::TranscriptEntry::Shell { command, output }
if command == "echo persisted" && output.contains("persisted")
)));
restored = true;
}
Event::ModelListChanged { .. } if restored => {
refreshed = true;
break;
}
_ => {}
}
}
panic!("expected ConversationRestored");
assert!(restored, "expected ConversationRestored");
assert!(refreshed, "expected ModelListChanged after resume");
}

#[tokio::test]
Expand Down
20 changes: 20 additions & 0 deletions crates/goat-daemon/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,15 @@ impl Manager {
break;
}
}
let catalog_refresh = {
let inner = live.inner.lock().await;
if inner.snapshot.is_some() {
let credentials = CredentialStore::new(self.inner.auth_path.clone());
Some(goat_agent::model_list_entries(&credentials).await)
} else {
None
}
};
let (backlog, live_rx) = {
let mut inner = live.inner.lock().await;
let mut backlog = Vec::new();
Expand All @@ -292,6 +301,17 @@ impl Manager {
event: event.clone(),
});
}
if let Some(entries) = catalog_refresh {
let event = goat_protocol::Event::ModelListChanged { entries };
let seq = inner.next_seq;
inner.next_seq += 1;
inner.log.push_back((seq, event.clone()));
backlog.push(ServerFrame::Event {
session,
seq,
event,
});
}
let (bridge_tx, bridge_rx) = mpsc::channel(SUBSCRIBER_QUEUE);
crate::session::subscriber_upsert(&mut inner.subscribers, client, bridge_tx);
let clients = inner.presence();
Expand Down
97 changes: 97 additions & 0 deletions crates/goat-provider-openai-compat/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ struct ChatOptions {
discovery: ChatDiscovery,
model_list_source: Option<ModelListSource>,
metadata: ProviderMetadata,
extra_headers: &'static [(&'static str, &'static str)],
}

impl Default for ChatOptions {
Expand All @@ -59,6 +60,7 @@ impl Default for ChatOptions {
discovery: ChatDiscovery::ModelsEndpoint,
model_list_source: None,
metadata: ProviderMetadata::default(),
extra_headers: &[],
}
}
}
Expand Down Expand Up @@ -193,6 +195,12 @@ impl OpenAiCompatProvider {
self.options.metadata = metadata;
self
}

#[must_use]
pub fn with_extra_headers(mut self, headers: &'static [(&'static str, &'static str)]) -> Self {
self.options.extra_headers = headers;
self
}
}

fn normalize_base_url(base_url: &str) -> String {
Expand Down Expand Up @@ -610,6 +618,9 @@ impl Provider for OpenAiCompatProvider {
if let Some(token) = &bearer {
builder = builder.bearer_auth(token);
}
for (name, value) in options.extra_headers {
builder = builder.header(*name, *value);
}
let resp = match builder.send().await {
Ok(resp) => resp,
Err(err) => {
Expand Down Expand Up @@ -875,4 +886,90 @@ mod tests {
]
);
}

#[tokio::test]
async fn stream_sends_extra_headers() {
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};

use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpListener,
sync::mpsc,
};

const HEADERS: &[(&str, &str)] = &[
("User-Agent", "xai-grok-cli"),
("x-grok-client-version", "0.2.82"),
("x-grok-client-identifier", "xai-grok-cli"),
];

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let saw_version = Arc::new(AtomicBool::new(false));
let saw_identifier = Arc::new(AtomicBool::new(false));
let saw_user_agent = Arc::new(AtomicBool::new(false));
let saw_version_server = saw_version.clone();
let saw_identifier_server = saw_identifier.clone();
let saw_user_agent_server = saw_user_agent.clone();

let server = tokio::spawn(async move {
let (mut socket, _) = listener.accept().await.unwrap();
let mut buf = vec![0u8; 16_384];
let n = socket.read(&mut buf).await.unwrap();
let request = String::from_utf8_lossy(&buf[..n]);
if request.contains("x-grok-client-version: 0.2.82") {
saw_version_server.store(true, Ordering::SeqCst);
}
if request.contains("x-grok-client-identifier: xai-grok-cli") {
saw_identifier_server.store(true, Ordering::SeqCst);
}
if request
.to_ascii_lowercase()
.contains("user-agent: xai-grok-cli")
{
saw_user_agent_server.store(true, Ordering::SeqCst);
}
let body = concat!(
"data: {\"choices\":[{\"delta\":{\"content\":\"ok\"},\"finish_reason\":null}]}\n\n",
"data: [DONE]\n\n"
);
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nContent-Length: {}\r\n\r\n{body}",
body.len()
);
socket.write_all(response.as_bytes()).await.unwrap();
});

let provider = OpenAiCompatProvider::new(
"test".into(),
format!("http://{addr}/v1"),
None,
AuthMethod::None,
)
.with_extra_headers(HEADERS);
let (events, mut rx) = mpsc::channel(8);
let handle = provider.stream(
Request {
model: "grok-composer-2.5-fast".to_owned(),
messages: vec![Message::text(MessageRole::User, "hi")],
tools: Vec::new(),
effort: None,
tool_choice: ToolChoice::None,
},
events,
);
let _ = handle.await;
server.await.unwrap();
assert!(saw_version.load(Ordering::SeqCst));
assert!(saw_identifier.load(Ordering::SeqCst));
assert!(saw_user_agent.load(Ordering::SeqCst));
assert!(matches!(
rx.recv().await,
Some(StreamEvent::TextDelta { .. })
));
assert!(matches!(rx.recv().await, Some(StreamEvent::Completed)));
}
}
105 changes: 103 additions & 2 deletions crates/goat-provider-xai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ const SETUP: &[&str] = &[

const COMPOSER_CATALOG: &[&str] = &["grok-composer-2.5-fast"];

const GROK_CLI_CLIENT_VERSION: &str = "0.2.82";
const GROK_CLI_CLIENT_IDENTIFIER: &str = "xai-grok-cli";
const GROK_CLI_USER_AGENT: &str = "xai-grok-cli";

const GROK_CLI_HEADERS: &[(&str, &str)] = &[
("User-Agent", GROK_CLI_USER_AGENT),
("x-grok-client-version", GROK_CLI_CLIENT_VERSION),
("x-grok-client-identifier", GROK_CLI_CLIENT_IDENTIFIER),
];

const OAUTH_CATALOG: &[&str] = &[
"grok-composer-2.5-fast",
"grok-4.3",
Expand Down Expand Up @@ -91,11 +101,28 @@ impl XaiProvider {

async fn resolve_auth(&self) -> Option<XaiAuth> {
let cred = self.store.resolve(&self.key, Some("XAI_API_KEY"))?;
Self::auth_from_credential(&self.store, &self.key, cred).await
}

async fn resolve_auth_for_model(&self, model: &str) -> Option<XaiAuth> {
if Self::is_oauth_model(model)
&& let Some(cred @ Credential::OAuth(_)) = self.store.get(&self.key)
{
return Self::auth_from_credential(&self.store, &self.key, cred).await;
}
self.resolve_auth().await
}

async fn auth_from_credential(
store: &CredentialStore,
key: &CredentialKey,
cred: Credential,
) -> Option<XaiAuth> {
match cred {
Credential::ApiKey(secret) | Credential::ApiKeyWithEndpoint { secret, .. } => {
Some(XaiAuth::ApiKey(secret.expose().to_owned()))
}
Credential::OAuth(_) => oauth::current_oauth_token(&self.store, &self.key)
Credential::OAuth(_) => oauth::current_oauth_token(store, key)
.await
.map(XaiAuth::OAuth),
}
Expand Down Expand Up @@ -144,6 +171,7 @@ impl XaiProvider {
.with_vision_filter(no_vision)
.with_efforts(no_efforts)
.with_reasoning_effort(false)
.with_extra_headers(GROK_CLI_HEADERS)
}

async fn emit_models(provider: &XaiProvider, out: &mpsc::Sender<Model>) -> bool {
Expand Down Expand Up @@ -260,7 +288,7 @@ impl Provider for XaiProvider {
let model = req.model.clone();
tokio::spawn(async move {
let provider = XaiProvider { store, key };
let Some(auth) = provider.resolve_auth().await else {
let Some(auth) = provider.resolve_auth_for_model(&model).await else {
let _ = events
.send(StreamEvent::Failed {
error: StreamError::auth("no credentials"),
Expand Down Expand Up @@ -419,4 +447,77 @@ mod tests {
.collect::<Vec<_>>()
);
}

#[tokio::test]
async fn oauth_model_prefers_stored_oauth_for_composer() {
use goat_auth::{Credential, CredentialKey, TokenSet};

let store = store("goat-provider-xai-oauth-pref.json");
store
.store(
&CredentialKey::model(PROVIDER_ID, "default"),
Credential::OAuth(TokenSet::from_parts(
"oauth-access".to_owned(),
None,
Some(3600),
None,
)),
)
.unwrap();
let provider = XaiProvider::new(store, CredentialKey::model(PROVIDER_ID, "default"));
let auth = provider
.resolve_auth_for_model("grok-composer-2.5-fast")
.await
.expect("oauth should win for composer");
assert!(matches!(auth, XaiAuth::OAuth(token) if token == "oauth-access"));
}

#[tokio::test]
#[ignore = "live network and OAuth credentials required"]
async fn composer_proxy_passes_grok_cli_version_gate() {
use goat_auth::CredentialKey;
use goat_provider::{Message, MessageRole, Request, StreamEvent, ToolChoice};

let auth_path = std::env::var("HOME")
.map(|home| std::path::PathBuf::from(home).join(".goat-code/auth.json"))
.expect("HOME");
if !auth_path.is_file() {
return;
}
let store = CredentialStore::new(auth_path);
let key = CredentialKey::model(PROVIDER_ID, "default");
let Some(Credential::OAuth(_)) = store.get(&key) else {
return;
};
let provider = build(&store, "default");
let (events, mut rx) = tokio::sync::mpsc::channel(8);
let handle = provider.stream(
Request {
model: "grok-composer-2.5-fast".to_owned(),
messages: vec![Message::text(MessageRole::User, "Reply with exactly: ok")],
tools: Vec::new(),
effort: None,
tool_choice: ToolChoice::None,
},
events,
);
while let Some(event) = rx.recv().await {
if let StreamEvent::Failed { error } = event {
let message = error.to_string();
assert!(
!message.contains("426"),
"composer proxy rejected client version: {message}"
);
assert!(
!message.contains("outdated"),
"composer proxy rejected client version: {message}"
);
panic!("composer request failed: {message}");
}
if matches!(event, StreamEvent::Completed) {
break;
}
}
let _ = handle.await;
}
}
Loading