diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt index d14ae88e8..4f300290d 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt @@ -6,9 +6,14 @@ import io.kotest.matchers.string.shouldContain import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequestParams import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.types.ListPromptsRequest +import io.modelcontextprotocol.kotlin.sdk.types.ListPromptsResult import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.types.RPCError import io.modelcontextprotocol.kotlin.sdk.types.Role import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.types.TextContent @@ -19,6 +24,7 @@ import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertNotNull import kotlin.test.assertTrue @@ -157,8 +163,8 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { // validate required arguments val requiredArgs = listOf("arg1", "arg2", "arg3") for (argName in requiredArgs) { - if (request.params.arguments?.get(argName) == null) { - throw IllegalArgumentException("Missing required argument: $argName") + requireNotNull(request.params.arguments?.get(argName)) { + "Missing required argument: $argName" } } @@ -697,4 +703,62 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { exception.message shouldBe expectedMessage } } + + @Test + fun testListPromptsPagination() = runBlocking(Dispatchers.IO) { + val pagePrefix = "paginated-prompt-" + (0 until 5).forEach { i -> + val name = "$pagePrefix$i" + server.addPrompt(name = name, description = "desc", arguments = listOf()) { _ -> + GetPromptResult( + description = "desc", + messages = listOf(PromptMessage(role = Role.Assistant, content = TextContent(text = name))), + ) + } + } + + server.sessions.forEach { (_, session) -> + session.setRequestHandler(Method.Defined.PromptsList) { request, _ -> + val all = server.prompts.values.map { it.prompt } + val cursor = request.cursor?.toIntOrNull() ?: 0 + val pageSize = 2 + val page = all.drop(cursor).take(pageSize) + val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null + ListPromptsResult(prompts = page, nextCursor = next) + } + } + + val allPrompts = mutableListOf() + var currentCursor: String? = null + do { + val request = if (currentCursor == null) { + ListPromptsRequest() + } else { + ListPromptsRequest(PaginatedRequestParams(cursor = currentCursor)) + } + val response = client.listPrompts(request) + allPrompts.addAll(response.prompts) + currentCursor = response.nextCursor + } while (currentCursor != null) + + assertTrue(allPrompts.any { it.name.startsWith(pagePrefix) }) + } + + @Test + fun testListPromptsInvalidCursor() = runBlocking(Dispatchers.IO) { + server.sessions.forEach { (_, session) -> + session.setRequestHandler(Method.Defined.PromptsList) { request, _ -> + val cursor = requireNotNull(request.cursor?.toIntOrNull()) { "Invalid cursor" } + val all = server.prompts.values.map { it.prompt } + val page = all.drop(cursor).take(2) + ListPromptsResult(prompts = page, nextCursor = null) + } + } + + val exception = assertFailsWith { + client.listPrompts(ListPromptsRequest(PaginatedRequestParams(cursor = "not-a-number"))) + } + + assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code) + } } diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt index 4fc53a52c..ce4cb6712 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt @@ -1,7 +1,11 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin import io.modelcontextprotocol.kotlin.sdk.types.BlobResourceContents +import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesRequest +import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesResult import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams import io.modelcontextprotocol.kotlin.sdk.types.RPCError import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams @@ -309,4 +313,69 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() { assertTrue(result.contents.isNotEmpty(), "Result contents should not be empty") } } + + @Test + fun testListResourcesPagination() = runBlocking(Dispatchers.IO) { + val prefix = "paginated-resource-" + (0 until 6).forEach { i -> + val uri = "test://$prefix$i.txt" + server.addResource(uri = uri, name = "Name-$i", description = "desc", mimeType = "text/plain") { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = uri, + uri = request.params.uri, + mimeType = "text/plain", + ), + ), + ) + } + } + + server.sessions.forEach { (_, session) -> + session.setRequestHandler(Method.Defined.ResourcesList) { request, _ -> + val all = server.resources.values.map { it.resource } + val cursor = request.cursor?.toIntOrNull() ?: 0 + val pageSize = 3 + val page = all.drop(cursor).take(pageSize) + val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null + ListResourcesResult(resources = page, nextCursor = next) + } + } + + val combinedUris = mutableListOf() + var currentCursor: String? = null + + do { + val request = if (currentCursor == null) { + ListResourcesRequest() + } else { + ListResourcesRequest(PaginatedRequestParams(cursor = currentCursor)) + } + + val response = client.listResources(request) + combinedUris += response.resources.map { it.uri } + currentCursor = response.nextCursor + } while (currentCursor != null) + + assertTrue(combinedUris.any { it.contains(prefix) }) + } + + @Test + fun testListResourcesInvalidCursor() = runBlocking(Dispatchers.IO) { + server.sessions.forEach { (_, session) -> + session.setRequestHandler(Method.Defined.ResourcesList) { request, _ -> + val cursor = requireNotNull(request.cursor?.toIntOrNull()) { "Invalid cursor" } + val all = server.resources.values.map { it.resource } + val page = all.drop(cursor).take(2) + ListResourcesResult(resources = page, nextCursor = null) + } + } + + val exception = kotlin.test.assertFailsWith { + client.listResources(ListResourcesRequest(PaginatedRequestParams(cursor = "bad"))) + } + + assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code) + } } diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt index 7da82cc3a..661ad5dd7 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt @@ -6,6 +6,12 @@ import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult import io.modelcontextprotocol.kotlin.sdk.types.ContentBlock import io.modelcontextprotocol.kotlin.sdk.types.ImageContent +import io.modelcontextprotocol.kotlin.sdk.types.ListToolsRequest +import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.RPCError import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.types.TextContent import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema @@ -791,4 +797,67 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() { "Error message should indicate the tool was not found", ) } + + @Test + fun testListToolsPagination() = runBlocking(Dispatchers.IO) { + val prefix = "paginated-tool-" + (0 until 5).forEach { i -> + val name = "$prefix$i" + server.addTool(name = name, description = "desc") { request -> + CallToolResult( + content = listOf(TextContent(text = name)), + structuredContent = buildJsonObject { put("name", name) }, + ) + } + } + + server.sessions.forEach { (_, session) -> + session.setRequestHandler(Method.Defined.ToolsList) { request, _ -> + val all = server.tools.values.map { it.tool }.sortedBy { it.name } + val cursor = request.cursor?.toIntOrNull() ?: 0 + val pageSize = 2 + val page = all.drop(cursor).take(pageSize) + val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null + ListToolsResult(tools = page, nextCursor = next) + } + } + + val combinedNames = mutableListOf() + var currentCursor: String? = null + + do { + val request = if (currentCursor == null) { + ListToolsRequest() + } else { + ListToolsRequest(PaginatedRequestParams(cursor = currentCursor)) + } + + val response = client.listTools(request) + combinedNames += response.tools.map { it.name } + currentCursor = response.nextCursor + } while (currentCursor != null) + + val paginatedNames = combinedNames.filter { it.startsWith(prefix) } + assertEquals(5, paginatedNames.size, "All 5 paginated tools should appear") + assertEquals(combinedNames.size, combinedNames.distinct().size, "No duplicate tools across pages") + assertEquals(server.tools.size, combinedNames.size, "Total tools should match server registry") + } + + @Test + fun testListToolsInvalidCursor() = runBlocking(Dispatchers.IO) { + server.sessions.forEach { (_, session) -> + session.setRequestHandler(Method.Defined.ToolsList) { request, _ -> + val cursor = requireNotNull(request.cursor?.toIntOrNull()) { "Invalid cursor" } + val all = server.tools.values.map { it.tool } + val page = all.drop(cursor).take(2) + ListToolsResult(tools = page) + } + } + + val exception = kotlin.test.assertFailsWith { + client.listTools(ListToolsRequest(PaginatedRequestParams(cursor = "bad"))) + } + + assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code) + } } diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/LoggingIntegrationTestStreamableHttp.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/LoggingIntegrationTestStreamableHttp.kt new file mode 100644 index 000000000..31dabd94e --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/LoggingIntegrationTestStreamableHttp.kt @@ -0,0 +1,123 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.streamablehttp + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.KotlinTestBase +import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel +import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification +import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotificationParams +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.delay +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode + +@Execution(ExecutionMode.SAME_THREAD) +class LoggingIntegrationTestStreamableHttp : KotlinTestBase() { + + override val transportKind = TransportKind.STREAMABLE_HTTP + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = true), + logging = JsonObject(emptyMap()), + ) + + override fun configureServer() { + server.addTool(name = "test-notification", description = "test") { request -> + notification( + LoggingMessageNotification( + LoggingMessageNotificationParams( + level = LoggingLevel.Info, + data = JsonPrimitive("test-data-sample"), + ), + ), + ) + CallToolResult(listOf(TextContent("ok"))) + } + + server.addTool(name = "test-logging", description = "test") { request -> + sendLoggingMessage( + LoggingMessageNotification( + LoggingMessageNotificationParams( + level = LoggingLevel.Info, + data = JsonObject(mapOf("key" to JsonPrimitive("value"))), + ), + ), + ) + CallToolResult(listOf(TextContent("ok"))) + } + + server.addTool(name = "test-logging-level", description = "test") { request -> + LoggingLevel.entries.forEach { level -> + sendLoggingMessage( + LoggingMessageNotification( + LoggingMessageNotificationParams( + level = level, + data = JsonPrimitive(level.name), + ), + ), + ) + } + CallToolResult(listOf(TextContent("ok"))) + } + } + + @Test + fun `notification should send logging message to client`() = runBlocking { + val notificationReceived = CompletableDeferred() + client.setNotificationHandler(Method.Defined.NotificationsMessage) { + notificationReceived.complete(it) + CompletableDeferred(Unit) + } + + client.callTool(CallToolRequest(CallToolRequestParams("test-notification"))) + val received = notificationReceived.await() + kotlin.test.assertEquals(LoggingLevel.Info, received.params.level) + kotlin.test.assertEquals(JsonPrimitive("test-data-sample"), received.params.data) + } + + @Test + fun `sendLoggingMessage should send message at level`() = runBlocking { + val notificationReceived = CompletableDeferred() + client.setNotificationHandler(Method.Defined.NotificationsMessage) { + notificationReceived.complete(it) + CompletableDeferred(Unit) + } + + client.callTool(CallToolRequest(CallToolRequestParams("test-logging"))) + val received = notificationReceived.await() + kotlin.test.assertEquals(LoggingLevel.Info, received.params.level) + kotlin.test.assertEquals(JsonObject(mapOf("key" to JsonPrimitive("value"))), received.params.data) + } + + @Test + fun `sendLoggingMessage should filter messages below level`() = runBlocking { + val receivedMessages = mutableListOf() + client.setNotificationHandler(Method.Defined.NotificationsMessage) { + receivedMessages.add(it) + CompletableDeferred(Unit) + } + + client.setLoggingLevel(LoggingLevel.Warning) + + client.callTool(CallToolRequest(CallToolRequestParams("test-logging-level"))) + + val expectedLevels = LoggingLevel.entries.filter { it >= LoggingLevel.Warning } + // wait for expected notifications to arrive (transport may deliver asynchronously) + withTimeout(2000) { + while (receivedMessages.size < expectedLevels.size) { + delay(10) + } + } + kotlin.test.assertEquals(expectedLevels.size, receivedMessages.size) + kotlin.test.assertEquals(expectedLevels.toList(), receivedMessages.map { it.params.level }) + } +}