Skip to content

Commit 2fc370a

Browse files
committed
rework how streams work
1 parent 888cf3f commit 2fc370a

3 files changed

Lines changed: 192 additions & 128 deletions

File tree

dimos/agents/mcp/mcp_client.py

Lines changed: 55 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from langgraph.graph.state import CompiledStateGraph
2828
from reactivex.disposable import Disposable
2929

30-
from dimos.agents.mcp.tool_stream import ToolStreamEvent
3130
from dimos.agents.system_prompt import SYSTEM_PROMPT
3231
from dimos.agents.utils import pretty_print_langchain_message
3332
from dimos.core.core import rpc
@@ -62,8 +61,6 @@ class McpClient(Module[McpClientConfig]):
6261
_stop_event: Event
6362
_http_client: httpx.Client
6463
_seq_ids: SequentialIds
65-
_sse_thread: Thread | None
66-
_sse_client: httpx.Client | None
6764

6865
def __init__(self, **kwargs: Any) -> None:
6966
super().__init__(**kwargs)
@@ -80,8 +77,6 @@ def __init__(self, **kwargs: Any) -> None:
8077
self._stop_event = Event()
8178
self._http_client = httpx.Client(timeout=120.0)
8279
self._seq_ids = SequentialIds()
83-
self._sse_thread = None
84-
self._sse_client = None
8580

8681
def __reduce__(self) -> Any:
8782
return (self.__class__, (), {})
@@ -105,6 +100,59 @@ def _mcp_request(self, method: str, params: dict[str, Any] | None = None) -> dic
105100
result: dict[str, Any] = data.get("result")
106101
return result
107102

103+
def _mcp_tool_call(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
104+
"""Execute a tool call, handling both JSON and SSE streaming responses."""
105+
body: dict[str, Any] = {
106+
"jsonrpc": "2.0",
107+
"id": self._seq_ids.next(),
108+
"method": "tools/call",
109+
"params": {"name": name, "arguments": arguments},
110+
}
111+
112+
with self._http_client.stream(
113+
"POST",
114+
self.config.mcp_server_url,
115+
json=body,
116+
headers={"Accept": "application/json, text/event-stream"},
117+
) as resp:
118+
resp.raise_for_status()
119+
content_type = resp.headers.get("content-type", "")
120+
121+
if "text/event-stream" in content_type:
122+
return self._consume_sse_tool_response(resp, name)
123+
124+
data = json.loads(resp.read())
125+
if "error" in data:
126+
raise RuntimeError(f"MCP error {data['error']['code']}: {data['error']['message']}")
127+
result: dict[str, Any] = data.get("result", {})
128+
return result
129+
130+
def _consume_sse_tool_response(
131+
self, response: httpx.Response, tool_name: str
132+
) -> dict[str, Any]:
133+
"""Parse an SSE tool response, injecting notifications as HumanMessages."""
134+
result: dict[str, Any] | None = None
135+
for line in response.iter_lines():
136+
if not line.startswith("data: "):
137+
continue
138+
try:
139+
data = json.loads(line[6:])
140+
except json.JSONDecodeError:
141+
continue
142+
143+
if data.get("method") == "notifications/message":
144+
text = data.get("params", {}).get("data", "")
145+
if text:
146+
self._message_queue.put(
147+
HumanMessage(content=f"[Tool stream update from '{tool_name}']: {text}")
148+
)
149+
elif "result" in data:
150+
result = data["result"]
151+
152+
if result is None:
153+
return {"content": [{"type": "text", "text": "Stream ended without result."}]}
154+
return result
155+
108156
def _fetch_tools(self, timeout: float = 60.0, interval: float = 1.0) -> list[StructuredTool]:
109157
result = self._try_fetch_tools(timeout=timeout, interval=interval)
110158
if result is None:
@@ -144,7 +192,7 @@ def _mcp_tool_to_langchain(self, mcp_tool: dict[str, Any]) -> StructuredTool:
144192
input_schema = mcp_tool.get("inputSchema", {"type": "object", "properties": {}})
145193

146194
def call_tool(**kwargs: Any) -> str:
147-
result = self._mcp_request("tools/call", {"name": name, "arguments": kwargs})
195+
result = self._mcp_tool_call(name, kwargs)
148196
content = result.get("content", [])
149197
parts = [c.get("text", "") for c in content if c.get("type") == "text"]
150198
text = "\n".join(parts)
@@ -194,18 +242,11 @@ def on_system_modules(self, _modules: list[RPCClient]) -> None:
194242
)
195243
self._thread.start()
196244

197-
self._start_sse_listener()
198-
199245
@rpc
200246
def stop(self) -> None:
201247
self._stop_event.set()
202-
with self._lock:
203-
if self._sse_client is not None:
204-
self._sse_client.close()
205248
if self._thread.is_alive():
206249
self._thread.join(timeout=2.0)
207-
if self._sse_thread is not None and self._sse_thread.is_alive():
208-
self._sse_thread.join(timeout=2.0)
209250
self._http_client.close()
210251
super().stop()
211252

