Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
304 changes: 156 additions & 148 deletions packages/runtimeuse-client-python/src/runtimeuse_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
raise ValueError("Either ws_url or transport must be provided")

self._abort_event = asyncio.Event()
self._send_queue: asyncio.Queue[dict] | None = None

def abort(self) -> None:
"""Signal the current query to cancel.
Expand All @@ -61,6 +62,13 @@ def abort(self) -> None:
coroutine on the same event loop.
"""
self._abort_event.set()
send_queue = self._send_queue
if send_queue is not None:
send_queue.put_nowait(
CancelMessage(message_type="cancel_message").model_dump(
mode="json"
)
)

async def query(
self,
Expand Down Expand Up @@ -107,98 +115,98 @@ async def query(
)

send_queue: asyncio.Queue[dict] = asyncio.Queue()
self._send_queue = send_queue
await send_queue.put(invocation.model_dump(mode="json"))

wire_result: ResultMessageInterface | None = None

async with asyncio.timeout(options.timeout):
async for message in self._transport(send_queue=send_queue):
if self._abort_event.is_set():
logger.info("Query cancelled by caller")
await send_queue.put(
CancelMessage(message_type="cancel_message").model_dump(
mode="json"
)
)
await send_queue.join()
raise CancelledException("Query was cancelled")

try:
message_interface = AgentRuntimeMessageInterface.model_validate(
message
)
except pydantic.ValidationError:
logger.error(
f"Received unknown message type from agent runtime: {message}"
)
continue

if message_interface.message_type == "result_message":
wire_result = ResultMessageInterface.model_validate(message)
logger.info(
f"Received result message from agent runtime: {message}"
)
continue

elif message_interface.message_type == "assistant_message":
if options.on_assistant_message is not None:
assistant_message_interface = (
AssistantMessageInterface.model_validate(message)
)
await options.on_assistant_message(assistant_message_interface)
continue
try:
async with asyncio.timeout(options.timeout):
async for message in self._transport(send_queue=send_queue):
if self._abort_event.is_set():
raise CancelledException("Query was cancelled")

elif message_interface.message_type == "error_message":
try:
error_message_interface = ErrorMessageInterface.model_validate(
message_interface = AgentRuntimeMessageInterface.model_validate(
message
)
except pydantic.ValidationError:
logger.error(
f"Received malformed error message from agent runtime: {message}",
f"Received unknown message type from agent runtime: {message}"
)
raise AgentRuntimeError(str(message))
logger.error(
f"Error from agent runtime: {error_message_interface}",
)
raise AgentRuntimeError(
error_message_interface.error,
metadata=error_message_interface.metadata,
)

elif (
message_interface.message_type == "artifact_upload_request_message"
):
logger.info(
f"Received artifact upload request message from agent runtime: {message}"
)
if options.on_artifact_upload_request is not None:
artifact_upload_request_message_interface = (
ArtifactUploadRequestMessageInterface.model_validate(
continue

if message_interface.message_type == "result_message":
wire_result = ResultMessageInterface.model_validate(message)
logger.info(
f"Received result message from agent runtime: {message}"
)
continue

elif message_interface.message_type == "assistant_message":
if options.on_assistant_message is not None:
assistant_message_interface = (
AssistantMessageInterface.model_validate(message)
)
await options.on_assistant_message(assistant_message_interface)
continue

elif message_interface.message_type == "error_message":
try:
error_message_interface = ErrorMessageInterface.model_validate(
message
)
except pydantic.ValidationError:
logger.error(
f"Received malformed error message from agent runtime: {message}",
)
raise AgentRuntimeError(str(message))
logger.error(
f"Error from agent runtime: {error_message_interface}",
)
upload_result = await options.on_artifact_upload_request(
artifact_upload_request_message_interface
raise AgentRuntimeError(
error_message_interface.error,
metadata=error_message_interface.metadata,
)
artifact_upload_response_message_interface = ArtifactUploadResponseMessageInterface(
message_type="artifact_upload_response_message",
filename=artifact_upload_request_message_interface.filename,
filepath=artifact_upload_request_message_interface.filepath,
presigned_url=upload_result.presigned_url,
content_type=upload_result.content_type,

elif (
message_interface.message_type == "artifact_upload_request_message"
):
logger.info(
f"Received artifact upload request message from agent runtime: {message}"
)
await send_queue.put(
artifact_upload_response_message_interface.model_dump(
mode="json"
if options.on_artifact_upload_request is not None:
artifact_upload_request_message_interface = (
ArtifactUploadRequestMessageInterface.model_validate(
message
)
)
upload_result = await options.on_artifact_upload_request(
artifact_upload_request_message_interface
)
artifact_upload_response_message_interface = ArtifactUploadResponseMessageInterface(
message_type="artifact_upload_response_message",
filename=artifact_upload_request_message_interface.filename,
filepath=artifact_upload_request_message_interface.filepath,
presigned_url=upload_result.presigned_url,
content_type=upload_result.content_type,
)
await send_queue.put(
artifact_upload_response_message_interface.model_dump(
mode="json"
)
)
continue

else:
logger.info(
f"Received non-result message from agent runtime: {message}"
)
continue
finally:
self._send_queue = None

else:
logger.info(
f"Received non-result message from agent runtime: {message}"
)
if self._abort_event.is_set():
raise CancelledException("Query was cancelled")
Comment thread
cursor[bot] marked this conversation as resolved.

if wire_result is None:
raise AgentRuntimeError("No result message received")
Expand Down Expand Up @@ -240,96 +248,96 @@ async def execute_commands(
)

send_queue: asyncio.Queue[dict] = asyncio.Queue()
self._send_queue = send_queue
await send_queue.put(message.model_dump(mode="json"))

wire_result: CommandExecutionResultMessageInterface | None = None

async with asyncio.timeout(options.timeout):
async for msg in self._transport(send_queue=send_queue):
if self._abort_event.is_set():
logger.info("Command execution cancelled by caller")
await send_queue.put(
CancelMessage(message_type="cancel_message").model_dump(
mode="json"
)
)
await send_queue.join()
raise CancelledException("Command execution was cancelled")

try:
message_interface = AgentRuntimeMessageInterface.model_validate(msg)
except pydantic.ValidationError:
logger.error(
f"Received unknown message type from agent runtime: {msg}"
)
continue

if message_interface.message_type == "command_execution_result_message":
wire_result = CommandExecutionResultMessageInterface.model_validate(
msg
)
logger.info(
f"Received command execution result from agent runtime: {msg}"
)
continue

elif message_interface.message_type == "assistant_message":
if options.on_assistant_message is not None:
assistant_message_interface = (
AssistantMessageInterface.model_validate(msg)
)
await options.on_assistant_message(assistant_message_interface)
continue
try:
async with asyncio.timeout(options.timeout):
async for msg in self._transport(send_queue=send_queue):
if self._abort_event.is_set():
raise CancelledException("Command execution was cancelled")

elif message_interface.message_type == "error_message":
try:
error_message_interface = ErrorMessageInterface.model_validate(
msg
)
message_interface = AgentRuntimeMessageInterface.model_validate(msg)
except pydantic.ValidationError:
logger.error(
f"Received malformed error message from agent runtime: {msg}",
f"Received unknown message type from agent runtime: {msg}"
)
continue

if message_interface.message_type == "command_execution_result_message":
wire_result = CommandExecutionResultMessageInterface.model_validate(
msg
)
raise AgentRuntimeError(str(msg))
logger.error(
f"Error from agent runtime: {error_message_interface}",
)
raise AgentRuntimeError(
error_message_interface.error,
metadata=error_message_interface.metadata,
)

elif (
message_interface.message_type == "artifact_upload_request_message"
):
logger.info(
f"Received artifact upload request message from agent runtime: {msg}"
)
if options.on_artifact_upload_request is not None:
artifact_upload_request_message_interface = (
ArtifactUploadRequestMessageInterface.model_validate(msg)
logger.info(
f"Received command execution result from agent runtime: {msg}"
)
continue

elif message_interface.message_type == "assistant_message":
if options.on_assistant_message is not None:
assistant_message_interface = (
AssistantMessageInterface.model_validate(msg)
)
await options.on_assistant_message(assistant_message_interface)
continue

elif message_interface.message_type == "error_message":
try:
error_message_interface = ErrorMessageInterface.model_validate(
msg
)
except pydantic.ValidationError:
logger.error(
f"Received malformed error message from agent runtime: {msg}",
)
raise AgentRuntimeError(str(msg))
logger.error(
f"Error from agent runtime: {error_message_interface}",
)
upload_result = await options.on_artifact_upload_request(
artifact_upload_request_message_interface
raise AgentRuntimeError(
error_message_interface.error,
metadata=error_message_interface.metadata,
)
artifact_upload_response_message_interface = ArtifactUploadResponseMessageInterface(
message_type="artifact_upload_response_message",
filename=artifact_upload_request_message_interface.filename,
filepath=artifact_upload_request_message_interface.filepath,
presigned_url=upload_result.presigned_url,
content_type=upload_result.content_type,

elif (
message_interface.message_type == "artifact_upload_request_message"
):
logger.info(
f"Received artifact upload request message from agent runtime: {msg}"
)
await send_queue.put(
artifact_upload_response_message_interface.model_dump(
mode="json"
if options.on_artifact_upload_request is not None:
artifact_upload_request_message_interface = (
ArtifactUploadRequestMessageInterface.model_validate(msg)
)
upload_result = await options.on_artifact_upload_request(
artifact_upload_request_message_interface
)
artifact_upload_response_message_interface = ArtifactUploadResponseMessageInterface(
message_type="artifact_upload_response_message",
filename=artifact_upload_request_message_interface.filename,
filepath=artifact_upload_request_message_interface.filepath,
presigned_url=upload_result.presigned_url,
content_type=upload_result.content_type,
)
await send_queue.put(
artifact_upload_response_message_interface.model_dump(
mode="json"
)
)
continue

else:
logger.info(
f"Received non-result message from agent runtime: {msg}"
)
continue
finally:
self._send_queue = None

else:
logger.info(
f"Received non-result message from agent runtime: {msg}"
)
if self._abort_event.is_set():
raise CancelledException("Command execution was cancelled")

if wire_result is None:
raise AgentRuntimeError("No result message received")
Expand Down