From dd6b3ecf3a6f48cdb0826c639a333d92c91ed27e Mon Sep 17 00:00:00 2001 From: Beon de Nood Date: Sun, 3 May 2026 00:45:16 -0400 Subject: [PATCH] fix: migrate to a2a-sdk v1.0 protobuf types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BREAKING: Updates all code from a2a-sdk v0.x Pydantic models to v1.0 protobuf-generated types: - TextPart/FilePart/DataPart → unified Part with text/raw/url/data fields - MessageSendParams → SendMessageRequest - Role.user/Role.agent → Role.ROLE_USER/Role.ROLE_AGENT - FileWithBytes/FileWithUri → Part(raw=...) / Part(url=...) - new_agent_text_message() → construct Message directly - RequestContext now requires ServerCallContext - Message metadata now uses protobuf Struct - DataPart data field now uses protobuf Value Also updates: - capiscio_sdk executor to handle protobuf Message serialization via MessageToDict (was model_dump() only) - pyproject.toml dependency: a2a-sdk>=1.0.0 - examples/simple_agent for v1.0 compatibility Fixes cross-product E2E test ImportError: cannot import name 'TextPart' from 'a2a.types' --- capiscio_sdk/executor.py | 14 ++- examples/simple_agent/agent_executor.py | 16 ++- examples/simple_agent/test_client.py | 13 +-- pyproject.toml | 2 +- tests/integration/test_real_executor.py | 129 +++++++++++++----------- 5 files changed, 105 insertions(+), 69 deletions(-) diff --git a/capiscio_sdk/executor.py b/capiscio_sdk/executor.py index ecb7075..30b697c 100644 --- a/capiscio_sdk/executor.py +++ b/capiscio_sdk/executor.py @@ -8,6 +8,13 @@ except ImportError: RequestContext = Any # type: ignore[misc,assignment] +try: + from google.protobuf.json_format import MessageToDict + from google.protobuf.message import Message as ProtobufMessage +except ImportError: + MessageToDict = None # type: ignore[assignment,misc] + ProtobufMessage = None # type: ignore[assignment,misc] + from .config import SecurityConfig from .validators import MessageValidator, ProtocolValidator from .infrastructure import ValidationCache, RateLimiter @@ -86,7 +93,12 @@ async def execute(self, context: RequestContext, event_queue: Any) -> None: return # Convert message to dict for validation (our validators expect dict format) - message_dict = message.model_dump() if hasattr(message, 'model_dump') else {} + if hasattr(message, 'model_dump'): + message_dict = message.model_dump() + elif ProtobufMessage is not None and isinstance(message, ProtobufMessage): + message_dict = MessageToDict(message, preserving_proto_field_name=True) + else: + message_dict = {} # Extract identifier for rate limiting identifier = message_dict.get("message_id") or message.message_id diff --git a/examples/simple_agent/agent_executor.py b/examples/simple_agent/agent_executor.py index d108782..dd0e07a 100644 --- a/examples/simple_agent/agent_executor.py +++ b/examples/simple_agent/agent_executor.py @@ -8,7 +8,7 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue -from a2a.utils import new_agent_text_message +from a2a.types import Message, Part, Role class SimpleAgent: @@ -69,7 +69,12 @@ async def execute( result = await self.agent.invoke(text or "hello") # Send response back to client - await event_queue.enqueue_event(new_agent_text_message(result)) + response_msg = Message( + message_id="response", + role=Role.ROLE_AGENT, + parts=[Part(text=result)], + ) + await event_queue.enqueue_event(response_msg) async def cancel( self, @@ -84,6 +89,9 @@ async def cancel( event_queue: Queue for sending events """ # This simple agent doesn't support cancellation - await event_queue.enqueue_event( - new_agent_text_message("Task cancellation not supported by this simple agent.") + cancel_msg = Message( + message_id="cancel-response", + role=Role.ROLE_AGENT, + parts=[Part(text="Task cancellation not supported by this simple agent.")], ) + await event_queue.enqueue_event(cancel_msg) diff --git a/examples/simple_agent/test_client.py b/examples/simple_agent/test_client.py index 1ecf8d5..c2b7069 100644 --- a/examples/simple_agent/test_client.py +++ b/examples/simple_agent/test_client.py @@ -17,7 +17,8 @@ import httpx import asyncio from datetime import datetime -from a2a.types import Message, TextPart, Role, MessageSendParams +from a2a.types import Message, Part, Role, SendMessageRequest +from google.protobuf.json_format import MessageToDict async def send_message(client: httpx.AsyncClient, text: str, message_id: str = None): @@ -28,16 +29,16 @@ async def send_message(client: httpx.AsyncClient, text: str, message_id: str = N # Create proper A2A message message = Message( message_id=message_id, - role=Role.user, - parts=[TextPart(text=text)] + role=Role.ROLE_USER, + parts=[Part(text=text)] ) - # Wrap in MessageSendParams and serialize - params = MessageSendParams(message=message) + # Wrap in SendMessageRequest and serialize + params = SendMessageRequest(message=message) response = await client.post( "http://localhost:8080/v1/tasks", - json=params.model_dump(mode="json") + json=MessageToDict(params) ) return response diff --git a/pyproject.toml b/pyproject.toml index 22709c6..f3bdddc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ ] dependencies = [ - "a2a-sdk>=0.1.0", + "a2a-sdk>=1.0.0", "httpx>=0.27.0", "pydantic>=2.0.0", "cryptography>=42.0.0", diff --git a/tests/integration/test_real_executor.py b/tests/integration/test_real_executor.py index ab73229..c1f4c6f 100644 --- a/tests/integration/test_real_executor.py +++ b/tests/integration/test_real_executor.py @@ -13,9 +13,11 @@ import pytest import base64 from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue -from a2a.types import Message, TextPart, FilePart, DataPart, Role, MessageSendParams, FileWithBytes, FileWithUri -from a2a.utils import new_agent_text_message +from a2a.types import Message, Part, Role, SendMessageRequest +from google.protobuf.json_format import ParseDict +from google.protobuf.struct_pb2 import Struct, Value from capiscio_sdk import secure, SecurityConfig from capiscio_sdk.errors import CapiscioValidationError, CapiscioRateLimitError @@ -52,27 +54,37 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non # Process and respond result = await self.agent.invoke(text) - await event_queue.enqueue_event(new_agent_text_message(result)) + response_msg = Message( + message_id="response-1", + role=Role.ROLE_AGENT, + parts=[Part(text=result)], + ) + await event_queue.enqueue_event(response_msg) async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """Handle cancellation.""" - await event_queue.enqueue_event(new_agent_text_message("Cancelled")) + cancel_msg = Message( + message_id="cancel-1", + role=Role.ROLE_AGENT, + parts=[Part(text="Cancelled")], + ) + await event_queue.enqueue_event(cancel_msg) def create_valid_message(text="test", message_id="msg-1"): """Helper to create a valid A2A message.""" return Message( message_id=message_id, - role=Role.user, - parts=[TextPart(text=text)] + role=Role.ROLE_USER, + parts=[Part(text=text)] ) def create_request_context(message): """Helper to create a RequestContext.""" - # RequestContext expects MessageSendParams, not Message directly - params = MessageSendParams(message=message) - return RequestContext(request=params) + # RequestContext expects SendMessageRequest, not Message directly + params = SendMessageRequest(message=message) + return RequestContext(call_context=ServerCallContext(), request=params) class SimpleEventQueue: @@ -119,7 +131,7 @@ async def test_integration_invalid_message_blocked(): # Create invalid message (empty message_id) invalid_message = Message( message_id="", # Empty! - role=Role.user, + role=Role.ROLE_USER, parts=[] # Empty! ) @@ -174,7 +186,7 @@ async def test_integration_monitor_mode_allows_invalid(): # Create invalid message invalid_message = Message( message_id="", # Empty! - role=Role.user, + role=Role.ROLE_USER, parts=[] ) @@ -319,12 +331,12 @@ async def test_integration_file_part_with_bytes(): file_content = b"Hello, this is file content!" message = Message( message_id="msg-file-1", - role=Role.user, - parts=[FilePart(file=FileWithBytes( - bytes=base64.b64encode(file_content).decode('utf-8'), + role=Role.ROLE_USER, + parts=[Part( + raw=file_content, media_type="text/plain", - name="test.txt" - ))] + filename="test.txt", + )] ) context = create_request_context(message) @@ -347,12 +359,12 @@ async def test_integration_file_part_with_uri(): # Create message with FilePart (URI) message = Message( message_id="msg-file-2", - role=Role.user, - parts=[FilePart(file=FileWithUri( - uri="https://example.com/document.pdf", + role=Role.ROLE_USER, + parts=[Part( + url="https://example.com/document.pdf", media_type="application/pdf", - name="document.pdf" - ))] + filename="document.pdf", + )] ) context = create_request_context(message) @@ -373,14 +385,14 @@ async def test_integration_data_part_structured(): secured = secure(executor, SecurityConfig.production()) # Create message with DataPart + data_value = ParseDict( + {"query": "SELECT * FROM users", "parameters": {"limit": 10}, "metadata": {"source": "analytics"}}, + Value(), + ) message = Message( message_id="msg-data-1", - role=Role.user, - parts=[DataPart(data={ - "query": "SELECT * FROM users", - "parameters": {"limit": 10}, - "metadata": {"source": "analytics"} - })] + role=Role.ROLE_USER, + parts=[Part(data=data_value)] ) context = create_request_context(message) @@ -401,13 +413,14 @@ async def test_integration_multiple_parts_mixed_types(): secured = secure(executor, SecurityConfig.production()) # Create message with mixed parts + data_value = ParseDict({"query": "search term"}, Value()) message = Message( message_id="msg-mixed-1", - role=Role.user, + role=Role.ROLE_USER, parts=[ - TextPart(text="Here's the query and data:"), - DataPart(data={"query": "search term"}), - TextPart(text="Please process this.") + Part(text="Here's the query and data:"), + Part(data=data_value), + Part(text="Please process this.") ] ) @@ -431,8 +444,8 @@ async def test_integration_agent_role_message(): # Create message from agent message = Message( message_id="msg-agent-1", - role=Role.agent, # Agent sending message - parts=[TextPart(text="Response from upstream agent")] + role=Role.ROLE_AGENT, # Agent sending message + parts=[Part(text="Response from upstream agent")] ) context = create_request_context(message) @@ -455,8 +468,8 @@ async def test_integration_message_with_context_and_task_ids(): # Create message with optional fields message = Message( message_id="msg-ctx-1", - role=Role.user, - parts=[TextPart(text="Continuing task")], + role=Role.ROLE_USER, + parts=[Part(text="Continuing task")], context_id="ctx-123", task_id="task-456" ) @@ -479,11 +492,13 @@ async def test_integration_message_with_metadata(): secured = secure(executor, SecurityConfig.production()) # Create message with metadata + meta = Struct() + meta.update({"priority": "high", "source": "api"}) message = Message( message_id="msg-meta-1", - role=Role.user, - parts=[TextPart(text="Test with metadata")], - metadata={"priority": "high", "source": "api"} + role=Role.ROLE_USER, + parts=[Part(text="Test with metadata")], + metadata=meta, ) context = create_request_context(message) @@ -506,8 +521,8 @@ async def test_integration_empty_text_part(): # Create message with empty text message = Message( message_id="msg-empty-1", - role=Role.user, - parts=[TextPart(text="")] # Empty text + role=Role.ROLE_USER, + parts=[Part(text="")] # Empty text ) context = create_request_context(message) @@ -531,8 +546,8 @@ async def test_integration_very_long_text(): long_text = "A" * 10000 message = Message( message_id="msg-long-1", - role=Role.user, - parts=[TextPart(text=long_text)] + role=Role.ROLE_USER, + parts=[Part(text=long_text)] ) context = create_request_context(message) @@ -555,8 +570,8 @@ async def test_integration_special_characters_unicode(): # Create message with various Unicode characters message = Message( message_id="msg-unicode-1", - role=Role.user, - parts=[TextPart(text="Hello 世界! 🚀 Émojis & Spëcial çhars: <>&\"'")] + role=Role.ROLE_USER, + parts=[Part(text="Hello 世界! 🚀 Émojis & Spëcial çhars: <>&\"'")] ) context = create_request_context(message) @@ -584,8 +599,8 @@ async def test_integration_xss_attempt_in_text(): # Create message with XSS pattern (should be allowed to pass - content validation is app responsibility) message = Message( message_id="msg-xss-1", - role=Role.user, - parts=[TextPart(text="")] + role=Role.ROLE_USER, + parts=[Part(text="")] ) context = create_request_context(message) @@ -609,8 +624,8 @@ async def test_integration_sql_injection_pattern(): # Create message with SQL injection pattern message = Message( message_id="msg-sql-1", - role=Role.user, - parts=[TextPart(text="'; DROP TABLE users; --")] + role=Role.ROLE_USER, + parts=[Part(text="'; DROP TABLE users; --")] ) context = create_request_context(message) @@ -631,10 +646,10 @@ async def test_integration_oversized_message_parts(): secured = secure(executor, SecurityConfig.production()) # Create message with many parts (100 parts) - parts = [TextPart(text=f"Part {i}") for i in range(100)] + parts = [Part(text=f"Part {i}") for i in range(100)] message = Message( message_id="msg-large-1", - role=Role.user, + role=Role.ROLE_USER, parts=parts ) @@ -658,8 +673,8 @@ async def test_integration_null_bytes_in_text(): # Create message with null bytes message = Message( message_id="msg-null-1", - role=Role.user, - parts=[TextPart(text="Hello\x00World")] + role=Role.ROLE_USER, + parts=[Part(text="Hello\x00World")] ) context = create_request_context(message) @@ -680,7 +695,7 @@ async def test_integration_null_bytes_in_text(): @pytest.mark.asyncio async def test_integration_invalid_role_value(): """Test: Invalid enum values in role are caught by SDK type checking.""" - # Note: The A2A SDK's Pydantic models prevent invalid role values + # Note: The A2A SDK's protobuf models prevent invalid role values # at construction time, so this test validates that the SDK itself # provides type safety. Invalid roles cannot reach our validator. @@ -695,7 +710,7 @@ async def test_integration_invalid_role_value(): Message( message_id="msg-bad-1", role="hacker", # Invalid role - SDK validates this - parts=[TextPart(text="Hello")] + parts=[Part(text="Hello")] ) @@ -711,8 +726,8 @@ async def test_integration_missing_message_id(): # Create message with empty messageId message = Message( message_id="", # Empty! - role=Role.user, - parts=[TextPart(text="Test")] + role=Role.ROLE_USER, + parts=[Part(text="Test")] ) context = create_request_context(message) @@ -740,7 +755,7 @@ async def test_integration_empty_parts_array(): # Create message with empty parts message = Message( message_id="msg-empty-parts", - role=Role.user, + role=Role.ROLE_USER, parts=[] # Empty but valid structure )