diff --git a/buildSrc/src/main/kotlin/mcp.dokka.gradle.kts b/buildSrc/src/main/kotlin/mcp.dokka.gradle.kts index 7ad493134..4863e98da 100644 --- a/buildSrc/src/main/kotlin/mcp.dokka.gradle.kts +++ b/buildSrc/src/main/kotlin/mcp.dokka.gradle.kts @@ -21,8 +21,8 @@ dokka { documentedVisibilities(VisibilityModifier.Public) - externalDocumentationLinks.register("ktor-client") { - url("https://api.ktor.io/ktor-client/") + externalDocumentationLinks.register("ktor") { + url("https://api.ktor.io/") packageListUrl("https://api.ktor.io/package-list") } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 08bf61eea..ccb649ddd 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -49,6 +49,7 @@ ktor-client-logging = { group = "io.ktor", name = "ktor-client-logging", version ktor-server-content-negotiation = { group = "io.ktor", name = "ktor-server-content-negotiation", version.ref = "ktor" } ktor-client-content-negotiation = { group = "io.ktor", name = "ktor-client-content-negotiation", version.ref = "ktor" } ktor-serialization = { group = "io.ktor", name = "ktor-serialization-kotlinx-json", version.ref = "ktor" } +ktor-server-auth = { group = "io.ktor", name = "ktor-server-auth", version.ref = "ktor" } ktor-server-core = { group = "io.ktor", name = "ktor-server-core", version.ref = "ktor" } ktor-server-sse = { group = "io.ktor", name = "ktor-server-sse", version.ref = "ktor" } ktor-server-websockets = { group = "io.ktor", name = "ktor-server-websockets", version.ref = "ktor" } diff --git a/integration-test/build.gradle.kts b/integration-test/build.gradle.kts index 82078e7c6..30afe9e8a 100644 --- a/integration-test/build.gradle.kts +++ b/integration-test/build.gradle.kts @@ -22,6 +22,7 @@ kotlin { implementation(libs.ktor.server.content.negotiation) implementation(libs.ktor.serialization) implementation(libs.ktor.server.websockets) + implementation(libs.ktor.server.auth) implementation(libs.ktor.server.test.host) implementation(libs.ktor.server.content.negotiation) implementation(libs.ktor.serialization) diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt new file mode 100644 index 000000000..27d70bb42 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt @@ -0,0 +1,184 @@ +package io.modelcontextprotocol.kotlin.sdk.integration + +import io.kotest.matchers.shouldBe +import io.ktor.client.HttpClient +import io.ktor.client.request.basicAuth +import io.ktor.client.request.get +import io.ktor.http.HttpStatusCode +import io.ktor.serialization.kotlinx.json.json +import io.ktor.server.application.Application +import io.ktor.server.application.ApplicationCall +import io.ktor.server.application.install +import io.ktor.server.auth.Authentication +import io.ktor.server.auth.UserIdPrincipal +import io.ktor.server.auth.authenticate +import io.ktor.server.auth.basic +import io.ktor.server.auth.principal +import io.ktor.server.engine.embeddedServer +import io.ktor.server.plugins.contentnegotiation.ContentNegotiation +import io.ktor.server.routing.Route +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents +import io.modelcontextprotocol.kotlin.test.utils.actualPort +import kotlinx.coroutines.runBlocking +import java.util.UUID +import kotlin.test.Test +import io.ktor.client.engine.cio.CIO as ClientCIO +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.sse.SSE as ServerSSE + +/** + * Base class for MCP authentication integration tests. + */ +abstract class AbstractAuthenticationTest { + + protected companion object { + const val HOST = "127.0.0.1" + const val AUTH_REALM = "mcp-auth" + const val WHOAMI_URI = "whoami://me" + } + + protected val validUser: String = "user-${UUID.randomUUID().toString().take(8)}" + protected val validPassword: String = UUID.randomUUID().toString() + protected val invalidUser: String = "user-${UUID.randomUUID().toString().take(8)}" + protected val invalidPassword: String = UUID.randomUUID().toString() + + /** + * Installs Ktor plugins required by the transport under test. + */ + protected open fun Application.configurePlugins() { + install(ServerSSE) + // ContentNegotiation is required by the StreamableHttp transport for JSON body handling. + // Installing it for SSE tests as well is harmless. + install(ContentNegotiation) { json(McpJson) } + } + + /** + * Registers the MCP server on the given route. + */ + abstract fun Route.registerMcpServer(serverFactory: ApplicationCall.() -> Server) + + /** + * Creates a client transport configured with the given credentials. + */ + abstract fun createClientTransport(baseUrl: String, user: String, pass: String): Transport + + @Test + fun `mcp behind basic auth rejects unauthenticated requests with 401`(): Unit = runBlocking { + val server = startAuthenticatedServer() + + val httpClient = HttpClient(ClientCIO) + try { + httpClient.get("http://$HOST:${server.actualPort()}").status shouldBe HttpStatusCode.Unauthorized + } finally { + httpClient.close() + server.stopSuspend(500, 1000) + } + } + + @Test + fun `mcp rejects requests with invalid credentials`(): Unit = runBlocking { + val server = startAuthenticatedServer() + + val httpClient = HttpClient(ClientCIO) { + expectSuccess = false + } + try { + httpClient.get("http://$HOST:${server.actualPort()}") { + basicAuth(invalidUser, invalidPassword) + }.status shouldBe HttpStatusCode.Unauthorized + } finally { + httpClient.close() + server.stopSuspend(500, 1000) + } + } + + @Test + fun `authenticated mcp client can read resource scoped to principal`(): Unit = runBlocking { + val server = startAuthenticatedServer() + + val baseUrl = "http://$HOST:${server.actualPort()}" + var mcpClient: Client? = null + try { + mcpClient = Client(Implementation(name = "test-client", version = "1.0.0")) + mcpClient.connect(createClientTransport(baseUrl, validUser, validPassword)) + + val result = mcpClient.readResource( + ReadResourceRequest(ReadResourceRequestParams(uri = WHOAMI_URI)), + ) + + result.contents shouldBe listOf( + TextResourceContents( + text = validUser, + uri = WHOAMI_URI, + mimeType = "text/plain", + ), + ) + } finally { + mcpClient?.close() + server.stopSuspend(500, 1000) + } + } + + private suspend fun startAuthenticatedServer() = embeddedServer(ServerCIO, host = HOST, port = 0) { + configurePlugins() + installBasicAuth() + routing { + authenticate(AUTH_REALM) { + registerMcpServer { + createMcpServer { principal()?.name } + } + } + } + }.startSuspend(wait = false) + + private fun Application.installBasicAuth() { + install(Authentication) { + basic(AUTH_REALM) { + validate { credentials -> + if (credentials.name == validUser && credentials.password == validPassword) { + UserIdPrincipal(credentials.name) + } else { + null + } + } + } + } + } + + protected fun createMcpServer(principalProvider: () -> String?): Server = Server( + serverInfo = Implementation(name = "test-server", version = "1.0.0"), + options = ServerOptions( + capabilities = ServerCapabilities( + resources = ServerCapabilities.Resources(), + ), + ), + ).apply { + addResource( + uri = WHOAMI_URI, + name = "Current User", + description = "Returns the name of the authenticated user", + mimeType = "text/plain", + ) { + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = principalProvider() ?: "anonymous", + uri = WHOAMI_URI, + mimeType = "text/plain", + ), + ), + ) + } + } +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/SseAuthenticationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/SseAuthenticationTest.kt new file mode 100644 index 000000000..0b5cf8fb5 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/sse/SseAuthenticationTest.kt @@ -0,0 +1,28 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.sse + +import io.ktor.client.HttpClient +import io.ktor.client.plugins.sse.SSE +import io.ktor.client.request.basicAuth +import io.ktor.server.application.ApplicationCall +import io.ktor.server.routing.Route +import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport +import io.modelcontextprotocol.kotlin.sdk.integration.AbstractAuthenticationTest +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.mcp +import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import io.ktor.client.engine.cio.CIO as ClientCIO + +class SseAuthenticationTest : AbstractAuthenticationTest() { + + override fun Route.registerMcpServer(serverFactory: ApplicationCall.() -> Server) { + mcp { + serverFactory(call) + } + } + + override fun createClientTransport(baseUrl: String, user: String, pass: String): Transport = SseClientTransport( + client = HttpClient(ClientCIO) { install(SSE) }, + urlString = baseUrl, + requestBuilder = { basicAuth(user, pass) }, + ) +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/streamablehttp/StreamableHttpAuthenticationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/streamablehttp/StreamableHttpAuthenticationTest.kt new file mode 100644 index 000000000..46491a6bd --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/streamablehttp/StreamableHttpAuthenticationTest.kt @@ -0,0 +1,29 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.streamablehttp + +import io.ktor.client.HttpClient +import io.ktor.client.plugins.sse.SSE +import io.ktor.client.request.basicAuth +import io.ktor.server.application.ApplicationCall +import io.ktor.server.routing.Route +import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport +import io.modelcontextprotocol.kotlin.sdk.integration.AbstractAuthenticationTest +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.mcpStreamableHttp +import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import io.ktor.client.engine.cio.CIO as ClientCIO + +class StreamableHttpAuthenticationTest : AbstractAuthenticationTest() { + + override fun Route.registerMcpServer(serverFactory: ApplicationCall.() -> Server) { + mcpStreamableHttp { + serverFactory(call) + } + } + + override fun createClientTransport(baseUrl: String, user: String, pass: String): Transport = + StreamableHttpClientTransport( + client = HttpClient(ClientCIO) { install(SSE) }, + url = baseUrl, + requestBuilder = { basicAuth(user, pass) }, + ) +} diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index f97845bb9..a7d39662c 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -8,10 +8,30 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V public static final fun mcp (Lio/ktor/server/routing/Route;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V public static final fun mcp (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function1;)V - public static final fun mcpStatelessStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V - public static synthetic fun mcpStatelessStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V - public static final fun mcpStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V - public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static final fun mcpStatelessStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)V + public static final fun mcpStatelessStreamableHttp (Lio/ktor/server/routing/Route;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)V + public static final fun mcpStatelessStreamableHttp (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun mcpStatelessStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static synthetic fun mcpStatelessStreamableHttp$default (Lio/ktor/server/routing/Route;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static synthetic fun mcpStatelessStreamableHttp$default (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static final fun mcpStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)V + public static final fun mcpStreamableHttp (Lio/ktor/server/routing/Route;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)V + public static final fun mcpStreamableHttp (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/routing/Route;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/routing/Route;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V +} + +public final class io/modelcontextprotocol/kotlin/sdk/server/McpStreamableHttpConfig { + public fun ()V + public final fun getAllowedHosts ()Ljava/util/List; + public final fun getAllowedOrigins ()Ljava/util/List; + public final fun getEnableDnsRebindingProtection ()Z + public final fun getEventStore ()Lio/modelcontextprotocol/kotlin/sdk/server/EventStore; + public final fun setAllowedHosts (Ljava/util/List;)V + public final fun setAllowedOrigins (Ljava/util/List;)V + public final fun setEnableDnsRebindingProtection (Z)V + public final fun setEventStore (Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;)V } public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt : io/modelcontextprotocol/kotlin/sdk/server/Feature { diff --git a/kotlin-sdk-server/detekt-baseline-commonMainSourceSet.xml b/kotlin-sdk-server/detekt-baseline-commonMainSourceSet.xml index 67d5e7307..97532e092 100644 --- a/kotlin-sdk-server/detekt-baseline-commonMainSourceSet.xml +++ b/kotlin-sdk-server/detekt-baseline-commonMainSourceSet.xml @@ -7,14 +7,12 @@ MaxLineLength:SSEServerTransport.kt:SseServerTransport$* MaxLineLength:ServerSession.kt:ServerSession$"Client requested unsupported protocol version $requestedVersion, falling back to $LATEST_PROTOCOL_VERSION" MaxLineLength:ServerSession.kt:ServerSession$"Creating message with ${params.params.messages.size} messages, maxTokens=${params.params.maxTokens}, temperature=${params.params.temperature}, systemPrompt=${if (params.params.systemPrompt != null) "present" else "absent"}" - ReturnCount:KtorServer.kt:private suspend fun existingStreamableTransport: StreamableHttpServerTransport? ThrowsCount:ServerSession.kt:ServerSession$override fun assertCapabilityForMethod ThrowsCount:ServerSession.kt:ServerSession$override fun assertNotificationCapability ThrowsCount:ServerSession.kt:ServerSession$override fun assertRequestHandlerCapability TooGenericExceptionCaught:SSEServerTransport.kt:SseServerTransport$e: Exception TooGenericExceptionCaught:Server.kt:Server$e: Exception TooGenericExceptionCaught:StdioServerTransport.kt:StdioServerTransport$e: Throwable - TooManyFunctions:KtorServer.kt:io.modelcontextprotocol.kotlin.sdk.server.KtorServer.kt TooManyFunctions:Server.kt:Server TooManyFunctions:ServerSession.kt:ServerSession : Protocol diff --git a/kotlin-sdk-server/detekt-baseline-main.xml b/kotlin-sdk-server/detekt-baseline-main.xml index c016bed02..eceff6b42 100644 --- a/kotlin-sdk-server/detekt-baseline-main.xml +++ b/kotlin-sdk-server/detekt-baseline-main.xml @@ -3,7 +3,6 @@ InjectDispatcher:FeatureNotificationService.kt:FeatureNotificationService$Default - LongParameterList:KtorServer.kt:private suspend fun RoutingContext.streamableTransport: StreamableHttpServerTransport? LongParameterList:Server.kt:Server$public fun addTool MagicNumber:StdioServerTransport.kt:StdioServerTransport$8192 MaxLineLength:SSEServerTransport.kt:SseServerTransport$"SSEServerTransport already started! If using Server class, note that connect() calls start() automatically." @@ -12,14 +11,12 @@ MaxLineLength:ServerSession.kt:ServerSession$"Creating message with ${params.params.messages.size} messages, maxTokens=${params.params.maxTokens}, temperature=${params.params.temperature}, systemPrompt=${if (params.params.systemPrompt != null) "present" else "absent"}" NoNameShadowing:FeatureNotificationService.kt:FeatureNotificationService${ it.remove(job) } RedundantSuspendModifier:ServerSession.kt:ServerSession$suspend - ReturnCount:KtorServer.kt:private suspend fun existingStreamableTransport: StreamableHttpServerTransport? ThrowsCount:ServerSession.kt:ServerSession$override fun assertCapabilityForMethod ThrowsCount:ServerSession.kt:ServerSession$override fun assertNotificationCapability ThrowsCount:ServerSession.kt:ServerSession$override fun assertRequestHandlerCapability TooGenericExceptionCaught:SSEServerTransport.kt:SseServerTransport$e: Exception TooGenericExceptionCaught:Server.kt:Server$e: Exception TooGenericExceptionCaught:StdioServerTransport.kt:StdioServerTransport$e: Throwable - TooManyFunctions:KtorServer.kt:io.modelcontextprotocol.kotlin.sdk.server.KtorServer.kt TooManyFunctions:Server.kt:Server TooManyFunctions:ServerSession.kt:ServerSession : Protocol UnsafeCallOnNullableType:StreamableHttpServerTransport.kt:StreamableHttpServerTransport$responseRequestId!! diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 5a750a5d6..31702d9ca 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -1,3 +1,5 @@ +@file:Suppress("TooManyFunctions") + package io.modelcontextprotocol.kotlin.sdk.server import io.github.oshai.kotlinlogging.KotlinLogging @@ -6,11 +8,13 @@ import io.ktor.server.application.Application import io.ktor.server.application.ApplicationCall import io.ktor.server.application.MissingApplicationPluginException import io.ktor.server.application.install +import io.ktor.server.application.plugin import io.ktor.server.request.ApplicationRequest import io.ktor.server.request.header import io.ktor.server.response.respond import io.ktor.server.routing.Route import io.ktor.server.routing.RoutingContext +import io.ktor.server.routing.application import io.ktor.server.routing.delete import io.ktor.server.routing.get import io.ktor.server.routing.post @@ -20,31 +24,27 @@ import io.ktor.server.sse.SSE import io.ktor.server.sse.ServerSSESession import io.ktor.server.sse.sse import io.ktor.utils.io.KtorDsl -import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.types.RPCError -import kotlinx.atomicfu.AtomicRef -import kotlinx.atomicfu.atomic -import kotlinx.atomicfu.update -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.toPersistentMap import kotlinx.coroutines.awaitCancellation private val logger = KotlinLogging.logger {} -internal class TransportManager(transports: Map = emptyMap()) { - private val transports: AtomicRef> = atomic(transports.toPersistentMap()) - - fun hasTransport(sessionId: String): Boolean = transports.value.containsKey(sessionId) - - fun getTransport(sessionId: String): AbstractTransport? = transports.value[sessionId] - - fun addTransport(sessionId: String, transport: AbstractTransport) { - transports.update { it.put(sessionId, transport) } - } - - fun removeTransport(sessionId: String) { - transports.update { it.remove(sessionId) } - } +/** + * Configuration for Streamable HTTP MCP endpoints. + * + * @property enableDnsRebindingProtection Whether to enable DNS rebinding protection by + * validating the `Host` header against [allowedHosts]. + * @property allowedHosts Allowed hosts for DNS rebinding validation; only consulted when + * [enableDnsRebindingProtection] is `true`. + * @property allowedOrigins Allowed origins for cross-origin request validation. + * @property eventStore An optional [EventStore] for persistent, resumable sessions. + */ +@KtorDsl +public class McpStreamableHttpConfig { + public var enableDnsRebindingProtection: Boolean = false + public var allowedHosts: List? = null + public var allowedOrigins: List? = null + public var eventStore: EventStore? = null } /** @@ -81,7 +81,7 @@ public fun Route.mcp(block: ServerSSESession.() -> Server) { ) } - val transportManager = TransportManager() + val transportManager = TransportManager() sse { mcpSseEndpoint("", transportManager, block) @@ -101,93 +101,207 @@ public fun Application.mcp(block: ServerSSESession.() -> Server) { } } +/** + * Registers Streamable HTTP MCP endpoints at the specified [path] as a [Route] extension. + * + * This allows placing the endpoints inside an [Route.authenticate] block. + * + * **Precondition:** the [SSE] plugin must be installed on the application before calling this function. + * Use [Application.mcpStreamableHttp] if you want SSE to be installed automatically. + * + * @param path the URL path to register the routes. + * @param config optional configuration for DNS rebinding protection, CORS, and event store. + * @throws IllegalStateException if the [SSE] plugin is not installed. + */ +@KtorDsl +public fun Route.mcpStreamableHttp( + path: String, + config: McpStreamableHttpConfig.() -> Unit = {}, + serverFactory: RoutingContext.() -> Server, +) { + route(path) { + mcpStreamableHttp( + config = config, + serverFactory = serverFactory, + ) + } +} + +/** + * Registers Streamable HTTP MCP endpoints on the current route. + * + * This allows placing the endpoints inside an [Route.authenticate] block. + * Each call creates its own session namespace; registering this endpoint twice on the same + * route tree produces two independent session spaces. + * + * **Precondition:** the [SSE] plugin must be installed on the application before calling this function. + * Use [Application.mcpStreamableHttp] if you want SSE to be installed automatically. + * + * @param config optional configuration for DNS rebinding protection, CORS, and event store. + * @throws IllegalStateException if the [SSE] plugin is not installed. + */ +@KtorDsl +public fun Route.mcpStreamableHttp( + config: McpStreamableHttpConfig.() -> Unit = {}, + serverFactory: RoutingContext.() -> Server, +) { + try { + application.plugin(SSE) + } catch (e: MissingApplicationPluginException) { + throw IllegalStateException( + "The SSE plugin must be installed before registering MCP routes. " + + "Add `install(SSE)` to your application configuration, " + + "or use Application.mcpStreamableHttp() which installs it automatically.", + e, + ) + } + + val mcpConfig = McpStreamableHttpConfig().apply(config) + val transportManager = TransportManager() + + sse { + val transport = call.resolveStreamableTransport(transportManager) ?: return@sse + transport.handleRequest(this, call) + } + + post { + val transport = streamableTransport( + transportManager = transportManager, + config = mcpConfig, + serverFactory = serverFactory, + ) + ?: return@post + + transport.handleRequest(null, call) + } + + delete { + val transport = call.resolveStreamableTransport(transportManager) ?: return@delete + transport.handleRequest(null, call) + } +} + @KtorDsl -@Suppress("LongParameterList") public fun Application.mcpStreamableHttp( path: String = "/mcp", - enableDnsRebindingProtection: Boolean = false, - allowedHosts: List? = null, - allowedOrigins: List? = null, - eventStore: EventStore? = null, - block: RoutingContext.() -> Server, + config: McpStreamableHttpConfig.() -> Unit = {}, + serverFactory: RoutingContext.() -> Server, ) { install(SSE) - val transportManager = TransportManager() - routing { - route(path) { - sse { - val transport = existingStreamableTransport(call, transportManager) ?: return@sse - transport.handleRequest(this, call) - } - - post { - val transport = streamableTransport( - transportManager = transportManager, - enableDnsRebindingProtection = enableDnsRebindingProtection, - allowedHosts = allowedHosts, - allowedOrigins = allowedOrigins, - eventStore = eventStore, - block = block, - ) - ?: return@post - - transport.handleRequest(null, call) - } - - delete { - val transport = existingStreamableTransport(call, transportManager) ?: return@delete - transport.handleRequest(null, call) - } - } + mcpStreamableHttp( + path = path, + config = config, + serverFactory = serverFactory, + ) + } +} + +/** + * Registers stateless Streamable HTTP MCP endpoints at the specified [path] as a [Route] extension. + * + * This allows placing the endpoints inside an [Route.authenticate] block. + * Unlike [mcpStreamableHttp], each request creates a fresh server instance with no session + * persistence between calls. + * + * **Precondition:** the [SSE] plugin must be installed on the application before calling this function. + * Use [Application.mcpStatelessStreamableHttp] if you want SSE to be installed automatically. + * + * @param path the URL path to register the routes. + * @param config optional configuration for DNS rebinding protection, CORS, and event store. + * @throws IllegalStateException if the [SSE] plugin is not installed. + */ +@KtorDsl +public fun Route.mcpStatelessStreamableHttp( + path: String, + config: McpStreamableHttpConfig.() -> Unit = {}, + serverFactory: RoutingContext.() -> Server, +) { + route(path) { + mcpStatelessStreamableHttp( + config = config, + serverFactory = serverFactory, + ) + } +} + +/** + * Registers stateless Streamable HTTP MCP endpoints on the current route. + * + * This allows placing the endpoints inside an [Route.authenticate] block. + * Unlike [mcpStreamableHttp], each request creates a fresh server instance with no session + * persistence between calls. + * + * **Precondition:** the [SSE] plugin must be installed on the application before calling this function. + * Use [Application.mcpStatelessStreamableHttp] if you want SSE to be installed automatically. + * + * @param config optional configuration for DNS rebinding protection, CORS, and event store. + * @throws IllegalStateException if the [SSE] plugin is not installed. + */ +@KtorDsl +public fun Route.mcpStatelessStreamableHttp( + config: McpStreamableHttpConfig.() -> Unit = {}, + serverFactory: RoutingContext.() -> Server, +) { + try { + application.plugin(SSE) + } catch (e: MissingApplicationPluginException) { + throw IllegalStateException( + "The SSE plugin must be installed before registering MCP routes. " + + "Add `install(SSE)` to your application configuration, " + + "or use Application.mcpStatelessStreamableHttp() which installs it automatically.", + e, + ) + } + + val mcpConfig = McpStreamableHttpConfig().apply(config) + + post { + mcpStatelessStreamableHttpEndpoint( + enableDnsRebindingProtection = mcpConfig.enableDnsRebindingProtection, + allowedHosts = mcpConfig.allowedHosts, + allowedOrigins = mcpConfig.allowedOrigins, + eventStore = mcpConfig.eventStore, + serverFactory = serverFactory, + ) + } + get { + call.reject( + HttpStatusCode.MethodNotAllowed, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Method not allowed.", + ) + } + delete { + call.reject( + HttpStatusCode.MethodNotAllowed, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Method not allowed.", + ) } } @KtorDsl -@Suppress("LongParameterList") public fun Application.mcpStatelessStreamableHttp( path: String = "/mcp", - enableDnsRebindingProtection: Boolean = false, - allowedHosts: List? = null, - allowedOrigins: List? = null, - eventStore: EventStore? = null, - block: RoutingContext.() -> Server, + config: McpStreamableHttpConfig.() -> Unit = {}, + serverFactory: RoutingContext.() -> Server, ) { install(SSE) routing { - route(path) { - post { - mcpStatelessStreamableHttpEndpoint( - enableDnsRebindingProtection = enableDnsRebindingProtection, - allowedHosts = allowedHosts, - allowedOrigins = allowedOrigins, - eventStore = eventStore, - block = block, - ) - } - get { - call.reject( - HttpStatusCode.MethodNotAllowed, - RPCError.ErrorCode.CONNECTION_CLOSED, - "Method not allowed.", - ) - } - delete { - call.reject( - HttpStatusCode.MethodNotAllowed, - RPCError.ErrorCode.CONNECTION_CLOSED, - "Method not allowed.", - ) - } - } + mcpStatelessStreamableHttp( + path = path, + config = config, + serverFactory = serverFactory, + ) } } private suspend fun ServerSSESession.mcpSseEndpoint( postEndpoint: String, - transportManager: TransportManager, + transportManager: TransportManager, block: ServerSSESession.() -> Server, ) { val transport = mcpSseTransport(postEndpoint, transportManager) @@ -208,7 +322,7 @@ private suspend fun ServerSSESession.mcpSseEndpoint( private fun ServerSSESession.mcpSseTransport( postEndpoint: String, - transportManager: TransportManager, + transportManager: TransportManager, ): SseServerTransport { val transport = SseServerTransport(postEndpoint, this) transportManager.addTransport(transport.sessionId, transport) @@ -222,7 +336,7 @@ private suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint( allowedHosts: List? = null, allowedOrigins: List? = null, eventStore: EventStore? = null, - block: RoutingContext.() -> Server, + serverFactory: RoutingContext.() -> Server, ) { val transport = StreamableHttpServerTransport( enableDnsRebindingProtection = enableDnsRebindingProtection, @@ -234,7 +348,7 @@ private suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint( logger.info { "New stateless StreamableHttp connection established without sessionId" } - val server = block() + val server = serverFactory() server.onClose { logger.info { "Server connection closed without sessionId" } } server.createSession(transport) @@ -242,7 +356,7 @@ private suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint( logger.debug { "Server connected to transport without sessionId" } } -private suspend fun RoutingContext.mcpPostEndpoint(transportManager: TransportManager) { +private suspend fun RoutingContext.mcpPostEndpoint(transportManager: TransportManager) { val sessionId: String = call.request.queryParameters["sessionId"] ?: run { call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided") return @@ -250,7 +364,7 @@ private suspend fun RoutingContext.mcpPostEndpoint(transportManager: TransportMa logger.debug { "Received message for sessionId: $sessionId" } - val transport = transportManager.getTransport(sessionId) as SseServerTransport? + val transport = transportManager.getTransport(sessionId) if (transport == null) { logger.warn { "Session not found for sessionId: $sessionId" } call.respond(HttpStatusCode.NotFound, "Session not found") @@ -263,13 +377,12 @@ private suspend fun RoutingContext.mcpPostEndpoint(transportManager: TransportMa private fun ApplicationRequest.sessionId(): String? = header(MCP_SESSION_ID_HEADER) -private suspend fun existingStreamableTransport( - call: ApplicationCall, - transportManager: TransportManager, +private suspend fun ApplicationCall.resolveStreamableTransport( + transportManager: TransportManager, ): StreamableHttpServerTransport? { - val sessionId = call.request.sessionId() + val sessionId = request.sessionId() if (sessionId.isNullOrEmpty()) { - call.reject( + reject( HttpStatusCode.BadRequest, RPCError.ErrorCode.CONNECTION_CLOSED, "Bad Request: No valid session ID provided", @@ -277,38 +390,32 @@ private suspend fun existingStreamableTransport( return null } - val transport = transportManager.getTransport(sessionId) as? StreamableHttpServerTransport - if (transport == null) { - call.reject( + return transportManager.getTransport(sessionId) ?: run { + reject( HttpStatusCode.NotFound, RPCError.ErrorCode.CONNECTION_CLOSED, "Session not found", ) - return null + null } - - return transport } private suspend fun RoutingContext.streamableTransport( - transportManager: TransportManager, - enableDnsRebindingProtection: Boolean, - allowedHosts: List?, - allowedOrigins: List?, - eventStore: EventStore?, - block: RoutingContext.() -> Server, + transportManager: TransportManager, + config: McpStreamableHttpConfig, + serverFactory: RoutingContext.() -> Server, ): StreamableHttpServerTransport? { val sessionId = call.request.sessionId() if (sessionId != null) { - val transport = transportManager.getTransport(sessionId) as? StreamableHttpServerTransport - return transport ?: existingStreamableTransport(call, transportManager) + val transport = transportManager.getTransport(sessionId) + return transport ?: call.resolveStreamableTransport(transportManager) } val transport = StreamableHttpServerTransport( - enableDnsRebindingProtection = enableDnsRebindingProtection, - allowedHosts = allowedHosts, - allowedOrigins = allowedOrigins, - eventStore = eventStore, + enableDnsRebindingProtection = config.enableDnsRebindingProtection, + allowedHosts = config.allowedHosts, + allowedOrigins = config.allowedOrigins, + eventStore = config.eventStore, enableJsonResponse = true, ) @@ -322,7 +429,7 @@ private suspend fun RoutingContext.streamableTransport( logger.info { "Closed StreamableHttp connection and removed sessionId: $closedSession" } } - val server = block() + val server = serverFactory() server.onClose { transport.sessionId?.let { transportManager.removeTransport(it) } logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 1894cdc45..406272c32 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -12,6 +12,7 @@ import io.ktor.server.request.receiveText import io.ktor.server.response.header import io.ktor.server.response.respond import io.ktor.server.response.respondNullable +import io.ktor.server.response.respondText import io.ktor.server.sse.ServerSSESession import io.ktor.util.collections.ConcurrentMap import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport @@ -666,6 +667,6 @@ public class StreamableHttpServerTransport( } internal suspend fun ApplicationCall.reject(status: HttpStatusCode, code: Int, message: String) { - this.response.status(status) - this.respond(JSONRPCError(id = null, error = RPCError(code = code, message = message))) + val body = McpJson.encodeToString(JSONRPCError(id = null, error = RPCError(code = code, message = message))) + respondText(body, ContentType.Application.Json, status) } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/TransportManager.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/TransportManager.kt new file mode 100644 index 000000000..8ff98e128 --- /dev/null +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/TransportManager.kt @@ -0,0 +1,30 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.PersistentMap +import kotlinx.collections.immutable.toPersistentMap + +/** + * Manages active transports keyed by session ID. + * + * Each invocation of [mcpStreamableHttp] or [mcpStatelessStreamableHttp] creates its own + * [TransportManager] with an independent session namespace. Registering the same endpoint + * function twice on the same route tree results in two disjoint session spaces — sessions + * established through one registration are invisible to the other. + */ +internal class TransportManager(transports: Map = emptyMap()) { + private val transports: AtomicRef> = atomic(transports.toPersistentMap()) + + fun getTransport(sessionId: String): T? = transports.value[sessionId] + + fun addTransport(sessionId: String, transport: T) { + transports.update { it.put(sessionId, transport) } + } + + fun removeTransport(sessionId: String) { + transports.update { it.remove(sessionId) } + } +} diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/AbstractKtorExtensionsTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/AbstractKtorExtensionsTest.kt deleted file mode 100644 index 28e21b9c5..000000000 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/AbstractKtorExtensionsTest.kt +++ /dev/null @@ -1,68 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.server - -import io.kotest.assertions.ktor.client.shouldHaveContentType -import io.kotest.assertions.ktor.client.shouldHaveStatus -import io.ktor.client.HttpClient -import io.ktor.client.request.post -import io.ktor.client.request.prepareGet -import io.ktor.client.request.setBody -import io.ktor.client.statement.bodyAsChannel -import io.ktor.http.ContentType -import io.ktor.http.HttpStatusCode -import io.ktor.http.contentType -import io.ktor.utils.io.readUTF8Line -import io.modelcontextprotocol.kotlin.sdk.types.Implementation -import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities - -@Suppress("AbstractClassCanBeConcreteClass") -abstract class AbstractKtorExtensionsTest { - - protected val sseContentType = ContentType("text", "event-stream") - - protected fun testServer() = Server( - serverInfo = Implementation(name = "test-server", version = "1.0.0"), - options = ServerOptions(capabilities = ServerCapabilities()), - ) - - /** - * Asserts that both MCP transport endpoints are registered at [path]: - * - GET returns 200 with `text/event-stream` content type (SSE endpoint) - * - POST with a valid MCP payload and session returns 202 Accepted - * - POST without a sessionId returns 400 Bad Request - */ - protected suspend fun HttpClient.assertMcpEndpointsAt(path: String) { - prepareGet(path).execute { response -> - response.shouldHaveStatus(HttpStatusCode.OK) - response.shouldHaveContentType(sseContentType) - - // Extract sessionId from the SSE "endpoint" event - val channel = response.bodyAsChannel() - var eventName: String? = null - var sessionId: String? = null - - while (sessionId == null && !channel.isClosedForRead) { - val line = channel.readUTF8Line() ?: break - when { - line.startsWith("event:") -> eventName = line.substringAfter("event:").trim() - - line.startsWith("data:") && eventName == "endpoint" -> { - val data = line.substringAfter("data:").trim() - sessionId = data.substringAfter("sessionId=").ifEmpty { null } - } - } - } - - requireNotNull(sessionId) { "sessionId not found in SSE endpoint event" } - - // POST a valid JSON-RPC ping while the SSE connection is alive - val postResponse = post("$path?sessionId=$sessionId") { - contentType(ContentType.Application.Json) - setBody("""{"jsonrpc":"2.0","id":1,"method":"ping"}""") - } - postResponse.shouldHaveStatus(HttpStatusCode.Accepted) - } - - // POST without sessionId returns 400 Bad Request - post(path).shouldHaveStatus(HttpStatusCode.BadRequest) - } -} diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorApplicationExtensionsTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorApplicationExtensionsTest.kt index 05200e59b..357f74c31 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorApplicationExtensionsTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorApplicationExtensionsTest.kt @@ -12,19 +12,15 @@ import io.ktor.server.testing.testApplication import kotlin.test.Test /** - * Integration tests for Ktor Application.mcp() extension. + * Integration tests for [Application.mcp] extension. * - * Verifies that Application.mcp() installs SSE automatically and registers - * MCP endpoints at the application root, without requiring explicit install(SSE). + * Verifies that [Application.mcp] installs the SSE plugin automatically and registers + * MCP endpoints at the application root, without requiring an explicit `install(SSE)` call. */ -class KtorApplicationExtensionsTest : AbstractKtorExtensionsTest() { +class KtorApplicationExtensionsTest { - /** - * Verifies that Application.mcp() does not interfere with other routes - * added to the same application. - */ @Test - fun `Application mcp should installs SSE and coexist with other routes`() = testApplication { + fun `Application mcp should coexist with other routes`() = testApplication { application { mcp { testServer() } diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorRouteExtensionsTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorRouteExtensionsTest.kt index 5ba4448f6..94a794422 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorRouteExtensionsTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorRouteExtensionsTest.kt @@ -1,6 +1,7 @@ package io.modelcontextprotocol.kotlin.sdk.server import io.kotest.assertions.ktor.client.shouldHaveStatus +import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldContain import io.ktor.client.request.get @@ -24,7 +25,7 @@ import kotlin.test.assertFailsWith * The key issue was that Routing.mcp() registered at top-level, preventing use on subpaths. * Now Route.mcp() allows registration on any route path. */ -class KtorRouteExtensionsTest : AbstractKtorExtensionsTest() { +class KtorRouteExtensionsTest { /** * Verifies that Route.mcp() throws immediately at route registration time @@ -43,8 +44,10 @@ class KtorRouteExtensionsTest : AbstractKtorExtensionsTest() { client.get("/") } } - exception.message shouldContain "SSE" - exception.message shouldContain "install" + exception.message shouldNotBeNull { + shouldContain("SSE") + shouldContain("install") + } } /** diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorStatelessStreamableHttpRouteExtensionsTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorStatelessStreamableHttpRouteExtensionsTest.kt new file mode 100644 index 000000000..62a08fb1a --- /dev/null +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorStatelessStreamableHttpRouteExtensionsTest.kt @@ -0,0 +1,112 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.kotest.assertions.ktor.client.shouldHaveStatus +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import io.ktor.client.request.get +import io.ktor.client.request.post +import io.ktor.client.statement.bodyAsText +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.install +import io.ktor.server.response.respondText +import io.ktor.server.routing.get +import io.ktor.server.routing.route +import io.ktor.server.routing.routing +import io.ktor.server.sse.SSE +import io.ktor.server.testing.testApplication +import kotlin.test.Test +import kotlin.test.assertFailsWith + +class KtorStatelessStreamableHttpRouteExtensionsTest { + + @Test + fun `Route mcpStatelessStreamableHttp should throw at registration time if SSE plugin is not installed`() { + val exception = assertFailsWith { + testApplication { + application { + routing { + mcpStatelessStreamableHttp { testServer() } + } + } + client.get("/") + } + } + exception.message shouldNotBeNull { + shouldContain("SSE") + shouldContain("install") + } + } + + @Test + fun `Route mcpStatelessStreamableHttp GET and DELETE should return 405 Method Not Allowed`() = testApplication { + application { + install(SSE) + routing { + route("/mcp") { + mcpStatelessStreamableHttp { testServer() } + } + } + } + + client.assertStatelessStreamableHttpEndpointsAt("/mcp") + } + + @Test + fun `Route mcpStatelessStreamableHttp should register endpoints at the full nested path`() = testApplication { + application { + install(SSE) + routing { + route("/v1") { + route("/mcp") { + mcpStatelessStreamableHttp { testServer() } + } + } + } + } + + client.assertStatelessStreamableHttpEndpointsAt("/v1/mcp") + } + + @Test + fun `Route mcpStatelessStreamableHttp with path should register endpoints at the resolved subpath`() = + testApplication { + application { + install(SSE) + routing { + route("/api") { + mcpStatelessStreamableHttp("/mcp") { testServer() } + } + } + } + + client.assertStatelessStreamableHttpEndpointsAt("/api/mcp") + + // The parent route /api is not an MCP endpoint + client.post("/api").shouldHaveStatus(HttpStatusCode.NotFound) + } + + @Test + fun `Route mcpStatelessStreamableHttp should not interfere with sibling routes`() = testApplication { + application { + install(SSE) + routing { + get("/health") { call.respondText("ok") } + route("/mcp") { + get("/docs") { call.respondText("docs") } + mcpStatelessStreamableHttp { testServer() } + } + } + } + + val healthResponse = client.get("/health") + healthResponse.shouldHaveStatus(HttpStatusCode.OK) + healthResponse.bodyAsText() shouldBe "ok" + + val docsResponse = client.get("/mcp/docs") + docsResponse.shouldHaveStatus(HttpStatusCode.OK) + docsResponse.bodyAsText() shouldBe "docs" + + client.assertStatelessStreamableHttpEndpointsAt("/mcp") + } +} diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorStreamableHttpApplicationExtensionsTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorStreamableHttpApplicationExtensionsTest.kt new file mode 100644 index 000000000..556cc6ff8 --- /dev/null +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorStreamableHttpApplicationExtensionsTest.kt @@ -0,0 +1,74 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.kotest.assertions.ktor.client.shouldHaveStatus +import io.kotest.matchers.shouldBe +import io.ktor.client.request.get +import io.ktor.client.statement.bodyAsText +import io.ktor.http.HttpStatusCode +import io.ktor.server.response.respondText +import io.ktor.server.routing.get +import io.ktor.server.routing.routing +import io.ktor.server.testing.testApplication +import kotlin.test.Test + +class KtorStreamableHttpApplicationExtensionsTest { + + @Test + fun `Application mcpStreamableHttp should install SSE and register endpoints at default path`() = testApplication { + application { + mcpStreamableHttp { testServer() } + } + + client.assertStreamableHttpEndpointsAt("/mcp") + } + + @Test + fun `Application mcpStreamableHttp should register endpoints at a custom path`() = testApplication { + application { + mcpStreamableHttp(path = "/api/v1/mcp") { testServer() } + } + + client.assertStreamableHttpEndpointsAt("/api/v1/mcp") + + // Default path is not registered + client.get("/mcp").shouldHaveStatus(HttpStatusCode.NotFound) + } + + @Test + fun `Application mcpStreamableHttp should coexist with other routes`() = testApplication { + application { + mcpStreamableHttp { testServer() } + routing { + get("/health") { call.respondText("healthy") } + } + } + + val healthResponse = client.get("/health") + healthResponse.shouldHaveStatus(HttpStatusCode.OK) + healthResponse.bodyAsText() shouldBe "healthy" + + client.assertStreamableHttpEndpointsAt("/mcp") + } + + @Test + fun `Application mcpStatelessStreamableHttp should install SSE and register endpoints at default path`() = + testApplication { + application { + mcpStatelessStreamableHttp { testServer() } + } + + client.assertStatelessStreamableHttpEndpointsAt("/mcp") + } + + @Test + fun `Application mcpStatelessStreamableHttp should register endpoints at a custom path`() = testApplication { + application { + mcpStatelessStreamableHttp(path = "/api/v1/mcp") { testServer() } + } + + client.assertStatelessStreamableHttpEndpointsAt("/api/v1/mcp") + + // Default path is not registered + client.get("/mcp").shouldHaveStatus(HttpStatusCode.NotFound) + } +} diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorStreamableHttpRouteExtensionsTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorStreamableHttpRouteExtensionsTest.kt new file mode 100644 index 000000000..947c6e500 --- /dev/null +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorStreamableHttpRouteExtensionsTest.kt @@ -0,0 +1,121 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.kotest.assertions.ktor.client.shouldHaveStatus +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldContain +import io.ktor.client.request.get +import io.ktor.client.request.post +import io.ktor.client.statement.bodyAsText +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.install +import io.ktor.server.response.respondText +import io.ktor.server.routing.get +import io.ktor.server.routing.route +import io.ktor.server.routing.routing +import io.ktor.server.sse.SSE +import io.ktor.server.testing.testApplication +import kotlin.test.Test +import kotlin.test.assertFailsWith + +/** + * Integration tests for [Route.mcpStreamableHttp] and [Route.mcpStatelessStreamableHttp]. + * + * Route-level tests focus on routing correctness (correct paths registered, fail-fast + * on missing SSE plugin, sibling routes unaffected). + */ +class KtorStreamableHttpRouteExtensionsTest { + + @Test + fun `Route mcpStreamableHttp should throw at registration time if SSE plugin is not installed`() { + val exception = assertFailsWith { + testApplication { + application { + // Intentionally omit install(SSE) + routing { + mcpStreamableHttp { testServer() } + } + } + client.get("/") + } + } + exception.message shouldNotBeNull { + shouldContain("SSE") + shouldContain("install") + } + } + + @Test + fun `Route mcpStreamableHttp should register GET DELETE and POST endpoints at the current route`() = + testApplication { + application { + install(SSE) + routing { + route("/mcp") { + mcpStreamableHttp { testServer() } + } + } + } + + client.assertStreamableHttpEndpointsAt("/mcp") + } + + @Test + fun `Route mcpStreamableHttp should register endpoints at the full nested path`() = testApplication { + application { + install(SSE) + routing { + route("/v1") { + route("/services") { + route("/mcp") { + mcpStreamableHttp { testServer() } + } + } + } + } + } + + client.assertStreamableHttpEndpointsAt("/v1/services/mcp") + } + + @Test + fun `Route mcpStreamableHttp with path should register endpoints at the resolved subpath`() = testApplication { + application { + install(SSE) + routing { + route("/api") { + mcpStreamableHttp("/mcp-endpoint") { testServer() } + } + } + } + + client.assertStreamableHttpEndpointsAt("/api/mcp-endpoint") + + // The parent route /api is not an MCP endpoint + client.post("/api").shouldHaveStatus(HttpStatusCode.NotFound) + } + + @Test + fun `Route mcpStreamableHttp should not interfere with sibling routes`() = testApplication { + application { + install(SSE) + routing { + get("/health") { call.respondText("ok") } + route("/mcp") { + get("/docs") { call.respondText("docs") } + mcpStreamableHttp { testServer() } + } + } + } + + val healthResponse = client.get("/health") + healthResponse.shouldHaveStatus(HttpStatusCode.OK) + healthResponse.bodyAsText() shouldBe "ok" + + val docsResponse = client.get("/mcp/docs") + docsResponse.shouldHaveStatus(HttpStatusCode.OK) + docsResponse.bodyAsText() shouldBe "docs" + + client.assertStreamableHttpEndpointsAt("/mcp") + } +} diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/TestHelpers.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/TestHelpers.kt new file mode 100644 index 000000000..b903218f7 --- /dev/null +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/TestHelpers.kt @@ -0,0 +1,126 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.kotest.assertions.ktor.client.shouldHaveContentType +import io.kotest.assertions.ktor.client.shouldHaveStatus +import io.kotest.assertions.withClue +import io.kotest.matchers.nulls.shouldNotBeNull +import io.ktor.client.HttpClient +import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.delete +import io.ktor.client.request.get +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.client.request.prepareGet +import io.ktor.client.request.setBody +import io.ktor.client.statement.bodyAsChannel +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.utils.io.readUTF8Line +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities + +internal fun testServer() = Server( + serverInfo = Implementation(name = "test-server", version = "1.0.0"), + options = ServerOptions(capabilities = ServerCapabilities()), +) + +/** + * Asserts that stateless Streamable HTTP MCP endpoints are registered at [path]: + * - GET returns 405 Method Not Allowed (explicitly rejected by the stateless routing layer) + * - DELETE returns 405 Method Not Allowed (same) + * - POST is routed to the transport (returns 406 for a deliberately wrong Accept, confirming the route exists) + * + * Use [configureRequest] to add headers (e.g. `basicAuth(...)`) to every request. + */ +internal suspend fun HttpClient.assertStatelessStreamableHttpEndpointsAt( + path: String, + configureRequest: HttpRequestBuilder.() -> Unit = {}, +) { + get(path) { configureRequest() }.shouldHaveStatus(HttpStatusCode.MethodNotAllowed) + delete(path) { configureRequest() }.shouldHaveStatus(HttpStatusCode.MethodNotAllowed) + + post(path) { + contentType(ContentType.Application.Json) + header(HttpHeaders.Accept, ContentType.Text.Plain.toString()) + configureRequest() + }.shouldHaveStatus(HttpStatusCode.NotAcceptable) +} + +/** + * Asserts that stateful Streamable HTTP MCP endpoints are registered at [path]: + * - GET opens an SSE connection (200 OK); session validation inside the SSE body cannot change + * the already-committed status, so the connection closes immediately without a session + * - DELETE without a session ID returns 400 Bad Request + * - POST is routed to the transport (returns 406 for a deliberately wrong Accept, confirming the route exists) + * + * Use [configureRequest] to add headers (e.g. `basicAuth(...)`) to every request. + */ +internal suspend fun HttpClient.assertStreamableHttpEndpointsAt( + path: String, + configureRequest: HttpRequestBuilder.() -> Unit = {}, +) { + // GET starts an SSE handshake — 200 is committed before the body runs + get(path) { configureRequest() }.shouldHaveStatus(HttpStatusCode.OK) + + // DELETE without session ID is rejected by the route handler + delete(path) { configureRequest() }.shouldHaveStatus(HttpStatusCode.BadRequest) + + // POST reaches the transport: a wrong Accept header triggers 406, not 404 + post(path) { + contentType(ContentType.Application.Json) + header(HttpHeaders.Accept, ContentType.Text.Plain.toString()) + configureRequest() + }.shouldHaveStatus(HttpStatusCode.NotAcceptable) +} + +/** + * Asserts that both MCP transport endpoints are registered at [path]: + * - GET returns 200 with `text/event-stream` content type (SSE endpoint) + * - POST with a valid MCP payload and session returns 202 Accepted + * - POST without a sessionId returns 400 Bad Request + * + * Use [configureRequest] to add headers (e.g. `basicAuth(...)`) to every request. + */ +internal suspend fun HttpClient.assertMcpEndpointsAt( + path: String, + configureRequest: HttpRequestBuilder.() -> Unit = {}, +) { + prepareGet(path) { configureRequest() }.execute { response -> + response.shouldHaveStatus(HttpStatusCode.OK) + response.shouldHaveContentType(ContentType("text", "event-stream")) + + // Extract sessionId from the SSE "endpoint" event + val channel = response.bodyAsChannel() + var eventName: String? = null + var sessionId: String? = null + + while (sessionId == null && !channel.isClosedForRead) { + val line = channel.readUTF8Line() ?: break + when { + line.startsWith("event:") -> eventName = line.substringAfter("event:").trim() + + line.startsWith("data:") && eventName == "endpoint" -> { + val data = line.substringAfter("data:").trim() + sessionId = data.substringAfter("sessionId=").ifEmpty { null } + } + } + } + + val resolvedSessionId = withClue("sessionId not found in SSE endpoint event") { + sessionId.shouldNotBeNull() + } + + // POST a valid JSON-RPC ping while the SSE connection is alive + val postResponse = post("$path?sessionId=$resolvedSessionId") { + contentType(ContentType.Application.Json) + setBody("""{"jsonrpc":"2.0","id":1,"method":"ping"}""") + configureRequest() + } + postResponse.shouldHaveStatus(HttpStatusCode.Accepted) + } + + // POST without sessionId returns 400 Bad Request + post(path) { configureRequest() }.shouldHaveStatus(HttpStatusCode.BadRequest) +} diff --git a/kotlin-sdk-server/src/jvmTest/resources/junit-platform.properties b/kotlin-sdk-server/src/jvmTest/resources/junit-platform.properties new file mode 100644 index 000000000..f14d10085 --- /dev/null +++ b/kotlin-sdk-server/src/jvmTest/resources/junit-platform.properties @@ -0,0 +1,5 @@ +## https://docs.junit.org/5.3.0-M1/user-guide/index.html#writing-tests-parallel-execution +junit.jupiter.execution.parallel.enabled=true +junit.jupiter.execution.parallel.config.strategy=dynamic +junit.jupiter.execution.parallel.mode.default=concurrent +junit.jupiter.execution.parallel.mode.classes.default=concurrent