Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion src/__tests__/mcp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@ import { JsonBlock, type TextBlock, type ToolResultBlock } from '../types/messag
import type { AgentData } from '../types/agent.js'
import type { ToolContext } from '../tools/tool.js'

vi.mock('@modelcontextprotocol/sdk/types.js', () => ({
ElicitRequestSchema: { method: 'elicitation/create' },
}))

vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
Client: vi.fn(function () {
return {
connect: vi.fn(),
close: vi.fn(),
listTools: vi.fn(),
callTool: vi.fn(),
setRequestHandler: vi.fn(),
}
}),
}))
Expand Down Expand Up @@ -70,7 +75,7 @@ describe('MCP Integration', () => {
})

it('initializes SDK client with correct configuration', () => {
expect(Client).toHaveBeenCalledWith({ name: 'TestApp', version: '0.0.1' })
expect(Client).toHaveBeenCalledWith({ name: 'TestApp', version: '0.0.1' }, undefined)
})

it('manages connection state lazily', async () => {
Expand Down Expand Up @@ -125,6 +130,95 @@ describe('MCP Integration', () => {
expect(sdkClientMock.close).toHaveBeenCalled()
expect(mockTransport.close).toHaveBeenCalled()
})

describe('elicitation callback', () => {
it('registers callback when provided', async () => {
const callback = vi.fn()
const clientWithCallback = new McpClient({
applicationName: 'TestApp',
transport: mockTransport,
elicitationCallback: callback,
})
const sdkClient = vi.mocked(Client).mock.results[1]!.value

await clientWithCallback.connect()

expect(sdkClient.setRequestHandler).toHaveBeenCalled()
})

it('does not register callback when not provided', async () => {
await client.connect()

expect(sdkClientMock.setRequestHandler).not.toHaveBeenCalled()
})

it('invokes callback and returns all action types correctly', async () => {
const callback = vi.fn()
const clientWithCallback = new McpClient({
applicationName: 'TestApp',
transport: mockTransport,
elicitationCallback: callback,
})
const sdkClient = vi.mocked(Client).mock.results[vi.mocked(Client).mock.results.length - 1]!.value

await clientWithCallback.connect()

const handler = sdkClient.setRequestHandler.mock.calls[0]![1]
const mockExtra = { sessionId: 'test-session' }

// Test accept action
callback.mockResolvedValueOnce({ action: 'accept', content: { response: 'yes' } })
const acceptResult = await handler(
{
params: {
message: 'Do you want to continue?',
requestedSchema: { type: 'object' },
},
},
mockExtra
)

expect(callback).toHaveBeenCalledWith(mockExtra, {
message: 'Do you want to continue?',
requestedSchema: { type: 'object' },
})
expect(acceptResult).toEqual({
action: 'accept',
content: { response: 'yes' },
})

// Test decline action
callback.mockResolvedValueOnce({ action: 'decline' })
const declineResult = await handler({ params: { message: 'Proceed?' } }, mockExtra)
expect(declineResult).toEqual({ action: 'decline', content: undefined })

// Test cancel action
callback.mockResolvedValueOnce({ action: 'cancel' })
const cancelResult = await handler({ params: { message: 'Cancel operation?' } }, mockExtra)
expect(cancelResult).toEqual({ action: 'cancel', content: undefined })
})

it('handles callback errors gracefully', async () => {
const callback = vi.fn().mockRejectedValue(new Error('User cancelled'))

const clientWithCallback = new McpClient({
applicationName: 'TestApp',
transport: mockTransport,
elicitationCallback: callback,
})
const sdkClient = vi.mocked(Client).mock.results[vi.mocked(Client).mock.results.length - 1]!.value

await clientWithCallback.connect()

const handler = sdkClient.setRequestHandler.mock.calls[0]![1]
const mockRequest = {
params: { message: 'Continue?' },
}
const mockExtra = { sessionId: 'test-session' }

await expect(handler(mockRequest, mockExtra)).rejects.toThrow('User cancelled')
})
})
})