@@ -254,7 +295,7 @@ def dispatch_continuation(
254295
tool_args[key] = continuation_context[context_key]
255296

256297
try:
257-
result = self._mcp_request("tools/call", {"name": tool_name, "arguments": tool_args})
298+
result = self._mcp_tool_call(tool_name, tool_args)
258299
content = result.get("content", [])
259300
parts = [c.get("text", "") for c in content if c.get("type") == "text"]
260301
text = "\n".join(parts)
@@ -302,53 +343,6 @@ def _process_message(
302343
if self._message_queue.empty():
303344
self.agent_idle.publish(True)
304345

305-
def _start_sse_listener(self) -> None:
306-
"""Connect to the MCP server SSE endpoint to receive tool stream updates."""
307-
self._sse_thread = Thread(target=self._sse_loop, name="McpClient-SSE", daemon=True)
308-
self._sse_thread.start()
309-
310-
def _sse_loop(self) -> None:
311-
base_url = self.config.mcp_server_url.rsplit("/mcp", 1)[0]
312-
sse_url = f"{base_url}/mcp/streams"
313-
314-
while not self._stop_event.is_set():
315-
try:
316-
self._sse_connect(sse_url)
317-
except Exception:
318-
if not self._stop_event.is_set():
319-
# Try reconnecting after a short delay
320-
time.sleep(1.0)
321-
322-
def _sse_connect(self, sse_url: str) -> None:
323-
client = httpx.Client(timeout=None)
324-
with self._lock:
325-
self._sse_client = client
326-
try:
327-
with client.stream("GET", sse_url) as response:
328-
self._sse_consume(response)
329-
finally:
330-
with self._lock:
331-
self._sse_client = None
332-
client.close()
333-
334-
def _sse_consume(self, response: httpx.Response) -> None:
335-
for line in response.iter_lines():
336-
if self._stop_event.is_set():
337-
return
338-
if not line.startswith("data: "):
339-
continue
340-
try:
341-
data = json.loads(line[6:])
342-
except json.JSONDecodeError:
343-
continue
344-
event = ToolStreamEvent(**data)
345-
if event.type == "update":
346-
self._message_queue.put(
347-
HumanMessage(
348-
content=f"[Tool stream update from '{event.tool_name}']: {event.text}"
349-
)
350-
)
351-
352346

353347
def _append_image_to_history(
354348
mcp_client: McpClient, func_name: str, uuid_: str, result: Any

dimos/agents/mcp/mcp_server.py

Lines changed: 98 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import asyncio
1717
from collections.abc import AsyncGenerator
1818
import concurrent.futures
19+
from dataclasses import dataclass
1920
import json
2021
import os
2122
import time
@@ -30,6 +31,7 @@
3031
import uvicorn
3132

3233
from dimos.agents.annotation import skill
34+
from dimos.agents.mcp.tool_stream import ToolStreamEvent
3335
from dimos.core.core import rpc
3436
from dimos.core.module import Module
3537
from dimos.core.rpc_client import RpcCall, RPCClient
@@ -66,12 +68,20 @@ def _jsonrpc_error(req_id: Any, code: int, message: str) -> dict[str, Any]:
6668
return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}}
6769

6870

71+
@dataclass(frozen=True, slots=True)
72+
class _StreamingToolResult:
73+
"""Marker returned when a tool starts a background stream."""
74+
75+
req_id: Any
76+
tool_name: str
77+
78+
6979
def _handle_initialize(req_id: Any) -> dict[str, Any]:
7080
return _jsonrpc_result(
7181
req_id,
7282
{
7383
"protocolVersion": "2025-11-25",
74-
"capabilities": {"tools": {}},
84+
"capabilities": {"tools": {}, "logging": {}},
7585
"serverInfo": {"name": "dimensional", "version": "1.0.0"},
7686
},
7787
)
@@ -94,7 +104,7 @@ def _handle_tools_list(req_id: Any, skills: list[SkillInfo]) -> dict[str, Any]:
94104

95105
async def _handle_tools_call(
96106
req_id: Any, params: dict[str, Any], rpc_calls: dict[str, Any]
97-
) -> dict[str, Any]:
107+
) -> dict[str, Any] | _StreamingToolResult:
98108
name = params.get("name", "")
99109
args: dict[str, Any] = params.get("arguments") or {}
100110

