diff --git a/Cargo.lock b/Cargo.lock index f4bc2e4d..3d01356a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2983,12 +2983,14 @@ dependencies = [ "miette", "openshell-core", "openshell-policy", + "openshell-router", "petname", "pin-project-lite", "prost", "prost-types", "rand 0.9.2", "rcgen", + "reqwest", "russh", "rustls", "rustls-pemfile", @@ -3008,6 +3010,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", + "wiremock", ] [[package]] diff --git a/architecture/inference-routing.md b/architecture/inference-routing.md index b8ce8a80..0d3a95af 100644 --- a/architecture/inference-routing.md +++ b/architecture/inference-routing.md @@ -66,8 +66,9 @@ The gateway implements the `Inference` gRPC service defined in `proto/inference. 1. Validates that both fields are non-empty. 2. Fetches the named provider record from the store. 3. Validates the provider by resolving its route (checking that the provider type is supported and has a usable API key). -4. Builds a managed route spec that stores only `provider_name` and `model_id`. The spec intentionally leaves `base_url`, `api_key`, and `protocols` empty -- these are resolved dynamically at bundle time from the provider record. -5. Upserts the route with name `inference.local`. Version starts at 1 and increments monotonically on each update. +4. By default, performs a lightweight provider-shaped probe against the resolved upstream endpoint (for example, a tiny chat/messages request with `max_tokens: 1`) to confirm the endpoint is reachable and accepts the expected auth/request shape. `--no-verify` disables this probe when the endpoint is not up yet. +5. Builds a managed route spec that stores only `provider_name` and `model_id`. The spec intentionally leaves `base_url`, `api_key`, and `protocols` empty -- these are resolved dynamically at bundle time from the provider record. +6. Upserts the route with name `inference.local`. Version starts at 1 and increments monotonically on each update. `GetClusterInference` returns `provider_name`, `model_id`, and `version` for the managed route. Returns `NOT_FOUND` if cluster inference is not configured. 
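For reference, the step-4 probe described above is roughly equivalent to the following request for an OpenAI-compatible provider (Anthropic-type providers get the same treatment against `/v1/messages`). The base URL, API key, and model shown are placeholders resolved from the provider record, so this is a sketch of the probe's shape rather than its literal bytes:

```bash
# Approximate shape of the verification probe sent to an OpenAI-compatible upstream.
# OPENAI_BASE_URL and OPENAI_API_KEY stand in for values stored on the provider record.
curl -sS -X POST "$OPENAI_BASE_URL/v1/chat/completions" \
  -H "Authorization: Bearer $OPENAI_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{"model":"gpt-4o-mini","messages":[{"role":"user","content":"ping"}],"max_tokens":1}'
```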
@@ -91,7 +92,7 @@ File: `proto/inference.proto` Key messages: -- `SetClusterInferenceRequest` -- `provider_name` + `model_id` +- `SetClusterInferenceRequest` -- `provider_name` + `model_id` + optional `no_verify` override, with verification enabled by default - `SetClusterInferenceResponse` -- `provider_name` + `model_id` + `version` - `GetInferenceBundleResponse` -- `repeated ResolvedRoute routes` + `revision` + `generated_at_ms` - `ResolvedRoute` -- `name`, `base_url`, `protocols`, `api_key`, `model_id`, `provider_type` @@ -296,13 +297,15 @@ The system route is stored as a separate `InferenceRoute` record in the gateway Cluster inference commands: -- `openshell cluster inference set --provider --model ` -- configures user-facing cluster inference -- `openshell cluster inference set --system --provider --model ` -- configures system inference -- `openshell cluster inference get` -- displays both user and system inference configuration -- `openshell cluster inference get --system` -- displays only the system inference configuration +- `openshell inference set --provider --model ` -- configures user-facing cluster inference +- `openshell inference set --system --provider --model ` -- configures system inference +- `openshell inference get` -- displays both user and system inference configuration +- `openshell inference get --system` -- displays only the system inference configuration The `--provider` flag references a provider record name (not a provider type). The provider must already exist in the cluster and have a supported inference type (`openai`, `anthropic`, or `nvidia`). +Inference writes verify by default. `--no-verify` is the explicit opt-out for endpoints that are not up yet. + ## Provider Discovery Files: diff --git a/crates/openshell-bootstrap/src/docker.rs b/crates/openshell-bootstrap/src/docker.rs index 1cb62b7b..3dc832aa 100644 --- a/crates/openshell-bootstrap/src/docker.rs +++ b/crates/openshell-bootstrap/src/docker.rs @@ -526,9 +526,13 @@ pub async fn ensure_container( port_bindings: Some(port_bindings), binds: Some(vec![format!("{}:/var/lib/rancher/k3s", volume_name(name))]), network_mode: Some(network_name(name)), - // Add host.docker.internal mapping for DNS resolution - // This allows the entrypoint script to configure CoreDNS to use the host gateway - extra_hosts: Some(vec!["host.docker.internal:host-gateway".to_string()]), + // Add host gateway aliases for DNS resolution. + // This allows both the entrypoint script and the running gateway + // process to reach services on the Docker host. + extra_hosts: Some(vec![ + "host.docker.internal:host-gateway".to_string(), + "host.openshell.internal:host-gateway".to_string(), + ]), ..Default::default() }; diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index dcca4703..8995c3df 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -906,9 +906,6 @@ enum InferenceCommands { system: bool, /// Skip endpoint verification before saving the route. - /// - /// Accepted now so scripts can opt out explicitly ahead of a future - /// default switch to verification. #[arg(long)] no_verify: bool, }, @@ -929,9 +926,6 @@ enum InferenceCommands { system: bool, /// Skip endpoint verification before saving the route. - /// - /// Accepted now so scripts can opt out explicitly ahead of a future - /// default switch to verification. 
#[arg(long)] no_verify: bool, }, @@ -1810,17 +1804,19 @@ async fn main() -> Result<()> { provider, model, system, - no_verify: _, + no_verify, } => { let route_name = if system { "sandbox-system" } else { "" }; - run::gateway_inference_set(endpoint, &provider, &model, route_name, &tls) - .await?; + run::gateway_inference_set( + endpoint, &provider, &model, route_name, no_verify, &tls, + ) + .await?; } InferenceCommands::Update { provider, model, system, - no_verify: _, + no_verify, } => { let route_name = if system { "sandbox-system" } else { "" }; run::gateway_inference_update( @@ -1828,6 +1824,7 @@ async fn main() -> Result<()> { provider.as_deref(), model.as_deref(), route_name, + no_verify, &tls, ) .await?; @@ -2559,6 +2556,54 @@ mod tests { )); } + #[test] + fn inference_set_accepts_no_verify_flag() { + let cli = Cli::try_parse_from([ + "openshell", + "inference", + "set", + "--provider", + "openai-dev", + "--model", + "gpt-4.1", + "--no-verify", + ]) + .expect("inference set should parse --no-verify"); + + assert!(matches!( + cli.command, + Some(Commands::Inference { + command: Some(InferenceCommands::Set { + no_verify: true, + .. + }) + }) + )); + } + + #[test] + fn inference_update_accepts_no_verify_flag() { + let cli = Cli::try_parse_from([ + "openshell", + "inference", + "update", + "--provider", + "openai-dev", + "--no-verify", + ]) + .expect("inference update should parse --no-verify"); + + assert!(matches!( + cli.command, + Some(Commands::Inference { + command: Some(InferenceCommands::Update { + no_verify: true, + .. + }) + }) + )); + } + #[test] fn completion_script_uses_openshell_command_name() { let script = normalize_completion_script( @@ -2747,52 +2792,4 @@ mod tests { other => panic!("expected SshProxy, got: {other:?}"), } } - - #[test] - fn inference_set_accepts_no_verify_flag() { - let cli = Cli::try_parse_from([ - "openshell", - "inference", - "set", - "--provider", - "openai-dev", - "--model", - "gpt-4.1", - "--no-verify", - ]) - .expect("inference set should parse --no-verify"); - - assert!(matches!( - cli.command, - Some(Commands::Inference { - command: Some(InferenceCommands::Set { - no_verify: true, - .. - }) - }) - )); - } - - #[test] - fn inference_update_accepts_no_verify_flag() { - let cli = Cli::try_parse_from([ - "openshell", - "inference", - "update", - "--provider", - "openai-dev", - "--no-verify", - ]) - .expect("inference update should parse --no-verify"); - - assert!(matches!( - cli.command, - Some(Commands::Inference { - command: Some(InferenceCommands::Update { - no_verify: true, - .. 
- }) - }) - )); - } } diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 42ecbbb1..22123c1d 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -40,7 +40,7 @@ use std::io::{IsTerminal, Write}; use std::path::{Path, PathBuf}; use std::process::Command; use std::time::{Duration, Instant}; -use tonic::Code; +use tonic::{Code, Status}; // Re-export SSH functions for backward compatibility pub use crate::ssh::{Editor, print_ssh_config}; @@ -3390,17 +3390,38 @@ pub async fn gateway_inference_set( provider_name: &str, model_id: &str, route_name: &str, + no_verify: bool, tls: &TlsOptions, ) -> Result<()> { + let progress = if std::io::stdout().is_terminal() { + let spinner = ProgressBar::new_spinner(); + spinner.set_style( + ProgressStyle::with_template("{spinner:.cyan} {msg} ({elapsed})") + .unwrap_or_else(|_| ProgressStyle::default_spinner()), + ); + spinner.set_message("Configuring inference..."); + spinner.enable_steady_tick(Duration::from_millis(120)); + Some(spinner) + } else { + None + }; + let mut client = grpc_inference_client(server, tls).await?; let response = client .set_cluster_inference(SetClusterInferenceRequest { provider_name: provider_name.to_string(), model_id: model_id.to_string(), route_name: route_name.to_string(), + verify: false, + no_verify, }) - .await - .into_diagnostic()?; + .await; + + if let Some(progress) = &progress { + progress.finish_and_clear(); + } + + let response = response.map_err(format_inference_status)?; let configured = response.into_inner(); let label = if configured.route_name == "sandbox-system" { @@ -3414,6 +3435,12 @@ pub async fn gateway_inference_set( println!(" {} {}", "Provider:".dimmed(), configured.provider_name); println!(" {} {}", "Model:".dimmed(), configured.model_id); println!(" {} {}", "Version:".dimmed(), configured.version); + if configured.validation_performed { + println!(" {}", "Validated Endpoints:".dimmed()); + for endpoint in configured.validated_endpoints { + println!(" - {} ({})", endpoint.url, endpoint.protocol); + } + } Ok(()) } @@ -3422,6 +3449,7 @@ pub async fn gateway_inference_update( provider_name: Option<&str>, model_id: Option<&str>, route_name: &str, + no_verify: bool, tls: &TlsOptions, ) -> Result<()> { if provider_name.is_none() && model_id.is_none() { @@ -3444,14 +3472,34 @@ pub async fn gateway_inference_update( let provider = provider_name.unwrap_or(¤t.provider_name); let model = model_id.unwrap_or(¤t.model_id); + let progress = if std::io::stdout().is_terminal() { + let spinner = ProgressBar::new_spinner(); + spinner.set_style( + ProgressStyle::with_template("{spinner:.cyan} {msg} ({elapsed})") + .unwrap_or_else(|_| ProgressStyle::default_spinner()), + ); + spinner.set_message("Configuring inference..."); + spinner.enable_steady_tick(Duration::from_millis(120)); + Some(spinner) + } else { + None + }; + let response = client .set_cluster_inference(SetClusterInferenceRequest { provider_name: provider.to_string(), model_id: model.to_string(), route_name: route_name.to_string(), + verify: false, + no_verify, }) - .await - .into_diagnostic()?; + .await; + + if let Some(progress) = &progress { + progress.finish_and_clear(); + } + + let response = response.map_err(format_inference_status)?; let configured = response.into_inner(); let label = if configured.route_name == "sandbox-system" { @@ -3465,6 +3513,12 @@ pub async fn gateway_inference_update( println!(" {} {}", "Provider:".dimmed(), configured.provider_name); println!(" {} {}", "Model:".dimmed(), 
configured.model_id); println!(" {} {}", "Version:".dimmed(), configured.version); + if configured.validation_performed { + println!(" {}", "Validated Endpoints:".dimmed()); + for endpoint in configured.validated_endpoints { + println!(" - {} ({})", endpoint.url, endpoint.protocol); + } + } Ok(()) } @@ -3536,6 +3590,16 @@ async fn print_inference_route( } } +fn format_inference_status(status: Status) -> miette::Report { + let message = status.message().trim(); + + if message.is_empty() { + return miette::miette!("inference configuration failed ({})", status.code()); + } + + miette::miette!("{message}") +} + pub fn git_repo_root(local_path: &Path) -> Result { let git_dir = if local_path.is_dir() { local_path diff --git a/crates/openshell-router/src/backend.rs b/crates/openshell-router/src/backend.rs index a060d3f9..f9d53800 100644 --- a/crates/openshell-router/src/backend.rs +++ b/crates/openshell-router/src/backend.rs @@ -3,6 +3,35 @@ use crate::RouterError; use crate::config::{AuthHeader, ResolvedRoute}; +use crate::mock; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidatedEndpoint { + pub url: String, + pub protocol: String, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ValidationFailureKind { + RequestShape, + Credentials, + RateLimited, + Connectivity, + UpstreamHealth, + Unexpected, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidationFailure { + pub kind: ValidationFailureKind, + pub details: String, +} + +struct ValidationProbe { + path: &'static str, + protocol: &'static str, + body: bytes::Bytes, +} /// Response from a proxied HTTP request to a backend (fully buffered). #[derive(Debug)] @@ -128,6 +157,151 @@ async fn send_backend_request( }) } +fn validation_probe(route: &ResolvedRoute) -> Result<ValidationProbe, ValidationFailure> { + if route + .protocols + .iter() + .any(|protocol| protocol == "openai_chat_completions") + { + return Ok(ValidationProbe { + path: "/v1/chat/completions", + protocol: "openai_chat_completions", + body: bytes::Bytes::from_static( + br#"{"messages":[{"role":"user","content":"ping"}],"max_tokens":1}"#, + ), + }); + } + + if route + .protocols + .iter() + .any(|protocol| protocol == "anthropic_messages") + { + return Ok(ValidationProbe { + path: "/v1/messages", + protocol: "anthropic_messages", + body: bytes::Bytes::from_static( + br#"{"messages":[{"role":"user","content":"ping"}],"max_tokens":1}"#, + ), + }); + } + + if route + .protocols + .iter() + .any(|protocol| protocol == "openai_responses") + { + return Ok(ValidationProbe { + path: "/v1/responses", + protocol: "openai_responses", + body: bytes::Bytes::from_static(br#"{"input":"ping","max_output_tokens":1}"#), + }); + } + + if route + .protocols + .iter() + .any(|protocol| protocol == "openai_completions") + { + return Ok(ValidationProbe { + path: "/v1/completions", + protocol: "openai_completions", + body: bytes::Bytes::from_static(br#"{"prompt":"ping","max_tokens":1}"#), + }); + } + + Err(ValidationFailure { + kind: ValidationFailureKind::RequestShape, + details: format!( + "route '{}' does not expose a writable inference protocol for validation", + route.name + ), + }) +} + +pub async fn verify_backend_endpoint( + client: &reqwest::Client, + route: &ResolvedRoute, +) -> Result<ValidatedEndpoint, ValidationFailure> { + let probe = validation_probe(route)?; + + if mock::is_mock_route(route) { + return Ok(ValidatedEndpoint { + url: build_backend_url(&route.endpoint, probe.path), + protocol: probe.protocol.to_string(), + }); + } + + let response = send_backend_request(client, route, "POST", probe.path, Vec::new(), probe.body) +
.await + .map_err(|err| match err { + RouterError::UpstreamUnavailable(details) => ValidationFailure { + kind: ValidationFailureKind::Connectivity, + details, + }, + RouterError::Internal(details) | RouterError::UpstreamProtocol(details) => { + ValidationFailure { + kind: ValidationFailureKind::Unexpected, + details, + } + } + RouterError::RouteNotFound(details) + | RouterError::NoCompatibleRoute(details) + | RouterError::Unauthorized(details) => ValidationFailure { + kind: ValidationFailureKind::Unexpected, + details, + }, + })?; + let url = build_backend_url(&route.endpoint, probe.path); + + if response.status().is_success() { + return Ok(ValidatedEndpoint { + url, + protocol: probe.protocol.to_string(), + }); + } + + let status = response.status(); + let body = response.text().await.map_err(|e| ValidationFailure { + kind: ValidationFailureKind::Unexpected, + details: format!("failed to read validation response body: {e}"), + })?; + let body = body.trim(); + let body_suffix = if body.is_empty() { + String::new() + } else { + format!( + " Response body: {}", + body.chars().take(200).collect::<String>() + ) + }; + + let details = match status.as_u16() { + 400 | 404 | 405 | 422 => { + format!("upstream rejected the validation request with HTTP {status}.{body_suffix}") + } + 401 | 403 => { + format!("upstream rejected credentials with HTTP {status}.{body_suffix}") + } + 429 => { + format!("upstream rate-limited the validation request with HTTP {status}.{body_suffix}") + } + 500..=599 => format!("upstream returned HTTP {status}.{body_suffix}"), + _ => format!("upstream returned unexpected HTTP {status}.{body_suffix}"), + }; + + Err(ValidationFailure { + kind: match status.as_u16() { + 400 | 404 | 405 | 422 => ValidationFailureKind::RequestShape, + 401 | 403 => ValidationFailureKind::Credentials, + 429 => ValidationFailureKind::RateLimited, + 500..=599 => ValidationFailureKind::UpstreamHealth, + _ => ValidationFailureKind::Unexpected, + }, + details, + }) +} + /// Extract status and headers from a [`reqwest::Response`].
fn extract_response_metadata(response: &reqwest::Response) -> (u16, Vec<(String, String)>) { let status = response.status().as_u16(); @@ -201,7 +375,11 @@ fn build_backend_url(endpoint: &str, path: &str) -> String { #[cfg(test)] mod tests { - use super::build_backend_url; + use super::{build_backend_url, verify_backend_endpoint}; + use crate::config::ResolvedRoute; + use openshell_core::inference::AuthHeader; + use wiremock::matchers::{body_partial_json, header, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; #[test] fn build_backend_url_dedupes_v1_prefix() { @@ -226,4 +404,61 @@ mod tests { "https://api.openai.com/v1" ); } + + fn test_route(endpoint: &str, protocols: &[&str], auth: AuthHeader) -> ResolvedRoute { + ResolvedRoute { + name: "inference.local".to_string(), + endpoint: endpoint.to_string(), + model: "test-model".to_string(), + api_key: "sk-test".to_string(), + protocols: protocols.iter().map(|p| (*p).to_string()).collect(), + auth, + default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())], + } + } + + #[tokio::test] + async fn verify_backend_endpoint_uses_route_auth_and_shape() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &["anthropic_messages"], + AuthHeader::Custom("x-api-key"), + ); + + Mock::given(method("POST")) + .and(path("/v1/messages")) + .and(header("x-api-key", "sk-test")) + .and(header("anthropic-version", "2023-06-01")) + .and(body_partial_json(serde_json::json!({ + "model": "test-model", + "max_tokens": 1, + }))) + .respond_with( + ResponseTemplate::new(200).set_body_json(serde_json::json!({"id": "msg_1"})), + ) + .mount(&mock_server) + .await; + + let client = reqwest::Client::builder().build().unwrap(); + let validated = verify_backend_endpoint(&client, &route).await.unwrap(); + + assert_eq!(validated.protocol, "anthropic_messages"); + assert_eq!(validated.url, format!("{}/v1/messages", mock_server.uri())); + } + + #[tokio::test] + async fn verify_backend_endpoint_accepts_mock_routes() { + let route = test_route( + "mock://test-backend", + &["openai_chat_completions"], + AuthHeader::Bearer, + ); + + let client = reqwest::Client::builder().build().unwrap(); + let validated = verify_backend_endpoint(&client, &route).await.unwrap(); + + assert_eq!(validated.protocol, "openai_chat_completions"); + assert_eq!(validated.url, "mock://test-backend/v1/chat/completions"); + } } diff --git a/crates/openshell-router/src/lib.rs b/crates/openshell-router/src/lib.rs index 4edd4f87..a5712d9a 100644 --- a/crates/openshell-router/src/lib.rs +++ b/crates/openshell-router/src/lib.rs @@ -7,7 +7,10 @@ mod mock; use std::time::Duration; -pub use backend::{ProxyResponse, StreamingProxyResponse}; +pub use backend::{ + ProxyResponse, StreamingProxyResponse, ValidatedEndpoint, ValidationFailure, + ValidationFailureKind, verify_backend_endpoint, +}; use config::{ResolvedRoute, RouterConfig}; use tracing::info; diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index 7d53abe6..7bd72113 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -17,6 +17,7 @@ path = "src/main.rs" [dependencies] openshell-core = { path = "../openshell-core" } openshell-policy = { path = "../openshell-policy" } +openshell-router = { path = "../openshell-router" } # Async runtime tokio = { workspace = true } @@ -61,6 +62,7 @@ serde = { workspace = true } serde_json = { workspace = true } tokio-stream = { workspace = true } sqlx = { workspace = 
true } +reqwest = { workspace = true } kube = { workspace = true } kube-runtime = { workspace = true } k8s-openapi = { workspace = true } @@ -78,6 +80,7 @@ rcgen = { version = "0.13", features = ["crypto", "pem"] } tempfile = "3" tokio-tungstenite = { workspace = true } futures-util = "0.3" +wiremock = "0.6" [lints] workspace = true diff --git a/crates/openshell-server/src/inference.rs b/crates/openshell-server/src/inference.rs index f7a8427e..c1354e75 100644 --- a/crates/openshell-server/src/inference.rs +++ b/crates/openshell-server/src/inference.rs @@ -4,9 +4,13 @@ use openshell_core::proto::{ ClusterInferenceConfig, GetClusterInferenceRequest, GetClusterInferenceResponse, GetInferenceBundleRequest, GetInferenceBundleResponse, InferenceRoute, Provider, ResolvedRoute, - SetClusterInferenceRequest, SetClusterInferenceResponse, inference_server::Inference, + SetClusterInferenceRequest, SetClusterInferenceResponse, ValidatedEndpoint, + inference_server::Inference, }; +use openshell_router::config::ResolvedRoute as RouterResolvedRoute; +use openshell_router::{ValidationFailureKind, verify_backend_endpoint}; use std::sync::Arc; +use std::time::Duration; use tonic::{Request, Response, Status}; use crate::{ @@ -76,15 +80,18 @@ impl Inference for InferenceService { ) -> Result<Response<SetClusterInferenceResponse>, Status> { let req = request.into_inner(); let route_name = effective_route_name(&req.route_name)?; + let verify = !req.no_verify; let route = upsert_cluster_inference_route( self.state.store.as_ref(), route_name, &req.provider_name, &req.model_id, + verify, ) .await?; let config = route + .route .config .as_ref() .ok_or_else(|| Status::internal("managed route missing config"))?; @@ -92,8 +99,10 @@ impl Inference for InferenceService { Ok(Response::new(SetClusterInferenceResponse { provider_name: config.provider_name.clone(), model_id: config.model_id.clone(), - version: route.version, + version: route.route.version, route_name: route_name.to_string(), + validation_performed: !route.validation.is_empty(), + validated_endpoints: route.validation, })) } @@ -111,7 +120,7 @@ impl Inference for InferenceService { .map_err(|e| Status::internal(format!("fetch route failed: {e}")))? .ok_or_else(|| { Status::not_found(format!( - "inference route '{route_name}' is not configured; run 'openshell cluster inference set --provider --model '" + "inference route '{route_name}' is not configured; run 'openshell inference set --provider --model '" )) })?; @@ -140,7 +149,8 @@ async fn upsert_cluster_inference_route( route_name: &str, provider_name: &str, model_id: &str, -) -> Result<InferenceRoute, Status> { + verify: bool, +) -> Result<UpsertedInferenceRoute, Status> { if provider_name.trim().is_empty() { return Err(Status::invalid_argument("provider_name is required")); } @@ -156,9 +166,12 @@ async fn upsert_cluster_inference_route( Status::failed_precondition(format!("provider '{provider_name}' not found")) })?; - // Validate provider shape at set time; endpoint/auth are resolved from the - // provider record when generating sandbox bundles. - let _ = resolve_provider_route(&provider)?; + let resolved = resolve_provider_route(&provider)?; + let validation = if verify { + vec![verify_provider_endpoint(&provider.name, model_id, &resolved).await?]
+ } else { + Vec::new() + }; let config = build_cluster_inference_config(&provider, model_id); @@ -188,7 +201,7 @@ .await .map_err(|e| Status::internal(format!("persist route failed: {e}")))?; - Ok(route) + Ok(UpsertedInferenceRoute { route, validation }) } fn build_cluster_inference_config(provider: &Provider, model_id: &str) -> ClusterInferenceConfig { @@ -200,9 +213,13 @@ fn build_cluster_inference_config(provider: &Provider, model_id: &str) -> Cluste struct ResolvedProviderRoute { provider_type: String, - base_url: String, - protocols: Vec<String>, - api_key: String, + route: RouterResolvedRoute, +} + +#[derive(Debug)] +struct UpsertedInferenceRoute { + route: InferenceRoute, + validation: Vec<ValidatedEndpoint>, } fn resolve_provider_route(provider: &Provider) -> Result<ResolvedProviderRoute, Status> { @@ -238,12 +255,86 @@ fn resolve_provider_route(provider: &Provider) -> Result +fn validation_failure( + provider_name: &str, + model_id: &str, + base_url: &str, + details: &str, + next_steps: &str, +) -> Status { + Status::failed_precondition(format!( + "failed to verify inference endpoint for provider '{provider_name}' and model '{model_id}' at '{base_url}': {details}. Next steps: {next_steps}, or retry with '--no-verify' if you want to skip verification" + )) +} + +fn validation_next_steps(kind: ValidationFailureKind) -> &'static str { + match kind { + ValidationFailureKind::Credentials => { + "verify the provider API key and any required auth headers" + } + ValidationFailureKind::RateLimited => { + "retry later or verify quota/limits on the upstream provider" + } + ValidationFailureKind::RequestShape => { + "confirm the provider type, base URL, and model identifier" + } + ValidationFailureKind::Connectivity => { + "check that the service is running, confirm the base URL and protocol, and verify credentials" + } + ValidationFailureKind::UpstreamHealth => { + "check whether the endpoint is healthy and serving requests" + } + ValidationFailureKind::Unexpected => { + "confirm the endpoint URL, protocol, credentials, and model identifier" + } + } +} + +async fn verify_provider_endpoint( + provider_name: &str, + model_id: &str, + route: &ResolvedProviderRoute, +) -> Result<ValidatedEndpoint, Status> { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(30)) + .build() + .map_err(|err| Status::internal(format!("build validation client failed: {err}")))?; + let mut route = route.route.clone(); + route.model = model_id.to_string(); + + verify_backend_endpoint(&client, &route) + .await + .map(|validated| ValidatedEndpoint { + url: validated.url, + protocol: validated.protocol, + }) + .map_err(|err| { + validation_failure( + provider_name, + model_id, + &route.endpoint, + &err.details, + validation_next_steps(err.kind), + ) + }) +} + fn find_provider_api_key(provider: &Provider, preferred_key_names: &[&str]) -> Option<String> { for key in preferred_key_names { if let Some(value) = provider.credentials.get(*key) @@ -358,10 +449,10 @@ async fn resolve_route_by_name( Ok(Some(ResolvedRoute { name: route_name.to_string(), - base_url: resolved.base_url, + base_url: resolved.route.endpoint, model_id: config.model_id.clone(), - api_key: resolved.api_key, - protocols: resolved.protocols, + api_key: resolved.route.api_key, + protocols: resolved.route.protocols, provider_type: resolved.provider_type, })) } @@ -369,6 +460,8 @@ #[cfg(test)] mod tests { use super::*; + use wiremock::matchers::{body_partial_json, header, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; fn make_route(name: &str, provider_name: &str, model_id: &str) -> InferenceRoute { InferenceRoute { @@ -392,6 +485,20 @@ } } + fn 
make_provider_with_base_url( + name: &str, + provider_type: &str, + key_name: &str, + key_value: &str, + base_url_key: &str, + base_url: &str, + ) -> Provider { + Provider { + config: std::iter::once((base_url_key.to_string(), base_url.to_string())).collect(), + ..make_provider(name, provider_type, key_name, key_value) + } + } + #[tokio::test] async fn upsert_cluster_route_creates_and_increments_version() { let store = Store::connect("sqlite::memory:?cache=shared") @@ -409,24 +516,26 @@ mod tests { CLUSTER_INFERENCE_ROUTE_NAME, "openai-dev", "gpt-4o", + false, ) .await .expect("first set should succeed"); - assert_eq!(first.name, CLUSTER_INFERENCE_ROUTE_NAME); - assert_eq!(first.version, 1); + assert_eq!(first.route.name, CLUSTER_INFERENCE_ROUTE_NAME); + assert_eq!(first.route.version, 1); let second = upsert_cluster_inference_route( &store, CLUSTER_INFERENCE_ROUTE_NAME, "openai-dev", "gpt-4.1", + false, ) .await .expect("second set should succeed"); - assert_eq!(second.version, 2); - assert_eq!(second.id, first.id); + assert_eq!(second.route.version, 2); + assert_eq!(second.route.id, first.route.id); - let config = second.config.as_ref().expect("config"); + let config = second.route.config.as_ref().expect("config"); assert_eq!(config.provider_name, "openai-dev"); assert_eq!(config.model_id, "gpt-4.1"); } @@ -630,13 +739,14 @@ mod tests { SANDBOX_SYSTEM_ROUTE_NAME, "anthropic-dev", "claude-sonnet-4-20250514", + false, ) .await .expect("should succeed"); - assert_eq!(route.name, SANDBOX_SYSTEM_ROUTE_NAME); - assert_eq!(route.version, 1); - let config = route.config.as_ref().expect("config"); + assert_eq!(route.route.name, SANDBOX_SYSTEM_ROUTE_NAME); + assert_eq!(route.route.version, 1); + let config = route.route.config.as_ref().expect("config"); assert_eq!(config.provider_name, "anthropic-dev"); assert_eq!(config.model_id, "claude-sonnet-4-20250514"); } @@ -715,6 +825,7 @@ mod tests { SANDBOX_SYSTEM_ROUTE_NAME, "openai-dev", "gpt-4o-mini", + false, ) .await .expect("upsert should succeed"); @@ -730,6 +841,141 @@ mod tests { assert_eq!(config.model_id, "gpt-4o-mini"); } + #[tokio::test] + async fn upsert_cluster_route_verifies_endpoint_when_requested() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .expect("store"); + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .and(header("authorization", "Bearer sk-test")) + .and(body_partial_json(serde_json::json!({ + "model": "gpt-4o-mini", + "max_tokens": 1, + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "id": "chatcmpl-123", + "object": "chat.completion", + "choices": [{"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "stop"}], + "model": "gpt-4o-mini" + }))) + .mount(&mock_server) + .await; + + let provider = make_provider_with_base_url( + "openai-dev", + "openai", + "OPENAI_API_KEY", + "sk-test", + "OPENAI_BASE_URL", + &mock_server.uri(), + ); + store + .put_message(&provider) + .await + .expect("persist provider"); + + let route = upsert_cluster_inference_route( + &store, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-4o-mini", + true, + ) + .await + .expect("validation should succeed"); + + assert_eq!(route.route.version, 1); + assert_eq!(route.validation.len(), 1); + assert_eq!(route.validation[0].protocol, "openai_chat_completions"); + } + + #[tokio::test] + async fn upsert_cluster_route_rejects_failed_validation() { + let store = 
Store::connect("sqlite::memory:?cache=shared") + .await + .expect("store"); + let mock_server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(401).set_body_string("bad key")) + .mount(&mock_server) + .await; + + let provider = make_provider_with_base_url( + "openai-dev", + "openai", + "OPENAI_API_KEY", + "sk-test", + "OPENAI_BASE_URL", + &mock_server.uri(), + ); + store + .put_message(&provider) + .await + .expect("persist provider"); + + let err = upsert_cluster_inference_route( + &store, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-4o-mini", + true, + ) + .await + .expect_err("validation should fail"); + + assert_eq!(err.code(), tonic::Code::FailedPrecondition); + assert!( + err.message() + .contains("failed to verify inference endpoint") + ); + assert!(err.message().contains("verify the provider API key")); + assert!(err.message().contains("--no-verify")); + + let persisted = store + .get_message_by_name::(CLUSTER_INFERENCE_ROUTE_NAME) + .await + .expect("fetch route") + .is_none(); + assert!(persisted, "route should not persist on failed validation"); + } + + #[tokio::test] + async fn upsert_cluster_route_skips_validation_by_default() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .expect("store"); + let provider = make_provider_with_base_url( + "openai-dev", + "openai", + "OPENAI_API_KEY", + "sk-test", + "OPENAI_BASE_URL", + "http://127.0.0.1:9", + ); + store + .put_message(&provider) + .await + .expect("persist provider"); + + let route = upsert_cluster_inference_route( + &store, + CLUSTER_INFERENCE_ROUTE_NAME, + "openai-dev", + "gpt-4o-mini", + false, + ) + .await + .expect("non-verified route should persist"); + + assert_eq!(route.route.version, 1); + assert!(route.validation.is_empty()); + } + #[test] fn effective_route_name_defaults_empty_to_inference_local() { assert_eq!( diff --git a/deploy/helm/openshell/templates/statefulset.yaml b/deploy/helm/openshell/templates/statefulset.yaml index 175b2606..83ece499 100644 --- a/deploy/helm/openshell/templates/statefulset.yaml +++ b/deploy/helm/openshell/templates/statefulset.yaml @@ -31,6 +31,13 @@ spec: {{- toYaml . | nindent 8 }} {{- end }} serviceAccountName: {{ include "openshell.serviceAccountName" . }} + {{- if .Values.server.hostGatewayIP }} + hostAliases: + - ip: {{ .Values.server.hostGatewayIP | quote }} + hostnames: + - host.docker.internal + - host.openshell.internal + {{- end }} securityContext: {{- toYaml .Values.podSecurityContext | nindent 8 }} containers: diff --git a/docs/inference/configure.md b/docs/inference/configure.md index b4dcd781..bf0103a7 100644 --- a/docs/inference/configure.md +++ b/docs/inference/configure.md @@ -135,9 +135,9 @@ Use this endpoint when inference should stay local to the host for privacy and s ### Verify the Endpoint from a Sandbox -`openshell inference get` confirms the configuration was saved, but does not verify the upstream endpoint is reachable. The CLI also accepts `--no-verify` on `openshell inference set` and `openshell inference update` so automation can opt out explicitly ahead of a future verify-by-default rollout. +`openshell inference set` and `openshell inference update` verify the resolved upstream endpoint by default before saving the configuration. If the endpoint is not live yet, retry with `--no-verify` to persist the route without the probe. 
-To confirm end-to-end connectivity, connect to a sandbox and run: +`openshell inference get` confirms the current saved configuration. To confirm end-to-end connectivity from a sandbox, run: ```bash curl https://inference.local/v1/responses \ diff --git a/e2e/rust/tests/host_gateway_alias.rs b/e2e/rust/tests/host_gateway_alias.rs index 76d8be57..547a9238 100644 --- a/e2e/rust/tests/host_gateway_alias.rs +++ b/e2e/rust/tests/host_gateway_alias.rs @@ -16,6 +16,7 @@ use tempfile::NamedTempFile; use tokio::time::{interval, timeout}; const INFERENCE_PROVIDER_NAME: &str = "e2e-host-inference"; +const INFERENCE_PROVIDER_UNREACHABLE_NAME: &str = "e2e-host-inference-unreachable"; const TEST_SERVER_IMAGE: &str = "python:3.13-alpine"; static INFERENCE_ROUTE_LOCK: Mutex<()> = Mutex::new(()); @@ -177,6 +178,22 @@ async fn delete_provider(name: &str) { let _ = cmd.status().await; } +async fn create_openai_provider(name: &str, base_url: &str) -> Result { + run_cli(&[ + "provider", + "create", + "--name", + name, + "--type", + "openai", + "--credential", + "OPENAI_API_KEY=dummy", + "--config", + &format!("OPENAI_BASE_URL={base_url}"), + ]) + .await +} + fn write_policy(port: u16) -> Result { let mut file = NamedTempFile::new().map_err(|e| format!("create temp policy file: {e}"))?; let policy = format!( @@ -282,36 +299,33 @@ async fn sandbox_inference_local_routes_to_host_openshell_internal() { delete_provider(INFERENCE_PROVIDER_NAME).await; } - run_cli(&[ - "provider", - "create", - "--name", + let create_output = create_openai_provider( INFERENCE_PROVIDER_NAME, - "--type", - "openai", - "--credential", - "OPENAI_API_KEY=dummy", - "--config", - &format!( - "OPENAI_BASE_URL=http://host.openshell.internal:{}/v1", - server.port - ), - ]) + &format!("http://host.openshell.internal:{}/v1", server.port), + ) .await .expect("create host-backed OpenAI provider"); - run_cli(&[ + let inference_output = run_cli(&[ "inference", "set", "--provider", INFERENCE_PROVIDER_NAME, "--model", "host-echo-model", - "--no-verify", ]) .await .expect("point inference.local at host-backed provider"); + assert!( + inference_output.contains("Validated Endpoints:"), + "expected verification details in output:\n{inference_output}" + ); + assert!( + inference_output.contains("/v1/chat/completions (openai_chat_completions)"), + "expected validated endpoint in output:\n{inference_output}" + ); + let guard = SandboxGuard::create(&[ "--", "curl", @@ -338,4 +352,69 @@ async fn sandbox_inference_local_routes_to_host_openshell_internal() { "expected sandbox to receive echoed inference content:\n{}", guard.create_output ); + + let _ = create_output; +} + +#[tokio::test] +async fn inference_set_supports_no_verify_for_unreachable_endpoint() { + let _inference_lock = INFERENCE_ROUTE_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + + let current_inference = run_cli(&["inference", "get"]) + .await + .expect("read current inference config"); + if !current_inference.contains("Not configured") { + eprintln!("Skipping test: existing inference config would make shared state unsafe"); + return; + } + + if provider_exists(INFERENCE_PROVIDER_UNREACHABLE_NAME).await { + delete_provider(INFERENCE_PROVIDER_UNREACHABLE_NAME).await; + } + + create_openai_provider( + INFERENCE_PROVIDER_UNREACHABLE_NAME, + "http://host.openshell.internal:9/v1", + ) + .await + .expect("create unreachable OpenAI provider"); + + let verify_err = run_cli(&[ + "inference", + "set", + "--provider", + INFERENCE_PROVIDER_UNREACHABLE_NAME, + "--model", + 
"host-echo-model", + ]) + .await + .expect_err("default verification should fail for unreachable endpoint"); + + assert!( + verify_err.contains("failed to verify inference endpoint"), + "expected verification failure output:\n{verify_err}" + ); + assert!( + verify_err.contains("--no-verify"), + "expected retry hint in failure output:\n{verify_err}" + ); + + let no_verify_output = run_cli(&[ + "inference", + "set", + "--provider", + INFERENCE_PROVIDER_UNREACHABLE_NAME, + "--model", + "host-echo-model", + "--no-verify", + ]) + .await + .expect("no-verify should bypass validation"); + + assert!( + !no_verify_output.contains("Validated Endpoints:"), + "did not expect validation output when bypassing verification:\n{no_verify_output}" + ); } diff --git a/proto/inference.proto b/proto/inference.proto index 11670c4a..a15f4b84 100644 --- a/proto/inference.proto +++ b/proto/inference.proto @@ -56,6 +56,15 @@ message SetClusterInferenceRequest { // Route name to target. Empty string defaults to "inference.local" (user-facing). // Use "sandbox-system" for the sandbox system-level inference route. string route_name = 3; + // Verify the resolved upstream endpoint synchronously before persistence. + bool verify = 4; + // Skip synchronous endpoint validation before persistence. + bool no_verify = 5; +} + +message ValidatedEndpoint { + string url = 1; + string protocol = 2; } message SetClusterInferenceResponse { @@ -64,6 +73,10 @@ message SetClusterInferenceResponse { uint64 version = 3; // Route name that was configured. string route_name = 4; + // Whether endpoint verification ran as part of this request. + bool validation_performed = 5; + // The concrete endpoints that were probed during validation, when available. + repeated ValidatedEndpoint validated_endpoints = 6; } message GetClusterInferenceRequest { diff --git a/python/openshell/sandbox.py b/python/openshell/sandbox.py index 7b48ab3b..19bdcdf6 100644 --- a/python/openshell/sandbox.py +++ b/python/openshell/sandbox.py @@ -398,11 +398,13 @@ def set_cluster( *, provider_name: str, model_id: str, + no_verify: bool = False, ) -> ClusterInferenceConfig: response = self._stub.SetClusterInference( inference_pb2.SetClusterInferenceRequest( provider_name=provider_name, model_id=model_id, + no_verify=no_verify, ), timeout=self._timeout, ) diff --git a/python/openshell/sandbox_test.py b/python/openshell/sandbox_test.py index 4c0eebcd..c0148dcc 100644 --- a/python/openshell/sandbox_test.py +++ b/python/openshell/sandbox_test.py @@ -10,6 +10,7 @@ from openshell.sandbox import ( _PYTHON_CLOUDPICKLE_BOOTSTRAP, _SANDBOX_PYTHON_BIN, + InferenceRouteClient, SandboxClient, ) @@ -33,6 +34,22 @@ def ExecSandbox( ) +class _FakeInferenceStub: + def __init__(self) -> None: + self.request = None + + def SetClusterInference(self, request: Any, timeout: float | None = None) -> Any: + self.request = request + _ = timeout + + class _Response: + provider_name = request.provider_name + model_id = request.model_id + version = 1 + + return _Response() + + def _client_with_fake_stub(stub: _FakeStub) -> SandboxClient: client = cast("SandboxClient", object.__new__(SandboxClient)) client._timeout = 30.0 @@ -120,3 +137,19 @@ def test_from_active_cluster_prefers_openshell_gateway_env( assert client._cluster_name == gateway_name finally: client.close() + + +def test_inference_set_cluster_forwards_no_verify_flag() -> None: + stub = _FakeInferenceStub() + client = cast("InferenceRouteClient", object.__new__(InferenceRouteClient)) + client._timeout = 30.0 + client._stub = cast("Any", 
stub) + + client.set_cluster( + provider_name="openai-dev", + model_id="gpt-4.1", + no_verify=True, + ) + + assert stub.request is not None + assert stub.request.no_verify is True
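End to end, the CLI flow added in this change looks like the following; the provider and model names are illustrative:

```bash
# Verification runs by default: the probe must succeed before the route is persisted,
# and the CLI prints the validated endpoint(s) on success.
openshell inference set --provider openai-dev --model gpt-4o-mini

# If the upstream is not reachable yet, the failure message suggests the explicit opt-out:
openshell inference set --provider openai-dev --model gpt-4o-mini --no-verify
```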