diff --git a/packages/runtimeuse-client-python/src/runtimeuse_client/client.py b/packages/runtimeuse-client-python/src/runtimeuse_client/client.py index 032129e..5e6e9e3 100644 --- a/packages/runtimeuse-client-python/src/runtimeuse_client/client.py +++ b/packages/runtimeuse-client-python/src/runtimeuse_client/client.py @@ -52,6 +52,7 @@ def __init__( raise ValueError("Either ws_url or transport must be provided") self._abort_event = asyncio.Event() + self._send_queue: asyncio.Queue[dict] | None = None def abort(self) -> None: """Signal the current query to cancel. @@ -61,6 +62,13 @@ def abort(self) -> None: coroutine on the same event loop. """ self._abort_event.set() + send_queue = self._send_queue + if send_queue is not None: + send_queue.put_nowait( + CancelMessage(message_type="cancel_message").model_dump( + mode="json" + ) + ) async def query( self, @@ -107,98 +115,98 @@ async def query( ) send_queue: asyncio.Queue[dict] = asyncio.Queue() + self._send_queue = send_queue await send_queue.put(invocation.model_dump(mode="json")) wire_result: ResultMessageInterface | None = None - async with asyncio.timeout(options.timeout): - async for message in self._transport(send_queue=send_queue): - if self._abort_event.is_set(): - logger.info("Query cancelled by caller") - await send_queue.put( - CancelMessage(message_type="cancel_message").model_dump( - mode="json" - ) - ) - await send_queue.join() - raise CancelledException("Query was cancelled") - - try: - message_interface = AgentRuntimeMessageInterface.model_validate( - message - ) - except pydantic.ValidationError: - logger.error( - f"Received unknown message type from agent runtime: {message}" - ) - continue - - if message_interface.message_type == "result_message": - wire_result = ResultMessageInterface.model_validate(message) - logger.info( - f"Received result message from agent runtime: {message}" - ) - continue - - elif message_interface.message_type == "assistant_message": - if options.on_assistant_message is not None: - assistant_message_interface = ( - AssistantMessageInterface.model_validate(message) - ) - await options.on_assistant_message(assistant_message_interface) - continue + try: + async with asyncio.timeout(options.timeout): + async for message in self._transport(send_queue=send_queue): + if self._abort_event.is_set(): + raise CancelledException("Query was cancelled") - elif message_interface.message_type == "error_message": try: - error_message_interface = ErrorMessageInterface.model_validate( + message_interface = AgentRuntimeMessageInterface.model_validate( message ) except pydantic.ValidationError: logger.error( - f"Received malformed error message from agent runtime: {message}", + f"Received unknown message type from agent runtime: {message}" ) - raise AgentRuntimeError(str(message)) - logger.error( - f"Error from agent runtime: {error_message_interface}", - ) - raise AgentRuntimeError( - error_message_interface.error, - metadata=error_message_interface.metadata, - ) - - elif ( - message_interface.message_type == "artifact_upload_request_message" - ): - logger.info( - f"Received artifact upload request message from agent runtime: {message}" - ) - if options.on_artifact_upload_request is not None: - artifact_upload_request_message_interface = ( - ArtifactUploadRequestMessageInterface.model_validate( + continue + + if message_interface.message_type == "result_message": + wire_result = ResultMessageInterface.model_validate(message) + logger.info( + f"Received result message from agent runtime: {message}" + ) + continue + + elif message_interface.message_type == "assistant_message": + if options.on_assistant_message is not None: + assistant_message_interface = ( + AssistantMessageInterface.model_validate(message) + ) + await options.on_assistant_message(assistant_message_interface) + continue + + elif message_interface.message_type == "error_message": + try: + error_message_interface = ErrorMessageInterface.model_validate( message ) + except pydantic.ValidationError: + logger.error( + f"Received malformed error message from agent runtime: {message}", + ) + raise AgentRuntimeError(str(message)) + logger.error( + f"Error from agent runtime: {error_message_interface}", ) - upload_result = await options.on_artifact_upload_request( - artifact_upload_request_message_interface + raise AgentRuntimeError( + error_message_interface.error, + metadata=error_message_interface.metadata, ) - artifact_upload_response_message_interface = ArtifactUploadResponseMessageInterface( - message_type="artifact_upload_response_message", - filename=artifact_upload_request_message_interface.filename, - filepath=artifact_upload_request_message_interface.filepath, - presigned_url=upload_result.presigned_url, - content_type=upload_result.content_type, + + elif ( + message_interface.message_type == "artifact_upload_request_message" + ): + logger.info( + f"Received artifact upload request message from agent runtime: {message}" ) - await send_queue.put( - artifact_upload_response_message_interface.model_dump( - mode="json" + if options.on_artifact_upload_request is not None: + artifact_upload_request_message_interface = ( + ArtifactUploadRequestMessageInterface.model_validate( + message + ) + ) + upload_result = await options.on_artifact_upload_request( + artifact_upload_request_message_interface + ) + artifact_upload_response_message_interface = ArtifactUploadResponseMessageInterface( + message_type="artifact_upload_response_message", + filename=artifact_upload_request_message_interface.filename, + filepath=artifact_upload_request_message_interface.filepath, + presigned_url=upload_result.presigned_url, + content_type=upload_result.content_type, + ) + await send_queue.put( + artifact_upload_response_message_interface.model_dump( + mode="json" + ) ) + continue + + else: + logger.info( + f"Received non-result message from agent runtime: {message}" ) - continue + finally: + self._send_queue = None - else: - logger.info( - f"Received non-result message from agent runtime: {message}" - ) + if self._abort_event.is_set(): + raise CancelledException("Query was cancelled") if wire_result is None: raise AgentRuntimeError("No result message received") @@ -240,96 +248,96 @@ async def execute_commands( ) send_queue: asyncio.Queue[dict] = asyncio.Queue() + self._send_queue = send_queue await send_queue.put(message.model_dump(mode="json")) wire_result: CommandExecutionResultMessageInterface | None = None - async with asyncio.timeout(options.timeout): - async for msg in self._transport(send_queue=send_queue): - if self._abort_event.is_set(): - logger.info("Command execution cancelled by caller") - await send_queue.put( - CancelMessage(message_type="cancel_message").model_dump( - mode="json" - ) - ) - await send_queue.join() - raise CancelledException("Command execution was cancelled") - - try: - message_interface = AgentRuntimeMessageInterface.model_validate(msg) - except pydantic.ValidationError: - logger.error( - f"Received unknown message type from agent runtime: {msg}" - ) - continue - - if message_interface.message_type == "command_execution_result_message": - wire_result = CommandExecutionResultMessageInterface.model_validate( - msg - ) - logger.info( - f"Received command execution result from agent runtime: {msg}" - ) - continue - - elif message_interface.message_type == "assistant_message": - if options.on_assistant_message is not None: - assistant_message_interface = ( - AssistantMessageInterface.model_validate(msg) - ) - await options.on_assistant_message(assistant_message_interface) - continue + try: + async with asyncio.timeout(options.timeout): + async for msg in self._transport(send_queue=send_queue): + if self._abort_event.is_set(): + raise CancelledException("Command execution was cancelled") - elif message_interface.message_type == "error_message": try: - error_message_interface = ErrorMessageInterface.model_validate( - msg - ) + message_interface = AgentRuntimeMessageInterface.model_validate(msg) except pydantic.ValidationError: logger.error( - f"Received malformed error message from agent runtime: {msg}", + f"Received unknown message type from agent runtime: {msg}" + ) + continue + + if message_interface.message_type == "command_execution_result_message": + wire_result = CommandExecutionResultMessageInterface.model_validate( + msg ) - raise AgentRuntimeError(str(msg)) - logger.error( - f"Error from agent runtime: {error_message_interface}", - ) - raise AgentRuntimeError( - error_message_interface.error, - metadata=error_message_interface.metadata, - ) - - elif ( - message_interface.message_type == "artifact_upload_request_message" - ): - logger.info( - f"Received artifact upload request message from agent runtime: {msg}" - ) - if options.on_artifact_upload_request is not None: - artifact_upload_request_message_interface = ( - ArtifactUploadRequestMessageInterface.model_validate(msg) + logger.info( + f"Received command execution result from agent runtime: {msg}" + ) + continue + + elif message_interface.message_type == "assistant_message": + if options.on_assistant_message is not None: + assistant_message_interface = ( + AssistantMessageInterface.model_validate(msg) + ) + await options.on_assistant_message(assistant_message_interface) + continue + + elif message_interface.message_type == "error_message": + try: + error_message_interface = ErrorMessageInterface.model_validate( + msg + ) + except pydantic.ValidationError: + logger.error( + f"Received malformed error message from agent runtime: {msg}", + ) + raise AgentRuntimeError(str(msg)) + logger.error( + f"Error from agent runtime: {error_message_interface}", ) - upload_result = await options.on_artifact_upload_request( - artifact_upload_request_message_interface + raise AgentRuntimeError( + error_message_interface.error, + metadata=error_message_interface.metadata, ) - artifact_upload_response_message_interface = ArtifactUploadResponseMessageInterface( - message_type="artifact_upload_response_message", - filename=artifact_upload_request_message_interface.filename, - filepath=artifact_upload_request_message_interface.filepath, - presigned_url=upload_result.presigned_url, - content_type=upload_result.content_type, + + elif ( + message_interface.message_type == "artifact_upload_request_message" + ): + logger.info( + f"Received artifact upload request message from agent runtime: {msg}" ) - await send_queue.put( - artifact_upload_response_message_interface.model_dump( - mode="json" + if options.on_artifact_upload_request is not None: + artifact_upload_request_message_interface = ( + ArtifactUploadRequestMessageInterface.model_validate(msg) ) + upload_result = await options.on_artifact_upload_request( + artifact_upload_request_message_interface + ) + artifact_upload_response_message_interface = ArtifactUploadResponseMessageInterface( + message_type="artifact_upload_response_message", + filename=artifact_upload_request_message_interface.filename, + filepath=artifact_upload_request_message_interface.filepath, + presigned_url=upload_result.presigned_url, + content_type=upload_result.content_type, + ) + await send_queue.put( + artifact_upload_response_message_interface.model_dump( + mode="json" + ) + ) + continue + + else: + logger.info( + f"Received non-result message from agent runtime: {msg}" ) - continue + finally: + self._send_queue = None - else: - logger.info( - f"Received non-result message from agent runtime: {msg}" - ) + if self._abort_event.is_set(): + raise CancelledException("Command execution was cancelled") if wire_result is None: raise AgentRuntimeError("No result message received")