diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/FileJwtRetriever.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/FileJwtRetriever.java index eeaee1cfb53e3..d22e98c267c88 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/FileJwtRetriever.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/FileJwtRetriever.java @@ -31,8 +31,8 @@ import static org.apache.kafka.common.security.oauthbearer.internals.secured.CachedFile.STRING_JSON_VALIDATING_TRANSFORMER; /** - * FileJwtRetriever is an {@link JwtRetriever} that will load the contents - * of a file, interpreting them as a JWT access key in the serialized form. + * FileJwtRetriever is a {@link JwtRetriever} that loads the contents + * of a file, interpreting them as a JWT access token in serialized form. */ public class FileJwtRetriever implements JwtRetriever { diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java index ddbbd1a787a1a..83a3d278014e6 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java @@ -343,10 +343,6 @@ private void identifyExtensions() throws LoginException { extensionsRequiringCommit = EMPTY_EXTENSIONS; log.debug("CallbackHandler {} does not support SASL extensions. No extensions will be added", callbackHandler.getClass().getName()); } - if (extensionsRequiringCommit == null) { - log.error("SASL Extensions cannot be null. Check whether your callback handler is explicitly setting them as null."); - throw new LoginException("Extensions cannot be null."); - } } @Override diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerRefreshingLogin.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerRefreshingLogin.java index 9c8ee63e4f2d7..f562871e61cd7 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerRefreshingLogin.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerRefreshingLogin.java @@ -20,16 +20,11 @@ import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; import org.apache.kafka.common.security.auth.Login; import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; -import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; -import org.apache.kafka.common.security.oauthbearer.internals.expiring.ExpiringCredential; import org.apache.kafka.common.security.oauthbearer.internals.expiring.ExpiringCredentialRefreshConfig; import org.apache.kafka.common.security.oauthbearer.internals.expiring.ExpiringCredentialRefreshingLogin; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.apache.kafka.common.security.oauthbearer.internals.expiring.OAuthBearerExpiringCredentialRefreshingLogin; import java.util.Map; -import java.util.Set; import javax.security.auth.Subject; import javax.security.auth.login.Configuration; @@ -77,7 +72,6 @@ * @see SaslConfigs#SASL_LOGIN_REFRESH_BUFFER_SECONDS_DOC */ public class OAuthBearerRefreshingLogin implements Login { - private static final Logger log = LoggerFactory.getLogger(OAuthBearerRefreshingLogin.class); private ExpiringCredentialRefreshingLogin expiringCredentialRefreshingLogin = null; @Override @@ -92,41 +86,9 @@ public void configure(Map configs, String contextName, Configuration * reasonable. */ Class classToSynchronizeOnPriorToRefresh = OAuthBearerRefreshingLogin.class; - expiringCredentialRefreshingLogin = new ExpiringCredentialRefreshingLogin(contextName, configuration, + expiringCredentialRefreshingLogin = new OAuthBearerExpiringCredentialRefreshingLogin(contextName, configuration, new ExpiringCredentialRefreshConfig(configs, true), loginCallbackHandler, - classToSynchronizeOnPriorToRefresh) { - @Override - public ExpiringCredential expiringCredential() { - Set privateCredentialTokens = expiringCredentialRefreshingLogin.subject() - .getPrivateCredentials(OAuthBearerToken.class); - if (privateCredentialTokens.isEmpty()) - return null; - final OAuthBearerToken token = privateCredentialTokens.iterator().next(); - if (log.isDebugEnabled()) - log.debug("Found expiring credential with principal '{}'.", token.principalName()); - return new ExpiringCredential() { - @Override - public String principalName() { - return token.principalName(); - } - - @Override - public Long startTimeMs() { - return token.startTimeMs(); - } - - @Override - public long expireTimeMs() { - return token.lifetimeMs(); - } - - @Override - public Long absoluteLastRefreshTimeMs() { - return null; - } - }; - } - }; + classToSynchronizeOnPriorToRefresh); } @Override diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLogin.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLogin.java index b12ee7ffe6146..3591ad2b409dd 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLogin.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLogin.java @@ -157,7 +157,7 @@ public ExpiringCredentialRefreshingLogin(String contextName, Configuration confi mandatoryClassToSynchronizeOnPriorToRefresh, new LoginContextFactory(), Time.SYSTEM); } - public ExpiringCredentialRefreshingLogin(String contextName, Configuration configuration, + ExpiringCredentialRefreshingLogin(String contextName, Configuration configuration, ExpiringCredentialRefreshConfig expiringCredentialRefreshConfig, AuthenticateCallbackHandler callbackHandler, Class mandatoryClassToSynchronizeOnPriorToRefresh, LoginContextFactory loginContextFactory, Time time) { diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/OAuthBearerExpiringCredentialRefreshingLogin.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/OAuthBearerExpiringCredentialRefreshingLogin.java new file mode 100644 index 0000000000000..89172ec3ecc2c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/OAuthBearerExpiringCredentialRefreshingLogin.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.internals.expiring; + +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.utils.Time; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Set; + +import javax.security.auth.login.Configuration; + +public class OAuthBearerExpiringCredentialRefreshingLogin extends ExpiringCredentialRefreshingLogin { + + private static final Logger log = LoggerFactory.getLogger(OAuthBearerExpiringCredentialRefreshingLogin.class); + + public OAuthBearerExpiringCredentialRefreshingLogin(String contextName, Configuration configuration, + ExpiringCredentialRefreshConfig expiringCredentialRefreshConfig, + AuthenticateCallbackHandler callbackHandler, + Class mandatoryClassToSynchronizeOnPriorToRefresh) { + super(contextName, configuration, expiringCredentialRefreshConfig, callbackHandler, + mandatoryClassToSynchronizeOnPriorToRefresh); + } + + OAuthBearerExpiringCredentialRefreshingLogin(String contextName, Configuration configuration, + ExpiringCredentialRefreshConfig expiringCredentialRefreshConfig, + AuthenticateCallbackHandler callbackHandler, + Class mandatoryClassToSynchronizeOnPriorToRefresh, + ExpiringCredentialRefreshingLogin.LoginContextFactory loginContextFactory, + Time time) { + super(contextName, configuration, expiringCredentialRefreshConfig, callbackHandler, + mandatoryClassToSynchronizeOnPriorToRefresh, loginContextFactory, time); + } + + @Override + public ExpiringCredential expiringCredential() { + Set privateCredentialTokens = this.subject() + .getPrivateCredentials(OAuthBearerToken.class); + if (privateCredentialTokens.isEmpty()) + return null; + final OAuthBearerToken token = privateCredentialTokens.iterator().next(); + if (log.isDebugEnabled()) + log.debug("Found expiring credential with principal '{}'.", token.principalName()); + return new ExpiringCredential() { + @Override + public String principalName() { + return token.principalName(); + } + + @Override + public Long startTimeMs() { + return token.startTimeMs(); + } + + @Override + public long expireTimeMs() { + return token.lifetimeMs(); + } + + @Override + public Long absoluteLastRefreshTimeMs() { + return null; + } + }; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ConfigurationUtils.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ConfigurationUtils.java index 3eebecf8fde10..40ff05beb0228 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ConfigurationUtils.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ConfigurationUtils.java @@ -221,12 +221,7 @@ public URL validateUrl(String name) { throw new ConfigException(String.format("The OAuth configuration option %s contains a URL (%s) that is malformed: %s", name, value, e.getMessage())); } - String protocol = url.getProtocol(); - - if (protocol == null || protocol.trim().isEmpty()) - throw new ConfigException(String.format("The OAuth configuration option %s contains a URL (%s) that is missing the protocol", name, value)); - - protocol = protocol.toLowerCase(Locale.ROOT); + String protocol = url.getProtocol().toLowerCase(Locale.ROOT); if (!(protocol.equals("http") || protocol.equals("https") || protocol.equals("file"))) throw new ConfigException(String.format("The OAuth configuration option %s contains a URL (%s) that contains an invalid protocol (%s); only \"http\", \"https\", and \"file\" protocol are supported", name, value, protocol)); @@ -414,4 +409,8 @@ private void throwIfResourceIsNotAllowed(String resourceType, throw new ConfigException(configName, configValue, message); } } + + String prefix() { + return prefix; + } } diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwksVerificationKeyResolver.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwksVerificationKeyResolver.java index d6f6a01089419..27b82e4d37481 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwksVerificationKeyResolver.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwksVerificationKeyResolver.java @@ -41,11 +41,11 @@ * RefreshingHttpsJwksVerificationKeyResolver is a * {@link VerificationKeyResolver} implementation that will periodically refresh the * JWKS using its {@link HttpsJwks} instance. - * + *

