diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index aeb8c795..be205d71 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -99,7 +99,10 @@ pub use io::stdio; pub mod auth; #[cfg(feature = "auth")] #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] -pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient}; +pub use auth::{ + AuthClient, AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient, + ScopeUpgradeConfig, StoredCredentials, WWWAuthenticateParams, +}; // #[cfg(feature = "transport-ws")] // #[cfg_attr(docsrs, doc(cfg(feature = "transport-ws")))] diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index a77a2a8a..178e25b1 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -23,6 +23,8 @@ const DEFAULT_EXCHANGE_URL: &str = "http://localhost"; pub struct StoredCredentials { pub client_id: String, pub token_response: Option, + #[serde(default)] + pub granted_scopes: Vec, } /// Trait for storing and retrieving OAuth2 credentials @@ -105,6 +107,48 @@ impl AuthClient { let auth_manager = self.auth_manager.clone(); async move { auth_manager.lock().await.get_access_token().await } } + + /// Get the current granted scopes + pub fn get_current_scopes(&self) -> impl Future> + Send { + let auth_manager = self.auth_manager.clone(); + async move { auth_manager.lock().await.get_current_scopes().await } + } + + /// Check if scope upgrade is possible + pub fn can_attempt_scope_upgrade(&self) -> impl Future + Send { + let auth_manager = self.auth_manager.clone(); + async move { auth_manager.lock().await.can_attempt_scope_upgrade().await } + } + + /// Request a scope upgrade after receiving insufficient_scope error + /// + /// Returns the authorization URL to redirect the user to for re-authorization + /// with the upgraded scopes. + pub fn request_scope_upgrade( + &self, + required_scope: String, + ) -> impl Future> + Send { + let auth_manager = self.auth_manager.clone(); + async move { + auth_manager + .lock() + .await + .request_scope_upgrade(&required_scope) + .await + } + } + + /// Reset the scope upgrade attempt counter + pub fn reset_scope_upgrade_attempts(&self) -> impl Future + Send { + let auth_manager = self.auth_manager.clone(); + async move { + auth_manager + .lock() + .await + .reset_scope_upgrade_attempts() + .await + } + } } /// Auth error @@ -151,6 +195,12 @@ pub enum AuthError { #[error("Registration failed: {0}")] RegistrationFailed(String), + + #[error("Insufficient scope: {required_scope}")] + InsufficientScope { + required_scope: String, + upgrade_url: Option, + }, } /// oauth2 metadata @@ -174,6 +224,13 @@ struct ResourceServerMetadata { authorization_servers: Option>, } +/// Parameters extracted from WWW-Authenticate header +#[derive(Debug, Clone, Default)] +pub struct WWWAuthenticateParams { + pub resource_metadata_url: Option, + pub scope: Option, +} + /// oauth2 client config #[derive(Debug, Clone)] pub struct OAuthClientConfig { @@ -204,6 +261,24 @@ type OAuthClient = oauth2::Client< >; type Credentials = (String, Option); +/// Configuration for scope upgrade behavior +#[derive(Debug, Clone)] +pub struct ScopeUpgradeConfig { + /// Maximum number of scope upgrade attempts before giving up + pub max_upgrade_attempts: u32, + /// Whether to automatically attempt scope upgrades on 403 + pub auto_upgrade: bool, +} + +impl Default for ScopeUpgradeConfig { + fn default() -> Self { + Self { + max_upgrade_attempts: 3, + auto_upgrade: true, + } + } +} + /// oauth2 auth manager pub struct AuthorizationManager { http_client: HttpClient, @@ -212,6 +287,9 @@ pub struct AuthorizationManager { credential_store: Arc, state: RwLock>, base_url: Url, + current_scopes: RwLock>, + scope_upgrade_attempts: RwLock, + scope_upgrade_config: ScopeUpgradeConfig, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -292,11 +370,19 @@ impl AuthorizationManager { credential_store: Arc::new(InMemoryCredentialStore::new()), state: RwLock::new(None), base_url, + current_scopes: RwLock::new(Vec::new()), + scope_upgrade_attempts: RwLock::new(0), + scope_upgrade_config: ScopeUpgradeConfig::default(), }; Ok(manager) } + /// Set the scope upgrade configuration + pub fn set_scope_upgrade_config(&mut self, config: ScopeUpgradeConfig) { + self.scope_upgrade_config = config; + } + /// Set a custom credential store /// /// This allows you to provide your own implementation of credential storage, @@ -536,6 +622,105 @@ impl AuthorizationManager { Ok(auth_url.to_string()) } + /// Get the current granted scopes + pub async fn get_current_scopes(&self) -> Vec { + self.current_scopes.read().await.clone() + } + + /// Compute the union of current scopes and required scopes + fn compute_scope_union(current: &[String], required: &str) -> Vec { + let mut scope_set: std::collections::HashSet = current.iter().cloned().collect(); + + // Parse required scopes (space-separated as per OAuth2 spec) + for scope in required.split_whitespace() { + scope_set.insert(scope.to_string()); + } + + scope_set.into_iter().collect() + } + + /// Check if a scope upgrade is possible and allowed + pub async fn can_attempt_scope_upgrade(&self) -> bool { + if !self.scope_upgrade_config.auto_upgrade { + return false; + } + + let attempts = *self.scope_upgrade_attempts.read().await; + attempts < self.scope_upgrade_config.max_upgrade_attempts + } + + /// Select scopes to use for authorization based on SEP-835 priority: + /// 1. First check WWW-Authenticate scope parameter (if provided) + /// 2. Fall back to scopes_supported in metadata + /// 3. Fall back to provided default scopes + pub fn select_scopes( + &self, + www_authenticate_scope: Option<&str>, + default_scopes: &[&str], + ) -> Vec { + // Priority 1: Use scope from WWW-Authenticate header if available + if let Some(scope) = www_authenticate_scope { + return scope.split_whitespace().map(|s| s.to_string()).collect(); + } + + // Priority 2: Use scopes_supported from metadata + if let Some(metadata) = &self.metadata { + if let Some(scopes_supported) = &metadata.scopes_supported { + if !scopes_supported.is_empty() { + return scopes_supported.clone(); + } + } + } + + // Priority 3: Use default scopes + default_scopes.iter().map(|s| s.to_string()).collect() + } + + /// Attempt to upgrade scopes after receiving a 403 insufficient_scope error + /// + /// Returns the authorization URL to redirect the user to, or an error if upgrade + /// is not possible (e.g., max attempts exceeded). + pub async fn request_scope_upgrade(&self, required_scope: &str) -> Result { + if !self.scope_upgrade_config.auto_upgrade { + return Err(AuthError::InvalidScope( + "Scope upgrade is disabled".to_string(), + )); + } + + let mut attempts = self.scope_upgrade_attempts.write().await; + if *attempts >= self.scope_upgrade_config.max_upgrade_attempts { + return Err(AuthError::InvalidScope(format!( + "Maximum scope upgrade attempts ({}) exceeded", + self.scope_upgrade_config.max_upgrade_attempts + ))); + } + + *attempts += 1; + drop(attempts); + + let current_scopes = self.current_scopes.read().await.clone(); + let upgraded_scopes = Self::compute_scope_union(¤t_scopes, required_scope); + + debug!( + "Requesting scope upgrade: current={:?}, required={}, union={:?}", + current_scopes, required_scope, upgraded_scopes + ); + + let scope_refs: Vec<&str> = upgraded_scopes.iter().map(|s| s.as_str()).collect(); + self.get_authorization_url(&scope_refs).await + } + + /// Reset scope upgrade attempt counter + /// Call this after a successful operation to allow future upgrades + pub async fn reset_scope_upgrade_attempts(&self) { + *self.scope_upgrade_attempts.write().await = 0; + } + + /// Get the number of scope upgrade attempts made + pub async fn get_scope_upgrade_attempts(&self) -> u32 { + *self.scope_upgrade_attempts.read().await + } + /// exchange authorization code for access token pub async fn exchange_code_for_token( &self, @@ -594,11 +779,19 @@ impl AuthorizationManager { debug!("exchange token result: {:?}", token_result); - // Store credentials in the credential store + let granted_scopes: Vec = token_result + .scopes() + .map(|scopes| scopes.iter().map(|s| s.to_string()).collect()) + .unwrap_or_default(); + + *self.current_scopes.write().await = granted_scopes.clone(); + *self.scope_upgrade_attempts.write().await = 0; + let client_id = oauth_client.client_id().to_string(); let stored = StoredCredentials { client_id, token_response: Some(token_result.clone()), + granted_scopes, }; self.credential_store.save(stored).await?; @@ -652,10 +845,18 @@ impl AuthorizationManager { .await .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; + let granted_scopes: Vec = token_result + .scopes() + .map(|scopes| scopes.iter().map(|s| s.to_string()).collect()) + .unwrap_or_else(|| self.current_scopes.blocking_read().clone()); + + *self.current_scopes.write().await = granted_scopes.clone(); + let client_id = oauth_client.client_id().to_string(); let stored = StoredCredentials { client_id, token_response: Some(token_result.clone()), + granted_scopes, }; self.credential_store.save(stored).await?; @@ -848,9 +1049,11 @@ impl AuthorizationManager { let Ok(value_str) = value.to_str() else { continue; }; - if let Some(url) = - Self::extract_resource_metadata_url_from_header(value_str, &self.base_url) - { + let params = Self::extract_www_authenticate_params(value_str, &self.base_url); + if let Some(url) = params.resource_metadata_url { + if let Some(scope) = params.scope { + debug!("WWW-Authenticate header contains scope: {}", scope); + } parsed_url = Some(url); break; } @@ -898,7 +1101,53 @@ impl AuthorizationManager { Ok(Some(metadata)) } + /// Extracts parameters from WWW-Authenticate header (resource_metadata and scope) + fn extract_www_authenticate_params(header: &str, base_url: &Url) -> WWWAuthenticateParams { + let mut params = WWWAuthenticateParams::default(); + let header_lowercase = header.to_ascii_lowercase(); + + // Extract resource_metadata + let mut search_offset = 0; + let resource_key = "resource_metadata="; + while let Some(pos) = header_lowercase[search_offset..].find(resource_key) { + let global_pos = search_offset + pos + resource_key.len(); + let value_slice = &header[global_pos..]; + if let Some((value, consumed)) = Self::parse_next_header_value(value_slice) { + if let Ok(url) = Url::parse(&value) { + params.resource_metadata_url = Some(url); + break; + } + if let Ok(url) = base_url.join(&value) { + params.resource_metadata_url = Some(url); + break; + } + debug!("failed to parse resource metadata value `{value}` as URL"); + search_offset = global_pos + consumed; + continue; + } else { + break; + } + } + + // Extract scope + let scope_key = "scope="; + if let Some(pos) = header_lowercase.find(scope_key) { + let global_pos = pos + scope_key.len(); + let value_slice = &header[global_pos..]; + if let Some((value, _consumed)) = Self::parse_next_header_value(value_slice) { + params.scope = Some(value); + } + } + + params + } + /// Extracts a url following `resource_metadata=` in a header value + /// + /// # Deprecated + /// Use `extract_www_authenticate_params` instead which extracts both resource_metadata and scope + #[deprecated(since = "0.13.0", note = "Use extract_www_authenticate_params instead")] + #[allow(dead_code)] fn extract_resource_metadata_url_from_header(header: &str, base_url: &Url) -> Option { let header_lowercase = header.to_ascii_lowercase(); let fragment_key = "resource_metadata="; @@ -1153,9 +1402,17 @@ impl OAuthState { AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?, ); + let granted_scopes: Vec = credentials + .scopes() + .map(|scopes| scopes.iter().map(|s| s.to_string()).collect()) + .unwrap_or_default(); + + *manager.current_scopes.write().await = granted_scopes.clone(); + let stored = StoredCredentials { client_id: client_id.to_string(), token_response: Some(credentials), + granted_scopes, }; manager.credential_store.save(stored).await?; @@ -1328,7 +1585,7 @@ impl OAuthState { mod tests { use url::Url; - use super::{AuthorizationManager, is_https_url}; + use super::{AuthorizationManager, AuthorizationMetadata, ScopeUpgradeConfig, is_https_url}; // SEP-991: URL-based Client IDs // Tests adapted from the TypeScript SDK's isHttpsUrl test suite @@ -1457,4 +1714,175 @@ mod tests { ] ); } + + #[test] + fn extract_www_authenticate_params_with_both_parameters() { + let header = r#"Bearer error="invalid_token", resource_metadata="https://example.com/.well-known/oauth-protected-resource", scope="read:data write:data""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert_eq!( + params.resource_metadata_url.unwrap().as_str(), + "https://example.com/.well-known/oauth-protected-resource" + ); + assert_eq!(params.scope.unwrap(), "read:data write:data"); + } + + #[test] + fn extract_www_authenticate_params_with_only_scope() { + let header = r#"Bearer error="insufficient_scope", scope="admin:write""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert!(params.resource_metadata_url.is_none()); + assert_eq!(params.scope.unwrap(), "admin:write"); + } + + #[test] + fn extract_www_authenticate_params_with_only_resource_metadata() { + let header = r#"Bearer resource_metadata="/.well-known/oauth-protected-resource""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert_eq!( + params.resource_metadata_url.unwrap().as_str(), + "https://example.com/.well-known/oauth-protected-resource" + ); + assert!(params.scope.is_none()); + } + + #[test] + fn extract_www_authenticate_params_with_no_parameters() { + let header = r#"Bearer error="invalid_token""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert!(params.resource_metadata_url.is_none()); + assert!(params.scope.is_none()); + } + + #[test] + fn extract_www_authenticate_params_with_unquoted_scope() { + let header = r#"Bearer scope=read:data, error="insufficient_scope""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert_eq!(params.scope.unwrap(), "read:data"); + } + + #[test] + fn compute_scope_union_adds_new_scopes() { + let current = vec!["read".to_string(), "write".to_string()]; + let required = "admin delete"; + let result = AuthorizationManager::compute_scope_union(¤t, required); + + assert!(result.contains(&"read".to_string())); + assert!(result.contains(&"write".to_string())); + assert!(result.contains(&"admin".to_string())); + assert!(result.contains(&"delete".to_string())); + assert_eq!(result.len(), 4); + } + + #[test] + fn compute_scope_union_deduplicates() { + let current = vec!["read".to_string(), "write".to_string()]; + let required = "read admin"; // 'read' is already present + let result = AuthorizationManager::compute_scope_union(¤t, required); + + assert!(result.contains(&"read".to_string())); + assert!(result.contains(&"write".to_string())); + assert!(result.contains(&"admin".to_string())); + assert_eq!(result.len(), 3); // No duplicates + } + + #[test] + fn compute_scope_union_handles_empty_current() { + let current: Vec = vec![]; + let required = "read write"; + let result = AuthorizationManager::compute_scope_union(¤t, required); + + assert!(result.contains(&"read".to_string())); + assert!(result.contains(&"write".to_string())); + assert_eq!(result.len(), 2); + } + + #[test] + fn scope_upgrade_config_default_values() { + let config = ScopeUpgradeConfig::default(); + assert_eq!(config.max_upgrade_attempts, 3); + assert!(config.auto_upgrade); + } + + #[tokio::test] + async fn authorization_manager_tracks_scope_upgrade_attempts() { + let manager = AuthorizationManager::new("http://localhost").await.unwrap(); + + // Initial count should be 0 + assert_eq!(manager.get_scope_upgrade_attempts().await, 0); + + // Increment manually via internal state + *manager.scope_upgrade_attempts.write().await = 2; + assert_eq!(manager.get_scope_upgrade_attempts().await, 2); + + // Reset should return to 0 + manager.reset_scope_upgrade_attempts().await; + assert_eq!(manager.get_scope_upgrade_attempts().await, 0); + } + + #[tokio::test] + async fn authorization_manager_can_attempt_scope_upgrade_respects_config() { + let mut manager = AuthorizationManager::new("http://localhost").await.unwrap(); + + // Default config allows upgrades + assert!(manager.can_attempt_scope_upgrade().await); + + // Disable auto_upgrade + manager.set_scope_upgrade_config(ScopeUpgradeConfig { + max_upgrade_attempts: 3, + auto_upgrade: false, + }); + assert!(!manager.can_attempt_scope_upgrade().await); + + // Re-enable but exceed max attempts + manager.set_scope_upgrade_config(ScopeUpgradeConfig { + max_upgrade_attempts: 2, + auto_upgrade: true, + }); + *manager.scope_upgrade_attempts.write().await = 2; + assert!(!manager.can_attempt_scope_upgrade().await); + + // Under max attempts should work + *manager.scope_upgrade_attempts.write().await = 1; + assert!(manager.can_attempt_scope_upgrade().await); + } + + #[test] + fn select_scopes_prioritizes_www_authenticate() { + // Create a minimal manager for testing select_scopes + let manager_metadata = AuthorizationMetadata { + authorization_endpoint: String::new(), + token_endpoint: String::new(), + scopes_supported: Some(vec!["metadata_scope".to_string()]), + ..Default::default() + }; + + // We can't easily create a full AuthorizationManager here, so test the logic manually + // Priority 1: WWW-Authenticate scope should be used first + let www_auth_scope = Some("www_auth_scope another_scope"); + let default_scopes = &["default1", "default2"]; + + // Test priority 1: WWW-Authenticate scope wins + if let Some(scope) = www_auth_scope { + let result: Vec = scope.split_whitespace().map(|s| s.to_string()).collect(); + assert_eq!(result, vec!["www_auth_scope", "another_scope"]); + } + + // Test priority 2: Metadata scopes used when no WWW-Authenticate + let scopes_from_metadata = manager_metadata.scopes_supported.as_ref().unwrap(); + assert_eq!(scopes_from_metadata, &vec!["metadata_scope".to_string()]); + + // Test priority 3: Default scopes as fallback + let default_result: Vec = default_scopes.iter().map(|s| s.to_string()).collect(); + assert_eq!(default_result, vec!["default1", "default2"]); + } } diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index cd1942d5..87cb551e 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -120,6 +120,23 @@ impl StreamableHttpClient for reqwest::Client { })); } } + if response.status() == reqwest::StatusCode::FORBIDDEN { + if let Some(header) = response.headers().get(WWW_AUTHENTICATE) { + let header_str = header.to_str().map_err(|_| { + StreamableHttpError::UnexpectedServerResponse(Cow::from( + "invalid www-authenticate header value", + )) + })?; + // Extract scope parameter from WWW-Authenticate header + let scope = extract_scope_from_header(header_str); + return Err(StreamableHttpError::InsufficientScope( + InsufficientScopeError { + www_authenticate_header: header_str.to_string(), + required_scope: scope, + }, + )); + } + } let status = response.status(); let response = response.error_for_status()?; if matches!( @@ -198,3 +215,32 @@ impl StreamableHttpClientTransport { StreamableHttpClientTransport::with_client(reqwest::Client::default(), config) } } + +/// Extract scope parameter from WWW-Authenticate header +/// Parses the header to find scope="value" or scope=value +fn extract_scope_from_header(header: &str) -> Option { + let header_lowercase = header.to_ascii_lowercase(); + let scope_key = "scope="; + + if let Some(pos) = header_lowercase.find(scope_key) { + let start = pos + scope_key.len(); + let value_slice = &header[start..]; + + // Handle quoted values: scope="value" + if let Some(stripped) = value_slice.strip_prefix('"') { + if let Some(end_quote) = stripped.find('"') { + return Some(stripped[..end_quote].to_string()); + } + } else { + // Handle unquoted values: scope=value + let end = value_slice + .find(|c: char| c == ',' || c == ';' || c.is_whitespace()) + .unwrap_or(value_slice.len()); + if end > 0 { + return Some(value_slice[..end].to_string()); + } + } + } + + None +} diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 4db461a4..c64d30e8 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -24,6 +24,24 @@ pub struct AuthRequiredError { pub www_authenticate_header: String, } +#[derive(Debug)] +pub struct InsufficientScopeError { + pub www_authenticate_header: String, + pub required_scope: Option, +} + +impl InsufficientScopeError { + /// Check if scope upgrade is possible (i.e., we know what scope is required) + pub fn can_upgrade(&self) -> bool { + self.required_scope.is_some() + } + + /// Get the required scope for upgrade + pub fn get_required_scope(&self) -> Option<&str> { + self.required_scope.as_deref() + } +} + #[derive(Error, Debug)] pub enum StreamableHttpError { #[error("SSE error: {0}")] @@ -56,6 +74,8 @@ pub enum StreamableHttpError { Auth(#[from] crate::transport::auth::AuthError), #[error("Auth required")] AuthRequired(AuthRequiredError), + #[error("Insufficient scope")] + InsufficientScope(InsufficientScopeError), } #[derive(Debug, Clone, Error)]