Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 30 additions & 11 deletions crates/forge_main/src/ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1874,7 +1874,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
self.on_custom_event(event.into()).await?;
}
SlashCommand::Model => {
self.on_model_selection().await?;
self.on_model_selection(None).await?;
}
SlashCommand::Provider => {
self.on_provider_selection().await?;
Expand Down Expand Up @@ -2024,11 +2024,18 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
/// Shows columns: MODEL, PROVIDER, CONTEXT WINDOW, TOOL SUPPORTED, IMAGE
/// with a non-selectable header row.
///
/// When `provider_filter` is `Some`, only models belonging to that provider
/// are shown. This is used during onboarding so that after a provider is
/// selected the model list is scoped to that provider only.
///
/// # Returns
/// - `Ok(Some(ModelId))` if a model was selected
/// - `Ok(None)` if selection was canceled
#[async_recursion::async_recursion]
async fn select_model(&mut self) -> Result<Option<ModelId>> {
async fn select_model(
&mut self,
provider_filter: Option<ProviderId>,
) -> Result<Option<ModelId>> {
// Check if provider is set otherwise first ask to select a provider
if self.api.get_default_provider().await.is_err() {
self.on_provider_selection().await?;
Expand All @@ -2042,11 +2049,19 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
}

// Fetch models from ALL configured providers (matches shell plugin's
// `forge list models --porcelain`)
// `forge list models --porcelain`), then optionally filter by provider.
self.spinner.start(Some("Loading"))?;
let mut all_provider_models = self.api.get_all_provider_models().await?;
self.spinner.stop(None)?;

// When a provider filter is specified (e.g. during onboarding after a
// provider was just selected), restrict the list to that provider's
// models so the user cannot accidentally pick a model from a different
// provider.
if let Some(ref filter_id) = provider_filter {
all_provider_models.retain(|pm| &pm.provider_id == filter_id);
}

if all_provider_models.is_empty() {
return Ok(None);
}
Expand Down Expand Up @@ -2636,11 +2651,15 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
self.select_provider_from_list(providers, "Provider", current_provider_id)
}

// Helper method to handle model selection and update the conversation
// Helper method to handle model selection and update the conversation.
// When `provider_filter` is `Some`, only models from that provider are shown.
#[async_recursion::async_recursion]
async fn on_model_selection(&mut self) -> Result<Option<ModelId>> {
async fn on_model_selection(
&mut self,
provider_filter: Option<ProviderId>,
) -> Result<Option<ModelId>> {
// Select a model
let model_option = self.select_model().await?;
let model_option = self.select_model(provider_filter).await?;

// If no model was selected (user canceled), return early
let model = match model_option {
Expand Down Expand Up @@ -2742,13 +2761,13 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
let model_available = models.iter().any(|m| m.id == current_model);

if !model_available {
// Prompt user to select a new model
// Prompt user to select a new model, scoped to the activated provider
self.writeln_title(TitleFormat::info("Please select a new model"))?;
self.on_model_selection().await?;
self.on_model_selection(Some(provider.id.clone())).await?;
}
} else {
// No model set, select one now
self.on_model_selection().await?;
// No model set, select one now scoped to the activated provider
self.on_model_selection(Some(provider.id.clone())).await?;
}

Ok(())
Expand Down Expand Up @@ -2862,7 +2881,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
let mut operating_model = self.get_agent_model(active_agent.clone()).await;
if operating_model.is_none() {
// Use the model returned from selection instead of re-fetching
operating_model = self.on_model_selection().await?;
operating_model = self.on_model_selection(None).await?;
}

// Validate provider is configured before loading agents
Expand Down
Loading