diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 847ad4e68d..26e63ab66d 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -1874,7 +1874,7 @@ impl A + Send + Sync> UI { self.on_custom_event(event.into()).await?; } SlashCommand::Model => { - self.on_model_selection().await?; + self.on_model_selection(None).await?; } SlashCommand::Provider => { self.on_provider_selection().await?; @@ -2024,11 +2024,18 @@ impl A + Send + Sync> UI { /// Shows columns: MODEL, PROVIDER, CONTEXT WINDOW, TOOL SUPPORTED, IMAGE /// with a non-selectable header row. /// + /// When `provider_filter` is `Some`, only models belonging to that provider + /// are shown. This is used during onboarding so that after a provider is + /// selected the model list is scoped to that provider only. + /// /// # Returns /// - `Ok(Some(ModelId))` if a model was selected /// - `Ok(None)` if selection was canceled #[async_recursion::async_recursion] - async fn select_model(&mut self) -> Result> { + async fn select_model( + &mut self, + provider_filter: Option, + ) -> Result> { // Check if provider is set otherwise first ask to select a provider if self.api.get_default_provider().await.is_err() { self.on_provider_selection().await?; @@ -2042,11 +2049,19 @@ impl A + Send + Sync> UI { } // Fetch models from ALL configured providers (matches shell plugin's - // `forge list models --porcelain`) + // `forge list models --porcelain`), then optionally filter by provider. self.spinner.start(Some("Loading"))?; let mut all_provider_models = self.api.get_all_provider_models().await?; self.spinner.stop(None)?; + // When a provider filter is specified (e.g. during onboarding after a + // provider was just selected), restrict the list to that provider's + // models so the user cannot accidentally pick a model from a different + // provider. + if let Some(ref filter_id) = provider_filter { + all_provider_models.retain(|pm| &pm.provider_id == filter_id); + } + if all_provider_models.is_empty() { return Ok(None); } @@ -2636,11 +2651,15 @@ impl A + Send + Sync> UI { self.select_provider_from_list(providers, "Provider", current_provider_id) } - // Helper method to handle model selection and update the conversation + // Helper method to handle model selection and update the conversation. + // When `provider_filter` is `Some`, only models from that provider are shown. #[async_recursion::async_recursion] - async fn on_model_selection(&mut self) -> Result> { + async fn on_model_selection( + &mut self, + provider_filter: Option, + ) -> Result> { // Select a model - let model_option = self.select_model().await?; + let model_option = self.select_model(provider_filter).await?; // If no model was selected (user canceled), return early let model = match model_option { @@ -2742,13 +2761,13 @@ impl A + Send + Sync> UI { let model_available = models.iter().any(|m| m.id == current_model); if !model_available { - // Prompt user to select a new model + // Prompt user to select a new model, scoped to the activated provider self.writeln_title(TitleFormat::info("Please select a new model"))?; - self.on_model_selection().await?; + self.on_model_selection(Some(provider.id.clone())).await?; } } else { - // No model set, select one now - self.on_model_selection().await?; + // No model set, select one now scoped to the activated provider + self.on_model_selection(Some(provider.id.clone())).await?; } Ok(()) @@ -2862,7 +2881,7 @@ impl A + Send + Sync> UI { let mut operating_model = self.get_agent_model(active_agent.clone()).await; if operating_model.is_none() { // Use the model returned from selection instead of re-fetching - operating_model = self.on_model_selection().await?; + operating_model = self.on_model_selection(None).await?; } // Validate provider is configured before loading agents