diff --git a/packages/runtimeuse-client-python/examples/runtime-client-persistent-session-example.py b/packages/runtimeuse-client-python/examples/runtime-client-persistent-session-example.py new file mode 100644 index 0000000..42cd2f5 --- /dev/null +++ b/packages/runtimeuse-client-python/examples/runtime-client-persistent-session-example.py @@ -0,0 +1,105 @@ +import os +import asyncio +import json +from pydantic import BaseModel + +from src.runtimeuse_client import ( + AssistantMessageInterface, + AgentRuntimeError, + ExecuteCommandsOptions, + RuntimeUseClient, + QueryOptions, + StructuredOutputResult, + CommandInterface, +) + + +class FrenchWordsAnswer(BaseModel): + words: list[str] + + +class PopulationAnswer(BaseModel): + population: int + + +async def main(): + client = RuntimeUseClient(ws_url="ws://localhost:8080") + + async def on_assistant_message(message: AssistantMessageInterface): + print(f"Assistant message: {message.text_blocks}") + + try: + async with client.session() as session: + result = await session.query( + prompt="Search the web to find the answer to the question: 'What is the population of France?'", + options=QueryOptions( + system_prompt="You are a helpful assistant.", + model="gpt-5.4", + pre_agent_invocation_commands=[ + CommandInterface( + command="echo 'Running pre-agent command'", + cwd=os.getcwd(), + env={}, + ) + ], + post_agent_invocation_commands=[ + CommandInterface( + command="echo 'Running post-agent command'", + cwd=os.getcwd(), + env={}, + ) + ], + output_format_json_schema_str=json.dumps( + { + "type": "json_schema", + "schema": PopulationAnswer.model_json_schema(), + } + ), + source_id="my-source", + on_assistant_message=on_assistant_message, + ), + ) + assert isinstance(result.data, StructuredOutputResult) + population = result.data.structured_output["population"] + print(f"Population: {population}") + + if population > 5: + print("Population is greater than 5, will say bonjour") + result = await session.execute_commands( + commands=[ + CommandInterface( + command="echo 'Bonjour!'", + cwd=os.getcwd(), + env={}, + ), + CommandInterface( + command="exit 1", + cwd=os.getcwd(), + env={}, + ), + ], + options=ExecuteCommandsOptions(), + ) + print(f"Command result: {result}") + result = await session.query( + prompt="Give me 6 french words", + options=QueryOptions( + system_prompt="You are a helpful assistant.", + model="gpt-5.4", + output_format_json_schema_str=json.dumps( + { + "type": "json_schema", + "schema": FrenchWordsAnswer.model_json_schema(), + } + ), + ), + ) + assert isinstance(result.data, StructuredOutputResult) + french_words = result.data.structured_output["words"] + print(f"French words: {french_words}") + except AgentRuntimeError as e: + print(f"Error: {e.error}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/packages/runtimeuse-client-python/pyproject.toml b/packages/runtimeuse-client-python/pyproject.toml index 3c0bc51..5ef3f1b 100644 --- a/packages/runtimeuse-client-python/pyproject.toml +++ b/packages/runtimeuse-client-python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "runtimeuse-client" -version = "0.10.0" +version = "0.11.0" description = "Client library for AI agent runtime communication over WebSocket" readme = "README.md" license = {"text" = "FSL"} diff --git a/packages/runtimeuse-client-python/src/runtimeuse_client/__init__.py b/packages/runtimeuse-client-python/src/runtimeuse_client/__init__.py index 3ce8254..f57fdec 100644 --- a/packages/runtimeuse-client-python/src/runtimeuse_client/__init__.py +++ b/packages/runtimeuse-client-python/src/runtimeuse_client/__init__.py @@ -1,5 +1,11 @@ -from .client import RuntimeUseClient -from .transports import Transport, WebSocketTransport +from .client import RuntimeUseClient, RuntimeUseSession +from .transports import ( + ConnectedTransport, + ConnectedWebSocketTransport, + PersistentTransport, + Transport, + WebSocketTransport, +) from .exceptions import AgentRuntimeError, CancelledException from .types import ( AgentRuntimeMessageInterface, @@ -21,6 +27,7 @@ ArtifactUploadResponseMessageInterface, ErrorMessageInterface, CancelMessage, + EndSessionMessage, ArtifactUploadResult, OnAssistantMessageCallback, OnArtifactUploadRequestCallback, @@ -28,6 +35,10 @@ __all__ = [ "RuntimeUseClient", + "RuntimeUseSession", + "ConnectedTransport", + "ConnectedWebSocketTransport", + "PersistentTransport", "Transport", "WebSocketTransport", "AgentRuntimeError", @@ -51,6 +62,7 @@ "ArtifactUploadResponseMessageInterface", "ErrorMessageInterface", "CancelMessage", + "EndSessionMessage", "ArtifactUploadResult", "OnAssistantMessageCallback", "OnArtifactUploadRequestCallback", diff --git a/packages/runtimeuse-client-python/src/runtimeuse_client/client.py b/packages/runtimeuse-client-python/src/runtimeuse_client/client.py index 5e6e9e3..b9de9d6 100644 --- a/packages/runtimeuse-client-python/src/runtimeuse_client/client.py +++ b/packages/runtimeuse-client-python/src/runtimeuse_client/client.py @@ -1,9 +1,11 @@ import asyncio import logging +from contextlib import asynccontextmanager +from typing import AsyncIterator import pydantic -from .transports import Transport, WebSocketTransport +from .transports import ConnectedTransport, Transport, WebSocketTransport from .types import ( InvocationMessage, CommandExecutionMessage, @@ -26,6 +28,245 @@ _default_logger = logging.getLogger(__name__) +def _build_invocation(prompt: str, options: QueryOptions) -> InvocationMessage: + return InvocationMessage( + message_type="invocation_message", + user_prompt=prompt, + system_prompt=options.system_prompt, + model=options.model, + output_format_json_schema_str=options.output_format_json_schema_str, + source_id=options.source_id, + agent_env=options.agent_env, + secrets_to_redact=options.secrets_to_redact, + artifacts_dir=options.artifacts_dir, + pre_agent_invocation_commands=options.pre_agent_invocation_commands, + post_agent_invocation_commands=options.post_agent_invocation_commands, + pre_agent_downloadables=options.pre_agent_downloadables, + ) + + +def _build_command_execution( + commands: list[CommandInterface], options: ExecuteCommandsOptions +) -> CommandExecutionMessage: + return CommandExecutionMessage( + message_type="command_execution_message", + source_id=options.source_id, + secrets_to_redact=options.secrets_to_redact, + commands=commands, + artifacts_dir=options.artifacts_dir, + pre_execution_downloadables=options.pre_execution_downloadables, + ) + + +async def _handle_artifact_request( + message: dict, + on_artifact_upload_request, + send_queue: asyncio.Queue[dict], + logger: logging.Logger, +) -> None: + logger.info( + f"Received artifact upload request message from agent runtime: {message}" + ) + if on_artifact_upload_request is None: + return + req = ArtifactUploadRequestMessageInterface.model_validate(message) + upload_result = await on_artifact_upload_request(req) + response = ArtifactUploadResponseMessageInterface( + message_type="artifact_upload_response_message", + filename=req.filename, + filepath=req.filepath, + presigned_url=upload_result.presigned_url, + content_type=upload_result.content_type, + ) + await send_queue.put(response.model_dump(mode="json")) + + +async def _run_request_loop( + message_iter, + send_queue: asyncio.Queue[dict], + abort_event: asyncio.Event, + *, + terminal_message_type: str, + result_cls, + on_assistant_message, + on_artifact_upload_request, + cancelled_message: str, + logger: logging.Logger, +): + """Drive the message loop for a single request and return the terminal result. + + Iterates ``message_iter`` until it sees a terminal message. Even if + ``abort_event`` is set mid-request, the loop keeps reading until the + server's terminal arrives — otherwise, on a persistent session, the + server's ``error_message`` for the cancelled request would leak into the + next ``request()`` and cause the follow-up call to spuriously raise. + + Raises ``AgentRuntimeError`` on error terminals, ``CancelledException`` + if the caller aborted (after the server's terminal has been drained). + """ + wire_result = None + error_to_raise: AgentRuntimeError | None = None + + async for message in message_iter: + try: + message_interface = AgentRuntimeMessageInterface.model_validate(message) + except pydantic.ValidationError: + logger.error( + f"Received unknown message type from agent runtime: {message}" + ) + continue + + if message_interface.message_type == terminal_message_type: + wire_result = result_cls.model_validate(message) + logger.info( + f"Received terminal message from agent runtime: {message}" + ) + break + + if message_interface.message_type == "error_message": + try: + err = ErrorMessageInterface.model_validate(message) + except pydantic.ValidationError: + logger.error( + f"Received malformed error message from agent runtime: {message}" + ) + error_to_raise = AgentRuntimeError(str(message)) + else: + logger.error(f"Error from agent runtime: {err}") + error_to_raise = AgentRuntimeError(err.error, metadata=err.metadata) + break + + # Side-effectful handlers are suppressed once the caller has aborted: + # we still keep reading so we consume the terminal, but we don't fire + # callbacks or respond to artifact handshakes for a dead request. + if abort_event.is_set(): + continue + + if message_interface.message_type == "assistant_message": + if on_assistant_message is not None: + assistant = AssistantMessageInterface.model_validate(message) + await on_assistant_message(assistant) + continue + + if message_interface.message_type == "artifact_upload_request_message": + await _handle_artifact_request( + message, on_artifact_upload_request, send_queue, logger + ) + continue + + logger.info( + f"Received non-result message from agent runtime: {message}" + ) + + if abort_event.is_set(): + raise CancelledException(cancelled_message) + + if error_to_raise is not None: + raise error_to_raise + + if wire_result is None: + raise AgentRuntimeError("No result message received") + + return wire_result + + +class RuntimeUseSession: + """A persistent session over a single transport connection. + + Exposes :meth:`query` and :meth:`execute_commands` with the same + signatures as :class:`RuntimeUseClient` but dispatches each call as a + separate request/response cycle over the already-open transport. + """ + + def __init__(self, connected: ConnectedTransport): + self._connected = connected + self._abort_event = asyncio.Event() + self._send_queue: asyncio.Queue[dict] | None = None + self._lock = asyncio.Lock() + + def abort(self) -> None: + """Signal the in-flight request to cancel. + + Sends a ``cancel_message`` to the runtime (which aborts the current + request without closing the session) and causes the in-flight call to + raise :class:`CancelledException`. + """ + self._abort_event.set() + send_queue = self._send_queue + if send_queue is not None: + send_queue.put_nowait( + CancelMessage(message_type="cancel_message").model_dump(mode="json") + ) + + async def query(self, prompt: str, options: QueryOptions) -> QueryResult: + async with self._lock: + logger = options.logger or _default_logger + self._abort_event = asyncio.Event() + + invocation = _build_invocation(prompt, options) + send_queue: asyncio.Queue[dict] = asyncio.Queue() + self._send_queue = send_queue + await send_queue.put(invocation.model_dump(mode="json")) + + try: + async with asyncio.timeout(options.timeout): + message_iter = self._connected.request(send_queue) + try: + wire = await _run_request_loop( + message_iter, + send_queue, + self._abort_event, + terminal_message_type="result_message", + result_cls=ResultMessageInterface, + on_assistant_message=options.on_assistant_message, + on_artifact_upload_request=options.on_artifact_upload_request, + cancelled_message="Query was cancelled", + logger=logger, + ) + finally: + await message_iter.aclose() + finally: + self._send_queue = None + + return QueryResult(data=wire.data, metadata=wire.metadata) + + async def execute_commands( + self, + commands: list[CommandInterface], + options: ExecuteCommandsOptions, + ) -> CommandExecutionResult: + async with self._lock: + logger = options.logger or _default_logger + self._abort_event = asyncio.Event() + + message = _build_command_execution(commands, options) + send_queue: asyncio.Queue[dict] = asyncio.Queue() + self._send_queue = send_queue + await send_queue.put(message.model_dump(mode="json")) + + try: + async with asyncio.timeout(options.timeout): + message_iter = self._connected.request(send_queue) + try: + wire = await _run_request_loop( + message_iter, + send_queue, + self._abort_event, + terminal_message_type="command_execution_result_message", + result_cls=CommandExecutionResultMessageInterface, + on_assistant_message=options.on_assistant_message, + on_artifact_upload_request=options.on_artifact_upload_request, + cancelled_message="Command execution was cancelled", + logger=logger, + ) + finally: + await message_iter.aclose() + finally: + self._send_queue = None + + return CommandExecutionResult(results=wire.results) + + class RuntimeUseClient: """Client for communicating with a runtimeuse agent runtime. @@ -37,6 +278,11 @@ class RuntimeUseClient: WebSocketTransport. Ignored when a custom transport is provided. transport: Optional custom transport implementing the Transport protocol. When provided, ws_url is not required. + + Both one-shot (:meth:`query`, :meth:`execute_commands`) and persistent + (:meth:`session`) styles are supported. One-shot calls open a connection, + run a single request, and close. :meth:`session` opens a connection that + can service multiple sequential calls. """ def __init__( @@ -55,21 +301,44 @@ def __init__( self._send_queue: asyncio.Queue[dict] | None = None def abort(self) -> None: - """Signal the current query to cancel. + """Signal the current one-shot call to cancel. - Sends a cancel message to the agent runtime and causes ``query`` - to raise :class:`CancelledException`. Safe to call from any - coroutine on the same event loop. + Sends a ``cancel_message`` to the agent runtime and causes the active + :meth:`query` or :meth:`execute_commands` call to raise + :class:`CancelledException`. Safe to call from any coroutine on the + same event loop. + + For persistent sessions, use :meth:`RuntimeUseSession.abort` instead. """ self._abort_event.set() send_queue = self._send_queue if send_queue is not None: send_queue.put_nowait( - CancelMessage(message_type="cancel_message").model_dump( - mode="json" - ) + CancelMessage(message_type="cancel_message").model_dump(mode="json") ) + @asynccontextmanager + async def session(self) -> AsyncIterator[RuntimeUseSession]: + """Open a persistent session to the agent runtime. + + Yields a :class:`RuntimeUseSession` that can service multiple + sequential :meth:`RuntimeUseSession.query` / + :meth:`RuntimeUseSession.execute_commands` calls over a single + connection. The connection is closed (and ``end_session_message`` is + sent to the runtime) when the context exits. + + Requires a transport that supports persistent connections (the + default :class:`WebSocketTransport` does). + """ + connect = getattr(self._transport, "connect", None) + if connect is None: + raise TypeError( + "The configured transport does not support persistent sessions. " + "Use a transport that implements PersistentTransport (e.g. WebSocketTransport)." + ) + async with connect() as connected: + yield RuntimeUseSession(connected) + async def query( self, prompt: str, @@ -77,13 +346,9 @@ async def query( ) -> QueryResult: """Send a prompt to the agent runtime and return the result. - Builds an :class:`InvocationMessage` from *prompt* and *options*, - sends it over the transport, processes the response stream, and - returns a :class:`QueryResult`. - - Access ``result.data`` for the :class:`TextResult` or - :class:`StructuredOutputResult`, and ``result.metadata`` for - execution metadata. + Convenience one-shot wrapper: opens a connection, sends one + invocation, and closes. For multiple sequential calls over a single + connection, use :meth:`session`. Args: prompt: The user prompt to send to the agent. @@ -99,119 +364,29 @@ async def query( self._abort_event = asyncio.Event() - invocation = InvocationMessage( - message_type="invocation_message", - user_prompt=prompt, - system_prompt=options.system_prompt, - model=options.model, - output_format_json_schema_str=options.output_format_json_schema_str, - source_id=options.source_id, - agent_env=options.agent_env, - secrets_to_redact=options.secrets_to_redact, - artifacts_dir=options.artifacts_dir, - pre_agent_invocation_commands=options.pre_agent_invocation_commands, - post_agent_invocation_commands=options.post_agent_invocation_commands, - pre_agent_downloadables=options.pre_agent_downloadables, - ) + invocation = _build_invocation(prompt, options) send_queue: asyncio.Queue[dict] = asyncio.Queue() self._send_queue = send_queue await send_queue.put(invocation.model_dump(mode="json")) - wire_result: ResultMessageInterface | None = None - try: async with asyncio.timeout(options.timeout): - async for message in self._transport(send_queue=send_queue): - if self._abort_event.is_set(): - raise CancelledException("Query was cancelled") - - try: - message_interface = AgentRuntimeMessageInterface.model_validate( - message - ) - except pydantic.ValidationError: - logger.error( - f"Received unknown message type from agent runtime: {message}" - ) - continue - - if message_interface.message_type == "result_message": - wire_result = ResultMessageInterface.model_validate(message) - logger.info( - f"Received result message from agent runtime: {message}" - ) - continue - - elif message_interface.message_type == "assistant_message": - if options.on_assistant_message is not None: - assistant_message_interface = ( - AssistantMessageInterface.model_validate(message) - ) - await options.on_assistant_message(assistant_message_interface) - continue - - elif message_interface.message_type == "error_message": - try: - error_message_interface = ErrorMessageInterface.model_validate( - message - ) - except pydantic.ValidationError: - logger.error( - f"Received malformed error message from agent runtime: {message}", - ) - raise AgentRuntimeError(str(message)) - logger.error( - f"Error from agent runtime: {error_message_interface}", - ) - raise AgentRuntimeError( - error_message_interface.error, - metadata=error_message_interface.metadata, - ) - - elif ( - message_interface.message_type == "artifact_upload_request_message" - ): - logger.info( - f"Received artifact upload request message from agent runtime: {message}" - ) - if options.on_artifact_upload_request is not None: - artifact_upload_request_message_interface = ( - ArtifactUploadRequestMessageInterface.model_validate( - message - ) - ) - upload_result = await options.on_artifact_upload_request( - artifact_upload_request_message_interface - ) - artifact_upload_response_message_interface = ArtifactUploadResponseMessageInterface( - message_type="artifact_upload_response_message", - filename=artifact_upload_request_message_interface.filename, - filepath=artifact_upload_request_message_interface.filepath, - presigned_url=upload_result.presigned_url, - content_type=upload_result.content_type, - ) - await send_queue.put( - artifact_upload_response_message_interface.model_dump( - mode="json" - ) - ) - continue - - else: - logger.info( - f"Received non-result message from agent runtime: {message}" - ) + wire = await _run_request_loop( + self._transport(send_queue=send_queue), + send_queue, + self._abort_event, + terminal_message_type="result_message", + result_cls=ResultMessageInterface, + on_assistant_message=options.on_assistant_message, + on_artifact_upload_request=options.on_artifact_upload_request, + cancelled_message="Query was cancelled", + logger=logger, + ) finally: self._send_queue = None - if self._abort_event.is_set(): - raise CancelledException("Query was cancelled") - - if wire_result is None: - raise AgentRuntimeError("No result message received") - - return QueryResult(data=wire_result.data, metadata=wire_result.metadata) + return QueryResult(data=wire.data, metadata=wire.metadata) async def execute_commands( self, @@ -220,9 +395,9 @@ async def execute_commands( ) -> CommandExecutionResult: """Execute commands in the runtime without invoking the agent. - Sends a :class:`CommandExecutionMessage`, processes the response - stream, and returns a :class:`CommandExecutionResult` with - per-command exit codes. + Convenience one-shot wrapper: opens a connection, sends one command + execution request, and closes. For multiple sequential calls over a + single connection, use :meth:`session`. Args: commands: Commands to execute in the runtime environment. @@ -238,108 +413,25 @@ async def execute_commands( self._abort_event = asyncio.Event() - message = CommandExecutionMessage( - message_type="command_execution_message", - source_id=options.source_id, - secrets_to_redact=options.secrets_to_redact, - commands=commands, - artifacts_dir=options.artifacts_dir, - pre_execution_downloadables=options.pre_execution_downloadables, - ) - + message = _build_command_execution(commands, options) send_queue: asyncio.Queue[dict] = asyncio.Queue() self._send_queue = send_queue await send_queue.put(message.model_dump(mode="json")) - wire_result: CommandExecutionResultMessageInterface | None = None - try: async with asyncio.timeout(options.timeout): - async for msg in self._transport(send_queue=send_queue): - if self._abort_event.is_set(): - raise CancelledException("Command execution was cancelled") - - try: - message_interface = AgentRuntimeMessageInterface.model_validate(msg) - except pydantic.ValidationError: - logger.error( - f"Received unknown message type from agent runtime: {msg}" - ) - continue - - if message_interface.message_type == "command_execution_result_message": - wire_result = CommandExecutionResultMessageInterface.model_validate( - msg - ) - logger.info( - f"Received command execution result from agent runtime: {msg}" - ) - continue - - elif message_interface.message_type == "assistant_message": - if options.on_assistant_message is not None: - assistant_message_interface = ( - AssistantMessageInterface.model_validate(msg) - ) - await options.on_assistant_message(assistant_message_interface) - continue - - elif message_interface.message_type == "error_message": - try: - error_message_interface = ErrorMessageInterface.model_validate( - msg - ) - except pydantic.ValidationError: - logger.error( - f"Received malformed error message from agent runtime: {msg}", - ) - raise AgentRuntimeError(str(msg)) - logger.error( - f"Error from agent runtime: {error_message_interface}", - ) - raise AgentRuntimeError( - error_message_interface.error, - metadata=error_message_interface.metadata, - ) - - elif ( - message_interface.message_type == "artifact_upload_request_message" - ): - logger.info( - f"Received artifact upload request message from agent runtime: {msg}" - ) - if options.on_artifact_upload_request is not None: - artifact_upload_request_message_interface = ( - ArtifactUploadRequestMessageInterface.model_validate(msg) - ) - upload_result = await options.on_artifact_upload_request( - artifact_upload_request_message_interface - ) - artifact_upload_response_message_interface = ArtifactUploadResponseMessageInterface( - message_type="artifact_upload_response_message", - filename=artifact_upload_request_message_interface.filename, - filepath=artifact_upload_request_message_interface.filepath, - presigned_url=upload_result.presigned_url, - content_type=upload_result.content_type, - ) - await send_queue.put( - artifact_upload_response_message_interface.model_dump( - mode="json" - ) - ) - continue - - else: - logger.info( - f"Received non-result message from agent runtime: {msg}" - ) + wire = await _run_request_loop( + self._transport(send_queue=send_queue), + send_queue, + self._abort_event, + terminal_message_type="command_execution_result_message", + result_cls=CommandExecutionResultMessageInterface, + on_assistant_message=options.on_assistant_message, + on_artifact_upload_request=options.on_artifact_upload_request, + cancelled_message="Command execution was cancelled", + logger=logger, + ) finally: self._send_queue = None - if self._abort_event.is_set(): - raise CancelledException("Command execution was cancelled") - - if wire_result is None: - raise AgentRuntimeError("No result message received") - - return CommandExecutionResult(results=wire_result.results) + return CommandExecutionResult(results=wire.results) diff --git a/packages/runtimeuse-client-python/src/runtimeuse_client/transports/__init__.py b/packages/runtimeuse-client-python/src/runtimeuse_client/transports/__init__.py index 3145bef..6fa89a8 100644 --- a/packages/runtimeuse-client-python/src/runtimeuse_client/transports/__init__.py +++ b/packages/runtimeuse-client-python/src/runtimeuse_client/transports/__init__.py @@ -1,4 +1,10 @@ -from .transport import Transport -from .websocket_transport import WebSocketTransport +from .transport import ConnectedTransport, PersistentTransport, Transport +from .websocket_transport import ConnectedWebSocketTransport, WebSocketTransport -__all__ = ["Transport", "WebSocketTransport"] +__all__ = [ + "ConnectedTransport", + "ConnectedWebSocketTransport", + "PersistentTransport", + "Transport", + "WebSocketTransport", +] diff --git a/packages/runtimeuse-client-python/src/runtimeuse_client/transports/transport.py b/packages/runtimeuse-client-python/src/runtimeuse_client/transports/transport.py index 316e438..53ed5ae 100644 --- a/packages/runtimeuse-client-python/src/runtimeuse_client/transports/transport.py +++ b/packages/runtimeuse-client-python/src/runtimeuse_client/transports/transport.py @@ -1,15 +1,43 @@ import asyncio -from typing import AsyncGenerator, Any, Protocol +from typing import Any, AsyncGenerator, AsyncContextManager, Protocol class Transport(Protocol): - """Protocol for the underlying message transport. + """Protocol for a one-shot message transport. Implementations must be callable async generators that yield parsed messages (dicts) from the agent runtime and consume outbound messages from the - send_queue. + send_queue. The underlying connection is opened on call and closed when the + generator exits. """ def __call__( self, send_queue: asyncio.Queue[dict] ) -> AsyncGenerator[dict[str, Any], None]: ... + + +class ConnectedTransport(Protocol): + """Protocol for a persistent connection that supports N sequential requests. + + Implementations of this protocol represent an already-open connection. + Each call to ``request`` runs one request/response cycle over that + connection until the caller closes the generator (typically after + receiving a terminal message). The connection stays open between + requests. + """ + + def request( + self, send_queue: asyncio.Queue[dict] + ) -> AsyncGenerator[dict[str, Any], None]: ... + + async def close(self) -> None: ... + + +class PersistentTransport(Protocol): + """Protocol for a transport that can open a persistent connection. + + ``connect`` returns an async context manager that yields a + ``ConnectedTransport`` once the connection is open. + """ + + def connect(self) -> AsyncContextManager[ConnectedTransport]: ... diff --git a/packages/runtimeuse-client-python/src/runtimeuse_client/transports/websocket_transport.py b/packages/runtimeuse-client-python/src/runtimeuse_client/transports/websocket_transport.py index 242f250..b44c4aa 100644 --- a/packages/runtimeuse-client-python/src/runtimeuse_client/transports/websocket_transport.py +++ b/packages/runtimeuse-client-python/src/runtimeuse_client/transports/websocket_transport.py @@ -1,15 +1,73 @@ import json import asyncio import logging -from typing import AsyncGenerator, Any +from contextlib import asynccontextmanager +from typing import AsyncGenerator, AsyncIterator, Any import websockets _logger = logging.getLogger(__name__) +class ConnectedWebSocketTransport: + """An already-open WebSocket connection supporting N sequential requests.""" + + def __init__(self, ws: "websockets.ClientConnection"): + self._ws = ws + + async def request( + self, send_queue: asyncio.Queue[dict] + ) -> AsyncGenerator[dict[str, Any], None]: + """Run one request/response cycle over the open socket. + + Yields incoming messages from the server and drains ``send_queue`` over + the socket concurrently. The caller is expected to close the generator + (e.g. by breaking out of ``async for``) once a terminal message has + been consumed; the socket stays open for subsequent requests. + """ + sender_task = asyncio.create_task(self._queue_sender(send_queue)) + try: + async for message in self._ws: + try: + yield json.loads(message) + except json.JSONDecodeError: + yield {"raw": message} + except websockets.exceptions.ConnectionClosed: + return + finally: + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + + async def close(self, send_end_message: bool = True) -> None: + """Close the connection, optionally sending end_session_message first.""" + if send_end_message: + try: + await self._ws.send( + json.dumps({"message_type": "end_session_message"}) + ) + except websockets.exceptions.ConnectionClosed: + pass + await self._ws.close() + + async def _queue_sender(self, send_queue: asyncio.Queue[dict]) -> None: + while True: + message = await send_queue.get() + try: + await self._ws.send(json.dumps(message)) + finally: + send_queue.task_done() + + class WebSocketTransport: - """Transport that communicates over a WebSocket connection.""" + """Transport that communicates over a WebSocket connection. + + Supports both one-shot use via ``__call__`` (for :meth:`RuntimeUseClient.query` + and :meth:`RuntimeUseClient.execute_commands`) and persistent sessions via + :meth:`connect` (for :meth:`RuntimeUseClient.session`). + """ def __init__(self, ws_url: str): self.ws_url = ws_url @@ -20,30 +78,24 @@ async def __call__( _logger.info("Connecting to WebSocket at %s", self.ws_url) async with websockets.connect(self.ws_url, open_timeout=60) as ws: - sender_task = asyncio.create_task(self._queue_sender(ws, send_queue)) + connected = ConnectedWebSocketTransport(ws) try: - async for message in ws: - try: - data = json.loads(message) - yield data - except json.JSONDecodeError: - yield {"raw": message} - except websockets.exceptions.ConnectionClosed as e: - e.add_note(f"Send queue is empty: {send_queue.empty()}") + async for message in connected.request(send_queue): + yield message finally: - sender_task.cancel() - try: - await sender_task - except asyncio.CancelledError: - pass _logger.info("Agent runtime connection closed") - async def _queue_sender( - self, ws: websockets.ClientConnection, send_queue: asyncio.Queue[dict] - ) -> None: - while True: - message = await send_queue.get() + @asynccontextmanager + async def connect(self) -> AsyncIterator[ConnectedWebSocketTransport]: + """Open a persistent WebSocket connection for use with a session. + + Sends ``end_session_message`` and closes the socket on exit. + """ + _logger.info("Connecting persistent WebSocket to %s", self.ws_url) + async with websockets.connect(self.ws_url, open_timeout=60) as ws: + connected = ConnectedWebSocketTransport(ws) try: - await ws.send(json.dumps(message)) + yield connected finally: - send_queue.task_done() + await connected.close(send_end_message=True) + _logger.info("Persistent agent runtime connection closed") diff --git a/packages/runtimeuse-client-python/src/runtimeuse_client/types.py b/packages/runtimeuse-client-python/src/runtimeuse_client/types.py index 55704f7..9da1696 100644 --- a/packages/runtimeuse-client-python/src/runtimeuse_client/types.py +++ b/packages/runtimeuse-client-python/src/runtimeuse_client/types.py @@ -102,6 +102,10 @@ class CancelMessage(BaseModel): message_type: Literal["cancel_message"] +class EndSessionMessage(BaseModel): + message_type: Literal["end_session_message"] + + class CommandExecutionMessage(BaseModel): message_type: Literal["command_execution_message"] source_id: str | None = None diff --git a/packages/runtimeuse-client-python/test/conftest.py b/packages/runtimeuse-client-python/test/conftest.py index 5f14606..450a7ba 100644 --- a/packages/runtimeuse-client-python/test/conftest.py +++ b/packages/runtimeuse-client-python/test/conftest.py @@ -13,32 +13,74 @@ class FakeTransport: """In-memory transport for testing. Yields pre-canned messages and captures everything written to the send queue. + Drains the send queue synchronously around each yield so tests can assert on + ``sent`` without waiting for a background drainer. """ def __init__(self, messages: list[dict] | None = None): self.messages = messages or [] self.sent: list[dict] = [] + def _drain(self, send_queue: asyncio.Queue[dict]) -> None: + while not send_queue.empty(): + item = send_queue.get_nowait() + self.sent.append(item) + send_queue.task_done() + async def __call__( self, send_queue: asyncio.Queue[dict] ) -> AsyncGenerator[dict[str, Any], None]: - async def _drain_forever() -> None: - while True: - item = await send_queue.get() - self.sent.append(item) - send_queue.task_done() - - drainer = asyncio.create_task(_drain_forever()) try: for msg in self.messages: + self._drain(send_queue) + yield msg + finally: + self._drain(send_queue) + + +class FakePersistentTransport: + """In-memory transport supporting persistent sessions. + + Each request pulls from a list of pre-canned message batches (one batch per + request). Captures everything written to the send queue across all requests. + """ + + def __init__(self, request_batches: list[list[dict]] | None = None): + self.request_batches = list(request_batches or []) + self.sent: list[dict] = [] + self.closed = False + + def _drain(self, send_queue: asyncio.Queue[dict]) -> None: + while not send_queue.empty(): + item = send_queue.get_nowait() + self.sent.append(item) + send_queue.task_done() + + async def request( + self, send_queue: asyncio.Queue[dict] + ) -> AsyncGenerator[dict[str, Any], None]: + batch = self.request_batches.pop(0) if self.request_batches else [] + try: + for msg in batch: + self._drain(send_queue) yield msg - await send_queue.join() finally: - drainer.cancel() - try: - await drainer - except asyncio.CancelledError: - pass + self._drain(send_queue) + + async def close(self) -> None: + self.closed = True + + def connect(self): + transport = self + + class _Ctx: + async def __aenter__(self): + return transport + + async def __aexit__(self, exc_type, exc, tb): + await transport.close() + + return _Ctx() DEFAULT_PROMPT = "Do something." @@ -85,3 +127,15 @@ def _make_execute_commands_options(**overrides: Any) -> ExecuteCommandsOptions: def make_execute_commands_options(): """Return the _make_execute_commands_options factory for tests.""" return _make_execute_commands_options + + +@pytest.fixture +def fake_persistent_transport(): + """Return a factory that creates a (FakePersistentTransport, RuntimeUseClient) pair.""" + + def _factory(request_batches: list[list[dict]] | None = None): + transport = FakePersistentTransport(request_batches) + client = RuntimeUseClient(transport=transport) + return transport, client + + return _factory diff --git a/packages/runtimeuse-client-python/test/test_client.py b/packages/runtimeuse-client-python/test/test_client.py index feeda37..a521f2a 100644 --- a/packages/runtimeuse-client-python/test/test_client.py +++ b/packages/runtimeuse-client-python/test/test_client.py @@ -778,3 +778,264 @@ async def _dummy_cb(req): with pytest.raises(ValueError, match="must be specified together"): ExecuteCommandsOptions(on_artifact_upload_request=_dummy_cb) + + +# --------------------------------------------------------------------------- +# Persistent session +# --------------------------------------------------------------------------- + + +TEXT_RESULT_MSG_PERSISTENT = { + "message_type": "result_message", + "data": {"type": "text", "text": "persistent hello"}, + "metadata": None, +} + + +class TestPersistentSession: + @pytest.mark.asyncio + async def test_two_sequential_queries_share_one_connection( + self, fake_persistent_transport, make_query_options + ): + transport, client = fake_persistent_transport( + [ + [ + { + "message_type": "result_message", + "data": {"type": "text", "text": "first"}, + "metadata": None, + } + ], + [ + { + "message_type": "result_message", + "data": {"type": "text", "text": "second"}, + "metadata": None, + } + ], + ] + ) + + async with client.session() as session: + first = await session.query( + prompt="one", options=make_query_options(source_id="a") + ) + second = await session.query( + prompt="two", options=make_query_options(source_id="b") + ) + + assert first.data.text == "first" + assert second.data.text == "second" + assert transport.closed is True + + invocation_msgs = [ + m for m in transport.sent if m.get("message_type") == "invocation_message" + ] + assert [m["source_id"] for m in invocation_msgs] == ["a", "b"] + + @pytest.mark.asyncio + async def test_mixed_query_and_execute_commands_in_one_session( + self, fake_persistent_transport, make_query_options, make_execute_commands_options + ): + transport, client = fake_persistent_transport( + [ + [TEXT_RESULT_MSG_PERSISTENT], + [ + { + "message_type": "command_execution_result_message", + "results": [{"command": "echo x", "exit_code": 0}], + } + ], + ] + ) + + async with client.session() as session: + query_result = await session.query( + prompt="hello", options=make_query_options() + ) + cmd_result = await session.execute_commands( + commands=[CommandInterface(command="echo x")], + options=make_execute_commands_options(), + ) + + assert query_result.data.text == "persistent hello" + assert cmd_result.results[0].command == "echo x" + + @pytest.mark.asyncio + async def test_cancel_mid_session_then_another_call( + self, fake_persistent_transport, make_query_options + ): + filler_msg = { + "message_type": "assistant_message", + "text_blocks": ["working..."], + } + + transport, client = fake_persistent_transport( + [ + [filler_msg, filler_msg], + [ + { + "message_type": "result_message", + "data": {"type": "text", "text": "after cancel"}, + "metadata": None, + } + ], + ] + ) + + async with client.session() as session: + async def abort_on_first(_msg): + session.abort() + + with pytest.raises(CancelledException): + await session.query( + prompt="first", + options=make_query_options(on_assistant_message=abort_on_first), + ) + + second = await session.query( + prompt="second", options=make_query_options() + ) + + assert second.data.text == "after cancel" + # The aborted call sent a cancel_message to the server + cancel_msgs = [ + m for m in transport.sent if m.get("message_type") == "cancel_message" + ] + assert len(cancel_msgs) == 1 + + @pytest.mark.asyncio + async def test_per_call_abort_targets_in_flight_request_only( + self, fake_persistent_transport, make_query_options + ): + filler_msg = { + "message_type": "assistant_message", + "text_blocks": ["tick"], + } + + transport, client = fake_persistent_transport( + [ + [TEXT_RESULT_MSG_PERSISTENT], + [filler_msg, filler_msg], + ] + ) + + async with client.session() as session: + first = await session.query( + prompt="one", options=make_query_options() + ) + assert first.data.text == "persistent hello" + + async def abort_on_first(_msg): + session.abort() + + with pytest.raises(CancelledException): + await session.query( + prompt="two", + options=make_query_options(on_assistant_message=abort_on_first), + ) + + cancel_msgs = [ + m for m in transport.sent if m.get("message_type") == "cancel_message" + ] + assert len(cancel_msgs) == 1 + + @pytest.mark.asyncio + async def test_session_closes_transport_on_exit( + self, fake_persistent_transport, make_query_options + ): + transport, client = fake_persistent_transport([[TEXT_RESULT_MSG_PERSISTENT]]) + + async with client.session() as session: + await session.query(prompt="hi", options=make_query_options()) + + assert transport.closed is True + + +class _StaleBufferTransport: + """Simulates a real WebSocket: a single FIFO feeds all request() calls. + + If a request() generator is closed early (e.g., the client aborts), any + messages not yet consumed stay in the buffer and are the first thing the + next request() sees. This mirrors real websockets library behaviour. + """ + + def __init__(self, messages: list[dict]): + self._queue: asyncio.Queue[dict] = asyncio.Queue() + for m in messages: + self._queue.put_nowait(m) + self.sent: list[dict] = [] + self.closed = False + + def _drain_send(self, send_queue: asyncio.Queue[dict]) -> None: + while not send_queue.empty(): + self.sent.append(send_queue.get_nowait()) + send_queue.task_done() + + async def request(self, send_queue: asyncio.Queue[dict]): + try: + while True: + self._drain_send(send_queue) + if self._queue.empty(): + return + yield self._queue.get_nowait() + finally: + self._drain_send(send_queue) + + async def close(self) -> None: + self.closed = True + + def connect(self): + transport = self + + class _Ctx: + async def __aenter__(self): + return transport + + async def __aexit__(self, exc_type, exc, tb): + await transport.close() + + return _Ctx() + + +class TestStaleTerminalDraining: + @pytest.mark.asyncio + async def test_cancelled_request_drains_server_terminal_before_next_call( + self, make_query_options + ): + # Scenario: request 1 is aborted mid-flight. The server processes the + # cancel and sends an error_message terminal. That terminal must be + # consumed by request 1's loop - NOT leak into request 2. + filler = {"message_type": "assistant_message", "text_blocks": ["tick"]} + cancel_terminal = { + "message_type": "error_message", + "error": "Request cancelled", + "metadata": {}, + } + result_after = { + "message_type": "result_message", + "data": {"type": "text", "text": "clean"}, + "metadata": None, + } + + transport = _StaleBufferTransport( + [filler, cancel_terminal, result_after] + ) + client = RuntimeUseClient(transport=transport) + + async with client.session() as session: + async def abort_on_first(_msg): + session.abort() + + with pytest.raises(CancelledException): + await session.query( + prompt="one", + options=make_query_options(on_assistant_message=abort_on_first), + ) + + # Second request must NOT see the stale cancel_terminal. + second = await session.query( + prompt="two", options=make_query_options() + ) + + assert second.data.text == "clean" diff --git a/packages/runtimeuse-client-python/uv.lock b/packages/runtimeuse-client-python/uv.lock index 5aed4d8..68cffa4 100644 --- a/packages/runtimeuse-client-python/uv.lock +++ b/packages/runtimeuse-client-python/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", @@ -1618,7 +1618,7 @@ wheels = [ [[package]] name = "runtimeuse-client" -version = "0.10.0" +version = "0.11.0" source = { editable = "." } dependencies = [ { name = "pydantic" }, diff --git a/packages/runtimeuse/package-lock.json b/packages/runtimeuse/package-lock.json index 0a67139..dce635e 100644 --- a/packages/runtimeuse/package-lock.json +++ b/packages/runtimeuse/package-lock.json @@ -1,12 +1,12 @@ { "name": "runtimeuse", - "version": "0.3.0", + "version": "0.11.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "runtimeuse", - "version": "0.3.0", + "version": "0.11.0", "license": "FSL", "dependencies": { "@anthropic-ai/claude-agent-sdk": "^0.2.73", @@ -3146,4 +3146,4 @@ } } } -} \ No newline at end of file +} diff --git a/packages/runtimeuse/package.json b/packages/runtimeuse/package.json index 23c54fe..5df2767 100644 --- a/packages/runtimeuse/package.json +++ b/packages/runtimeuse/package.json @@ -1,6 +1,6 @@ { "name": "runtimeuse", - "version": "0.10.0", + "version": "0.11.0", "description": "AI agent runtime with WebSocket protocol, artifact handling, and secret management", "license": "FSL", "type": "module", diff --git a/packages/runtimeuse/src/artifact-manager.test.ts b/packages/runtimeuse/src/artifact-manager.test.ts index efa64c2..950dc28 100644 --- a/packages/runtimeuse/src/artifact-manager.test.ts +++ b/packages/runtimeuse/src/artifact-manager.test.ts @@ -3,6 +3,7 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; const mockWatcher = { on: vi.fn().mockReturnThis(), close: vi.fn().mockResolvedValue(undefined), + add: vi.fn(), }; vi.mock("chokidar", () => ({ @@ -35,16 +36,19 @@ import { import { UploadTracker } from "./upload-tracker.js"; import { uploadFile } from "./storage.js"; -function createManager(overrides: Partial = {}) { +function createManager( + overrides: Partial = {}, + artifactsDir: string | null = "/tmp/artifacts", +) { const send = vi.fn(); const uploadTracker = new UploadTracker(); const config: ArtifactManagerConfig = { - artifactsDir: "/tmp/artifacts", uploadTracker, send, ...overrides, }; const manager = new ArtifactManager(config); + if (artifactsDir) manager.addDirectory(artifactsDir); return { manager, send, uploadTracker }; } @@ -65,12 +69,13 @@ describe("ArtifactManager", () => { }); describe("constructor", () => { - it("creates a chokidar watcher on the artifacts directory", () => { + it("creates an empty chokidar watcher and registers dirs later", () => { createManager(); - expect(chokidar.watch).toHaveBeenCalledWith("/tmp/artifacts", { + expect(chokidar.watch).toHaveBeenCalledWith([], { awaitWriteFinish: true, alwaysStat: true, }); + expect(mockWatcher.add).toHaveBeenCalledWith("/tmp/artifacts"); }); it("registers add and change handlers", () => { @@ -79,6 +84,23 @@ describe("ArtifactManager", () => { expect(events).toContain("add"); expect(events).toContain("change"); }); + + it("supports multiple directories in one session", () => { + const { manager } = createManager({}, null); + manager.addDirectory("/tmp/run-one"); + manager.addDirectory("/tmp/run-two"); + expect(mockWatcher.add).toHaveBeenCalledWith("/tmp/run-one"); + expect(mockWatcher.add).toHaveBeenCalledWith("/tmp/run-two"); + }); + + it("ignores repeat addDirectory calls for the same path", () => { + const { manager } = createManager({}, null); + manager.addDirectory("/tmp/run-one"); + manager.addDirectory("/tmp/run-one"); + expect( + mockWatcher.add.mock.calls.filter((c) => c[0] === "/tmp/run-one"), + ).toHaveLength(1); + }); }); describe("file event handling", () => { diff --git a/packages/runtimeuse/src/artifact-manager.ts b/packages/runtimeuse/src/artifact-manager.ts index 8131481..97d8436 100644 --- a/packages/runtimeuse/src/artifact-manager.ts +++ b/packages/runtimeuse/src/artifact-manager.ts @@ -12,7 +12,6 @@ import { DEFAULT_ARTIFACT_IGNORE } from "./constants.js"; import { defaultLogger, type Logger } from "./logger.js"; export interface ArtifactManagerConfig { - artifactsDir: string; uploadTracker: UploadTracker; send: (message: ArtifactUploadRequestMessage) => void; } @@ -23,22 +22,17 @@ export class ArtifactManager { string, { promise: Promise; resolve: () => void } >(); - private readonly artifactsDir: string; + private readonly watchedDirs = new Map(); private readonly uploadTracker: UploadTracker; private readonly send: (message: ArtifactUploadRequestMessage) => void; - private ig: Ignore = ignore(); private logger: Logger = defaultLogger; private loggingLevel: "info" | "debug" = "info"; constructor(config: ArtifactManagerConfig) { - this.artifactsDir = config.artifactsDir; this.uploadTracker = config.uploadTracker; this.send = config.send; - fs.mkdirSync(config.artifactsDir, { recursive: true }); - this.reloadIgnorePatterns(); - - this.watcher = chokidar.watch(config.artifactsDir, { + this.watcher = chokidar.watch([], { awaitWriteFinish: true, alwaysStat: true, }); @@ -47,15 +41,28 @@ export class ArtifactManager { this.watcher.on("change", (p, s) => this.onFileEvent(p, s)); } - private reloadIgnorePatterns(): void { - this.ig = ignore(); - const ignorePath = path.join(this.artifactsDir, ".artifactignore"); + /** + * Begin watching an artifacts directory. Safe to call repeatedly across + * requests in the same session — repeat calls for the same directory are + * no-ops. The watcher stays alive until {@link stopWatching} is called at + * session close. + */ + addDirectory(dir: string): void { + if (this.watchedDirs.has(dir)) return; + + fs.mkdirSync(dir, { recursive: true }); + + const ig = ignore(); + const ignorePath = path.join(dir, ".artifactignore"); if (fs.existsSync(ignorePath)) { - this.ig.add(fs.readFileSync(ignorePath, "utf-8")); + ig.add(fs.readFileSync(ignorePath, "utf-8")); this.logger.log(`Loaded .artifactignore from ${ignorePath}`); } else { - this.ig.add(DEFAULT_ARTIFACT_IGNORE); + ig.add(DEFAULT_ARTIFACT_IGNORE); } + + this.watchedDirs.set(dir, ig); + this.watcher.add(dir); } setLogger(logger: Logger): void { @@ -109,6 +116,16 @@ export class ArtifactManager { await this.watcher.close(); } + private findOwningDir(filePath: string): string | null { + for (const dir of this.watchedDirs.keys()) { + const rel = path.relative(dir, filePath); + if (!rel.startsWith("..") && !path.isAbsolute(rel)) { + return dir; + } + } + return null; + } + private onFileEvent(filePath: string, stats?: fs.Stats): void { if (this.loggingLevel === "debug") { this.logger.log( @@ -117,7 +134,8 @@ export class ArtifactManager { } if (path.basename(filePath) === ".artifactignore") { - this.reloadIgnorePatterns(); + const owningDir = this.findOwningDir(filePath); + if (owningDir) this.reloadIgnorePatterns(owningDir); return; } @@ -128,8 +146,17 @@ export class ArtifactManager { return; } - const relativePath = path.relative(this.artifactsDir, filePath); - if (!relativePath.startsWith("..") && this.ig.ignores(relativePath)) { + const owningDir = this.findOwningDir(filePath); + if (!owningDir) { + if (this.loggingLevel === "debug") { + this.logger.debug(`File not in any watched dir: ${filePath}`); + } + return; + } + + const ig = this.watchedDirs.get(owningDir); + const relativePath = path.relative(owningDir, filePath); + if (ig && ig.ignores(relativePath)) { if (this.loggingLevel === "debug") { this.logger.debug(`Skipping ignored artifact: ${relativePath}`); } @@ -139,6 +166,18 @@ export class ArtifactManager { this.requestUpload(filePath); } + private reloadIgnorePatterns(dir: string): void { + const ig = ignore(); + const ignorePath = path.join(dir, ".artifactignore"); + if (fs.existsSync(ignorePath)) { + ig.add(fs.readFileSync(ignorePath, "utf-8")); + this.logger.log(`Reloaded .artifactignore from ${ignorePath}`); + } else { + ig.add(DEFAULT_ARTIFACT_IGNORE); + } + this.watchedDirs.set(dir, ig); + } + private requestUpload(filePath: string): void { const filename = path.basename(filePath); diff --git a/packages/runtimeuse/src/command-handler.ts b/packages/runtimeuse/src/command-handler.ts index 811624b..9c3874b 100644 --- a/packages/runtimeuse/src/command-handler.ts +++ b/packages/runtimeuse/src/command-handler.ts @@ -53,7 +53,9 @@ class CommandHandler { }, (error, stdout, stderr) => { if (error) { - return resolve({ exitCode: error.code as number, error }); + const code = + typeof error.code === "number" ? error.code : -1; + return resolve({ exitCode: code, error }); } }, ); diff --git a/packages/runtimeuse/src/invocation-runner.test.ts b/packages/runtimeuse/src/invocation-runner.test.ts index 109f509..d482de3 100644 --- a/packages/runtimeuse/src/invocation-runner.test.ts +++ b/packages/runtimeuse/src/invocation-runner.test.ts @@ -104,7 +104,7 @@ describe("InvocationRunner", () => { it("calls handler with parsed output format", async () => { const { runner, message, abortController, logger, send } = createRunner(); - await runner.run(message); + const result = await runner.run(message); expect(mockHandlerRun).toHaveBeenCalledWith( expect.objectContaining({ @@ -125,7 +125,7 @@ describe("InvocationRunner", () => { }), ); - expect(send).toHaveBeenCalledWith({ + expect(result).toEqual({ message_type: "result_message", metadata: { duration_ms: 12 }, data: { type: "structured_output", structured_output: { ok: true } }, @@ -167,6 +167,23 @@ describe("InvocationRunner", () => { ]); }); + it("throws when a post-agent command exits non-zero", async () => { + mockHandlerRun.mockResolvedValueOnce({ + type: "text", + text: "agent succeeded", + } as AgentResult); + mockExecute.mockResolvedValueOnce({ exitCode: 7 }); + + const { runner, message } = createRunner({ + output_format_json_schema_str: undefined, + post_agent_invocation_commands: [{ command: "cleanup", cwd: "/app" }], + }); + + await expect(runner.run(message)).rejects.toThrow( + "post-agent command failed with exit code: 7", + ); + }); + it("forwards command stdout and stderr through assistant messages", async () => { mockExecute.mockImplementation(async (options) => { options.onStdout?.("stdout data"); @@ -190,7 +207,7 @@ describe("InvocationRunner", () => { }); }); - it("sends error message and throws when command exits non-zero", async () => { + it("throws when pre-agent command exits non-zero; caller decides wire error", async () => { mockExecute.mockResolvedValueOnce({ exitCode: 2 }); const { runner, message, send, logger } = createRunner({ pre_agent_invocation_commands: [{ command: "false", cwd: "/app" }], @@ -203,11 +220,11 @@ describe("InvocationRunner", () => { expect(logger.error).toHaveBeenCalledWith( "pre-agent command failed with exit code: 2", ); - expect(send).toHaveBeenCalledWith({ - message_type: "error_message", - error: "pre-agent command failed with exit code: 2", - metadata: {}, - }); + // runner no longer emits error_message itself; the caller (session) does + // so that exactly one terminal is sent per request. + expect(send).not.toHaveBeenCalledWith( + expect.objectContaining({ message_type: "error_message" }), + ); expect(mockHandlerRun).not.toHaveBeenCalled(); }); @@ -237,33 +254,33 @@ describe("InvocationRunner", () => { type: "structured_output", structuredOutput: { ok: true }, } as AgentResult); - const { runner, message, send } = createRunner(); + const { runner, message } = createRunner(); - await runner.run(message); + const result = await runner.run(message); - expect(send).toHaveBeenCalledWith({ + expect(result).toEqual({ message_type: "result_message", metadata: {}, data: { type: "structured_output", structured_output: { ok: true } }, }); }); - it("sends text result when handler returns text", async () => { + it("returns text result when handler returns text", async () => { mockHandlerRun.mockResolvedValueOnce({ type: "text", text: "Hello, world!", metadata: { model: "test" }, } as AgentResult); - const { runner, send } = createRunner({ + const { runner } = createRunner({ output_format_json_schema_str: undefined, }); - await runner.run({ + const result = await runner.run({ ...BASE_INVOCATION_MESSAGE, output_format_json_schema_str: undefined, }); - expect(send).toHaveBeenCalledWith({ + expect(result).toEqual({ message_type: "result_message", metadata: { model: "test" }, data: { type: "text", text: "Hello, world!" }, @@ -405,12 +422,12 @@ describe("InvocationRunner.runCommandsOnly", () => { mockExecute.mockResolvedValue({ exitCode: 0 }); }); - it("sends command_execution_result_message on success", async () => { - const { runner, message, send } = createCommandRunner(); + it("returns command_execution_result_message on success", async () => { + const { runner, message } = createCommandRunner(); - await runner.runCommandsOnly(message); + const result = await runner.runCommandsOnly(message); - expect(send).toHaveBeenCalledWith({ + expect(result).toEqual({ message_type: "command_execution_result_message", results: [{ command: "echo hello", exit_code: 0 }], }); @@ -418,16 +435,16 @@ describe("InvocationRunner.runCommandsOnly", () => { }); it("collects results for multiple commands", async () => { - const { runner, message, send } = createCommandRunner({ + const { runner, message } = createCommandRunner({ commands: [ { command: "echo 1", cwd: "/app" }, { command: "echo 2", cwd: "/app" }, ], }); - await runner.runCommandsOnly(message); + const result = await runner.runCommandsOnly(message); - expect(send).toHaveBeenCalledWith({ + expect(result).toEqual({ message_type: "command_execution_result_message", results: [ { command: "echo 1", exit_code: 0 }, @@ -461,9 +478,9 @@ describe("InvocationRunner.runCommandsOnly", () => { mockExecute.mockResolvedValueOnce({ exitCode: 2 }); const { runner, message, send } = createCommandRunner(); - await runner.runCommandsOnly(message); + const result = await runner.runCommandsOnly(message); - expect(send).toHaveBeenCalledWith({ + expect(result).toEqual({ message_type: "command_execution_result_message", results: [{ command: "echo hello", exit_code: 2 }], }); @@ -476,16 +493,16 @@ describe("InvocationRunner.runCommandsOnly", () => { mockExecute .mockResolvedValueOnce({ exitCode: 1 }) .mockResolvedValueOnce({ exitCode: 0 }); - const { runner, message, send } = createCommandRunner({ + const { runner, message } = createCommandRunner({ commands: [ { command: "failing", cwd: "/app" }, { command: "skipped", cwd: "/app" }, ], }); - await runner.runCommandsOnly(message); + const result = await runner.runCommandsOnly(message); - expect(send).toHaveBeenCalledWith({ + expect(result).toEqual({ message_type: "command_execution_result_message", results: [{ command: "failing", exit_code: 1 }], }); diff --git a/packages/runtimeuse/src/invocation-runner.ts b/packages/runtimeuse/src/invocation-runner.ts index 91e32b2..a52cf79 100644 --- a/packages/runtimeuse/src/invocation-runner.ts +++ b/packages/runtimeuse/src/invocation-runner.ts @@ -3,7 +3,9 @@ import type { InvocationMessage, CommandExecutionMessage, CommandExecutionResultItem, + CommandExecutionResultMessage, OutgoingMessage, + ResultMessage, RuntimeEnvironmentDownloadable, Command, } from "./types.js"; @@ -27,8 +29,8 @@ export class InvocationRunner { this.downloadHandler = new DownloadHandler(config.logger); } - async run(message: InvocationMessage): Promise { - const { handler, logger, abortController, send } = this.config; + async run(message: InvocationMessage): Promise { + const { handler, logger, abortController } = this.config; await this.downloadRuntimeEnvironment(message.pre_agent_downloadables); await this.runCommands( @@ -59,7 +61,13 @@ export class InvocationRunner { sender, ); - const resultMessage: OutgoingMessage = { + await this.runCommands( + message.post_agent_invocation_commands, + "post-agent", + message.secrets_to_redact, + ); + + return { message_type: "result_message", metadata: agentResult.metadata ?? {}, data: @@ -70,19 +78,12 @@ export class InvocationRunner { structured_output: agentResult.structuredOutput, }, }; - - logger.log("Sending result message:", JSON.stringify(resultMessage)); - send(resultMessage); - - await this.runCommands( - message.post_agent_invocation_commands, - "post-agent", - message.secrets_to_redact, - ); } - async runCommandsOnly(message: CommandExecutionMessage): Promise { - const { logger, send } = this.config; + async runCommandsOnly( + message: CommandExecutionMessage, + ): Promise { + const { logger } = this.config; await this.downloadRuntimeEnvironment(message.pre_execution_downloadables); @@ -101,15 +102,7 @@ export class InvocationRunner { } } - const resultMessage: OutgoingMessage = { - message_type: "command_execution_result_message", - results, - }; - logger.log( - "Sending command execution result:", - JSON.stringify(resultMessage), - ); - send(resultMessage); + return { message_type: "command_execution_result_message", results }; } private async runCommandAndCollect( @@ -184,7 +177,6 @@ export class InvocationRunner { if (result.exitCode !== 0) { const errorMsg = `${phase} command failed with exit code: ${result.exitCode}`; logger.error(errorMsg); - send({ message_type: "error_message", error: errorMsg, metadata: {} }); throw new Error(errorMsg); } } diff --git a/packages/runtimeuse/src/session.test.ts b/packages/runtimeuse/src/session.test.ts index ab11f06..5c45ceb 100644 --- a/packages/runtimeuse/src/session.test.ts +++ b/packages/runtimeuse/src/session.test.ts @@ -4,6 +4,7 @@ import { WebSocket } from "ws"; const mockArtifactManager = { setLogger: vi.fn(), + addDirectory: vi.fn(), handleUploadResponse: vi.fn().mockResolvedValue(undefined), waitForPendingRequests: vi.fn().mockResolvedValue(undefined), stopWatching: vi.fn().mockResolvedValue(undefined), @@ -84,6 +85,7 @@ function createSession() { const config: SessionConfig = { handler: mockHandler, uploadTracker, + postInvocationDelayMs: 0, }; const session = new WebSocketSession(ws as any, config); return { session, ws, config, uploadTracker }; @@ -93,6 +95,33 @@ function sendMessage(ws: MockWebSocket, message: unknown) { ws.emit("message", Buffer.from(JSON.stringify(message))); } +async function waitForSentCount( + ws: MockWebSocket, + predicate: (m: any) => boolean, + count = 1, +): Promise { + for (let i = 0; i < 200; i++) { + const sent = parseSentMessages(ws); + if (sent.filter(predicate).length >= count) return; + await new Promise((r) => setTimeout(r, 5)); + } + throw new Error("Expected message never sent"); +} + +async function waitForTerminal(ws: MockWebSocket, count = 1): Promise { + const terminals = new Set([ + "result_message", + "command_execution_result_message", + "error_message", + ]); + await waitForSentCount(ws, (m) => terminals.has(m.message_type), count); +} + +async function endSession(ws: MockWebSocket, done: Promise): Promise { + sendMessage(ws, { message_type: "end_session_message" }); + await done; +} + const INVOCATION_MSG = { message_type: "invocation_message" as const, source_id: "test-source-id", @@ -110,6 +139,8 @@ const INVOCATION_MSG = { describe("WebSocketSession", () => { beforeEach(() => { vi.clearAllMocks(); + mockHandlerRun.mockReset(); + mockCommandExecute.mockReset(); mockHandlerRun.mockResolvedValue({ type: "structured_output", @@ -125,46 +156,69 @@ describe("WebSocketSession", () => { }); describe("lifecycle", () => { - it("resolves when invocation finishes", async () => { + it("sends terminal after invocation finishes", async () => { + const { ws } = createSession().session && createSession(); + // Fresh session for this test + }); + + it("resolves when end_session_message is received", async () => { + const { session, ws } = createSession(); + const done = session.run(); + sendMessage(ws, INVOCATION_MSG); + await waitForTerminal(ws); + await endSession(ws, done); + }); + + it("resolves when the socket closes", async () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, INVOCATION_MSG); + await waitForTerminal(ws); + ws.close(); await done; }); }); describe("message routing", () => { - it("rejects non-invocation messages before invocation", async () => { + it("ignores cancel_message when no request is in flight", async () => { const { session, ws } = createSession(); - session.run(); + const done = session.run(); sendMessage(ws, { message_type: "cancel_message" }); await tick(); - expectSentError(ws, "non-invocation message before invocation"); + const sent = parseSentMessages(ws); + expect(sent.filter((m) => m.message_type === "error_message")).toHaveLength(0); + + await endSession(ws, done); }); - it("rejects duplicate invocation messages", async () => { + it("rejects a second request while one is in flight", async () => { let resolveAgent!: () => void; mockHandlerRun.mockImplementation( () => new Promise((r) => { resolveAgent = () => - r({ type: "structured_output", structuredOutput: { success: true } } as AgentResult); + r({ + type: "structured_output", + structuredOutput: { success: true }, + } as AgentResult); }), ); const { session, ws } = createSession(); - session.run(); + const done = session.run(); sendMessage(ws, INVOCATION_MSG); await tick(); sendMessage(ws, INVOCATION_MSG); await tick(); - expectSentError(ws, "multiple invocation messages"); + expectSentError(ws, "another is in flight"); resolveAgent(); + await waitForTerminal(ws); + await endSession(ws, done); }); it("delegates artifact upload responses to ArtifactManager", async () => { @@ -173,7 +227,10 @@ describe("WebSocketSession", () => { () => new Promise((r) => { resolveAgent = () => - r({ type: "structured_output", structuredOutput: { success: true } } as AgentResult); + r({ + type: "structured_output", + structuredOutput: { success: true }, + } as AgentResult); }), ); @@ -197,16 +254,17 @@ describe("WebSocketSession", () => { ); resolveAgent(); - await done; + await waitForTerminal(ws); + await endSession(ws, done); }); - it("aborts and closes on cancel message", async () => { - let resolveAgent!: () => void; + it("aborts in-flight request on cancel message without closing session", async () => { mockHandlerRun.mockImplementation( - () => - new Promise((r) => { - resolveAgent = () => - r({ type: "structured_output", structuredOutput: {} } as AgentResult); + (_inv, _sender) => + new Promise((_resolve, reject) => { + _inv.signal.addEventListener("abort", () => { + reject(new Error("aborted")); + }); }), ); @@ -217,11 +275,23 @@ describe("WebSocketSession", () => { await tick(); sendMessage(ws, { message_type: "cancel_message" }); - await tick(); + await waitForTerminal(ws); - expect(ws.close).toHaveBeenCalled(); + // WS should still be open — session continues + expect(ws.close).not.toHaveBeenCalled(); - resolveAgent(); + await endSession(ws, done); + }); + + it("closes the websocket on end_session_message", async () => { + const { session, ws } = createSession(); + const done = session.run(); + sendMessage(ws, INVOCATION_MSG); + await waitForTerminal(ws); + sendMessage(ws, { message_type: "end_session_message" }); + await done; + + expect(ws.close).toHaveBeenCalled(); }); }); @@ -230,7 +300,8 @@ describe("WebSocketSession", () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, INVOCATION_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); expect(mockHandlerRun).toHaveBeenCalledWith( expect.objectContaining({ @@ -257,7 +328,8 @@ describe("WebSocketSession", () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, INVOCATION_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); const sent = parseSentMessages(ws); const result = sent.find((m) => m.message_type === "result_message"); @@ -272,7 +344,8 @@ describe("WebSocketSession", () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, INVOCATION_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); expectSentError(ws, "agent crashed"); }); @@ -291,7 +364,8 @@ describe("WebSocketSession", () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, INVOCATION_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); const sent = parseSentMessages(ws); const runtimeError = sent.find((m) => m.message_type === "error_message"); @@ -308,33 +382,45 @@ describe("WebSocketSession", () => { }); describe("finalization", () => { - it("stops the artifact watcher", async () => { + it("stops the artifact watcher on session close, not per request", async () => { const { session, ws } = createSession(); const done = session.run(); + sendMessage(ws, INVOCATION_MSG); - await done; + await waitForTerminal(ws); + + // Watcher must still be running between requests. + expect(mockArtifactManager.stopWatching).not.toHaveBeenCalled(); + + await endSession(ws, done); expect(mockArtifactManager.stopWatching).toHaveBeenCalled(); }); - it("waits for pending artifact requests", async () => { + it("waits for pending artifact requests on session close", async () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, INVOCATION_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); expect(mockArtifactManager.waitForPendingRequests).toHaveBeenCalledWith( 60_000, ); }); - it("closes the websocket after finalization", async () => { + it("registers each request's artifacts_dir on the shared watcher", async () => { const { session, ws } = createSession(); const done = session.run(); - sendMessage(ws, INVOCATION_MSG); - await done; - expect(ws.close).toHaveBeenCalled(); + sendMessage(ws, { ...INVOCATION_MSG, artifacts_dir: "/tmp/first" }); + await waitForTerminal(ws, 1); + sendMessage(ws, { ...INVOCATION_MSG, artifacts_dir: "/tmp/second" }); + await waitForTerminal(ws, 2); + await endSession(ws, done); + + expect(mockArtifactManager.addDirectory).toHaveBeenCalledWith("/tmp/first"); + expect(mockArtifactManager.addDirectory).toHaveBeenCalledWith("/tmp/second"); }); }); @@ -355,7 +441,8 @@ describe("WebSocketSession", () => { }, ], }); - await done; + await waitForTerminal(ws); + await endSession(ws, done); expect(mockDownloadHandler.download).toHaveBeenCalledTimes(2); expect(mockDownloadHandler.download).toHaveBeenCalledWith( @@ -372,7 +459,8 @@ describe("WebSocketSession", () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, INVOCATION_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); expect(mockDownloadHandler.download).not.toHaveBeenCalled(); }); @@ -389,7 +477,8 @@ describe("WebSocketSession", () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, INVOCATION_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); const sent = parseSentMessages(ws); const result = sent.find((m) => m.message_type === "result_message"); @@ -407,7 +496,8 @@ describe("WebSocketSession", () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, INVOCATION_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); const sent = parseSentMessages(ws); const assistant = sent.find( @@ -428,7 +518,8 @@ describe("WebSocketSession", () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, INVOCATION_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); const sent = parseSentMessages(ws); const runtimeError = sent.find((m) => m.message_type === "error_message"); @@ -448,20 +539,13 @@ describe("WebSocketSession", () => { commands: [{ command: "echo hello", cwd: "/app" }], }; - it("resolves when command execution finishes", async () => { - mockCommandExecute.mockResolvedValueOnce({ exitCode: 0 }); - const { session, ws } = createSession(); - const done = session.run(); - sendMessage(ws, COMMAND_EXEC_MSG); - await done; - }); - it("sends command_execution_result_message on success", async () => { mockCommandExecute.mockResolvedValueOnce({ exitCode: 0 }); const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, COMMAND_EXEC_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); const sent = parseSentMessages(ws); const result = sent.find( @@ -478,7 +562,8 @@ describe("WebSocketSession", () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, COMMAND_EXEC_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); expect(mockHandlerRun).not.toHaveBeenCalled(); }); @@ -488,7 +573,8 @@ describe("WebSocketSession", () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, COMMAND_EXEC_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); const sent = parseSentMessages(ws); const result = sent.find( @@ -502,28 +588,6 @@ describe("WebSocketSession", () => { expect(errors).toHaveLength(0); }); - it("rejects duplicate command execution messages", async () => { - let resolveCmd!: () => void; - mockCommandExecute.mockImplementation( - () => - new Promise((r) => { - resolveCmd = () => r({ exitCode: 0 }); - }), - ); - - const { session, ws } = createSession(); - session.run(); - - sendMessage(ws, COMMAND_EXEC_MSG); - await tick(); - sendMessage(ws, COMMAND_EXEC_MSG); - await tick(); - - expectSentError(ws, "multiple invocation messages"); - - resolveCmd(); - }); - it("redacts secrets from command output", async () => { mockCommandExecute.mockImplementation(async function (this: any) { return { exitCode: 0 }; @@ -531,13 +595,45 @@ describe("WebSocketSession", () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, COMMAND_EXEC_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); const sent = parseSentMessages(ws); for (const msg of sent) { expect(JSON.stringify(msg)).not.toContain("secret123"); } }); + + it("handles two sequential command_execution_messages on one socket", async () => { + mockCommandExecute + .mockResolvedValueOnce({ exitCode: 0 }) + .mockResolvedValueOnce({ exitCode: 0 }); + + const { session, ws } = createSession(); + const done = session.run(); + + sendMessage(ws, { + ...COMMAND_EXEC_MSG, + commands: [{ command: "echo first" }], + }); + await waitForTerminal(ws, 1); + + sendMessage(ws, { + ...COMMAND_EXEC_MSG, + commands: [{ command: "echo second" }], + }); + await waitForTerminal(ws, 2); + + const sent = parseSentMessages(ws); + const results = sent.filter( + (m) => m.message_type === "command_execution_result_message", + ); + expect(results).toHaveLength(2); + expect(results[0].results[0].command).toBe("echo first"); + expect(results[1].results[0].command).toBe("echo second"); + + await endSession(ws, done); + }); }); describe("pre-agent invocation commands", () => { @@ -550,7 +646,8 @@ describe("WebSocketSession", () => { ...INVOCATION_MSG, pre_agent_invocation_commands: [{ command: "npm test", cwd: "/app" }], }); - await done; + await waitForTerminal(ws); + await endSession(ws, done); expect(mockHandlerRun).toHaveBeenCalled(); }); @@ -567,7 +664,8 @@ describe("WebSocketSession", () => { ...INVOCATION_MSG, pre_agent_invocation_commands: [{ command: "npm test", cwd: "/app" }], }); - await done; + await waitForTerminal(ws); + await endSession(ws, done); const sent = parseSentMessages(ws); const error = sent.find((m) => m.message_type === "error_message"); @@ -584,7 +682,8 @@ describe("WebSocketSession", () => { ...INVOCATION_MSG, pre_agent_invocation_commands: [{ command: "bad-cmd" }], }); - await done; + await waitForTerminal(ws); + await endSession(ws, done); const sent = parseSentMessages(ws); const errors = sent.filter((m) => m.message_type === "error_message"); @@ -599,7 +698,8 @@ describe("WebSocketSession", () => { const { session, ws } = createSession(); const done = session.run(); sendMessage(ws, INVOCATION_MSG); - await done; + await waitForTerminal(ws); + await endSession(ws, done); expect(CommandHandler).not.toHaveBeenCalled(); expect(mockHandlerRun).toHaveBeenCalled(); diff --git a/packages/runtimeuse/src/session.ts b/packages/runtimeuse/src/session.ts index b81204c..d0e0fa2 100644 --- a/packages/runtimeuse/src/session.ts +++ b/packages/runtimeuse/src/session.ts @@ -3,10 +3,20 @@ import { WebSocket } from "ws"; import type { AgentHandler } from "./agent-handler.js"; import { ArtifactManager } from "./artifact-manager.js"; import type { UploadTracker } from "./upload-tracker.js"; -import type { InvocationMessage, CommandExecutionMessage, IncomingMessage, OutgoingMessage } from "./types.js"; +import type { + InvocationMessage, + CommandExecutionMessage, + IncomingMessage, + OutgoingMessage, +} from "./types.js"; import { getErrorMessage, serializeErrorMetadata } from "./error-utils.js"; import { redactSecrets, sleep } from "./utils.js"; -import { createLogger, createRedactingLogger, defaultLogger, type Logger } from "./logger.js"; +import { + createLogger, + createRedactingLogger, + defaultLogger, + type Logger, +} from "./logger.js"; import { InvocationRunner } from "./invocation-runner.js"; export interface SessionConfig { @@ -21,11 +31,9 @@ export interface SessionConfig { export class WebSocketSession { private readonly ws: WebSocket; private readonly config: SessionConfig; - private readonly abortController = new AbortController(); + private currentAbortController: AbortController | null = null; private artifactManager: ArtifactManager | null = null; - private invocationReceived = false; - private finalized = false; - private cancelled = false; + private requestInFlight = false; private secrets: string[] = []; private logger: Logger; @@ -41,7 +49,7 @@ export class WebSocketSession { this.logger.log("Received new WS message"); try { const message: IncomingMessage = JSON.parse(rawData.toString()); - await this.handleMessage(message, resolve); + await this.handleMessage(message); } catch (error) { this.logger.error("Error processing message:", error); this.send({ @@ -53,14 +61,26 @@ export class WebSocketSession { }); this.ws.on("close", async (code, reason) => { - if (!this.finalized && !this.cancelled) { + if (this.requestInFlight) { this.logger.warn( - `WebSocket closed unexpectedly (code=${code}, reason=${reason?.toString() ?? ""}). Artifacts may not have been fully uploaded.`, + `WebSocket closed unexpectedly mid-request (code=${code}, reason=${reason?.toString() ?? ""}). Artifacts may not have been fully uploaded.`, ); } this.logger.log("WebSocket connection closed"); - this.abortController.abort(); + this.currentAbortController?.abort(); + + // Give chokidar time to observe files the agent wrote right before + // the session ended. Without this, late `add` events would not fire + // before we stop the watcher, and those artifacts would be lost. + const delayMs = this.config.postInvocationDelayMs ?? 3_000; + if (this.artifactManager && delayMs > 0) { + this.logger.log(`Waiting ${delayMs}ms for artifact watcher to drain...`); + await sleep(delayMs); + } await this.artifactManager?.stopWatching(); + await this.artifactManager?.waitForPendingRequests( + this.config.artifactWaitMs ?? 60_000, + ); await this.config.uploadTracker.waitForAll( this.config.uploadTimeoutMs ?? 30_000, ); @@ -73,22 +93,20 @@ export class WebSocketSession { }); } - private async handleMessage( - message: IncomingMessage, - resolve: () => void, - ): Promise { - if ( - !this.invocationReceived && - message.message_type !== "invocation_message" && - message.message_type !== "command_execution_message" - ) { - throw new Error( - "Received non-invocation message before invocation message! Received: " + - JSON.stringify(message), - ); - } - + private async handleMessage(message: IncomingMessage): Promise { switch (message.message_type) { + case "end_session_message": + this.logger.log("Received end_session_message. Closing session."); + this.ws.close(); + return; + + case "cancel_message": + this.logger.log( + "Received cancel message. Aborting in-flight request...", + ); + this.currentAbortController?.abort(); + return; + case "artifact_upload_response_message": try { await this.artifactManager?.handleUploadResponse(message); @@ -100,140 +118,111 @@ export class WebSocketSession { metadata: serializeErrorMetadata(error), }); } - break; - - case "cancel_message": - this.logger.log("Received cancel message. Aborting agent execution..."); - this.cancelled = true; - this.abortController.abort(); - this.ws.close(); - break; + return; case "invocation_message": - if (this.invocationReceived) { - throw new Error("Received multiple invocation messages!"); + if (this.requestInFlight) { + throw new Error("Received request while another is in flight!"); } - this.invocationReceived = true; - await this.executeInvocation(message); - const hasArtifacts = this.artifactManager !== null; - if (process.env.NODE_ENV !== "test" || hasArtifacts) { - this.logger.log("Waiting for post-invocation delay..."); - await sleep(this.config.postInvocationDelayMs ?? 3_000); - } - await this.finalize(); - resolve(); - break; + await this.handleRequest((runner) => runner.run(message), message); + return; case "command_execution_message": - if (this.invocationReceived) { - throw new Error("Received multiple invocation messages!"); - } - this.invocationReceived = true; - await this.executeCommandsOnly(message); - const hasCommandArtifacts = this.artifactManager !== null; - if (process.env.NODE_ENV !== "test" || hasCommandArtifacts) { - this.logger.log("Waiting for post-invocation delay..."); - await sleep(this.config.postInvocationDelayMs ?? 3_000); + if (this.requestInFlight) { + throw new Error("Received request while another is in flight!"); } - await this.finalize(); - resolve(); - break; - } - } - - private async executeInvocation(message: InvocationMessage): Promise { - const sourceId = message.source_id ?? crypto.randomUUID(); - this.secrets = message.secrets_to_redact ?? []; - this.logger = createRedactingLogger(createLogger(sourceId), this.secrets); - this.config.uploadTracker.setLogger(this.logger); - this.logger.log(`Received invocation: model=${message.model}`); - - this.initArtifactManager(message.artifacts_dir); - - const runner = new InvocationRunner({ - handler: this.config.handler, - logger: this.logger, - abortController: this.abortController, - send: (msg) => this.send(msg), - }); - - try { - await runner.run(message); - } catch (error) { - if (this.abortController.signal.aborted) { - this.ws.close(); - this.logger.log("Agent execution aborted."); + await this.handleRequest( + (runner) => runner.runCommandsOnly(message), + message, + ); return; - } - this.logger.error("Error in agent execution:", error); - this.send({ - message_type: "error_message", - error: getErrorMessage(error), - metadata: serializeErrorMetadata(error), - }); } } - private async executeCommandsOnly(message: CommandExecutionMessage): Promise { + private async handleRequest( + runFn: (runner: InvocationRunner) => Promise, + message: InvocationMessage | CommandExecutionMessage, + ): Promise { + const abortController = new AbortController(); + this.currentAbortController = abortController; + const sourceId = message.source_id ?? crypto.randomUUID(); this.secrets = message.secrets_to_redact ?? []; this.logger = createRedactingLogger(createLogger(sourceId), this.secrets); this.config.uploadTracker.setLogger(this.logger); - this.logger.log("Received command execution request"); + this.logger.log("Handling new request"); - this.initArtifactManager(message.artifacts_dir); + if (message.artifacts_dir) { + this.ensureArtifactManager(); + this.artifactManager!.addDirectory(message.artifacts_dir); + } const runner = new InvocationRunner({ handler: this.config.handler, logger: this.logger, - abortController: this.abortController, + abortController, send: (msg) => this.send(msg), }); + let terminal: OutgoingMessage; try { - await runner.runCommandsOnly(message); - } catch (error) { - if (this.abortController.signal.aborted) { - this.ws.close(); - this.logger.log("Command execution aborted."); - return; + this.requestInFlight = true; + try { + terminal = await runFn(runner); + if (abortController.signal.aborted) { + // Runner may return a partial result after abort (e.g., a command + // exits with a non-numeric code). Replace with a cancel terminal + // so the client sees exactly one error_message for the cancelled + // request, consistent with the throw path below. + this.logger.log("Request aborted (runner returned); emitting cancel terminal."); + terminal = { + message_type: "error_message", + error: "Request cancelled", + metadata: {}, + }; + } + } catch (error) { + if (abortController.signal.aborted) { + this.logger.log("Request aborted."); + terminal = { + message_type: "error_message", + error: "Request cancelled", + metadata: {}, + }; + } else { + this.logger.error("Error in request execution:", error); + terminal = { + message_type: "error_message", + error: getErrorMessage(error), + metadata: serializeErrorMetadata(error), + }; + } } - this.logger.error("Error in command execution:", error); - this.send({ - message_type: "error_message", - error: getErrorMessage(error), - metadata: serializeErrorMetadata(error), - }); + + // The artifact watcher stays alive for the whole session, so we no + // longer block each request on a 3s drain. Artifacts that finish + // writing after the terminal will still fire chokidar events and be + // uploaded; on session close we do a single drain for any that were + // written right before the ws closed. + this.send(terminal); + } finally { + this.currentAbortController = null; + this.requestInFlight = false; } } - private initArtifactManager(artifactsDir?: string): void { - if (!artifactsDir) return; + private ensureArtifactManager(): void { + if (this.artifactManager) { + this.artifactManager.setLogger(this.logger); + return; + } this.artifactManager = new ArtifactManager({ - artifactsDir, uploadTracker: this.config.uploadTracker, send: (msg) => this.send(msg), }); this.artifactManager.setLogger(this.logger); } - private async finalize(): Promise { - await this.artifactManager?.stopWatching(); - - if (!this.abortController.signal.aborted) { - await this.artifactManager?.waitForPendingRequests( - this.config.artifactWaitMs ?? 60_000, - ); - } - - await this.config.uploadTracker.waitForAll( - this.config.uploadTimeoutMs ?? 30_000, - ); - this.logger.log("All artifacts uploaded."); - this.finalized = true; - this.ws.close(); - } - private send(data: OutgoingMessage): void { if (this.ws.readyState === WebSocket.OPEN) { this.ws.send(JSON.stringify(redactSecrets(data, this.secrets))); diff --git a/packages/runtimeuse/src/types.ts b/packages/runtimeuse/src/types.ts index 8fe81b7..0f90c44 100644 --- a/packages/runtimeuse/src/types.ts +++ b/packages/runtimeuse/src/types.ts @@ -47,6 +47,10 @@ interface CancelMessage { message_type: "cancel_message"; } +interface EndSessionMessage { + message_type: "end_session_message"; +} + interface ResultMessage { message_type: "result_message"; metadata?: Record; @@ -92,7 +96,8 @@ type IncomingMessage = | InvocationMessage | CommandExecutionMessage | ArtifactUploadResponseMessage - | CancelMessage; + | CancelMessage + | EndSessionMessage; export type { IncomingMessage, @@ -102,6 +107,7 @@ export type { CommandExecutionResultMessage, CommandExecutionResultItem, CancelMessage, + EndSessionMessage, ResultMessage, AssistantMessage, ArtifactUploadRequestMessage,