2727from langgraph .graph .state import CompiledStateGraph
2828from reactivex .disposable import Disposable
2929
30- from dimos .agents .mcp .tool_stream import ToolStreamEvent
3130from dimos .agents .system_prompt import SYSTEM_PROMPT
3231from dimos .agents .utils import pretty_print_langchain_message
3332from dimos .core .core import rpc
@@ -62,8 +61,6 @@ class McpClient(Module[McpClientConfig]):
6261 _stop_event : Event
6362 _http_client : httpx .Client
6463 _seq_ids : SequentialIds
65- _sse_thread : Thread | None
66- _sse_client : httpx .Client | None
6764
6865 def __init__ (self , ** kwargs : Any ) -> None :
6966 super ().__init__ (** kwargs )
@@ -80,8 +77,6 @@ def __init__(self, **kwargs: Any) -> None:
8077 self ._stop_event = Event ()
8178 self ._http_client = httpx .Client (timeout = 120.0 )
8279 self ._seq_ids = SequentialIds ()
83- self ._sse_thread = None
84- self ._sse_client = None
8580
8681 def __reduce__ (self ) -> Any :
8782 return (self .__class__ , (), {})
@@ -105,6 +100,59 @@ def _mcp_request(self, method: str, params: dict[str, Any] | None = None) -> dic
105100 result : dict [str , Any ] = data .get ("result" )
106101 return result
107102
103+ def _mcp_tool_call (self , name : str , arguments : dict [str , Any ]) -> dict [str , Any ]:
104+ """Execute a tool call, handling both JSON and SSE streaming responses."""
105+ body : dict [str , Any ] = {
106+ "jsonrpc" : "2.0" ,
107+ "id" : self ._seq_ids .next (),
108+ "method" : "tools/call" ,
109+ "params" : {"name" : name , "arguments" : arguments },
110+ }
111+
112+ with self ._http_client .stream (
113+ "POST" ,
114+ self .config .mcp_server_url ,
115+ json = body ,
116+ headers = {"Accept" : "application/json, text/event-stream" },
117+ ) as resp :
118+ resp .raise_for_status ()
119+ content_type = resp .headers .get ("content-type" , "" )
120+
121+ if "text/event-stream" in content_type :
122+ return self ._consume_sse_tool_response (resp , name )
123+
124+ data = json .loads (resp .read ())
125+ if "error" in data :
126+ raise RuntimeError (f"MCP error { data ['error' ]['code' ]} : { data ['error' ]['message' ]} " )
127+ result : dict [str , Any ] = data .get ("result" , {})
128+ return result
129+
130+ def _consume_sse_tool_response (
131+ self , response : httpx .Response , tool_name : str
132+ ) -> dict [str , Any ]:
133+ """Parse an SSE tool response, injecting notifications as HumanMessages."""
134+ result : dict [str , Any ] | None = None
135+ for line in response .iter_lines ():
136+ if not line .startswith ("data: " ):
137+ continue
138+ try :
139+ data = json .loads (line [6 :])
140+ except json .JSONDecodeError :
141+ continue
142+
143+ if data .get ("method" ) == "notifications/message" :
144+ text = data .get ("params" , {}).get ("data" , "" )
145+ if text :
146+ self ._message_queue .put (
147+ HumanMessage (content = f"[Tool stream update from '{ tool_name } ']: { text } " )
148+ )
149+ elif "result" in data :
150+ result = data ["result" ]
151+
152+ if result is None :
153+ return {"content" : [{"type" : "text" , "text" : "Stream ended without result." }]}
154+ return result
155+
108156 def _fetch_tools (self , timeout : float = 60.0 , interval : float = 1.0 ) -> list [StructuredTool ]:
109157 result = self ._try_fetch_tools (timeout = timeout , interval = interval )
110158 if result is None :
@@ -144,7 +192,7 @@ def _mcp_tool_to_langchain(self, mcp_tool: dict[str, Any]) -> StructuredTool:
144192 input_schema = mcp_tool .get ("inputSchema" , {"type" : "object" , "properties" : {}})
145193
146194 def call_tool (** kwargs : Any ) -> str :
147- result = self ._mcp_request ( "tools/call" , { " name" : name , "arguments" : kwargs } )
195+ result = self ._mcp_tool_call ( name , kwargs )
148196 content = result .get ("content" , [])
149197 parts = [c .get ("text" , "" ) for c in content if c .get ("type" ) == "text" ]
150198 text = "\n " .join (parts )
@@ -194,18 +242,11 @@ def on_system_modules(self, _modules: list[RPCClient]) -> None:
194242 )
195243 self ._thread .start ()
196244
197- self ._start_sse_listener ()
198-
199245 @rpc
200246 def stop (self ) -> None :
201247 self ._stop_event .set ()
202- with self ._lock :
203- if self ._sse_client is not None :
204- self ._sse_client .close ()
205248 if self ._thread .is_alive ():
206249 self ._thread .join (timeout = 2.0 )
207- if self ._sse_thread is not None and self ._sse_thread .is_alive ():
208- self ._sse_thread .join (timeout = 2.0 )
209250 self ._http_client .close ()
210251 super ().stop ()
211252
@@ -254,7 +295,7 @@ def dispatch_continuation(
254295 tool_args [key ] = continuation_context [context_key ]
255296
256297 try :
257- result = self ._mcp_request ( "tools/call" , { "name" : tool_name , "arguments" : tool_args } )
298+ result = self ._mcp_tool_call ( tool_name , tool_args )
258299 content = result .get ("content" , [])
259300 parts = [c .get ("text" , "" ) for c in content if c .get ("type" ) == "text" ]
260301 text = "\n " .join (parts )
@@ -302,53 +343,6 @@ def _process_message(
302343 if self ._message_queue .empty ():
303344 self .agent_idle .publish (True )
304345
305- def _start_sse_listener (self ) -> None :
306- """Connect to the MCP server SSE endpoint to receive tool stream updates."""
307- self ._sse_thread = Thread (target = self ._sse_loop , name = "McpClient-SSE" , daemon = True )
308- self ._sse_thread .start ()
309-
310- def _sse_loop (self ) -> None :
311- base_url = self .config .mcp_server_url .rsplit ("/mcp" , 1 )[0 ]
312- sse_url = f"{ base_url } /mcp/streams"
313-
314- while not self ._stop_event .is_set ():
315- try :
316- self ._sse_connect (sse_url )
317- except Exception :
318- if not self ._stop_event .is_set ():
319- # Try reconnecting after a short delay
320- time .sleep (1.0 )
321-
322- def _sse_connect (self , sse_url : str ) -> None :
323- client = httpx .Client (timeout = None )
324- with self ._lock :
325- self ._sse_client = client
326- try :
327- with client .stream ("GET" , sse_url ) as response :
328- self ._sse_consume (response )
329- finally :
330- with self ._lock :
331- self ._sse_client = None
332- client .close ()
333-
334- def _sse_consume (self , response : httpx .Response ) -> None :
335- for line in response .iter_lines ():
336- if self ._stop_event .is_set ():
337- return
338- if not line .startswith ("data: " ):
339- continue
340- try :
341- data = json .loads (line [6 :])
342- except json .JSONDecodeError :
343- continue
344- event = ToolStreamEvent (** data )
345- if event .type == "update" :
346- self ._message_queue .put (
347- HumanMessage (
348- content = f"[Tool stream update from '{ event .tool_name } ']: { event .text } "
349- )
350- )
351-
352346
353347def _append_image_to_history (
354348 mcp_client : McpClient , func_name : str , uuid_ : str , result : Any
0 commit comments