Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@ import io.kotest.matchers.string.shouldContain
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult
import io.modelcontextprotocol.kotlin.sdk.types.ListPromptsRequest
import io.modelcontextprotocol.kotlin.sdk.types.ListPromptsResult
import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument
import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
import io.modelcontextprotocol.kotlin.sdk.types.Role
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
Expand All @@ -19,6 +24,7 @@ import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
import kotlin.test.assertTrue

Expand Down Expand Up @@ -157,8 +163,8 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
// validate required arguments
val requiredArgs = listOf("arg1", "arg2", "arg3")
for (argName in requiredArgs) {
if (request.params.arguments?.get(argName) == null) {
throw IllegalArgumentException("Missing required argument: $argName")
requireNotNull(request.params.arguments?.get(argName)) {
"Missing required argument: $argName"
}
}

Expand Down Expand Up @@ -697,4 +703,62 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
exception.message shouldBe expectedMessage
}
}

@Test
fun testListPromptsPagination() = runBlocking(Dispatchers.IO) {
val pagePrefix = "paginated-prompt-"
(0 until 5).forEach { i ->
val name = "$pagePrefix$i"
server.addPrompt(name = name, description = "desc", arguments = listOf()) { _ ->
GetPromptResult(
description = "desc",
messages = listOf(PromptMessage(role = Role.Assistant, content = TextContent(text = name))),
)
}
}

server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListPromptsRequest>(Method.Defined.PromptsList) { request, _ ->
val all = server.prompts.values.map { it.prompt }
val cursor = request.cursor?.toIntOrNull() ?: 0
val pageSize = 2
val page = all.drop(cursor).take(pageSize)
val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null
ListPromptsResult(prompts = page, nextCursor = next)
}
}

val allPrompts = mutableListOf<io.modelcontextprotocol.kotlin.sdk.types.Prompt>()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fqn

var currentCursor: String? = null
do {
val request = if (currentCursor == null) {
ListPromptsRequest()
} else {
ListPromptsRequest(PaginatedRequestParams(cursor = currentCursor))
}
val response = client.listPrompts(request)
allPrompts.addAll(response.prompts)
currentCursor = response.nextCursor
} while (currentCursor != null)

assertTrue(allPrompts.any { it.name.startsWith(pagePrefix) })
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not all?

}

@Test
fun testListPromptsInvalidCursor() = runBlocking(Dispatchers.IO) {
server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListPromptsRequest>(Method.Defined.PromptsList) { request, _ ->
val cursor = requireNotNull(request.cursor?.toIntOrNull()) { "Invalid cursor" }
val all = server.prompts.values.map { it.prompt }
val page = all.drop(cursor).take(2)
ListPromptsResult(prompts = page, nextCursor = null)
}
}

val exception = assertFailsWith<McpException> {
client.listPrompts(ListPromptsRequest(PaginatedRequestParams(cursor = "not-a-number")))
}

assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package io.modelcontextprotocol.kotlin.sdk.integration.kotlin

import io.modelcontextprotocol.kotlin.sdk.types.BlobResourceContents
import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesRequest
import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesResult
import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams
Expand Down Expand Up @@ -309,4 +313,69 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() {
assertTrue(result.contents.isNotEmpty(), "Result contents should not be empty")
}
}

@Test
fun testListResourcesPagination() = runBlocking(Dispatchers.IO) {
val prefix = "paginated-resource-"
(0 until 6).forEach { i ->
val uri = "test://$prefix$i.txt"
server.addResource(uri = uri, name = "Name-$i", description = "desc", mimeType = "text/plain") { request ->
ReadResourceResult(
contents = listOf(
TextResourceContents(
text = uri,
uri = request.params.uri,
mimeType = "text/plain",
),
),
)
}
}

server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListResourcesRequest>(Method.Defined.ResourcesList) { request, _ ->
val all = server.resources.values.map { it.resource }
val cursor = request.cursor?.toIntOrNull() ?: 0
val pageSize = 3
val page = all.drop(cursor).take(pageSize)
val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null
ListResourcesResult(resources = page, nextCursor = next)
}
}

val combinedUris = mutableListOf<String>()
var currentCursor: String? = null

do {
val request = if (currentCursor == null) {
ListResourcesRequest()
} else {
ListResourcesRequest(PaginatedRequestParams(cursor = currentCursor))
}

val response = client.listResources(request)
combinedUris += response.resources.map { it.uri }
currentCursor = response.nextCursor
} while (currentCursor != null)

assertTrue(combinedUris.any { it.contains(prefix) })
}

@Test
fun testListResourcesInvalidCursor() = runBlocking(Dispatchers.IO) {
server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListResourcesRequest>(Method.Defined.ResourcesList) { request, _ ->
val cursor = requireNotNull(request.cursor?.toIntOrNull()) { "Invalid cursor" }
val all = server.resources.values.map { it.resource }
val page = all.drop(cursor).take(2)
ListResourcesResult(resources = page, nextCursor = null)
}
}

val exception = kotlin.test.assertFailsWith<McpException> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fqn

client.listResources(ListResourcesRequest(PaginatedRequestParams(cursor = "bad")))
}

assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult
import io.modelcontextprotocol.kotlin.sdk.types.ContentBlock
import io.modelcontextprotocol.kotlin.sdk.types.ImageContent
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsRequest
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult
import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
Expand Down Expand Up @@ -791,4 +797,67 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() {
"Error message should indicate the tool was not found",
)
}

