diff --git a/.changeset/release-sse-reader-locks.md b/.changeset/release-sse-reader-locks.md new file mode 100644 index 0000000000..f313173e67 --- /dev/null +++ b/.changeset/release-sse-reader-locks.md @@ -0,0 +1,5 @@ +--- +'@modelcontextprotocol/sdk': patch +--- + +Release SSE stream reader locks after graceful closes or read errors in the streamable HTTP client. diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 736587973d..6d1649566f 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -328,40 +328,44 @@ export class StreamableHTTPClientTransport implements Transport { ) .getReader(); - while (true) { - const { value: event, done } = await reader.read(); - if (done) { - break; - } + try { + while (true) { + const { value: event, done } = await reader.read(); + if (done) { + break; + } - // Update last event ID if provided - if (event.id) { - lastEventId = event.id; - // Mark that we've received a priming event - stream is now resumable - hasPrimingEvent = true; - onresumptiontoken?.(event.id); - } + // Update last event ID if provided + if (event.id) { + lastEventId = event.id; + // Mark that we've received a priming event - stream is now resumable + hasPrimingEvent = true; + onresumptiontoken?.(event.id); + } - // Skip events with no data (priming events, keep-alives) - if (!event.data) { - continue; - } + // Skip events with no data (priming events, keep-alives) + if (!event.data) { + continue; + } - if (!event.event || event.event === 'message') { - try { - const message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); - if (isJSONRPCResultResponse(message)) { - // Mark that we received a response - no need to reconnect for this request - receivedResponse = true; - if (replayMessageId !== undefined) { - message.id = replayMessageId; + if (!event.event || event.event === 'message') { + try { + const message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); + if (isJSONRPCResultResponse(message)) { + // Mark that we received a response - no need to reconnect for this request + receivedResponse = true; + if (replayMessageId !== undefined) { + message.id = replayMessageId; + } } + this.onmessage?.(message); + } catch (error) { + this.onerror?.(error as Error); } - this.onmessage?.(message); - } catch (error) { - this.onerror?.(error as Error); } } + } finally { + reader.releaseLock(); } // Handle graceful server-side disconnect diff --git a/test/client/streamableHttp.test.ts b/test/client/streamableHttp.test.ts index 52c8f10748..3975d7c42d 100644 --- a/test/client/streamableHttp.test.ts +++ b/test/client/streamableHttp.test.ts @@ -5,9 +5,42 @@ import { InvalidClientError, InvalidGrantError, UnauthorizedClientError } from ' import { type Mock, type Mocked } from 'vitest'; describe('StreamableHTTPClientTransport', () => { + type ParsedSseEvent = { + id?: string; + event?: string; + data?: string; + }; + let transport: StreamableHTTPClientTransport; let mockAuthProvider: Mocked; + const waitForCondition = async (condition: () => boolean) => { + for (let i = 0; i < 20; i++) { + if (condition()) { + return; + } + await new Promise(resolve => setTimeout(resolve, 0)); + } + expect(condition()).toBe(true); + }; + + const handleSseStream = (stream: ReadableStream, options: StartSSEOptions = {}, isReconnectable = false) => { + ( + transport as unknown as { + _handleSseStream: (stream: ReadableStream, options: StartSSEOptions, isReconnectable: boolean) => void; + } + )._handleSseStream(stream, options, isReconnectable); + }; + + const makeSseStreamWithReader = (reader: Pick, 'read' | 'releaseLock'>) => + ({ + pipeThrough: vi.fn(() => ({ + pipeThrough: vi.fn(() => ({ + getReader: vi.fn(() => reader) + })) + })) + }) as unknown as ReadableStream; + beforeEach(() => { mockAuthProvider = { get redirectUrl() { @@ -311,6 +344,48 @@ describe('StreamableHTTPClientTransport', () => { ); }); + it('releases the SSE reader after the stream ends', async () => { + const reader = { + read: vi + .fn<() => Promise>>() + .mockResolvedValueOnce({ + done: false, + value: { + event: 'message', + data: '{"jsonrpc":"2.0","method":"serverNotification","params":{}}' + } + }) + .mockResolvedValueOnce({ done: true, value: undefined }), + releaseLock: vi.fn() + }; + const messageSpy = vi.fn(); + transport.onmessage = messageSpy; + + handleSseStream(makeSseStreamWithReader(reader)); + + await waitForCondition(() => reader.releaseLock.mock.calls.length === 1); + expect(messageSpy).toHaveBeenCalledWith({ + jsonrpc: '2.0', + method: 'serverNotification', + params: {} + }); + }); + + it('releases the SSE reader when stream processing fails', async () => { + const reader = { + read: vi.fn<() => Promise>>().mockRejectedValueOnce(new Error('network disconnect')), + releaseLock: vi.fn() + }; + const errorSpy = vi.fn(); + transport.onerror = errorSpy; + + handleSseStream(makeSseStreamWithReader(reader)); + + await waitForCondition(() => errorSpy.mock.calls.length === 1); + expect(reader.releaseLock).toHaveBeenCalledTimes(1); + expect(errorSpy).toHaveBeenCalledWith(new Error('SSE stream disconnected: Error: network disconnect')); + }); + it('should handle multiple concurrent SSE streams', async () => { // Mock two POST requests that return SSE streams const makeStream = (id: string) => {