describe('McpTool', () => {
Expand Down
3 changes: 3 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,6 @@ export type { Logger } from './logging/types.js'

// MCP Client types and implementations
export { type McpClientConfig, McpClient } from './mcp.js'

// Elicitation types
export type { ElicitationCallback, ElicitRequestParams, ElicitResult } from './types/elicitation.js'
51 changes: 46 additions & 5 deletions src/mcp.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
import { ElicitRequestSchema } from '@modelcontextprotocol/sdk/types.js'
import type { JSONSchema, JSONValue } from './types/json.js'
import type { ElicitationCallback, ElicitRequestParams } from './types/elicitation.js'
import { McpTool } from './tools/mcp-tool.js'

/** Temporary placeholder for RuntimeConfig */
Expand All @@ -10,7 +12,10 @@ export interface RuntimeConfig {
}

/** Arguments for configuring an MCP Client. */
export type McpClientConfig = RuntimeConfig & { transport: Transport }
export type McpClientConfig = RuntimeConfig & {
transport: Transport
elicitationCallback?: ElicitationCallback
}

/** MCP Client for interacting with Model Context Protocol servers. */
export class McpClient {
Expand All @@ -19,16 +24,33 @@ export class McpClient {
private _transport: Transport
private _connected: boolean
private _client: Client
private _elicitationCallback?: ElicitationCallback

constructor(args: McpClientConfig) {
this._clientName = args.applicationName || 'strands-agents-ts-sdk'
this._clientVersion = args.applicationVersion || '0.0.1'
this._transport = args.transport
this._connected = false
this._client = new Client({
name: this._clientName,
version: this._clientVersion,
})

if (args.elicitationCallback !== undefined) {
this._elicitationCallback = args.elicitationCallback
}

const clientOptions = this._elicitationCallback
? {
capabilities: {
elicitation: { form: {} },
},
}
: undefined

this._client = new Client(
{
name: this._clientName,
version: this._clientVersion,
},
clientOptions
)
}

get client(): Client {
Expand All @@ -54,6 +76,25 @@ export class McpClient {

await this._client.connect(this._transport)

if (this._elicitationCallback) {
const callback = this._elicitationCallback
this._client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
const params: ElicitRequestParams = {
message: request.params.message,
...(request.params.requestedSchema !== undefined && {
requestedSchema: request.params.requestedSchema as JSONSchema,
}),
}

const result = await callback(extra, params)

return {
action: result.action,
content: result.content,
}
})
}

this._connected = true
}

Expand Down
66 changes: 66 additions & 0 deletions src/types/elicitation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import type { JSONSchema, JSONValue } from './json.js'

/**
* Context information for elicitation requests provided by the MCP SDK.
*
* This object contains session metadata and capabilities from the MCP transport layer:
* - `sessionId`: Session identifier (for HTTP transports)
* - `authInfo`: Authentication information (for OAuth flows)
* - `sendNotification`: Function to send notifications to the server
* - Other transport-specific metadata
*/
export type ElicitationContext = unknown

/**
* Parameters passed to the elicitation callback when the server requests user input.
*/
export interface ElicitRequestParams {
/**
* Message to display to the user explaining what input is needed.
*/
message: string

/**
* Optional JSON Schema defining the structure of the expected input.
*/
requestedSchema?: JSONSchema
}

/**
* Result returned by the elicitation callback to indicate user's response.
*/
export interface ElicitResult {
/**
* Action taken by the user.
* - 'accept': User provided input and wants to continue
* - 'decline': User declined to provide input
* - 'cancel': User wants to cancel the entire operation
*/
action: 'accept' | 'decline' | 'cancel'

/**
* Optional content provided by the user when action is 'accept'.
*/
content?: Record<string, JSONValue>
}

/**
* Callback function invoked when an MCP server requests additional input during tool execution.
*
* @param context - Context information about the elicitation request
* @param params - Parameters including the message and optional schema
* @returns A promise that resolves with the user's response
*
* @example
* ```typescript
* const elicitationCallback: ElicitationCallback = async (_context, params) => {
* console.log(`Server is asking: ${params.message}`)
* const userInput = await getUserInput()
* return {
* action: 'accept',
* content: { response: userInput }
* }
* }
* ```
*/
export type ElicitationCallback = (context: ElicitationContext, params: ElicitRequestParams) => Promise<ElicitResult>
58 changes: 58 additions & 0 deletions test/integ/__fixtures__/test-mcp-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ function createTestServer(): McpServer {
{
capabilities: {
tools: {},
elicitation: { form: {} },
},
}
)
Expand Down Expand Up @@ -124,6 +125,63 @@ function createTestServer(): McpServer {
}
)

// Register elicitation tool
server.registerTool(
'confirm_action',
{
title: 'Confirm Action Tool',
description: 'Requests user confirmation before performing an action',
inputSchema: {
action: z.string(),
},
outputSchema: {
confirmed: z.boolean(),
action: z.string(),
},
},
async ({ action }) => {
// Request user confirmation via elicitation
const result = await server.server.elicitInput({
message: `Do you want to proceed with: ${action}?`,
requestedSchema: {
type: 'object',
properties: {
confirmed: {
type: 'boolean',
title: 'Confirm action',
description: 'Confirm whether to proceed',
},
},
required: ['confirmed'],
},
})

if (result.action === 'accept' && result.content?.confirmed) {
const output = { confirmed: true, action }
return {
content: [
{
type: 'text',
text: `Action confirmed: ${action}`,
},
],
structuredContent: output,
}
}

const output = { confirmed: false, action }
return {
content: [
{
type: 'text',
text: `Action declined: ${action}`,
},
],
structuredContent: output,
}
}
)

return server
}

Expand Down
Loading