* A JWKS (JSON Web Key Set) * is a JSON document provided by the OAuth/OIDC provider that lists the keys used to sign the JWTs * it issues. - * + *

* Here is a sample JWKS JSON document: * *

@@ -76,7 +76,7 @@
  * order to match up the JWT's signing key with the key in the JWKS. During the validation step of
  * the broker, the jose4j OAuth library will use the contents of the appropriate key in the JWKS
  * to validate the signature.
- *
+ * 

* Given that the JWKS is referenced by the JWT, the JWKS must be made available by the * OAuth/OIDC provider so that a JWT can be validated. * diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/FileJwtRetrieverTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/FileJwtRetrieverTest.java new file mode 100644 index 0000000000000..7ca1ace4f2e8b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/FileJwtRetrieverTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer; + +import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest; +import org.apache.kafka.test.TestUtils; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.Collections; +import java.util.Map; + +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL; +import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG; +import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule.OAUTHBEARER_MECHANISM; + +class FileJwtRetrieverTest extends OAuthBearerTest { + + @AfterEach + void tearDown() { + System.clearProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG); + } + + @Test + public void testRetrieveCalledBeforeConfigure() throws IOException { + try (FileJwtRetriever retriever = new FileJwtRetriever()) { + + Assertions.assertThrows( + IllegalStateException.class, + retriever::retrieve + ); + } + } + + @Test + public void testRetrieveReturnsTokenFromFile() throws Exception { + String jwtFileContent = createJwt("test"); + String jwtFileURI = TestUtils.tempFile(jwtFileContent).toURI().toString(); + System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, jwtFileURI); + + try (FileJwtRetriever retriever = new FileJwtRetriever()) { + retriever.configure( + Map.of(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, jwtFileURI), + OAUTHBEARER_MECHANISM, + Collections.emptyList() + ); + + Assertions.assertEquals(jwtFileContent, retriever.retrieve()); + } + } + + @Test + public void testRetrieveThrowsIfFileIsMissing() throws Exception { + String jwtFileContent = createJwt("test"); + File jwtFile = TestUtils.tempFile(jwtFileContent); + System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, jwtFile.toURI().toString()); + + try (FileJwtRetriever retriever = new FileJwtRetriever()) { + retriever.configure( + Map.of(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, jwtFile.toURI().toString()), + OAUTHBEARER_MECHANISM, + Collections.emptyList() + ); + Files.delete(jwtFile.toPath()); + + Assertions.assertThrows( + Exception.class, + retriever::retrieve + ); + } + } +} \ No newline at end of file diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/JwtBearerJwtRetrieverTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/JwtBearerJwtRetrieverTest.java index 4a4e567dedfdf..c7dd8d2d0be1e 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/JwtBearerJwtRetrieverTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/JwtBearerJwtRetrieverTest.java @@ -148,4 +148,11 @@ public void testConfigureWithInvalidPassphrase() throws Exception { assertInstanceOf(IOException.class, e.getCause()); } } + + @Test + public void testRetrieveCalledBeforeConfigure() throws IOException { + try (JwtBearerJwtRetriever jwtRetriever = new JwtBearerJwtRetriever()) { + assertThrows(IllegalStateException.class, jwtRetriever::retrieve); + } + } } diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java index 4efbc21072280..0519624fbef20 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java @@ -21,6 +21,7 @@ import org.apache.kafka.common.security.auth.SaslExtensions; import org.apache.kafka.common.security.auth.SaslExtensionsCallback; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.io.IOException; @@ -33,15 +34,21 @@ import javax.security.auth.Subject; import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.UnsupportedCallbackException; import javax.security.auth.login.AppConfigurationEntry; import javax.security.auth.login.LoginException; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verifyNoInteractions; @@ -49,6 +56,23 @@ public class OAuthBearerLoginModuleTest { public static final SaslExtensions RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG = null; + private OAuthBearerToken[] tokens; + private SaslExtensions[] extensions; + + @BeforeEach + public void setup() { + tokens = new OAuthBearerToken[] { + mock(OAuthBearerToken.class), + mock(OAuthBearerToken.class), + mock(OAuthBearerToken.class) + }; + extensions = new SaslExtensions[] { + saslExtensions(), + saslExtensions(), + saslExtensions() + }; + } + private static class TestCallbackHandler implements AuthenticateCallbackHandler { private final OAuthBearerToken[] tokens; private int index = 0; @@ -126,10 +150,6 @@ public void login1Commit1Login2Commit2Logout1Login3Commit3Logout2() throws Login Set publicCredentials = subject.getPublicCredentials(); // Create callback handler - OAuthBearerToken[] tokens = new OAuthBearerToken[] {mock(OAuthBearerToken.class), - mock(OAuthBearerToken.class), mock(OAuthBearerToken.class)}; - SaslExtensions[] extensions = new SaslExtensions[] {saslExtensions(), - saslExtensions(), saslExtensions()}; TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); // Create login modules @@ -222,10 +242,6 @@ public void login1Commit1Logout1Login2Commit2Logout2() throws LoginException { Set publicCredentials = subject.getPublicCredentials(); // Create callback handler - OAuthBearerToken[] tokens = new OAuthBearerToken[] {mock(OAuthBearerToken.class), - mock(OAuthBearerToken.class)}; - SaslExtensions[] extensions = new SaslExtensions[] {saslExtensions(), - saslExtensions()}; TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); // Create login modules @@ -282,10 +298,6 @@ public void loginAbortLoginCommitLogout() throws LoginException { Set publicCredentials = subject.getPublicCredentials(); // Create callback handler - OAuthBearerToken[] tokens = new OAuthBearerToken[] {mock(OAuthBearerToken.class), - mock(OAuthBearerToken.class)}; - SaslExtensions[] extensions = new SaslExtensions[] {saslExtensions(), - saslExtensions()}; TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); // Create login module @@ -334,10 +346,6 @@ public void login1Commit1Login2Abort2Login3Commit3Logout3() throws LoginExceptio Set publicCredentials = subject.getPublicCredentials(); // Create callback handler - OAuthBearerToken[] tokens = new OAuthBearerToken[] {mock(OAuthBearerToken.class), - mock(OAuthBearerToken.class), mock(OAuthBearerToken.class)}; - SaslExtensions[] extensions = new SaslExtensions[] {saslExtensions(), saslExtensions(), - saslExtensions()}; TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); // Create login modules @@ -415,8 +423,6 @@ public void commitDoesNotThrowOnUnsupportedExtensionsCallback() throws LoginExce Subject subject = new Subject(); // Create callback handler - OAuthBearerToken[] tokens = new OAuthBearerToken[] {mock(OAuthBearerToken.class), - mock(OAuthBearerToken.class), mock(OAuthBearerToken.class)}; TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, new SaslExtensions[] {RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG}); // Create login modules @@ -434,11 +440,171 @@ public void commitDoesNotThrowOnUnsupportedExtensionsCallback() throws LoginExce verifyNoInteractions((Object[]) tokens); } + @Test + public void testInitializeThrowsIfCallbackHandlerIsNotInstanceOfAuthenticateCallbackHandler() { + Subject subject = new Subject(); + CallbackHandler nonAuthCallbackHandler = callbacks -> { }; + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + + assertThrows( + IllegalArgumentException.class, + () -> loginModule.initialize(subject, nonAuthCallbackHandler, Collections.emptyMap(), Collections.emptyMap())); + } + + @Test + public void testLoginThrowsIfAlreadyLoggedInWithToken() throws Exception { + Subject subject = new Subject(); + TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + + loginModule.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + loginModule.login(); + + assertThrows(IllegalStateException.class, loginModule::login); + } + + @Test + public void testLoginThrowsIfAlreadyLoggedInWithoutToken() throws Exception { + Subject subject = new Subject(); + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + + AuthenticateCallbackHandler callbackHandler = mock(AuthenticateCallbackHandler.class); + + loginModule.initialize(subject, callbackHandler, Collections.emptyMap(), Collections.emptyMap()); + loginModule.login(); + + assertThrows(IllegalStateException.class, loginModule::login); + } + + @Test + public void testLoginCatchesIOExceptionFromHandleAndThrowsLoginException() throws Exception { + Subject subject = new Subject(); + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + + AuthenticateCallbackHandler callbackHandler = mock(AuthenticateCallbackHandler.class); + doThrow(IOException.class).when(callbackHandler).handle(any()); + + loginModule.initialize(subject, callbackHandler, Collections.emptyMap(), Collections.emptyMap()); + + assertThrows(LoginException.class, loginModule::login); + } + + @Test + public void testLoginThrowsIfAlreadyCommittedWithToken() throws Exception { + Subject subject = new Subject(); + TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + + loginModule.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + loginModule.login(); + loginModule.commit(); + + assertThrows(IllegalStateException.class, loginModule::login); + } + + @Test + public void testLoginThrowsIfAlreadyCommittedWithoutToken() throws Exception { + Subject subject = new Subject(); + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + AuthenticateCallbackHandler callbackHandler = mock(AuthenticateCallbackHandler.class); + + loginModule.initialize(subject, callbackHandler, Collections.emptyMap(), Collections.emptyMap()); + loginModule.login(); + loginModule.commit(); + + assertThrows(IllegalStateException.class, loginModule::login); + } + + @Test + public void testLoginThrowsLoginExceptionWhenCallbackReturnsErrorCode() throws Exception { + Subject subject = new Subject(); + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + + AuthenticateCallbackHandler callbackHandler = mock(AuthenticateCallbackHandler.class); + doAnswer(invocation -> { + OAuthBearerTokenCallback callback = (OAuthBearerTokenCallback) ((Callback[]) invocation.getArgument(0))[0]; + callback.error("invalid_token", "the token was invalid", "https://example.com/error"); + return null; + }).when(callbackHandler).handle(any()); + + loginModule.initialize(subject, callbackHandler, Collections.emptyMap(), Collections.emptyMap()); + + assertThrows(LoginException.class, loginModule::login); + } + + @Test + public void testLoginThrowsLoginExceptionWhenExtensionCallbackHandlerThrowsIOException() throws Exception { + Subject subject = new Subject(); + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + + AuthenticateCallbackHandler callbackHandler = mock(AuthenticateCallbackHandler.class); + doAnswer(invocation -> { + OAuthBearerTokenCallback callback = (OAuthBearerTokenCallback) ((Callback[]) invocation.getArgument(0))[0]; + callback.token(mock(OAuthBearerToken.class)); + return null; + }).doThrow(IOException.class).when(callbackHandler).handle(any()); + + loginModule.initialize(subject, callbackHandler, Collections.emptyMap(), Collections.emptyMap()); + + assertThrows(LoginException.class, loginModule::login); + } + + @Test + public void testLogoutThrowsIfLogoutCalledBeforeCommit() throws Exception { + Subject subject = new Subject(); + TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); + + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + loginModule.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + + loginModule.login(); + + assertThrows(IllegalStateException.class, loginModule::logout); + } + + @Test + public void testLogoutReturnsFalseIfLogoutCalledBeforeLogin() { + Subject subject = new Subject(); + TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); + + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + loginModule.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + + assertFalse(loginModule::logout); + } + + @Test + public void testCommitReturnsFalseIfCalledBeforeLogin() { + Subject subject = new Subject(); + TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); + + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + loginModule.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + + assertFalse(loginModule::commit); + } + + @Test + public void testAbortReturnsFalseIfCalledBeforeLogin() { + Subject subject = new Subject(); + TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); + + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + loginModule.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + + assertFalse(loginModule::abort); + } + /** * We don't want to use mocks for our tests as we need to make sure to test * {@link SaslExtensions}' {@link SaslExtensions#equals(Object)} and * {@link SaslExtensions#hashCode()} methods. - * *

* * We need to make distinct calls to this method (vs. caching the result and reusing it diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackHandlerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackHandlerTest.java index adabec6bc958d..d5b2874a75d16 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackHandlerTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackHandlerTest.java @@ -18,19 +18,23 @@ package org.apache.kafka.common.security.oauthbearer; import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.security.auth.SaslExtensions; import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder; import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver; import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest; import org.jose4j.jws.AlgorithmIdentifiers; import org.junit.jupiter.api.Test; +import org.mockito.Mockito; import java.io.IOException; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; import javax.security.auth.login.AppConfigurationEntry; import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE; @@ -45,7 +49,7 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest { @Test - public void testBasic() throws Exception { + public void testHandleValidatesToken() throws Exception { String expectedAudience = "a"; List allAudiences = Arrays.asList(expectedAudience, "b", "c"); AccessTokenBuilder builder = new AccessTokenBuilder() @@ -81,6 +85,62 @@ public void testBasic() throws Exception { } } + @Test + public void testHandleAcceptsAllInputExtensions() throws Exception { + String expectedAudience = "a"; + List allAudiences = Arrays.asList(expectedAudience, "b", "c"); + OAuthBearerToken token = new OAuthBearerTokenMock(); + + Map configs = getSaslConfigs(SASL_OAUTHBEARER_EXPECTED_AUDIENCE, allAudiences); + OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler(); + CloseableVerificationKeyResolver mockVerificationKeyResolver = Mockito.mock(CloseableVerificationKeyResolver.class); + JwtValidator mockJwtValidator = Mockito.mock(JwtValidator.class); + handler.configure(configs, OAUTHBEARER_MECHANISM, getJaasConfigEntries(), + mockVerificationKeyResolver, mockJwtValidator); + + try { + Map extensions = new HashMap<>(); + extensions.put("test1", "123"); + extensions.put("test2", "123"); + OAuthBearerExtensionsValidatorCallback callback = new OAuthBearerExtensionsValidatorCallback( + token, + new SaslExtensions(extensions) + ); + handler.handle(new Callback[]{callback}); + + assertTrue(callback.validatedExtensions().containsKey("test1")); + assertTrue(callback.validatedExtensions().containsKey("test2")); + } finally { + handler.close(); + } + } + + @Test + public void testHandleThrowsOnUnsupportedCallback() { + Map configs = getSaslConfigs(); + OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler(); + CloseableVerificationKeyResolver mockVerificationKeyResolver = Mockito.mock(CloseableVerificationKeyResolver.class); + JwtValidator mockJwtValidator = Mockito.mock(JwtValidator.class); + handler.configure(configs, OAUTHBEARER_MECHANISM, getJaasConfigEntries(), + mockVerificationKeyResolver, mockJwtValidator); + Callback unknown = new Callback() { + }; + + try { + assertThrows(UnsupportedCallbackException.class, () -> handler.handle(new Callback[]{unknown})); + } finally { + handler.close(); + } + } + + @Test + public void testHandleThrowsIfNotConfigured() { + OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler(); + OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback("some.jwt.token"); + + assertThrows(IllegalStateException.class, () -> handler.handle(new Callback[]{callback})); + } + @Test public void testInvalidAccessToken() throws Exception { // There aren't different error messages for the validation step, so these are all the diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java index 94f4f1fc8c49e..44e75be0d0f37 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java @@ -17,15 +17,22 @@ package org.apache.kafka.common.security.oauthbearer.internals; import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.errors.IllegalSaslStateException; import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; import org.apache.kafka.common.security.auth.SaslExtensions; import org.apache.kafka.common.security.auth.SaslExtensionsCallback; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EmptySource; +import org.junit.jupiter.params.provider.NullSource; +import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; @@ -33,12 +40,24 @@ import java.util.Set; import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.UnsupportedCallbackException; import javax.security.auth.login.AppConfigurationEntry; +import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; public class OAuthBearerSaslClientTest { @@ -49,8 +68,9 @@ public class OAuthBearerSaslClientTest { put("Three", "3"); } }; + private static final String ERROR_MESSAGE = "Error as expected!"; + private SaslExtensions testExtensions = new SaslExtensions(TEST_PROPERTIES); - private final String errorMessage = "Error as expected!"; public class ExtensionsCallbackHandler implements AuthenticateCallbackHandler { private final boolean toThrow; @@ -101,7 +121,7 @@ public Long startTimeMs() { }); else if (callback instanceof SaslExtensionsCallback) { if (toThrow) - throw new ConfigException(errorMessage); + throw new ConfigException(ERROR_MESSAGE); else ((SaslExtensionsCallback) callback).extensions(testExtensions); } else @@ -146,8 +166,191 @@ public void testWrapsExtensionsCallbackHandlingErrorInSaslExceptionInFirstClient } catch (SaslException e) { // assert it has caught our expected exception assertEquals(ConfigException.class, e.getCause().getClass()); - assertEquals(errorMessage, e.getCause().getMessage()); + assertEquals(ERROR_MESSAGE, e.getCause().getMessage()); } + } + + @Test + public void testGetMechanismNameReturnsOAuthBearerMechanism() { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + + assertEquals(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, client.getMechanismName()); + } + + @Test + public void testHasInitialResponseReturnsTrue() { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + + assertTrue(client.hasInitialResponse()); + } + + @Test + public void testEvaluateChallengeThrowsOnFirstMessageWithChallenge() { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + + assertThrows(SaslException.class, () -> client.evaluateChallenge("unexpected".getBytes())); + } + + @ParameterizedTest + @EmptySource + @NullSource + public void testEvaluateChallengeReturnsTokenOnEmptyChallenge(String challenge) throws Exception { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + byte[] challengeBytes = challenge == null ? null : challenge.getBytes(); + byte[] response = client.evaluateChallenge(challengeBytes); + + assertNotNull(response); + assertFalse(client.isComplete()); + } + + @Test + public void testEvaluateChallengeReturnsControlAOnServerError() throws Exception { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + client.evaluateChallenge(new byte[0]); + byte[] response = client.evaluateChallenge("error".getBytes()); + + assertArrayEquals(new byte[]{OAuthBearerSaslClient.BYTE_CONTROL_A}, response); + assertFalse(client.isComplete()); + } + + @ParameterizedTest + @EmptySource + @NullSource + void testEvaluateChallengeCompletesOnEmptyServerChallenge(String challenge) throws Exception { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + byte[] challengeBytes = challenge == null ? null : challenge.getBytes(); + client.evaluateChallenge(new byte[0]); + byte[] response = client.evaluateChallenge(challengeBytes); + + assertNull(response); + assertTrue(client.isComplete()); + } + + @Test + void testEvaluateChallengeThrowsOnUnexpectedState() throws Exception { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + client.evaluateChallenge(new byte[0]); + client.evaluateChallenge(new byte[0]); + + assertThrows(IllegalSaslStateException.class, + () -> client.evaluateChallenge(new byte[0])); + } + + @Test + void testEvaluateChallengeWrapsIOExceptionInSaslException() throws Exception { + AuthenticateCallbackHandler callbackHandler = mock(AuthenticateCallbackHandler.class); + doThrow(IOException.class).when(callbackHandler).handle(any()); + + OAuthBearerSaslClient client = new OAuthBearerSaslClient(callbackHandler); + SaslException ex = assertThrows(SaslException.class, + () -> client.evaluateChallenge(new byte[0])); + + assertInstanceOf(IOException.class, ex.getCause()); + assertFalse(client.isComplete()); + } + + @Test + public void testUnwrapThrowsIfNotComplete() { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + + assertThrows(IllegalStateException.class, () -> client.unwrap(null, 0, 0)); + } + + @Test + public void testUnwrapThrowsIfComplete() throws SaslException { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + client.evaluateChallenge(new byte[0]); + client.evaluateChallenge(new byte[0]); + + assertThrows(IllegalStateException.class, () -> client.unwrap(null, 0, 0)); + } + + @Test + public void testWrapThrowsIfNotComplete() { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + + assertThrows(IllegalStateException.class, () -> client.wrap(null, 0, 0)); + } + + @Test + public void testWrapThrowsIfComplete() throws SaslException { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + client.evaluateChallenge(new byte[0]); + client.evaluateChallenge(new byte[0]); + + assertThrows(IllegalStateException.class, () -> client.wrap(null, 0, 0)); + } + + @Test + public void testGetNegotiatedPropertyThrowsIfNotComplete() { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + + assertThrows(IllegalStateException.class, () -> client.getNegotiatedProperty("test")); + } + + @Test + public void testGetNegotiatedPropertyThrowsIfComplete() throws SaslException { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + client.evaluateChallenge(new byte[0]); + client.evaluateChallenge(new byte[0]); + + assertNull(client.getNegotiatedProperty("test")); + } + + @Test + public void testCreateSaslClientReturnsClientForSupportedMechanism() { + OAuthBearerSaslClient.OAuthBearerSaslClientFactory factory = new OAuthBearerSaslClient.OAuthBearerSaslClientFactory(); + AuthenticateCallbackHandler callbackHandler = mock(AuthenticateCallbackHandler.class); + SaslClient client = factory.createSaslClient( + new String[]{OAuthBearerLoginModule.OAUTHBEARER_MECHANISM}, + null, "https", "localhost", Collections.emptyMap(), + callbackHandler); + + assertNotNull(client); + assertInstanceOf(OAuthBearerSaslClient.class, client); + } + + @Test + public void testCreateSaslClientReturnsNullForUnsupportedMechanism() { + OAuthBearerSaslClient.OAuthBearerSaslClientFactory factory = new OAuthBearerSaslClient.OAuthBearerSaslClientFactory(); + AuthenticateCallbackHandler callbackHandler = mock(AuthenticateCallbackHandler.class); + SaslClient client = factory.createSaslClient( + new String[]{"PLAIN", "SCRAM-SHA-256"}, + null, "https", "localhost", Collections.emptyMap(), + callbackHandler); + + assertNull(client); + } + + @Test + public void testCreateSaslClientThrowsIfCallbackHandlerIsNotAuthenticateCallbackHandler() { + OAuthBearerSaslClient.OAuthBearerSaslClientFactory factory = new OAuthBearerSaslClient.OAuthBearerSaslClientFactory(); + CallbackHandler nonAuthHandler = callbacks -> { }; + + assertThrows(IllegalArgumentException.class, () -> + factory.createSaslClient( + new String[]{OAuthBearerLoginModule.OAUTHBEARER_MECHANISM}, + null, "https", "localhost", Collections.emptyMap(), + nonAuthHandler)); + } + + @Test + public void testCreateSaslClientThrowsIfCallbackHandlerIsNull() { + OAuthBearerSaslClient.OAuthBearerSaslClientFactory factory = new OAuthBearerSaslClient.OAuthBearerSaslClientFactory(); + + assertThrows(NullPointerException.class, () -> + factory.createSaslClient( + new String[]{OAuthBearerLoginModule.OAUTHBEARER_MECHANISM}, + null, "https", "localhost", Collections.emptyMap(), + null)); + } + + @Test + public void testGetMechanismNamesReturnsOAuthBearer() { + OAuthBearerSaslClient.OAuthBearerSaslClientFactory factory = new OAuthBearerSaslClient.OAuthBearerSaslClientFactory(); + String[] names = factory.getMechanismNames(Collections.emptyMap()); + assertNotNull(names); + assertTrue(Arrays.asList(names).contains(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM)); } } diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java index 581a72a52072b..3f3d1cb1537d5 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java @@ -43,12 +43,19 @@ import java.util.Map; import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.SaslException; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; public class OAuthBearerSaslServerTest { private static final String USER = "user"; @@ -197,6 +204,108 @@ public void illegalToken() throws Exception { assertEquals("{\"status\":\"invalid_token\"}", challenge); } + @Test + public void testConstructorThrowsIfCallbackHandlerIsNotInstanceOfAuthenticatedCallbackHandler() { + CallbackHandler nonAuthCallbackHandler = callbacks -> { }; + assertThrows(IllegalArgumentException.class, () -> new OAuthBearerSaslServer(nonAuthCallbackHandler)); + } + + @Test + public void testEvaluateResponseThrowsIfResponseIsByteControlA() throws IOException, UnsupportedCallbackException { + saslServer.evaluateResponse(clientInitialResponse(null, true, Collections.emptyMap())); + assertThrows( + SaslAuthenticationException.class, + () -> saslServer.evaluateResponse(new byte[]{OAuthBearerSaslClient.BYTE_CONTROL_A})); + } + + @Test + public void testEvaluateResponseThrowsSaslExceptionOnInvalidResponse() { + assertThrows(SaslException.class, + () -> saslServer.evaluateResponse("not a valid response".getBytes())); + } + + @Test + public void testGetAuthorizationIDThrowsIfNotComplete() { + assertThrows(IllegalStateException.class, () -> saslServer.getAuthorizationID()); + } + + @Test + public void testGetAuthorizationIDReturnsUserAfterCompletion() throws Exception { + saslServer.evaluateResponse(clientInitialResponse(USER)); + assertEquals(USER, saslServer.getAuthorizationID()); + } + + @Test + public void testGetMechanismNameReturnsOAuthBearer() { + assertEquals(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, saslServer.getMechanismName()); + } + + @Test + public void testGetNegotiatedPropertyThrowsIfNotComplete() { + assertThrows(IllegalStateException.class, + () -> saslServer.getNegotiatedProperty("any.property")); + } + + @Test + public void testIsCompleteReturnsFalseBeforeAuthentication() { + assertFalse(saslServer.isComplete()); + } + + @Test + public void testIsCompleteReturnsTrueAfterAuthentication() throws Exception { + saslServer.evaluateResponse(clientInitialResponse(USER)); + assertTrue(saslServer.isComplete()); + } + + @Test + public void testUnwrapThrowsIfNotComplete() { + assertThrows(IllegalStateException.class, + () -> saslServer.unwrap(new byte[0], 0, 0)); + } + + @Test + public void testUnwrapThrowsAfterCompletion() throws Exception { + saslServer.evaluateResponse(clientInitialResponse(USER)); + assertThrows(IllegalStateException.class, + () -> saslServer.unwrap(new byte[0], 0, 0)); + } + + @Test + public void testWrapThrowsIfNotComplete() { + assertThrows(IllegalStateException.class, + () -> saslServer.wrap(new byte[0], 0, 0)); + } + + @Test + public void testWrapThrowsAfterCompletion() throws Exception { + saslServer.evaluateResponse(clientInitialResponse(USER)); + assertThrows(IllegalStateException.class, + () -> saslServer.wrap(new byte[0], 0, 0)); + } + + @Test + public void testDisposeResetsState() throws Exception { + saslServer.evaluateResponse(clientInitialResponse(USER)); + assertTrue(saslServer.isComplete()); + + saslServer.dispose(); + + assertFalse(saslServer.isComplete()); + assertThrows(IllegalStateException.class, () -> saslServer.getAuthorizationID()); + assertThrows(IllegalStateException.class, () -> saslServer.getNegotiatedProperty("any.property")); + } + + @Test + public void testEvaluateResponseThrowsSaslExceptionWhenCallbackHandlerThrowsIOException() throws Exception { + AuthenticateCallbackHandler ioExceptionHandler = mock(AuthenticateCallbackHandler.class); + doThrow(IOException.class).when(ioExceptionHandler).handle(any()); + + saslServer = new OAuthBearerSaslServer(ioExceptionHandler); + + assertThrows(SaslException.class, + () -> saslServer.evaluateResponse(clientInitialResponse(null))); + } + private byte[] clientInitialResponse(String authorizationId) throws OAuthBearerConfigException, IOException, UnsupportedCallbackException { return clientInitialResponse(authorizationId, false, Collections.emptyMap()); diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/OAuthBearerExpiringCredentialRefreshingLoginTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/OAuthBearerExpiringCredentialRefreshingLoginTest.java new file mode 100644 index 0000000000000..0ac9e6e223c5e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/OAuthBearerExpiringCredentialRefreshingLoginTest.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.internals.expiring; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerRefreshingLogin; +import org.apache.kafka.common.utils.Time; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.util.Collections; + +import javax.security.auth.Subject; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.LoginContext; +import javax.security.auth.login.LoginException; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +class OAuthBearerExpiringCredentialRefreshingLoginTest { + + private OAuthBearerExpiringCredentialRefreshingLogin login; + private Subject mockSubject; + private OAuthBearerToken mockToken; + + @BeforeEach + public void setup() throws LoginException { + AuthenticateCallbackHandler mockCallbackHandler = mock(AuthenticateCallbackHandler.class); + Configuration mockConfiguration = mock(Configuration.class); + mockToken = mock(OAuthBearerToken.class); + mockSubject = mock(Subject.class); + + login = new OAuthBearerExpiringCredentialRefreshingLogin( + "KafkaClient", + mockConfiguration, + new ExpiringCredentialRefreshConfig( + new ConfigDef().withClientSaslSupport().parse(Collections.emptyMap()), + true), + mockCallbackHandler, + OAuthBearerRefreshingLogin.class, + new ExpiringCredentialRefreshingLogin.LoginContextFactory() { + @Override + public LoginContext createLoginContext(ExpiringCredentialRefreshingLogin expiringCredentialRefreshingLogin) { + LoginContext mockLoginContext = mock(LoginContext.class); + when(mockLoginContext.getSubject()).thenReturn(mockSubject); + return mockLoginContext; + } + }, + Time.SYSTEM + ); + + login.login(); + } + + @Test + public void testExpiringCredentialSubjectContainsNoTokens() { + when(mockSubject.getPrivateCredentials(Mockito.any())).thenReturn(Collections.emptySet()); + + assertNull(login.expiringCredential()); + } + + @Test + public void testExpiringCredentialMapsTokenFieldsCorrectly() { + when(mockToken.principalName()).thenReturn("test-user"); + when(mockToken.startTimeMs()).thenReturn(1000L); + when(mockToken.lifetimeMs()).thenReturn(9000L); + when(mockSubject.getPrivateCredentials(OAuthBearerToken.class)) + .thenReturn(Collections.singleton(mockToken)); + + ExpiringCredential result = login.expiringCredential(); + + assertNotNull(result); + assertEquals("test-user", result.principalName()); + assertEquals(1000L, result.startTimeMs()); + assertEquals(9000L, result.expireTimeMs()); + assertNull(result.absoluteLastRefreshTimeMs()); + } +} \ No newline at end of file diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ConfigurationUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ConfigurationUtilsTest.java index efc41d64b3290..b8ca52d1945c1 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ConfigurationUtilsTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ConfigurationUtilsTest.java @@ -18,19 +18,38 @@ package org.apache.kafka.common.security.oauthbearer.internals.secured; import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenMock; import org.apache.kafka.test.TestUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EmptySource; +import org.junit.jupiter.params.provider.NullSource; import java.io.File; import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.Map; +import javax.security.auth.login.AppConfigurationEntry; + import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALLOWED_SASL_OAUTHBEARER_FILES_CONFIG; import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG; +import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule.OAUTHBEARER_MECHANISM; +import static org.apache.kafka.common.security.oauthbearer.internals.secured.ConfigurationUtils.getConfiguredInstance; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; public class ConfigurationUtilsTest extends OAuthBearerTest { @@ -182,6 +201,307 @@ public void testThrowIfFileIsNotAllowed() { assertDoesNotThrow(() -> cu.throwIfFileIsNotAllowed(FILE_CONFIG_NAME, file2)); } + @Test + public void testConstructorSetsPrefixToSaslMechanism() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of(), OAUTHBEARER_MECHANISM); + assertEquals("oauthbearer.", cu.prefix()); + } + + @ParameterizedTest + @NullSource + @EmptySource + public void testConstructorSetsPrefixToNull(String saslMechanism) { + ConfigurationUtils cu = new ConfigurationUtils(Map.of(), saslMechanism); + assertNull(cu.prefix()); + } + + @Test + public void testContainsKeyReturnsTrueWhenKeyIsPresent() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", "value")); + assertTrue(cu.containsKey("key")); + } + + @Test + public void testContainsKeyReturnsFalseWhenKeyIsNotPresent() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", "value")); + assertFalse(cu.containsKey("key1")); + } + + @Test + public void testValidateIntegerReturnsValueWhenPresent() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", 42)); + assertEquals(42, cu.validateInteger("key", true)); + } + + @Test + public void testValidateIntegerThrowsWhenRequiredAndMissing() { + ConfigurationUtils cu = new ConfigurationUtils(Collections.emptyMap()); + assertThrows(ConfigException.class, + () -> cu.validateInteger("key", true)); + } + + @Test + public void testValidateIntegerReturnsNullWhenNotRequiredAndMissing() { + ConfigurationUtils cu = new ConfigurationUtils(Collections.emptyMap()); + assertNull(cu.validateInteger("key", false)); + } + + @Test + public void testValidateLongThrowsWhenMissing() { + ConfigurationUtils cu = new ConfigurationUtils(Collections.emptyMap()); + assertThrows(ConfigException.class, () -> cu.validateLong("key")); + } + + @Test + public void testValidateLongReturnsValueWhenPresent() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", 42L)); + assertEquals(42L, cu.validateLong("key", true)); + } + + @Test + public void testValidateLongThrowsWhenRequiredAndMissing() { + ConfigurationUtils cu = new ConfigurationUtils(Collections.emptyMap()); + assertThrows(ConfigException.class, () -> cu.validateLong("key", true)); + } + + @Test + public void testValidateLongReturnsNullWhenNotRequiredAndMissing() { + ConfigurationUtils cu = new ConfigurationUtils(Collections.emptyMap()); + assertNull(cu.validateLong("key", false)); + } + + @Test + public void testValidateLongThrowsWhenValueBelowMin() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", 5L)); + assertThrows(ConfigException.class, () -> cu.validateLong("key", true, 10L)); + } + + @Test + public void testValidateLongReturnsValueWhenValueEqualsMin() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", 10L)); + assertEquals(10L, cu.validateLong("key", true, 10L)); + } + + @Test + public void testValidateLongReturnsValueWhenValueAboveMin() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", 15L)); + assertEquals(15L, cu.validateLong("key", true, 10L)); + } + + @Test + public void testValidateLongReturnsValueWhenMinIsNull() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", 42L)); + assertEquals(42L, cu.validateLong("key", true, null)); + } + + @Test + public void testValidatePasswordReturnsValueWhenPresent() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", new Password("secret"))); + assertEquals("secret", cu.validatePassword("key")); + } + + @Test + public void testValidatePasswordThrowsWhenMissing() { + ConfigurationUtils cu = new ConfigurationUtils(Collections.emptyMap()); + assertThrows(ConfigException.class, () -> cu.validatePassword("key")); + } + + @Test + public void testValidatePasswordThrowsWhenBlank() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", new Password(" "))); + assertThrows(ConfigException.class, () -> cu.validatePassword("key")); + } + + @Test + public void testValidatePasswordTrimsWhitespace() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", new Password(" secret "))); + assertEquals("secret", cu.validatePassword("key")); + } + + @Test + public void testValidateStringReturnsValueWhenPresent() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", "value")); + assertEquals("value", cu.validateString("key", true)); + } + + @Test + public void testValidateStringTrimsWhitespace() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", " value ")); + assertEquals("value", cu.validateString("key", true)); + } + + @Test + public void testValidateStringThrowsWhenRequiredAndMissing() { + ConfigurationUtils cu = new ConfigurationUtils(Collections.emptyMap()); + assertThrows(ConfigException.class, () -> cu.validateString("key", true)); + } + + @Test + public void testValidateStringThrowsWhenRequiredAndBlank() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", " ")); + assertThrows(ConfigException.class, () -> cu.validateString("key", true)); + } + + @Test + public void testValidateStringReturnsNullWhenNotRequiredAndMissing() { + ConfigurationUtils cu = new ConfigurationUtils(Collections.emptyMap()); + assertNull(cu.validateString("key", false)); + } + + @Test + public void testValidateStringReturnsNullWhenNotRequiredAndBlank() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", " ")); + assertNull(cu.validateString("key", false)); + } + + @Test + public void testValidateBooleanReturnsTrueWhenPresent() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", true)); + assertEquals(true, cu.validateBoolean("key", true)); + } + + @Test + public void testValidateBooleanReturnsFalseWhenPresent() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", false)); + assertEquals(false, cu.validateBoolean("key", true)); + } + + @Test + public void testValidateBooleanThrowsWhenRequiredAndMissing() { + ConfigurationUtils cu = new ConfigurationUtils(Collections.emptyMap()); + assertThrows(ConfigException.class, () -> cu.validateBoolean("key", true)); + } + + @Test + public void testValidateBooleanReturnsNullWhenNotRequiredAndMissing() { + ConfigurationUtils cu = new ConfigurationUtils(Collections.emptyMap()); + assertNull(cu.validateBoolean("key", false)); + } + + @Test + public void testGetReturnsValueByName() { + ConfigurationUtils cu = new ConfigurationUtils(Map.of("key", "value"), null); + assertEquals("value", cu.get("key")); + } + + @Test + public void testGetReturnsNullWhenMissing() { + ConfigurationUtils cu = new ConfigurationUtils(Collections.emptyMap(), null); + assertNull(cu.get("missing")); + } + + @Test + public void testGetReturnsPrefixedValueOverUnprefixed() { + String prefix = ListenerName.saslMechanismPrefix(OAUTHBEARER_MECHANISM); + ConfigurationUtils cu = new ConfigurationUtils( + Map.of( + prefix + "key", "prefixed-value", + "key", "unprefixed-value" + ), + OAUTHBEARER_MECHANISM); + assertEquals("prefixed-value", cu.get("key")); + } + + @Test + public void testGetReturnsUnprefixedValueWhenPrefixedNotFound() { + ConfigurationUtils cu = new ConfigurationUtils( + Map.of("key", "unprefixed-value"), + OAUTHBEARER_MECHANISM); + assertEquals("unprefixed-value", cu.get("key")); + } + + @Test + public void testGetReturnsNullWhenNeitherPrefixedNorUnprefixedFound() { + ConfigurationUtils cu = new ConfigurationUtils( + Collections.emptyMap(), + OAUTHBEARER_MECHANISM); + assertNull(cu.get("key")); + } + + @Test + public void testGetIgnoresPrefixWhenSaslMechanismIsNull() { + ConfigurationUtils cu = new ConfigurationUtils( + Map.of("key", "value"), + null); + assertEquals("value", cu.get("key")); + } + + @Test + public void testGetIgnoresPrefixWhenSaslMechanismIsBlank() { + ConfigurationUtils cu = new ConfigurationUtils( + Map.of("key", "value"), + " "); + assertEquals("value", cu.get("key")); + } + + @Test + public void testGetConfiguredInstanceFromClassName() { + Map configs = Map.of("config.key", OAuthBearerTokenMock.class.getName()); + OAuthBearerToken result = getConfiguredInstance(configs, OAUTHBEARER_MECHANISM, + getJaasConfigEntries(), "config.key", OAuthBearerToken.class); + assertNotNull(result); + assertInstanceOf(OAuthBearerToken.class, result); + } + + @Test + public void testGetConfiguredInstanceFromClass() { + Map configs = Map.of("config.key", OAuthBearerTokenMock.class); + OAuthBearerToken result = getConfiguredInstance(configs, OAUTHBEARER_MECHANISM, + getJaasConfigEntries(), "config.key", OAuthBearerToken.class); + assertNotNull(result); + assertInstanceOf(OAuthBearerToken.class, result); + } + + @Test + public void testGetConfiguredInstanceThrowsWhenConfigIsNull() { + Map configs = Collections.emptyMap(); + assertThrows(ConfigException.class, () -> getConfiguredInstance(configs, OAUTHBEARER_MECHANISM, + getJaasConfigEntries(), "config.key", OAuthBearerToken.class)); + } + + @Test + public void testGetConfiguredInstanceThrowsWhenConfigIsWrongType() { + Map configs = Map.of("config.key", 42); + assertThrows(ConfigException.class, () -> getConfiguredInstance(configs, OAUTHBEARER_MECHANISM, + getJaasConfigEntries(), "config.key", OAuthBearerToken.class)); + } + + @Test + public void testGetConfiguredInstanceThrowsWhenClassNameIsInvalid() { + Map configs = Map.of("config.key", "com.nonexistent.ClassName"); + assertThrows(ConfigException.class, () -> getConfiguredInstance(configs, OAUTHBEARER_MECHANISM, + getJaasConfigEntries(), "config.key", OAuthBearerToken.class)); + } + + @Test + public void testGetConfiguredInstanceThrowsWhenClassIsWrongType() { + Map configs = Map.of("config.key", String.class); + assertThrows(ConfigException.class, () -> getConfiguredInstance(configs, OAUTHBEARER_MECHANISM, + getJaasConfigEntries(), "config.key", OAuthBearerToken.class)); + } + + @Test + public void testGetConfiguredInstanceCallsConfigureOnOAuthBearerConfigurable() { + Map configs = Map.of("config.key", MyConfigurableImpl.class); + MyConfigurableImpl result = getConfiguredInstance(configs, OAUTHBEARER_MECHANISM, + getJaasConfigEntries(), "config.key", MyConfigurableImpl.class); + assertTrue(result.configureCalled); + } + + @Test + public void testGetConfiguredInstanceThrowsWhenConfigureThrows() { + Map configs = Map.of("config.key", MyFailingConfigurableImpl.class); + assertThrows(ConfigException.class, () -> getConfiguredInstance(configs, OAUTHBEARER_MECHANISM, + getJaasConfigEntries(), "config.key", MyFailingConfigurableImpl.class)); + } + + @Test + public void testGetConfiguredInstanceThrowsWhenClassCannotBeInstantiated() { + Map configs = Map.of("config.key", NoDefaultConstructorImpl.class); + assertThrows(ConfigException.class, () -> getConfiguredInstance(configs, OAUTHBEARER_MECHANISM, + getJaasConfigEntries(), "config.key", NoDefaultConstructorImpl.class)); + } + private void testUrl(String value) { System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, value == null ? "" : value); Map configs = Collections.singletonMap(URL_CONFIG_NAME, value); @@ -202,4 +522,30 @@ private void testFileUrl(String value) { ConfigurationUtils cu = new ConfigurationUtils(configs); cu.validateFileUrl(URL_CONFIG_NAME); } + + public static class MyConfigurableImpl implements OAuthBearerConfigurable { + boolean configureCalled = false; + + @Override + public void configure(Map configs, String saslMechanism, + List jaasConfigEntries) { + configureCalled = true; + } + } + + public static class MyFailingConfigurableImpl implements OAuthBearerConfigurable { + @Override + public void configure(Map configs, String saslMechanism, + List jaasConfigEntries) { + throw new RuntimeException("configure() failed"); + } + } + + public static class NoDefaultConstructorImpl implements OAuthBearerConfigurable { + public NoDefaultConstructorImpl(String arg) { } + + @Override + public void configure(Map configs, String saslMechanism, + List jaasConfigEntries) { } + } } diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/JwtBearerRequestFormatterTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/JwtBearerRequestFormatterTest.java new file mode 100644 index 0000000000000..0250505bd6359 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/JwtBearerRequestFormatterTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.internals.secured; + +import org.junit.jupiter.api.Test; + +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class JwtBearerRequestFormatterTest { + + private static final String ASSERTION = "test.assertion.token"; + + @Test + public void testFormatBodyContainsGrantType() { + JwtBearerRequestFormatter formatter = new JwtBearerRequestFormatter(null, () -> ASSERTION); + assertTrue(formatter.formatBody().contains("grant_type=" + URLEncoder.encode(JwtBearerRequestFormatter.GRANT_TYPE, StandardCharsets.UTF_8))); + } + + @Test + public void testFormatBodyContainsAssertion() { + JwtBearerRequestFormatter formatter = new JwtBearerRequestFormatter(null, () -> ASSERTION); + assertTrue(formatter.formatBody().contains("assertion=" + URLEncoder.encode(ASSERTION, StandardCharsets.UTF_8))); + } + + @Test + public void testFormatBodyIncludesScopeWhenPresent() { + JwtBearerRequestFormatter formatter = new JwtBearerRequestFormatter("my-scope", () -> ASSERTION); + assertTrue(formatter.formatBody().contains("scope=my-scope")); + } + + @Test + public void testFormatBodyExcludesScopeWhenNull() { + JwtBearerRequestFormatter formatter = new JwtBearerRequestFormatter(null, () -> ASSERTION); + assertFalse(formatter.formatBody().contains("scope")); + } + + @Test + public void testFormatBodyExcludesScopeWhenBlank() { + JwtBearerRequestFormatter formatter = new JwtBearerRequestFormatter(" ", () -> ASSERTION); + assertFalse(formatter.formatBody().contains("scope")); + } + + @Test + public void testFormatBodyTrimsScopeWhitespace() { + JwtBearerRequestFormatter formatter = new JwtBearerRequestFormatter(" my-scope ", () -> ASSERTION); + assertTrue(formatter.formatBody().contains("scope=my-scope")); + } + + @Test + public void testFormatHeadersContainsRequiredHeaders() { + JwtBearerRequestFormatter formatter = new JwtBearerRequestFormatter(null, () -> ASSERTION); + Map headers = formatter.formatHeaders(); + assertEquals("application/json", headers.get("Accept")); + assertEquals("no-cache", headers.get("Cache-Control")); + assertEquals("application/x-www-form-urlencoded", headers.get("Content-Type")); + } +} \ No newline at end of file