Skip to content

Commit 0124539

Browse files
committed
test: add tests for uncovered branches in streamable HTTP tests
1 parent 846e7f7 commit 0124539

File tree

1 file changed

+186
-7
lines changed

1 file changed

+186
-7
lines changed

tests/shared/test_streamable_http.py

Lines changed: 186 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,13 +1020,6 @@ def test_get_validation(basic_server: None, basic_server_url: str):
10201020

10211021

10221022
# Client-specific fixtures
1023-
@pytest.fixture
1024-
async def http_client(basic_server: None, basic_server_url: str):
1025-
"""Create test client matching the SSE test pattern."""
1026-
async with httpx.AsyncClient(base_url=basic_server_url) as client:
1027-
yield client
1028-
1029-
10301023
@pytest.fixture
10311024
async def initialized_client_session(basic_server: None, basic_server_url: str):
10321025
"""Create initialized StreamableHTTP client session."""
@@ -2293,3 +2286,189 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(
22932286

22942287
assert "content-type" in headers_data
22952288
assert headers_data["content-type"] == "application/json"
2289+
2290+
2291+
@pytest.mark.anyio
2292+
async def test_replay_events_after_nonexistent_event_id():
2293+
"""Test replay_events_after returns None for non-existent event ID."""
2294+
store = SimpleEventStore()
2295+
2296+
# Store some events first
2297+
stream_id = "stream-1"
2298+
await store.store_event(stream_id, types.JSONRPCResponse(jsonrpc="2.0", id="1", result={"key": "value"}))
2299+
2300+
# Try to replay after a non-existent event ID
2301+
callback = MagicMock()
2302+
result = await store.replay_events_after("999", callback)
2303+
assert result is None
2304+
callback.assert_not_called()
2305+
2306+
2307+
@pytest.mark.anyio
2308+
async def test_replay_events_after_replays_messages():
2309+
"""Test replay_events_after correctly replays messages after a given event ID."""
2310+
store = SimpleEventStore()
2311+
2312+
stream_id = "stream-1"
2313+
msg1: types.JSONRPCMessage = types.JSONRPCResponse(jsonrpc="2.0", id="1", result={"first": True})
2314+
msg2: types.JSONRPCMessage = types.JSONRPCResponse(jsonrpc="2.0", id="2", result={"second": True})
2315+
# Store: priming event (None), then two real messages
2316+
eid0 = await store.store_event(stream_id, None)
2317+
eid1 = await store.store_event(stream_id, msg1)
2318+
eid2 = await store.store_event(stream_id, msg2)
2319+
2320+
# Replay after priming event — should get both messages
2321+
replayed: list[EventMessage] = []
2322+
2323+
async def callback(event_msg: EventMessage) -> None:
2324+
replayed.append(event_msg)
2325+
2326+
result = await store.replay_events_after(eid0, callback)
2327+
assert result == stream_id
2328+
assert len(replayed) == 2
2329+
assert replayed[0].event_id == eid1
2330+
assert replayed[1].event_id == eid2
2331+
2332+
# Replay after first message — should get only second
2333+
replayed.clear()
2334+
result = await store.replay_events_after(eid1, callback)
2335+
assert result == stream_id
2336+
assert len(replayed) == 1
2337+
assert replayed[0].event_id == eid2
2338+
2339+
2340+
@pytest.mark.anyio
2341+
async def test_streamable_http_client_slow_resource(initialized_client_session: ClientSession):
2342+
"""Test reading a slow:// resource."""
2343+
result = await initialized_client_session.read_resource("slow://test-host")
2344+
assert len(result.contents) == 1
2345+
assert isinstance(result.contents[0], TextResourceContents)
2346+
assert result.contents[0].text == "Slow response from test-host"
2347+
2348+
2349+
@pytest.mark.anyio
2350+
async def test_streamable_http_client_long_running_with_checkpoints(basic_server: None, basic_server_url: str):
2351+
"""Test calling the long_running_with_checkpoints tool."""
2352+
captured_notifications: list[str] = []
2353+
2354+
async def message_handler(
2355+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
2356+
) -> None:
2357+
if isinstance(message, Exception):
2358+
return # pragma: no cover
2359+
if isinstance(message, types.ServerNotification): # pragma: no branch
2360+
if isinstance(message, types.LoggingMessageNotification): # pragma: no branch
2361+
captured_notifications.append(str(message.params.data))
2362+
2363+
async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream):
2364+
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
2365+
await session.initialize()
2366+
2367+
result = await session.call_tool("long_running_with_checkpoints", {})
2368+
assert len(result.content) == 1
2369+
assert result.content[0].type == "text"
2370+
assert isinstance(result.content[0], TextContent)
2371+
assert result.content[0].text == "Completed!"
2372+
2373+
# Should have received the two log notifications
2374+
assert "Tool started" in captured_notifications
2375+
assert "Tool is almost done" in captured_notifications
2376+
2377+
2378+
@pytest.mark.anyio
2379+
async def test_streamablehttp_server_sampling_non_text_content(basic_server: None, basic_server_url: str):
2380+
"""Test server-initiated sampling where callback returns non-text content."""
2381+
2382+
async def sampling_callback(
2383+
context: RequestContext[ClientSession],
2384+
params: types.CreateMessageRequestParams,
2385+
) -> types.CreateMessageResult:
2386+
return types.CreateMessageResult(
2387+
role="assistant",
2388+
content=types.ImageContent(
2389+
type="image",
2390+
data="base64data",
2391+
mime_type="image/png",
2392+
),
2393+
model="test-model",
2394+
stop_reason="endTurn",
2395+
)
2396+
2397+
async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream):
2398+
async with ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session:
2399+
await session.initialize()
2400+
2401+
tool_result = await session.call_tool("test_sampling_tool", {})
2402+
assert len(tool_result.content) == 1
2403+
assert tool_result.content[0].type == "text"
2404+
# Non-text content should be stringified
2405+
assert "Response from sampling:" in tool_result.content[0].text
2406+
2407+
2408+
@pytest.mark.anyio
2409+
async def test_tool_with_multiple_stream_closes(
2410+
event_server: tuple[SimpleEventStore, str],
2411+
) -> None:
2412+
"""Test tool_with_multiple_stream_closes which calls close_sse_stream multiple times."""
2413+
_, server_url = event_server
2414+
captured_notifications: list[str] = []
2415+
2416+
async def message_handler(
2417+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
2418+
) -> None:
2419+
if isinstance(message, Exception):
2420+
return # pragma: no cover
2421+
if isinstance(message, types.ServerNotification): # pragma: no branch
2422+
if isinstance(message, types.LoggingMessageNotification): # pragma: no branch
2423+
captured_notifications.append(str(message.params.data))
2424+
2425+
async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream):
2426+
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
2427+
await session.initialize()
2428+
2429+
result = await session.call_tool(
2430+
"tool_with_multiple_stream_closes",
2431+
{"checkpoints": 3, "sleep_time": 0.2},
2432+
)
2433+
assert result.content[0].type == "text"
2434+
assert isinstance(result.content[0], TextContent)
2435+
assert "Completed 3 checkpoints" in result.content[0].text
2436+
2437+
# All checkpoint notifications should have been received
2438+
for i in range(3):
2439+
assert f"checkpoint_{i}" in captured_notifications
2440+
2441+
2442+
@pytest.mark.anyio
2443+
async def test_tool_with_multiple_stream_closes_no_event_store(
2444+
basic_server: None,
2445+
basic_server_url: str,
2446+
) -> None:
2447+
"""Test multi_close_tool without event store — close_sse_stream is None."""
2448+
async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream):
2449+
async with ClientSession(read_stream, write_stream) as session:
2450+
await session.initialize()
2451+
2452+
result = await session.call_tool(
2453+
"tool_with_multiple_stream_closes",
2454+
{"checkpoints": 2, "sleep_time": 0.1},
2455+
)
2456+
assert result.content[0].type == "text"
2457+
assert isinstance(result.content[0], TextContent)
2458+
assert "Completed 2 checkpoints" in result.content[0].text
2459+
2460+
2461+
@pytest.mark.anyio
2462+
async def test_tool_with_standalone_stream_close_no_event_store(
2463+
basic_server: None,
2464+
basic_server_url: str,
2465+
) -> None:
2466+
"""Test standalone_stream_close without event store — close_standalone_sse_stream is None."""
2467+
async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream):
2468+
async with ClientSession(read_stream, write_stream) as session:
2469+
await session.initialize()
2470+
2471+
result = await session.call_tool("tool_with_standalone_stream_close", {})
2472+
assert result.content[0].type == "text"
2473+
assert isinstance(result.content[0], TextContent)
2474+
assert result.content[0].text == "Standalone stream close test done"

0 commit comments

Comments
 (0)