@@ -115,8 +125,8 @@ async def _handle_tools_call(
115125
duration = f"{time.monotonic() - t0:.3f}s"
116126

117127
if result is None:
118-
logger.info("MCP tool done (async)", tool=name, duration=duration)
119-
return _jsonrpc_result_text(req_id, "It has started. You will be updated later.")
128+
logger.info("MCP tool streaming", tool=name, duration=duration)
129+
return _StreamingToolResult(req_id=req_id, tool_name=name)
120130

121131
response = str(result)[:200]
122132
if hasattr(result, "agent_encode"):
@@ -131,7 +141,7 @@ async def handle_request(
131141
request: dict[str, Any],
132142
skills: list[SkillInfo],
133143
rpc_calls: dict[str, Any],
134-
) -> dict[str, Any] | None:
144+
) -> dict[str, Any] | _StreamingToolResult | None:
135145
"""Handle a single MCP JSON-RPC request.
136146
137147
Returns None for JSON-RPC notifications (no ``id``), which must not
@@ -165,27 +175,101 @@ async def mcp_endpoint(request: Request) -> Response:
165175
{"jsonrpc": "2.0", "id": None, "error": {"code": -32700, "message": "Parse error"}},
166176
status_code=400,
167177
)
178+
179+
# Pre-register a queue for tool stream events when the client accepts SSE.
180+
accept = request.headers.get("accept", "")
181+
is_tool_call = body.get("method") == "tools/call"
182+
client_accepts_sse = "text/event-stream" in accept
183+
184+
stream_queue: asyncio.Queue[dict[str, Any]] | None = None
185+
if is_tool_call and client_accepts_sse:
186+
stream_queue = asyncio.Queue()
187+
app.state.sse_queues.append(stream_queue)
188+
168189
result = await handle_request(body, request.app.state.skills, request.app.state.rpc_calls)
190+
191+
# Streaming tool: return SSE response if client supports it.
192+
if isinstance(result, _StreamingToolResult):
193+
if stream_queue is not None:
194+
return _streaming_tool_response(result, stream_queue)
195+
# Client doesn't support SSE — fall back to immediate JSON.
196+
return JSONResponse(
197+
_jsonrpc_result_text(result.req_id, "It has started. You will be updated later.")
198+
)
199+
200+
# Non-streaming: remove the pre-registered queue if any.
201+
if stream_queue is not None:
202+
try:
203+
app.state.sse_queues.remove(stream_queue)
204+
except ValueError:
205+
pass
206+
169207
if result is None:
170208
return Response(status_code=204)
171209
return JSONResponse(result)
172210

173211

174-
@app.get("/mcp/streams")
175-
async def streams_sse_endpoint() -> StreamingResponse:
176-
"""Server-Sent Events endpoint for tool stream updates.
212+
_STREAM_TIMEOUT = 300.0 # seconds
213+
214+
215+
def _sse_event(data: dict[str, Any]) -> str:
216+
"""Format a JSON-RPC message as an SSE ``event: message`` frame."""
217+
return f"event: message\ndata: {json.dumps(data)}\n\n"
218+
219+
220+
def _streaming_tool_response(
221+
streaming: _StreamingToolResult,
222+
queue: asyncio.Queue[dict[str, Any]],
223+
) -> StreamingResponse:
224+
"""Build an SSE response that forwards ToolStream events as MCP log notifications.
177225
178-
Clients subscribe here to receive real-time updates from long-running
179-
skills that use ``ToolStream``.
226+
The response streams ``notifications/message`` for each update and ends
227+
with the JSON-RPC result carrying the accumulated text.
180228
"""
181-
queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
182-
app.state.sse_queues.append(queue)
183229

184230
async def event_generator() -> AsyncGenerator[str, None]:
231+
stream_id: str | None = None
232+
collected: list[str] = []
185233
try:
186234
while True:
187-
data = await queue.get()
188-
yield f"data: {json.dumps(data)}\n\n"
235+
try:
236+
data = await asyncio.wait_for(queue.get(), timeout=_STREAM_TIMEOUT)
237+
except asyncio.TimeoutError:
238+
text = "\n".join(collected) if collected else "No updates received."
239+
yield _sse_event(_jsonrpc_result_text(streaming.req_id, text))
240+
return
241+
242+
try:
243+
event = ToolStreamEvent(**data)
244+
except (TypeError, KeyError):
245+
continue
246+
247+
# Filter: match by tool_name, lock onto the first stream_id seen.
248+
if event.tool_name != streaming.tool_name:
249+
continue
250+
if stream_id is None:
251+
stream_id = event.stream_id
252+
elif event.stream_id != stream_id:
253+
continue
254+
255+
if event.type == "update" and event.text:
256+
collected.append(event.text)
257+
yield _sse_event(
258+
{
259+
"jsonrpc": "2.0",
260+
"method": "notifications/message",
261+
"params": {
262+
"level": "info",
263+
"logger": event.tool_name,
264+
"data": event.text,
265+
},
266+
}
267+
)
268+
269+
elif event.type == "close":
270+
text = "\n".join(collected) if collected else "Stream completed."
271+
yield _sse_event(_jsonrpc_result_text(streaming.req_id, text))
272+
return
189273
except asyncio.CancelledError:
190274
pass
191275
finally:

0 commit comments

Comments
 (0)