@Test
fun testListToolsPagination() = runBlocking(Dispatchers.IO) {
val prefix = "paginated-tool-"
(0 until 5).forEach { i ->
val name = "$prefix$i"
server.addTool(name = name, description = "desc") { request ->
CallToolResult(
content = listOf(TextContent(text = name)),
structuredContent = buildJsonObject { put("name", name) },
)
}
}

server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListToolsRequest>(Method.Defined.ToolsList) { request, _ ->
val all = server.tools.values.map { it.tool }.sortedBy { it.name }
val cursor = request.cursor?.toIntOrNull() ?: 0
val pageSize = 2
val page = all.drop(cursor).take(pageSize)
val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null
ListToolsResult(tools = page, nextCursor = next)
}
}

val combinedNames = mutableListOf<String>()
var currentCursor: String? = null

do {
val request = if (currentCursor == null) {
ListToolsRequest()
} else {
ListToolsRequest(PaginatedRequestParams(cursor = currentCursor))
}

val response = client.listTools(request)
combinedNames += response.tools.map { it.name }
currentCursor = response.nextCursor
} while (currentCursor != null)

val paginatedNames = combinedNames.filter { it.startsWith(prefix) }
assertEquals(5, paginatedNames.size, "All 5 paginated tools should appear")
assertEquals(combinedNames.size, combinedNames.distinct().size, "No duplicate tools across pages")
assertEquals(server.tools.size, combinedNames.size, "Total tools should match server registry")
}

@Test
fun testListToolsInvalidCursor() = runBlocking(Dispatchers.IO) {
server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListToolsRequest>(Method.Defined.ToolsList) { request, _ ->
val cursor = requireNotNull(request.cursor?.toIntOrNull()) { "Invalid cursor" }
val all = server.tools.values.map { it.tool }
val page = all.drop(cursor).take(2)
ListToolsResult(tools = page)
}
}

val exception = kotlin.test.assertFailsWith<McpException> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fqn

client.listTools(ListToolsRequest(PaginatedRequestParams(cursor = "bad")))
}

assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn’t this duplicate ClientConnectionLoggingTest?

Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.streamablehttp

import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.KotlinTestBase
import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest
import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult
import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotificationParams
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.parallel.Execution
import org.junit.jupiter.api.parallel.ExecutionMode

@Execution(ExecutionMode.SAME_THREAD)
class LoggingIntegrationTestStreamableHttp : KotlinTestBase() {

override val transportKind = TransportKind.STREAMABLE_HTTP

override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities(
tools = ServerCapabilities.Tools(listChanged = true),
logging = JsonObject(emptyMap()),
)

override fun configureServer() {
server.addTool(name = "test-notification", description = "test") { request ->
notification(
LoggingMessageNotification(
LoggingMessageNotificationParams(
level = LoggingLevel.Info,
data = JsonPrimitive("test-data-sample"),
),
),
)
CallToolResult(listOf(TextContent("ok")))
}

server.addTool(name = "test-logging", description = "test") { request ->
sendLoggingMessage(
LoggingMessageNotification(
LoggingMessageNotificationParams(
level = LoggingLevel.Info,
data = JsonObject(mapOf("key" to JsonPrimitive("value"))),
),
),
)
CallToolResult(listOf(TextContent("ok")))
}

server.addTool(name = "test-logging-level", description = "test") { request ->
LoggingLevel.entries.forEach { level ->
sendLoggingMessage(
LoggingMessageNotification(
LoggingMessageNotificationParams(
level = level,
data = JsonPrimitive(level.name),
),
),
)
}
CallToolResult(listOf(TextContent("ok")))
}
}

@Test
fun `notification should send logging message to client`() = runBlocking {
val notificationReceived = CompletableDeferred<LoggingMessageNotification>()
client.setNotificationHandler<LoggingMessageNotification>(Method.Defined.NotificationsMessage) {
notificationReceived.complete(it)
CompletableDeferred(Unit)
}

client.callTool(CallToolRequest(CallToolRequestParams("test-notification")))
val received = notificationReceived.await()
kotlin.test.assertEquals(LoggingLevel.Info, received.params.level)
kotlin.test.assertEquals(JsonPrimitive("test-data-sample"), received.params.data)
Comment on lines +83 to +84
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fqn import

}

@Test
fun `sendLoggingMessage should send message at level`() = runBlocking {
val notificationReceived = CompletableDeferred<LoggingMessageNotification>()
client.setNotificationHandler<LoggingMessageNotification>(Method.Defined.NotificationsMessage) {
notificationReceived.complete(it)
CompletableDeferred(Unit)
}

client.callTool(CallToolRequest(CallToolRequestParams("test-logging")))
val received = notificationReceived.await()
kotlin.test.assertEquals(LoggingLevel.Info, received.params.level)
kotlin.test.assertEquals(JsonObject(mapOf("key" to JsonPrimitive("value"))), received.params.data)
Comment on lines +97 to +98
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fqn

}

@Test
fun `sendLoggingMessage should filter messages below level`() = runBlocking {
val receivedMessages = mutableListOf<LoggingMessageNotification>()
client.setNotificationHandler<LoggingMessageNotification>(Method.Defined.NotificationsMessage) {
receivedMessages.add(it)
CompletableDeferred(Unit)
}

client.setLoggingLevel(LoggingLevel.Warning)

client.callTool(CallToolRequest(CallToolRequestParams("test-logging-level")))

val expectedLevels = LoggingLevel.entries.filter { it >= LoggingLevel.Warning }
// wait for expected notifications to arrive (transport may deliver asynchronously)
withTimeout(2000) {
while (receivedMessages.size < expectedLevels.size) {
delay(10)
}
}
kotlin.test.assertEquals(expectedLevels.size, receivedMessages.size)
kotlin.test.assertEquals(expectedLevels.toList(), receivedMessages.map { it.params.level })
Comment on lines +120 to +121
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fqn

}
}
Loading