From d8e3ff06b683c2c9f4538e610b2a64619b3164b3 Mon Sep 17 00:00:00 2001 From: Konippi Date: Sun, 7 Dec 2025 15:19:04 +0900 Subject: [PATCH] feat: add MCP elicitation support --- src/__tests__/mcp.test.ts | 96 ++++++++++++++++- src/index.ts | 3 + src/mcp.ts | 51 ++++++++- src/types/elicitation.ts | 66 ++++++++++++ test/integ/__fixtures__/test-mcp-server.ts | 58 ++++++++++ test/integ/mcp.test.ts | 119 +++++++++++++++++++++ 6 files changed, 387 insertions(+), 6 deletions(-) create mode 100644 src/types/elicitation.ts diff --git a/src/__tests__/mcp.test.ts b/src/__tests__/mcp.test.ts index eae131f7..7850ca08 100644 --- a/src/__tests__/mcp.test.ts +++ b/src/__tests__/mcp.test.ts @@ -7,6 +7,10 @@ import { JsonBlock, type TextBlock, type ToolResultBlock } from '../types/messag import type { AgentData } from '../types/agent.js' import type { ToolContext } from '../tools/tool.js' +vi.mock('@modelcontextprotocol/sdk/types.js', () => ({ + ElicitRequestSchema: { method: 'elicitation/create' }, +})) + vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({ Client: vi.fn(function () { return { @@ -14,6 +18,7 @@ vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({ close: vi.fn(), listTools: vi.fn(), callTool: vi.fn(), + setRequestHandler: vi.fn(), } }), })) @@ -70,7 +75,7 @@ describe('MCP Integration', () => { }) it('initializes SDK client with correct configuration', () => { - expect(Client).toHaveBeenCalledWith({ name: 'TestApp', version: '0.0.1' }) + expect(Client).toHaveBeenCalledWith({ name: 'TestApp', version: '0.0.1' }, undefined) }) it('manages connection state lazily', async () => { @@ -125,6 +130,95 @@ describe('MCP Integration', () => { expect(sdkClientMock.close).toHaveBeenCalled() expect(mockTransport.close).toHaveBeenCalled() }) + + describe('elicitation callback', () => { + it('registers callback when provided', async () => { + const callback = vi.fn() + const clientWithCallback = new McpClient({ + applicationName: 'TestApp', + transport: mockTransport, + elicitationCallback: callback, + }) + const sdkClient = vi.mocked(Client).mock.results[1]!.value + + await clientWithCallback.connect() + + expect(sdkClient.setRequestHandler).toHaveBeenCalled() + }) + + it('does not register callback when not provided', async () => { + await client.connect() + + expect(sdkClientMock.setRequestHandler).not.toHaveBeenCalled() + }) + + it('invokes callback and returns all action types correctly', async () => { + const callback = vi.fn() + const clientWithCallback = new McpClient({ + applicationName: 'TestApp', + transport: mockTransport, + elicitationCallback: callback, + }) + const sdkClient = vi.mocked(Client).mock.results[vi.mocked(Client).mock.results.length - 1]!.value + + await clientWithCallback.connect() + + const handler = sdkClient.setRequestHandler.mock.calls[0]![1] + const mockExtra = { sessionId: 'test-session' } + + // Test accept action + callback.mockResolvedValueOnce({ action: 'accept', content: { response: 'yes' } }) + const acceptResult = await handler( + { + params: { + message: 'Do you want to continue?', + requestedSchema: { type: 'object' }, + }, + }, + mockExtra + ) + + expect(callback).toHaveBeenCalledWith(mockExtra, { + message: 'Do you want to continue?', + requestedSchema: { type: 'object' }, + }) + expect(acceptResult).toEqual({ + action: 'accept', + content: { response: 'yes' }, + }) + + // Test decline action + callback.mockResolvedValueOnce({ action: 'decline' }) + const declineResult = await handler({ params: { message: 'Proceed?' } }, mockExtra) + expect(declineResult).toEqual({ action: 'decline', content: undefined }) + + // Test cancel action + callback.mockResolvedValueOnce({ action: 'cancel' }) + const cancelResult = await handler({ params: { message: 'Cancel operation?' } }, mockExtra) + expect(cancelResult).toEqual({ action: 'cancel', content: undefined }) + }) + + it('handles callback errors gracefully', async () => { + const callback = vi.fn().mockRejectedValue(new Error('User cancelled')) + + const clientWithCallback = new McpClient({ + applicationName: 'TestApp', + transport: mockTransport, + elicitationCallback: callback, + }) + const sdkClient = vi.mocked(Client).mock.results[vi.mocked(Client).mock.results.length - 1]!.value + + await clientWithCallback.connect() + + const handler = sdkClient.setRequestHandler.mock.calls[0]![1] + const mockRequest = { + params: { message: 'Continue?' }, + } + const mockExtra = { sessionId: 'test-session' } + + await expect(handler(mockRequest, mockExtra)).rejects.toThrow('User cancelled') + }) + }) }) describe('McpTool', () => { diff --git a/src/index.ts b/src/index.ts index 740b1fea..df288a1e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -166,3 +166,6 @@ export type { Logger } from './logging/types.js' // MCP Client types and implementations export { type McpClientConfig, McpClient } from './mcp.js' + +// Elicitation types +export type { ElicitationCallback, ElicitRequestParams, ElicitResult } from './types/elicitation.js' diff --git a/src/mcp.ts b/src/mcp.ts index b7128897..40d6fc83 100644 --- a/src/mcp.ts +++ b/src/mcp.ts @@ -1,6 +1,8 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js' import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' +import { ElicitRequestSchema } from '@modelcontextprotocol/sdk/types.js' import type { JSONSchema, JSONValue } from './types/json.js' +import type { ElicitationCallback, ElicitRequestParams } from './types/elicitation.js' import { McpTool } from './tools/mcp-tool.js' /** Temporary placeholder for RuntimeConfig */ @@ -10,7 +12,10 @@ export interface RuntimeConfig { } /** Arguments for configuring an MCP Client. */ -export type McpClientConfig = RuntimeConfig & { transport: Transport } +export type McpClientConfig = RuntimeConfig & { + transport: Transport + elicitationCallback?: ElicitationCallback +} /** MCP Client for interacting with Model Context Protocol servers. */ export class McpClient { @@ -19,16 +24,33 @@ export class McpClient { private _transport: Transport private _connected: boolean private _client: Client + private _elicitationCallback?: ElicitationCallback constructor(args: McpClientConfig) { this._clientName = args.applicationName || 'strands-agents-ts-sdk' this._clientVersion = args.applicationVersion || '0.0.1' this._transport = args.transport this._connected = false - this._client = new Client({ - name: this._clientName, - version: this._clientVersion, - }) + + if (args.elicitationCallback !== undefined) { + this._elicitationCallback = args.elicitationCallback + } + + const clientOptions = this._elicitationCallback + ? { + capabilities: { + elicitation: { form: {} }, + }, + } + : undefined + + this._client = new Client( + { + name: this._clientName, + version: this._clientVersion, + }, + clientOptions + ) } get client(): Client { @@ -54,6 +76,25 @@ export class McpClient { await this._client.connect(this._transport) + if (this._elicitationCallback) { + const callback = this._elicitationCallback + this._client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + const params: ElicitRequestParams = { + message: request.params.message, + ...(request.params.requestedSchema !== undefined && { + requestedSchema: request.params.requestedSchema as JSONSchema, + }), + } + + const result = await callback(extra, params) + + return { + action: result.action, + content: result.content, + } + }) + } + this._connected = true } diff --git a/src/types/elicitation.ts b/src/types/elicitation.ts new file mode 100644 index 00000000..adecf0fe --- /dev/null +++ b/src/types/elicitation.ts @@ -0,0 +1,66 @@ +import type { JSONSchema, JSONValue } from './json.js' + +/** + * Context information for elicitation requests provided by the MCP SDK. + * + * This object contains session metadata and capabilities from the MCP transport layer: + * - `sessionId`: Session identifier (for HTTP transports) + * - `authInfo`: Authentication information (for OAuth flows) + * - `sendNotification`: Function to send notifications to the server + * - Other transport-specific metadata + */ +export type ElicitationContext = unknown + +/** + * Parameters passed to the elicitation callback when the server requests user input. + */ +export interface ElicitRequestParams { + /** + * Message to display to the user explaining what input is needed. + */ + message: string + + /** + * Optional JSON Schema defining the structure of the expected input. + */ + requestedSchema?: JSONSchema +} + +/** + * Result returned by the elicitation callback to indicate user's response. + */ +export interface ElicitResult { + /** + * Action taken by the user. + * - 'accept': User provided input and wants to continue + * - 'decline': User declined to provide input + * - 'cancel': User wants to cancel the entire operation + */ + action: 'accept' | 'decline' | 'cancel' + + /** + * Optional content provided by the user when action is 'accept'. + */ + content?: Record +} + +/** + * Callback function invoked when an MCP server requests additional input during tool execution. + * + * @param context - Context information about the elicitation request + * @param params - Parameters including the message and optional schema + * @returns A promise that resolves with the user's response + * + * @example + * ```typescript + * const elicitationCallback: ElicitationCallback = async (_context, params) => { + * console.log(`Server is asking: ${params.message}`) + * const userInput = await getUserInput() + * return { + * action: 'accept', + * content: { response: userInput } + * } + * } + * ``` + */ +export type ElicitationCallback = (context: ElicitationContext, params: ElicitRequestParams) => Promise diff --git a/test/integ/__fixtures__/test-mcp-server.ts b/test/integ/__fixtures__/test-mcp-server.ts index 1715908c..47b42e50 100644 --- a/test/integ/__fixtures__/test-mcp-server.ts +++ b/test/integ/__fixtures__/test-mcp-server.ts @@ -25,6 +25,7 @@ function createTestServer(): McpServer { { capabilities: { tools: {}, + elicitation: { form: {} }, }, } ) @@ -124,6 +125,63 @@ function createTestServer(): McpServer { } ) + // Register elicitation tool + server.registerTool( + 'confirm_action', + { + title: 'Confirm Action Tool', + description: 'Requests user confirmation before performing an action', + inputSchema: { + action: z.string(), + }, + outputSchema: { + confirmed: z.boolean(), + action: z.string(), + }, + }, + async ({ action }) => { + // Request user confirmation via elicitation + const result = await server.server.elicitInput({ + message: `Do you want to proceed with: ${action}?`, + requestedSchema: { + type: 'object', + properties: { + confirmed: { + type: 'boolean', + title: 'Confirm action', + description: 'Confirm whether to proceed', + }, + }, + required: ['confirmed'], + }, + }) + + if (result.action === 'accept' && result.content?.confirmed) { + const output = { confirmed: true, action } + return { + content: [ + { + type: 'text', + text: `Action confirmed: ${action}`, + }, + ], + structuredContent: output, + } + } + + const output = { confirmed: false, action } + return { + content: [ + { + type: 'text', + text: `Action declined: ${action}`, + }, + ], + structuredContent: output, + } + } + ) + return server } diff --git a/test/integ/mcp.test.ts b/test/integ/mcp.test.ts index ccbef567..57c8b9ca 100644 --- a/test/integ/mcp.test.ts +++ b/test/integ/mcp.test.ts @@ -116,4 +116,123 @@ describe('MCP Integration Tests', () => { expect(hasErrorResult).toBe(true) }, 30000) }) + + describe('elicitation callback', () => { + it('handles elicitation requests with accept action', async () => { + let elicitationCalled = false + let elicitationMessage = '' + + const client = new McpClient({ + applicationName: 'test-mcp-elicitation', + transport: new StdioClientTransport({ + command: 'npx', + args: ['tsx', serverPath], + }), + elicitationCallback: async (_context, params) => { + elicitationCalled = true + elicitationMessage = params.message + return { + action: 'accept', + content: { confirmed: true }, + } + }, + }) + + const model = new BedrockModel({ maxTokens: 200 }) + + const agent = new Agent({ + systemPrompt: 'You are a helpful assistant. Use the confirm_action tool when asked.', + tools: [client], + model, + }) + + const result = await agent.invoke('Use the confirm_action tool to delete a file.') + + expect(result).toBeDefined() + expect(elicitationCalled).toBe(true) + expect(elicitationMessage).toContain('delete a file') + + // Verify the tool was used and completed + const hasConfirmUse = agent.messages.some((msg) => + msg.content.some((block) => block.type === 'toolUseBlock' && block.name === 'confirm_action') + ) + expect(hasConfirmUse).toBe(true) + + await client.disconnect() + }, 30000) + + it('handles elicitation requests with decline action', async () => { + let elicitationCalled = false + + const client = new McpClient({ + applicationName: 'test-mcp-elicitation-decline', + transport: new StdioClientTransport({ + command: 'npx', + args: ['tsx', serverPath], + }), + elicitationCallback: async (_context, _params) => { + elicitationCalled = true + return { + action: 'decline', + } + }, + }) + + const model = new BedrockModel({ maxTokens: 200 }) + + const agent = new Agent({ + systemPrompt: 'You are a helpful assistant. Use the confirm_action tool when asked.', + tools: [client], + model, + }) + + const result = await agent.invoke('Use the confirm_action tool to update settings.') + + expect(result).toBeDefined() + expect(elicitationCalled).toBe(true) + + // Verify the tool was used + const hasConfirmUse = agent.messages.some((msg) => + msg.content.some((block) => block.type === 'toolUseBlock' && block.name === 'confirm_action') + ) + expect(hasConfirmUse).toBe(true) + + await client.disconnect() + }, 30000) + + it('handles elicitation with requested schema', async () => { + let receivedSchema: unknown + + const client = new McpClient({ + applicationName: 'test-mcp-elicitation-schema', + transport: new StdioClientTransport({ + command: 'npx', + args: ['tsx', serverPath], + }), + elicitationCallback: async (_context, params) => { + receivedSchema = params.requestedSchema + return { + action: 'accept', + content: { confirmed: true }, + } + }, + }) + + const model = new BedrockModel({ maxTokens: 200 }) + + const agent = new Agent({ + systemPrompt: 'You are a helpful assistant. Use the confirm_action tool when asked.', + tools: [client], + model, + }) + + await agent.invoke('Use the confirm_action tool to restart the system.') + + expect(receivedSchema).toBeDefined() + expect(receivedSchema).toHaveProperty('type', 'object') + expect(receivedSchema).toHaveProperty('properties') + + await client.disconnect() + }, 30000) + }) })