From 1d649fe02072341455459c7ce2c28e8c271e5c61 Mon Sep 17 00:00:00 2001 From: Zhiyong Yang Date: Mon, 9 Mar 2026 22:25:39 -0700 Subject: [PATCH 01/10] [agentserver] Add azure-ai-agentserver v1.0.0b1 with invoke API support Co-Authored-By: Claude Opus 4.6 --- .../agentframework/agent_framework.py | 4 +- ...ramework_output_non_streaming_converter.py | 2 +- .../id_generator/foundry_id_generator.py | 3 +- .../pyproject.toml | 3 +- .../azure-ai-agentserver/CHANGELOG.md | 12 + sdk/agentserver/azure-ai-agentserver/LICENSE | 21 + .../azure-ai-agentserver/MANIFEST.in | 7 + .../azure-ai-agentserver/README.md | 312 +++ .../azure-ai-agentserver/azure/__init__.py | 1 + .../azure-ai-agentserver/azure/ai/__init__.py | 1 + .../azure/ai/agentserver/__init__.py | 9 + .../azure/ai/agentserver/_constants.py | 20 + .../azure/ai/agentserver/_errors.py | 52 + .../azure/ai/agentserver/_logger.py | 18 + .../azure/ai/agentserver/_tracing.py | 307 +++ .../azure/ai/agentserver/_version.py | 5 + .../azure/ai/agentserver/py.typed | 0 .../azure/ai/agentserver/server/__init__.py | 3 + .../azure/ai/agentserver/server/_base.py | 546 +++++ .../azure/ai/agentserver/server/_config.py | 180 ++ .../ai/agentserver/validation/__init__.py | 3 + .../validation/_openapi_validator.py | 695 ++++++ .../azure-ai-agentserver/cspell.json | 36 + .../azure-ai-agentserver/dev_requirements.txt | 9 + .../azure-ai-agentserver/pyproject.toml | 76 + .../azure-ai-agentserver/pyrightconfig.json | 11 + .../agentframework_invoke_agent/.env.sample | 4 + .../agentframework_invoke_agent.py | 81 + .../requirements.txt | 4 + .../async_invoke_agent/async_invoke_agent.py | 168 ++ .../async_invoke_agent/requirements.txt | 1 + .../human_in_the_loop_agent.py | 74 + .../human_in_the_loop_agent/requirements.txt | 1 + .../langgraph_invoke_agent/.env.sample | 5 + .../langgraph_invoke_agent.py | 100 + .../langgraph_invoke_agent/requirements.txt | 4 + .../openapi_validated_agent.py | 117 + .../openapi_validated_agent/requirements.txt | 1 + .../samples/simple_invoke_agent/Dockerfile | 19 + .../simple_invoke_agent/requirements.txt | 1 + .../simple_invoke_agent.py | 38 + .../azure-ai-agentserver/tests/conftest.py | 205 ++ .../tests/test_decorator_pattern.py | 337 +++ .../tests/test_edge_cases.py | 491 ++++ .../tests/test_get_cancel.py | 111 + .../tests/test_graceful_shutdown.py | 438 ++++ .../azure-ai-agentserver/tests/test_health.py | 23 + .../azure-ai-agentserver/tests/test_http2.py | 473 ++++ .../azure-ai-agentserver/tests/test_invoke.py | 119 + .../azure-ai-agentserver/tests/test_logger.py | 22 + .../tests/test_multimodal_protocol.py | 807 ++++++ .../tests/test_openapi_validation.py | 2167 +++++++++++++++++ .../tests/test_request_limits.py | 147 ++ .../tests/test_server_routes.py | 116 + .../tests/test_tracing.py | 523 ++++ sdk/agentserver/ci.yml | 2 + shared_requirements.txt | 1 + 57 files changed, 8931 insertions(+), 5 deletions(-) create mode 100644 sdk/agentserver/azure-ai-agentserver/CHANGELOG.md create mode 100644 sdk/agentserver/azure-ai-agentserver/LICENSE create mode 100644 sdk/agentserver/azure-ai-agentserver/MANIFEST.in create mode 100644 sdk/agentserver/azure-ai-agentserver/README.md create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_constants.py create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_errors.py create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_logger.py create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_tracing.py create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_version.py create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/py.typed create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_config.py create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py create mode 100644 sdk/agentserver/azure-ai-agentserver/cspell.json create mode 100644 sdk/agentserver/azure-ai-agentserver/dev_requirements.txt create mode 100644 sdk/agentserver/azure-ai-agentserver/pyproject.toml create mode 100644 sdk/agentserver/azure-ai-agentserver/pyrightconfig.json create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/.env.sample create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/agentframework_invoke_agent.py create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/requirements.txt create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/async_invoke_agent.py create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/requirements.txt create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/human_in_the_loop_agent.py create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/requirements.txt create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/.env.sample create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/langgraph_invoke_agent.py create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/requirements.txt create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/openapi_validated_agent.py create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/requirements.txt create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/Dockerfile create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/requirements.txt create mode 100644 sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/simple_invoke_agent.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/conftest.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_decorator_pattern.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_edge_cases.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_get_cancel.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_graceful_shutdown.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_health.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_http2.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_invoke.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_logger.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_multimodal_protocol.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_openapi_validation.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_request_limits.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_server_routes.py create mode 100644 sdk/agentserver/azure-ai-agentserver/tests/test_tracing.py diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/agent_framework.py index 7177b522d2a9..09e6904f8a63 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/agent_framework.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/agent_framework.py @@ -77,7 +77,9 @@ def _resolve_stream_timeout(self, request_body: CreateResponse) -> float: def init_tracing(self): exporter = os.environ.get(AdapterConstants.OTEL_EXPORTER_ENDPOINT) - app_insights_conn_str = os.environ.get(AdapterConstants.APPLICATION_INSIGHTS_CONNECTION_STRING) + app_insights_conn_str = os.environ.get( + AdapterConstants.APPLICATION_INSIGHTS_CONNECTION_STRING # pylint: disable=no-member + ) project_endpoint = os.environ.get(AdapterConstants.AZURE_AI_PROJECT_ENDPOINT) if project_endpoint: diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py index 805a5eeb9dec..ce2010566ce0 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py @@ -213,7 +213,7 @@ def _coerce_result_text(self, value: Any) -> str | dict: def _construct_response_data(self, output_items: List[dict]) -> dict: agent_id = AgentIdGenerator.generate(self._context) - response_data = { + response_data: dict[str, Any] = { "object": "response", "metadata": {}, "agent": agent_id, diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py index 910a7c481daa..3f571503620a 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py @@ -34,12 +34,11 @@ def __init__(self, response_id: Optional[str], conversation_id: Optional[str]): def from_request(cls, payload: dict) -> "FoundryIdGenerator": response_id = payload.get("metadata", {}).get("response_id", None) conv_id_raw = payload.get("conversation", None) + conv_id: Optional[str] = None if isinstance(conv_id_raw, str): conv_id = conv_id_raw elif isinstance(conv_id_raw, dict): conv_id = conv_id_raw.get("id", None) - else: - conv_id = None return cls(response_id, conv_id) def generate(self, category: Optional[str] = None) -> str: diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml index 5552ff8233d2..41dc302054a3 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml @@ -64,4 +64,5 @@ pyright = false verifytypes = false # incompatible python version for -core verify_keywords = false mindependency = false # depends on -core package -whl_no_aio = false \ No newline at end of file +whl_no_aio = false +apistub = false # pip 24.0 crashes resolving langchain/langgraph transitive deps during stub generation \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver/CHANGELOG.md new file mode 100644 index 000000000000..1ba538d24413 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/CHANGELOG.md @@ -0,0 +1,12 @@ +# Release History + +## 1.0.0b1 (Unreleased) + +### Features Added + +- Initial release of `azure-ai-agentserver`. +- Generic `AgentServer` base class with pluggable protocol heads. +- `/invoke` protocol head with all 4 operations: create, get, cancel, and OpenAPI spec. +- OpenAPI spec-based request/response validation via `jsonschema`. +- Health check endpoints (`/liveness`, `/readiness`). +- Streaming and non-streaming invocation support. diff --git a/sdk/agentserver/azure-ai-agentserver/LICENSE b/sdk/agentserver/azure-ai-agentserver/LICENSE new file mode 100644 index 000000000000..b2f52a2bad4e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/LICENSE @@ -0,0 +1,21 @@ +Copyright (c) Microsoft Corporation. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/sdk/agentserver/azure-ai-agentserver/MANIFEST.in b/sdk/agentserver/azure-ai-agentserver/MANIFEST.in new file mode 100644 index 000000000000..468601f6166b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/MANIFEST.in @@ -0,0 +1,7 @@ +include *.md +include LICENSE +recursive-include tests *.py +recursive-include samples *.py *.md +include azure/__init__.py +include azure/ai/__init__.py +include azure/ai/agentserver/py.typed diff --git a/sdk/agentserver/azure-ai-agentserver/README.md b/sdk/agentserver/azure-ai-agentserver/README.md new file mode 100644 index 000000000000..bd6a8c031225 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/README.md @@ -0,0 +1,312 @@ +# Azure AI Agent Server client library for Python + +A standalone, **protocol-agnostic agent server** package for Azure AI. Provides a +Starlette-based `AgentServer` class with pluggable protocol heads, OpenAPI spec +serving, optional request validation, optional tracing, and health +endpoints — with **zero framework coupling**. + +## Getting started + +### Install the package + +```bash +pip install azure-ai-agentserver +``` + +**Requires Python >= 3.10.** + +### Quick start + +```python +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver import AgentServer + +server = AgentServer() + + +@server.invoke_handler +async def handle(request: Request) -> Response: + data = await request.json() + greeting = f"Hello, {data['name']}!" + return JSONResponse({"greeting": greeting}) + + +if __name__ == "__main__": + server.run() +``` + +```bash +# Start the agent +python my_agent.py + +# Call it +curl -X POST http://localhost:8088/invocations \ + -H "Content-Type: application/json" \ + -d '{"name": "World"}' +# → {"greeting": "Hello, World!"} +``` + +## Key concepts + +`AgentServer` is a class for building agent endpoints that plug into the +Azure Agent Service. Register handlers with decorators on an `AgentServer` +instance. Multiple protocol heads (`/invoke` today, `/responses` and others +in the future) are supported through a pluggable handler architecture. + +The Azure Agent Service expects specific route paths (`/invocations`, `/liveness`, +`/readiness`, etc.) for deployment. `AgentServer` wires these automatically so +your agent is compatible with the hosting platform — no manual route setup required. + +**Key properties:** + +- **Platform-compatible routes** — automatically registers the exact endpoints the Azure + Agent Service expects, so your agent deploys without configuration changes. +- **Starlette + Hypercorn** — lightweight ASGI server with native HTTP/1.1 and HTTP/2 + support. +- **Raw protocol access** — receive raw Starlette `Request` objects, return raw + Starlette `Response` objects. Full control over content types, streaming, headers, + SSE, and status codes. +- **Automatic invocation ID tracking** — every request gets a unique ID injected into + `request.state` and the `x-agent-invocation-id` response header. +- **OpenAPI spec serving** — pass a spec and it is served at + `GET /invocations/docs/openapi.json` for documentation / tooling. +- **Optional request validation** — opt in to validate incoming request bodies + against the OpenAPI spec before reaching your code. +- **Request timeout and graceful shutdown** — configurable invoke timeouts (504) and graceful shutdown — all via + constructor args or environment variables. +- **Optional OpenTelemetry tracing** — opt-in span instrumentation that covers the full + request lifecycle, including streaming responses. +- **Health endpoints** — `/liveness` and `/readiness` out of the box. + +### Architecture + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ Azure Agent Service (Cloud Infrastructure) │ +│ Protocols: /invoke, /responses, /mcp, /a2a, /activity │ +└────────────────────────────┬────────────────────────────────────────┘ + │ + ▼ + ┌───────────────────────────────┐ + │ azure-ai-agentserver │ + │ AgentServer │ + │ │ + │ Protocol heads: │ + │ • /invoke │ + │ • /responses (planned) │ + │ │ + │ OpenAPI spec serving │ + │ Optional request validation │ + │ Invocation ID tracking │ + │ Invoke timeout / graceful │ + │ shutdown │ + │ Optional OTel tracing │ + │ Health & readiness endpoints │ + └───────────────────────────────┘ + │ + Your integration code: + ┌─────────┼─────────┐ + ▼ ▼ ▼ + LangGraph Agent Semantic + Framework Kernel +``` + +**Single package, multiple protocol heads, no framework coupling.** + +### Handler registration + +Instantiate `AgentServer` and register handlers with decorators: + +| Decorator | Required | Description | +|-----------|----------|-------------| +| `@server.invoke_handler` | **Yes** | Register the invoke handler function. | +| `@server.get_invocation_handler` | No | Register a get-invocation handler. Default returns 404. | +| `@server.cancel_invocation_handler` | No | Register a cancel-invocation handler. Default returns 404. | +| `@server.shutdown_handler` | No | Register a shutdown handler. Default is a no-op. | + +The invocation ID is available via `request.state.invocation_id` (auto-generated for +`invoke`, extracted from the URL path for `get_invocation` / `cancel_invocation`). +The server auto-injects the `x-agent-invocation-id` response header if not already set. + +### Routes + +| Route | Method | Description | +|-------|--------|-------------| +| `/invocations` | POST | Create and process an invocation | +| `/invocations/{id}` | GET | Retrieve a previous invocation result | +| `/invocations/{id}/cancel` | POST | Cancel a running invocation | +| `/invocations/docs/openapi.json` | GET | Return the registered OpenAPI spec | +| `/liveness` | GET | Health check | +| `/readiness` | GET | Readiness check | + +### Configuration + +All settings follow the same resolution order: **constructor argument > environment +variable > default**. Set a value to `0` to disable the corresponding feature. + +| Constructor param | Environment variable | Default | Description | +|---|---|---|---| +| `port` (on `run()`) | `AGENT_SERVER_PORT` | `8088` | Port to bind | +| `graceful_shutdown_timeout` | `AGENT_GRACEFUL_SHUTDOWN_TIMEOUT` | `30` (seconds) | Drain period after SIGTERM before forced exit | +| `request_timeout` | `AGENT_REQUEST_TIMEOUT` | `300` (seconds) | Max time for `invoke()` before 504 | +| `enable_tracing` | `AGENT_ENABLE_TRACING` | `false` | Enable OpenTelemetry tracing | +| `enable_request_validation` | `AGENT_ENABLE_REQUEST_VALIDATION` | `false` | Validate request bodies against `openapi_spec` | +| `log_level` | `AGENT_LOG_LEVEL` | `INFO` | Library log level (`DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`) | +| `debug_errors` | `AGENT_DEBUG_ERRORS` | `false` | Include exception details in error responses | + +```python +server = AgentServer( + request_timeout=60, # 1 minute + graceful_shutdown_timeout=15, # 15 s drain +) +server.run() +``` + +Or configure entirely via environment variables — no code changes needed for deployment +tuning. + +## Examples + +### OpenAPI spec & validation + +Pass an OpenAPI spec to serve it at `/invocations/docs/openapi.json`. +Opt in to runtime request validation with `enable_request_validation=True`: + +```python +spec = { + "openapi": "3.0.0", + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"greeting": {"type": "string"}}, + } + } + } + } + }, + } + } + }, +} + +# Spec served for documentation; validation off (default) +server = AgentServer(openapi_spec=spec) + +# Spec served AND requests validated at runtime +server = AgentServer(openapi_spec=spec, enable_request_validation=True) +server.run() +``` + +- `GET /invocations/docs/openapi.json` serves the registered spec (or 404 if none). +- When validation is enabled, non-conforming **requests** return 400 with details. + +### Tracing + +Tracing is **disabled by default**. Enable it via constructor or environment variable: + +```python +server = AgentServer(enable_tracing=True) +``` + +or: + +```bash +export AGENT_ENABLE_TRACING=true +``` + +Install the tracing extras (includes OpenTelemetry and the Azure Monitor exporter): + +```bash +pip install azure-ai-agentserver[tracing] +``` + +When enabled, spans are created for `invoke`, `get_invocation`, and `cancel_invocation` +endpoints. For streaming responses, the span stays open until the last chunk is sent, +accurately capturing the full transfer duration. Errors during streaming are recorded on +the span. Incoming `traceparent` / `tracestate` headers are propagated via W3C +TraceContext. + +#### Application Insights integration + +When tracing is enabled **and** an Application Insights connection string is available, +traces and logs are automatically exported to Azure Monitor. The connection string is +resolved in the following order: + +1. The `application_insights_connection_string` constructor parameter. +2. The `APPLICATIONINSIGHTS_CONNECTION_STRING` environment variable. + +```python +# Explicit connection string +server = AgentServer( + enable_tracing=True, + application_insights_connection_string="InstrumentationKey=...", +) + +# Or via environment variable (connection string auto-discovered) +# export APPLICATIONINSIGHTS_CONNECTION_STRING="InstrumentationKey=..." +server = AgentServer(enable_tracing=True) +``` + +If no connection string is found, tracing still works — spans are created but not exported +to Azure Monitor (you can bring your own `TracerProvider`). + +### More samples + +| Sample | Description | +|--------|-------------| +| `samples/simple_invoke_agent/` | Minimal from-scratch agent | +| `samples/openapi_validated_agent/` | OpenAPI spec with request/response validation | +| `samples/async_invoke_agent/` | Long-running tasks with get & cancel support | +| `samples/human_in_the_loop_agent/` | Synchronous human-in-the-loop interaction | +| `samples/langgraph_invoke_agent/` | Customer-managed LangGraph adapter | +| `samples/agentframework_invoke_agent/` | Customer-managed Agent Framework adapter | + +## Troubleshooting + +### Reporting issues + +To report an issue with the client library, or request additional features, please open a +GitHub issue [here](https://github.com/Azure/azure-sdk-for-python/issues). + +## Next steps + +Please visit [Samples](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/agentserver/azure-ai-agentserver/samples) +for more usage examples. + +## Contributing + +This project welcomes contributions and suggestions. Most contributions require +you to agree to a Contributor License Agreement (CLA) declaring that you have +the right to, and actually do, grant us the rights to use your contribution. +For details, visit https://cla.microsoft.com. + +When you submit a pull request, a CLA-bot will automatically determine whether +you need to provide a CLA and decorate the PR appropriately (e.g., label, +comment). Simply follow the instructions provided by the bot. You will only +need to do this once across all repos using our CLA. + +This project has adopted the +[Microsoft Open Source Code of Conduct][code_of_conduct]. For more information, +see the Code of Conduct FAQ or contact opencode@microsoft.com with any +additional questions or comments. + +[code_of_conduct]: https://opensource.microsoft.com/codeofconduct/ diff --git a/sdk/agentserver/azure-ai-agentserver/azure/__init__.py b/sdk/agentserver/azure-ai-agentserver/azure/__init__.py new file mode 100644 index 000000000000..d55ccad1f573 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/__init__.py @@ -0,0 +1 @@ +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/__init__.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/__init__.py new file mode 100644 index 000000000000..d55ccad1f573 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/__init__.py @@ -0,0 +1 @@ +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/__init__.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/__init__.py new file mode 100644 index 000000000000..aa692afa00e8 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/__init__.py @@ -0,0 +1,9 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from ._version import VERSION +from .server._base import AgentServer + +__all__ = ["AgentServer"] +__version__ = VERSION diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_constants.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_constants.py new file mode 100644 index 000000000000..d85e143e6dd4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_constants.py @@ -0,0 +1,20 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +class Constants: + """Well-known environment variables and defaults for AgentServer.""" + + AGENT_LOG_LEVEL = "AGENT_LOG_LEVEL" + AGENT_DEBUG_ERRORS = "AGENT_DEBUG_ERRORS" + AGENT_ENABLE_TRACING = "AGENT_ENABLE_TRACING" + AGENT_SERVER_PORT = "AGENT_SERVER_PORT" + AGENT_GRACEFUL_SHUTDOWN_TIMEOUT = "AGENT_GRACEFUL_SHUTDOWN_TIMEOUT" + AGENT_REQUEST_TIMEOUT = "AGENT_REQUEST_TIMEOUT" + AGENT_ENABLE_REQUEST_VALIDATION = "AGENT_ENABLE_REQUEST_VALIDATION" + APPLICATIONINSIGHTS_CONNECTION_STRING = "APPLICATIONINSIGHTS_CONNECTION_STRING" + DEFAULT_PORT = 8088 + DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT = 30 + DEFAULT_REQUEST_TIMEOUT = 300 # 5 minutes + INVOCATION_ID_HEADER = "x-agent-invocation-id" diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_errors.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_errors.py new file mode 100644 index 000000000000..a7774fd09a92 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_errors.py @@ -0,0 +1,52 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Standardised error response builder for AgentServer. + +Every error returned by the framework uses the shape:: + + { + "error": { + "code": "...", // required – machine-readable error code + "message": "...", // required – human-readable description + "details": [ ... ] // optional – child errors + } + } +""" +from __future__ import annotations + +from typing import Any, Optional + +from starlette.responses import JSONResponse + + +def error_response( + code: str, + message: str, + *, + status_code: int, + details: Optional[list[dict[str, Any]]] = None, + headers: Optional[dict[str, str]] = None, +) -> JSONResponse: + """Build a ``JSONResponse`` with the standard error envelope. + + :param code: Machine-readable error code (e.g. ``"internal_error"``). + :type code: str + :param message: Human-readable error message. + :type message: str + :keyword status_code: HTTP status code for the response. + :paramtype status_code: int + :keyword details: Child error objects, each with at least ``code`` and + ``message`` keys. + :paramtype details: Optional[list[dict[str, Any]]] + :keyword headers: Extra HTTP headers to include on the response. + :paramtype headers: Optional[dict[str, str]] + :return: A ready-to-send JSON error response. + :rtype: JSONResponse + """ + body: dict[str, Any] = {"code": code, "message": message} + if details is not None: + body["details"] = details + return JSONResponse( + {"error": body}, status_code=status_code, headers=headers + ) diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_logger.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_logger.py new file mode 100644 index 000000000000..4b4e960e971d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_logger.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import logging + + +def get_logger() -> logging.Logger: + """Return the library-scoped logger. + + The log level is configured by the ``log_level`` constructor parameter + of :class:`AgentServer` (or the ``AGENT_LOG_LEVEL`` env var as fallback). + This function simply returns the named logger without forcing a level so + that the level already set by the constructor is preserved. + + :return: Logger instance for azure.ai.agentserver. + :rtype: logging.Logger + """ + return logging.getLogger("azure.ai.agentserver") diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_tracing.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_tracing.py new file mode 100644 index 000000000000..2c612424ecaa --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_tracing.py @@ -0,0 +1,307 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Optional OpenTelemetry tracing for AgentServer. + +Tracing is **disabled by default**. Enable it in one of two ways: + +1. Set the environment variable ``AGENT_ENABLE_TRACING=true``. +2. Pass ``enable_tracing=True`` to the :class:`AgentServer` constructor + (constructor argument takes precedence over the env var). + +When enabled, the module requires ``opentelemetry-api`` to be installed:: + + pip install azure-ai-agentserver[tracing] + +If the package is not installed, tracing silently becomes a no-op. + +When an Application Insights connection string is available (via constructor +or ``APPLICATIONINSIGHTS_CONNECTION_STRING`` env var), traces **and** logs are +automatically exported to Azure Monitor. This requires the additional +``opentelemetry-sdk`` and ``azure-monitor-opentelemetry-exporter`` packages +(both included in the ``[tracing]`` extras group). +""" +from __future__ import annotations + +import logging +from contextlib import contextmanager +from collections.abc import AsyncIterable, AsyncIterator, Mapping # pylint: disable=import-error +from typing import TYPE_CHECKING, Any, Iterator, Optional, Union + +#: Starlette's ``Content`` type — the element type for streaming bodies. +_Content = Union[str, bytes, memoryview] + +#: W3C Trace Context header names used for distributed trace propagation. +_W3C_HEADERS = frozenset(("traceparent", "tracestate")) + +logger = logging.getLogger("azure.ai.agentserver") + +_HAS_OTEL = False +try: + from opentelemetry import trace + from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + + _HAS_OTEL = True +except ImportError: + if TYPE_CHECKING: + from opentelemetry import trace + from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + + +class TracingHelper: + """Lightweight wrapper around OpenTelemetry. + + Only instantiate when tracing is enabled. If ``opentelemetry-api`` is + not installed, a warning is logged and all methods become no-ops. + + When *connection_string* is provided, a :class:`TracerProvider` with an + Azure Monitor exporter is configured globally and log records from the + ``azure.ai.agentserver`` logger are forwarded to Application Insights. + This requires ``opentelemetry-sdk`` and + ``azure-monitor-opentelemetry-exporter``. + """ + + def __init__(self, connection_string: Optional[str] = None) -> None: + self._enabled = _HAS_OTEL + self._tracer: Any = None + + if not self._enabled: + logger.warning( + "Tracing was enabled but opentelemetry-api is not installed. " + "Install it with: pip install azure-ai-agentserver[tracing]" + ) + return + + if connection_string: + self._setup_azure_monitor(connection_string) + + self._tracer = trace.get_tracer("azure.ai.agentserver") + + # ------------------------------------------------------------------ + # Azure Monitor auto-configuration + # ------------------------------------------------------------------ + + @staticmethod + def _setup_azure_monitor(connection_string: str) -> None: + """Configure global TracerProvider and LoggerProvider for App Insights. + + Sets up ``AzureMonitorTraceExporter`` so spans are exported to + Application Insights, and ``AzureMonitorLogExporter`` so log records + from the ``azure.ai.agentserver`` namespace are forwarded. + + If the required packages are not installed, a warning is logged and + export is silently skipped — span creation still works via the + default no-op or user-configured provider. + + :param connection_string: Application Insights connection string. + :type connection_string: str + """ + try: + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + + from azure.monitor.opentelemetry.exporter import ( # type: ignore[import-untyped] + AzureMonitorTraceExporter, + ) + except ImportError: + logger.warning( + "Application Insights connection string was provided but " + "required packages are not installed. Install them with: " + "pip install azure-ai-agentserver[tracing]" + ) + return + + resource = Resource.create({"service.name": "azure.ai.agentserver"}) + + # --- Trace export --- + provider = SdkTracerProvider(resource=resource) + exporter = AzureMonitorTraceExporter( + connection_string=connection_string, + ) + provider.add_span_processor(BatchSpanProcessor(exporter)) + trace.set_tracer_provider(provider) + logger.info("Application Insights trace exporter configured.") + + # --- Log export --- + try: + from opentelemetry._logs import set_logger_provider + from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler + from opentelemetry.sdk._logs.export import BatchLogRecordProcessor + + from azure.monitor.opentelemetry.exporter import ( # type: ignore[import-untyped] + AzureMonitorLogExporter, + ) + + log_provider = LoggerProvider(resource=resource) + set_logger_provider(log_provider) + log_exporter = AzureMonitorLogExporter( + connection_string=connection_string, + ) + log_provider.add_log_record_processor( + BatchLogRecordProcessor(log_exporter) + ) + handler = LoggingHandler(logger_provider=log_provider) + logging.getLogger("azure.ai.agentserver").addHandler(handler) + logger.info("Application Insights log exporter configured.") + except ImportError: + logger.warning( + "Log export to Application Insights requires " + "opentelemetry-sdk. Logs will not be forwarded." + ) + + @contextmanager + def span( + self, + name: str, + attributes: Optional[dict[str, str]] = None, + carrier: Optional[dict[str, str]] = None, + ) -> Iterator[Any]: + """Create a traced span if tracing is enabled, otherwise no-op. + + Yields the OpenTelemetry span object when tracing is active, or + ``None`` when tracing is disabled. Callers may use the yielded span + together with :meth:`record_error` to attach error information. + + :param name: Span name, e.g. ``"AgentServer.invoke"``. + :type name: str + :param attributes: Key-value span attributes. + :type attributes: Optional[dict[str, str]] + :param carrier: Incoming HTTP headers for W3C trace-context propagation. + :type carrier: Optional[dict[str, str]] + :return: Context manager that yields the OTel span or *None*. + :rtype: Iterator[Any] + """ + if not self._enabled or self._tracer is None: + yield None + return + + # Extract parent context from W3C traceparent header if present + ctx = None + if carrier: + ctx = TraceContextTextMapPropagator().extract(carrier=carrier) + + with self._tracer.start_as_current_span( + name=name, + attributes=attributes or {}, + kind=trace.SpanKind.SERVER, + context=ctx, + ) as otel_span: + yield otel_span + + def start_span( + self, + name: str, + attributes: Optional[dict[str, str]] = None, + carrier: Optional[dict[str, str]] = None, + ) -> Any: + """Start a span without a context manager. + + Use this for streaming responses where the span must outlive the + initial ``invoke()`` call. The caller **must** call :meth:`end_span` + when the work is finished. + + :param name: Span name, e.g. ``"AgentServer.invoke"``. + :type name: str + :param attributes: Key-value span attributes. + :type attributes: Optional[dict[str, str]] + :param carrier: Incoming HTTP headers for W3C trace-context propagation. + :type carrier: Optional[dict[str, str]] + :return: The OTel span, or *None* when tracing is disabled. + :rtype: Any + """ + if not self._enabled or self._tracer is None: + return None + + ctx = None + if carrier: + ctx = TraceContextTextMapPropagator().extract(carrier=carrier) + + return self._tracer.start_span( + name=name, + attributes=attributes or {}, + kind=trace.SpanKind.SERVER, + context=ctx, + ) + + def end_span(self, span: Any, exc: Optional[Exception] = None) -> None: + """End a span started with :meth:`start_span`. + + Optionally records an error before ending. No-op when *span* is + ``None`` (tracing disabled). + + :param span: The OTel span, or *None*. + :type span: Any + :param exc: Optional exception to record before ending. + :type exc: Optional[Exception] + """ + if span is None: + return + if exc is not None: + self.record_error(span, exc) + span.end() + + @staticmethod + def record_error(span: Any, exc: Exception) -> None: + """Record an exception and ERROR status on a span. + + No-op when *span* is ``None`` (tracing disabled) or when + ``opentelemetry-api`` is not installed. + + :param span: The OTel span returned by :meth:`span`, or *None*. + :type span: Any + :param exc: The exception to record. + :type exc: Exception + """ + if span is not None and _HAS_OTEL: + span.set_status(trace.StatusCode.ERROR, str(exc)) + span.record_exception(exc) + + async def trace_stream( + self, iterator: AsyncIterable[_Content], span: Any + ) -> AsyncIterator[_Content]: + """Wrap a streaming body iterator so the tracing span covers the full + duration of data transmission. + + Yields chunks from *iterator* unchanged. When the iterator is + exhausted or raises an exception the span is ended (with error status + if applicable). Safe to call when tracing is disabled (*span* is + ``None``). + + :param iterator: The original async body iterator from + :class:`~starlette.responses.StreamingResponse`. + :type iterator: AsyncIterable[Union[str, bytes, memoryview]] + :param span: The OTel span (or *None* when tracing is disabled). + :type span: Any + :return: An async iterator that yields chunks unchanged. + :rtype: AsyncIterator[Union[str, bytes, memoryview]] + """ + error: Optional[Exception] = None + try: + async for chunk in iterator: + yield chunk + except Exception as exc: + error = exc + raise + finally: + self.end_span(span, exc=error) + + +def extract_w3c_carrier(headers: Mapping[str, str]) -> dict[str, str]: + """Extract W3C trace-context headers from a mapping. + + Filters the input to only ``traceparent`` and ``tracestate`` — the two + headers defined by the `W3C Trace Context`_ standard. This avoids + passing unrelated headers (e.g. ``authorization``, ``cookie``) into the + OpenTelemetry propagator. + + .. _W3C Trace Context: https://www.w3.org/TR/trace-context/ + + :param headers: A mapping of header name to value (e.g. + ``request.headers``). + :type headers: Mapping[str, str] + :return: A dict containing only the W3C propagation headers present + in *headers*. + :rtype: dict[str, str] + """ + return {k: v for k, v in headers.items() if k in _W3C_HEADERS} diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_version.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_version.py new file mode 100644 index 000000000000..67d209a8cafd --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_version.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +VERSION = "1.0.0b1" diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/py.typed b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/py.typed new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/__init__.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/__init__.py new file mode 100644 index 000000000000..d540fd20468c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/__init__.py @@ -0,0 +1,3 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py new file mode 100644 index 000000000000..5833eb9dd906 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py @@ -0,0 +1,546 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import asyncio # pylint: disable=do-not-import-asyncio +import contextlib +import logging +import uuid +from collections.abc import AsyncGenerator, Awaitable, Callable # pylint: disable=import-error +from typing import Any, Optional + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse +from starlette.routing import Route + +from .._constants import Constants +from .._errors import error_response +from .._tracing import TracingHelper, extract_w3c_carrier +from ..validation._openapi_validator import OpenApiValidator +from . import _config + +logger = logging.getLogger("azure.ai.agentserver") + + +class AgentServer: + """Agent server with pluggable protocol heads. + + Instantiate and register handlers with decorators:: + + server = AgentServer() + + @server.invoke_handler + async def handle(request): + return JSONResponse({"ok": True}) + + Optionally register handlers with :meth:`get_invocation_handler` and + :meth:`cancel_invocation_handler` for additional protocol support. + + Developer receives raw Starlette ``Request`` objects and returns raw + Starlette ``Response`` objects, giving full control over content types, + streaming, headers, and status codes. + + :param openapi_spec: Optional OpenAPI spec dict. When provided, the spec + is served at ``GET /invocations/docs/openapi.json`` for documentation. + Runtime request validation is **not** enabled by default — set + *enable_request_validation* to opt in. + :type openapi_spec: Optional[dict[str, Any]] + :param enable_request_validation: When *True*, incoming ``POST /invocations`` + request bodies are validated against the *openapi_spec* before reaching + :meth:`invoke`. When *None* (default) the + ``AGENT_ENABLE_REQUEST_VALIDATION`` env var is consulted (``"true"`` to + enable). Requires *openapi_spec* to be set. + :type enable_request_validation: Optional[bool] + :param enable_tracing: Enable OpenTelemetry tracing. When *None* (default) + the ``AGENT_ENABLE_TRACING`` env var is consulted (``"true"`` to enable). + Requires ``opentelemetry-api`` — install with + ``pip install azure-ai-agentserver[tracing]``. + When an Application Insights connection string is also available, + traces and logs are automatically exported to Azure Monitor. + :type enable_tracing: Optional[bool] + :param application_insights_connection_string: Application Insights + connection string for exporting traces and logs to Azure Monitor. + When *None* (default) the ``APPLICATIONINSIGHTS_CONNECTION_STRING`` + env var is consulted. Only takes effect when *enable_tracing* is + ``True``. Requires ``opentelemetry-sdk`` and + ``azure-monitor-opentelemetry-exporter`` (included in the + ``[tracing]`` extras group). + :type application_insights_connection_string: Optional[str] + :param graceful_shutdown_timeout: Seconds to wait for in-flight requests to + complete after receiving SIGTERM / shutdown signal. When *None* (default) + the ``AGENT_GRACEFUL_SHUTDOWN_TIMEOUT`` env var is consulted; if that is + also unset the default is 30 seconds. Set to ``0`` to disable the + drain period. + :type graceful_shutdown_timeout: Optional[int] + :param request_timeout: Maximum seconds an ``invoke()`` call may run before + being cancelled. When *None* (default) the ``AGENT_REQUEST_TIMEOUT`` + env var is consulted; if that is also unset the default is 300 seconds + (5 minutes). Set to ``0`` to disable the timeout. + :type request_timeout: Optional[int] + :param log_level: Library log level (e.g. ``"DEBUG"``, ``"INFO"``). When + *None* (default) the ``AGENT_LOG_LEVEL`` env var is consulted; if that + is also unset the default is ``"INFO"``. + :type log_level: Optional[str] + :param debug_errors: When *True*, error responses include the original + exception message instead of a generic ``"Internal server error"``. + When *None* (default) the ``AGENT_DEBUG_ERRORS`` env var is consulted + (any truthy value enables it). Defaults to *False*. + :type debug_errors: Optional[bool] + """ + + def __init__( + self, + *, + openapi_spec: Optional[dict[str, Any]] = None, + enable_request_validation: Optional[bool] = None, + enable_tracing: Optional[bool] = None, + application_insights_connection_string: Optional[str] = None, + graceful_shutdown_timeout: Optional[int] = None, + request_timeout: Optional[int] = None, + log_level: Optional[str] = None, + debug_errors: Optional[bool] = None, + ) -> None: + # Decorator handler slots ------------------------------------------ + self._invoke_fn: Optional[Callable] = None + self._get_invocation_fn: Optional[Callable] = None + self._cancel_invocation_fn: Optional[Callable] = None + self._shutdown_fn: Optional[Callable] = None + + # Logging & debug ------------------------------------------------- + resolved_level = _config.resolve_log_level(log_level) + logger.setLevel(resolved_level) + if not logger.handlers: + _console = logging.StreamHandler() + _console.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")) + logger.addHandler(_console) + self._debug_errors = _config.resolve_bool_feature( + debug_errors, Constants.AGENT_DEBUG_ERRORS + ) + + self._openapi_spec = openapi_spec + _validation_on = _config.resolve_bool_feature( + enable_request_validation, Constants.AGENT_ENABLE_REQUEST_VALIDATION + ) + self._validator: Optional[OpenApiValidator] = ( + OpenApiValidator(openapi_spec) + if openapi_spec and _validation_on + else None + ) + _tracing_on = _config.resolve_bool_feature(enable_tracing, Constants.AGENT_ENABLE_TRACING) + _conn_str = _config.resolve_appinsights_connection_string( + application_insights_connection_string + ) if _tracing_on else None + self._tracing: Optional[TracingHelper] = ( + TracingHelper(connection_string=_conn_str) if _tracing_on else None + ) + self._graceful_shutdown_timeout = _config.resolve_graceful_shutdown_timeout( + graceful_shutdown_timeout + ) + self._request_timeout = _config.resolve_request_timeout(request_timeout) + self.app: Starlette + self._build_app() + + # ------------------------------------------------------------------ + # Handler decorators + # ------------------------------------------------------------------ + + def invoke_handler( + self, fn: Callable[[Request], Awaitable[Response]] + ) -> Callable[[Request], Awaitable[Response]]: + """Register a function as the invoke handler. + + Usage:: + + @server.invoke_handler + async def handle(request: Request) -> Response: + ... + + :param fn: Async function accepting a Starlette Request and returning a Response. + :type fn: Callable[[Request], Awaitable[Response]] + :return: The original function (unmodified). + :rtype: Callable[[Request], Awaitable[Response]] + """ + self._invoke_fn = fn + return fn + + def get_invocation_handler( + self, fn: Callable[[Request], Awaitable[Response]] + ) -> Callable[[Request], Awaitable[Response]]: + """Register a function as the get-invocation handler. + + :param fn: Async function accepting a Starlette Request and returning a Response. + :type fn: Callable[[Request], Awaitable[Response]] + :return: The original function (unmodified). + :rtype: Callable[[Request], Awaitable[Response]] + """ + self._get_invocation_fn = fn + return fn + + def cancel_invocation_handler( + self, fn: Callable[[Request], Awaitable[Response]] + ) -> Callable[[Request], Awaitable[Response]]: + """Register a function as the cancel-invocation handler. + + :param fn: Async function accepting a Starlette Request and returning a Response. + :type fn: Callable[[Request], Awaitable[Response]] + :return: The original function (unmodified). + :rtype: Callable[[Request], Awaitable[Response]] + """ + self._cancel_invocation_fn = fn + return fn + + def shutdown_handler(self, fn: Callable[[], Awaitable[None]]) -> Callable[[], Awaitable[None]]: + """Register a function as the shutdown handler. + + :param fn: Async function called during graceful shutdown. + :type fn: Callable[[], Awaitable[None]] + :return: The original function (unmodified). + :rtype: Callable[[], Awaitable[None]] + """ + self._shutdown_fn = fn + return fn + + # ------------------------------------------------------------------ + # Dispatch methods (internal) + # ------------------------------------------------------------------ + + async def _dispatch_invoke(self, request: Request) -> Response: + """Dispatch to the registered invoke handler.""" + if self._invoke_fn is not None: + return await self._invoke_fn(request) + raise RuntimeError( + "No invoke handler registered. Use the @server.invoke_handler decorator." + ) + + async def _dispatch_get_invocation(self, request: Request) -> Response: + """Dispatch to the registered get-invocation handler, or return 404.""" + if self._get_invocation_fn is not None: + return await self._get_invocation_fn(request) + return error_response("not_supported", "get_invocation not supported", status_code=404) + + async def _dispatch_cancel_invocation(self, request: Request) -> Response: + """Dispatch to the registered cancel-invocation handler, or return 404.""" + if self._cancel_invocation_fn is not None: + return await self._cancel_invocation_fn(request) + return error_response("not_supported", "cancel_invocation not supported", status_code=404) + + async def _dispatch_shutdown(self) -> None: + """Dispatch to the registered shutdown handler, or no-op.""" + if self._shutdown_fn is not None: + await self._shutdown_fn() + + def get_openapi_spec(self) -> Optional[dict[str, Any]]: + """Return the OpenAPI spec dict for this agent, or None. + + :return: The registered OpenAPI spec or None. + :rtype: Optional[dict[str, Any]] + """ + return self._openapi_spec + + # ------------------------------------------------------------------ + # Run helpers + # ------------------------------------------------------------------ + + def _build_hypercorn_config(self, host: str, port: int) -> object: + """Create a Hypercorn config with resolved host, port and timeouts. + + :param host: Network interface to bind. + :type host: str + :param port: Port to bind. + :type port: int + :return: Configured Hypercorn config. + :rtype: hypercorn.config.Config + """ + from hypercorn.config import Config as HypercornConfig + + config = HypercornConfig() + config.bind = [f"{host}:{port}"] + config.graceful_timeout = float(self._graceful_shutdown_timeout) + return config + + def run(self, host: str = "127.0.0.1", port: Optional[int] = None) -> None: + """Start the server synchronously. + + Uses Hypercorn as the ASGI server, which supports HTTP/1.1 and HTTP/2. + + :param host: Network interface to bind. Defaults to ``127.0.0.1``. + Use ``"0.0.0.0"`` to listen on all interfaces. + :type host: str + :param port: Port to bind. Defaults to ``AGENT_SERVER_PORT`` env var or 8088. + :type port: Optional[int] + """ + from hypercorn.asyncio import serve as _hypercorn_serve + + resolved_port = _config.resolve_port(port) + logger.info("AgentServer starting on %s:%s", host, resolved_port) + config = self._build_hypercorn_config(host, resolved_port) + asyncio.run(_hypercorn_serve(self.app, config)) # type: ignore[arg-type] # Starlette is ASGI-compatible + + async def run_async(self, host: str = "127.0.0.1", port: Optional[int] = None) -> None: + """Start the server asynchronously (awaitable). + + Uses Hypercorn as the ASGI server, which supports HTTP/1.1 and HTTP/2. + + :param host: Network interface to bind. Defaults to ``127.0.0.1``. + Use ``"0.0.0.0"`` to listen on all interfaces. + :type host: str + :param port: Port to bind. Defaults to ``AGENT_SERVER_PORT`` env var or 8088. + :type port: Optional[int] + """ + from hypercorn.asyncio import serve as _hypercorn_serve + + resolved_port = _config.resolve_port(port) + logger.info("AgentServer starting on %s:%s (async)", host, resolved_port) + config = self._build_hypercorn_config(host, resolved_port) + await _hypercorn_serve(self.app, config) # type: ignore[arg-type] # Starlette is ASGI-compatible + + # ------------------------------------------------------------------ + # Private: app construction + # ------------------------------------------------------------------ + + def _build_app(self) -> None: + """Construct the Starlette ASGI application with all routes.""" + + @contextlib.asynccontextmanager + async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF029 + logger.info("AgentServer started") + yield + + # --- SHUTDOWN: runs once when the server is stopping --- + logger.info( + "AgentServer shutting down (graceful timeout=%ss)", + self._graceful_shutdown_timeout, + ) + try: + await asyncio.wait_for( + self._dispatch_shutdown(), + timeout=self._graceful_shutdown_timeout or None, + ) + except asyncio.TimeoutError: + logger.warning( + "on_shutdown did not complete within %ss timeout", + self._graceful_shutdown_timeout, + ) + except Exception: # pylint: disable=broad-exception-caught + logger.exception("Error in on_shutdown") + + routes = [ + Route( + "/invocations/docs/openapi.json", + self._get_openapi_spec_endpoint, + methods=["GET"], + name="get_openapi_spec", + ), + Route( + "/invocations", + self._create_invocation_endpoint, + methods=["POST"], + name="create_invocation", + ), + Route( + "/invocations/{invocation_id}", + self._get_invocation_endpoint, + methods=["GET"], + name="get_invocation", + ), + Route( + "/invocations/{invocation_id}/cancel", + self._cancel_invocation_endpoint, + methods=["POST"], + name="cancel_invocation", + ), + Route("/liveness", self._liveness_endpoint, methods=["GET"], name="liveness"), + Route("/readiness", self._readiness_endpoint, methods=["GET"], name="readiness"), + ] + + self.app = Starlette(routes=routes, lifespan=_lifespan) + + # ------------------------------------------------------------------ + # Private: endpoint handlers + # ------------------------------------------------------------------ + + async def _get_openapi_spec_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument + """GET /invocations/docs/openapi.json — return registered spec or 404. + + :param request: The incoming Starlette request. + :type request: Request + :return: JSON response with the spec or 404. + :rtype: Response + """ + spec = self.get_openapi_spec() + if spec is None: + return error_response("not_found", "No OpenAPI spec registered", status_code=404) + return JSONResponse(spec) + + async def _create_invocation_endpoint(self, request: Request) -> Response: + """POST /invocations — create and process an invocation. + + :param request: The incoming Starlette request. + :type request: Request + :return: The invocation result or error response. + :rtype: Response + """ + invocation_id = str(uuid.uuid4()) + request.state.invocation_id = invocation_id + + # Validate request body against OpenAPI spec + if self._validator is not None: + content_type = request.headers.get("content-type", "application/json") + body = await request.body() + errors = self._validator.validate_request(body, content_type) + if errors: + return error_response( + "invalid_payload", + "Request validation failed", + status_code=400, + details=[ + {"code": "validation_error", "message": e} + for e in errors + ], + ) + + carrier = extract_w3c_carrier(request.headers) + + # Use manual span management so that streaming responses keep the + # span open until the last chunk is yielded (or an error occurs). + otel_span = ( + self._tracing.start_span( + "AgentServer.invoke", + attributes={"invocation.id": invocation_id}, + carrier=carrier, + ) + if self._tracing is not None + else None + ) + try: + invoke_awaitable = self._dispatch_invoke(request) + timeout = self._request_timeout or None # 0 → None (no limit) + response = await asyncio.wait_for(invoke_awaitable, timeout=timeout) + except asyncio.TimeoutError: + if self._tracing is not None: + self._tracing.end_span(otel_span) + logger.error( + "Invocation %s timed out after %ss", + invocation_id, + self._request_timeout, + ) + return error_response( + "request_timeout", + f"Invocation timed out after {self._request_timeout}s", + status_code=504, + headers={Constants.INVOCATION_ID_HEADER: invocation_id}, + ) + except Exception as exc: # pylint: disable=broad-exception-caught + if self._tracing is not None: + self._tracing.end_span(otel_span, exc=exc) + logger.error("Error processing invocation %s: %s", invocation_id, exc, exc_info=True) + message = str(exc) if self._debug_errors else "Internal server error" + return error_response( + "internal_error", + message, + status_code=500, + headers={Constants.INVOCATION_ID_HEADER: invocation_id}, + ) + + # For streaming responses, wrap the body iterator so the span stays + # open until all chunks are sent and captures any streaming errors. + if isinstance(response, StreamingResponse) and self._tracing is not None: + response.body_iterator = self._tracing.trace_stream(response.body_iterator, otel_span) + elif self._tracing is not None: + self._tracing.end_span(otel_span) + + # Auto-inject invocation_id header if developer didn't set it + if Constants.INVOCATION_ID_HEADER not in response.headers: + response.headers[Constants.INVOCATION_ID_HEADER] = invocation_id + + return response + + async def _get_invocation_endpoint(self, request: Request) -> Response: + """GET /invocations/{invocation_id} — retrieve an invocation result. + + :param request: The incoming Starlette request. + :type request: Request + :return: The stored result or 404. + :rtype: Response + """ + invocation_id = request.path_params["invocation_id"] + request.state.invocation_id = invocation_id + carrier = extract_w3c_carrier(request.headers) + span_cm = ( + self._tracing.span( + "AgentServer.get_invocation", + attributes={"invocation.id": invocation_id}, + carrier=carrier, + ) + if self._tracing is not None + else contextlib.nullcontext(None) + ) + with span_cm as _otel_span: + try: + return await self._dispatch_get_invocation(request) + except Exception as exc: # pylint: disable=broad-exception-caught + if self._tracing is not None: + self._tracing.record_error(_otel_span, exc) + logger.error("Error in get_invocation %s: %s", invocation_id, exc, exc_info=True) + message = str(exc) if self._debug_errors else "Internal server error" + return error_response( + "internal_error", + message, + status_code=500, + ) + + async def _cancel_invocation_endpoint(self, request: Request) -> Response: + """POST /invocations/{invocation_id}/cancel — cancel an invocation. + + :param request: The incoming Starlette request. + :type request: Request + :return: The cancellation result or 404. + :rtype: Response + """ + invocation_id = request.path_params["invocation_id"] + request.state.invocation_id = invocation_id + carrier = extract_w3c_carrier(request.headers) + span_cm = ( + self._tracing.span( + "AgentServer.cancel_invocation", + attributes={"invocation.id": invocation_id}, + carrier=carrier, + ) + if self._tracing is not None + else contextlib.nullcontext(None) + ) + with span_cm as _otel_span: + try: + return await self._dispatch_cancel_invocation(request) + except Exception as exc: # pylint: disable=broad-exception-caught + if self._tracing is not None: + self._tracing.record_error(_otel_span, exc) + logger.error("Error in cancel_invocation %s: %s", invocation_id, exc, exc_info=True) + message = str(exc) if self._debug_errors else "Internal server error" + return error_response( + "internal_error", + message, + status_code=500, + ) + + async def _liveness_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument + """GET /liveness — health check. + + :param request: The incoming Starlette request. + :type request: Request + :return: 200 OK response. + :rtype: Response + """ + return JSONResponse({"status": "alive"}) + + async def _readiness_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument + """GET /readiness — readiness check. + + :param request: The incoming Starlette request. + :type request: Request + :return: 200 OK response. + :rtype: Response + """ + return JSONResponse({"status": "ready"}) + + diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_config.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_config.py new file mode 100644 index 000000000000..1102d0395a73 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_config.py @@ -0,0 +1,180 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Configuration resolution helpers for AgentServer. + +Each ``resolve_*`` function follows the same hierarchy: +1. Explicit argument (if not *None*) +2. Environment variable +3. Built-in default + +A value of ``0`` conventionally disables the corresponding feature. + +Invalid environment variable values raise ``ValueError`` immediately so +misconfiguration is surfaced at startup rather than silently masked. +""" +import os +from typing import Optional + +from .._constants import Constants + + +def _parse_int_env(var_name: str) -> Optional[int]: + """Parse an integer environment variable, raising on invalid values. + + :param var_name: Name of the environment variable. + :type var_name: str + :return: The parsed integer or None if the variable is not set. + :rtype: Optional[int] + :raises ValueError: If the variable is set but cannot be parsed as an integer. + """ + raw = os.environ.get(var_name) + if not raw: + return None + try: + return int(raw) + except ValueError: + raise ValueError( + f"Invalid value for {var_name}: {raw!r} (expected an integer)" + ) from None + + +def _require_int(name: str, value: object) -> int: + """Validate that *value* is an integer. + + :param name: Human-readable parameter/env-var name for the error message. + :type name: str + :param value: The value to validate. + :type value: object + :return: The value cast to int. + :rtype: int + :raises ValueError: If *value* is not an integer. + """ + if not isinstance(value, int): + raise ValueError( + f"Invalid value for {name}: {value!r} (expected an integer)" + ) + return value + + +def resolve_port(port: Optional[int]) -> int: + """Resolve the server port from argument, env var, or default. + + :param port: Explicitly requested port or None. + :type port: Optional[int] + :return: The resolved port number. + :rtype: int + :raises ValueError: If the port value is not a valid integer or is outside 1-65535. + """ + if port is not None: + result = _require_int("port", port) + if not 1 <= result <= 65535: + raise ValueError( + f"Invalid value for port: {result} (expected 1-65535)" + ) + return result + env_port = _parse_int_env(Constants.AGENT_SERVER_PORT) + if env_port is not None: + if not 1 <= env_port <= 65535: + raise ValueError( + f"Invalid value for {Constants.AGENT_SERVER_PORT}: {env_port} (expected 1-65535)" + ) + return env_port + return Constants.DEFAULT_PORT + + +def resolve_graceful_shutdown_timeout(timeout: Optional[int]) -> int: + """Resolve the graceful shutdown timeout from argument, env var, or default. + + :param timeout: Explicitly requested timeout or None. + :type timeout: Optional[int] + :return: The resolved timeout in seconds. + :rtype: int + :raises ValueError: If the env var is not a valid integer. + """ + if timeout is not None: + return max(0, _require_int("graceful_shutdown_timeout", timeout)) + env_timeout = _parse_int_env(Constants.AGENT_GRACEFUL_SHUTDOWN_TIMEOUT) + if env_timeout is not None: + return max(0, env_timeout) + return Constants.DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT + + +def resolve_request_timeout(timeout: Optional[int]) -> int: + """Resolve the request timeout from argument, env var, or default. + + :param timeout: Explicitly requested timeout in seconds or None. + :type timeout: Optional[int] + :return: The resolved timeout in seconds. + :rtype: int + :raises ValueError: If the env var is not a valid integer. + """ + if timeout is not None: + return max(0, _require_int("request_timeout", timeout)) + env_timeout = _parse_int_env(Constants.AGENT_REQUEST_TIMEOUT) + if env_timeout is not None: + return max(0, env_timeout) + return Constants.DEFAULT_REQUEST_TIMEOUT + + +def resolve_bool_feature(value: Optional[bool], env_var: str) -> bool: + """Resolve an opt-in boolean feature from argument, env var, or default (False). + + :param value: Explicitly requested value or None. + :type value: Optional[bool] + :param env_var: Name of the environment variable to consult. + :type env_var: str + :return: Whether the feature is enabled. + :rtype: bool + """ + if value is not None: + return bool(value) + return os.environ.get(env_var, "").lower() == "true" + + +_VALID_LOG_LEVELS = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL") + + +def resolve_appinsights_connection_string( + connection_string: Optional[str], +) -> Optional[str]: + """Resolve the Application Insights connection string. + + Resolution order: + + 1. Explicit *connection_string* argument (if not *None*). + 2. ``APPLICATIONINSIGHTS_CONNECTION_STRING`` env var (standard Azure + Monitor convention). + 3. *None* — no connection string available. + + :param connection_string: Explicitly provided connection string or None. + :type connection_string: Optional[str] + :return: The resolved connection string, or None. + :rtype: Optional[str] + """ + if connection_string is not None: + return connection_string + return os.environ.get( + Constants.APPLICATIONINSIGHTS_CONNECTION_STRING + ) + + +def resolve_log_level(level: Optional[str]) -> str: + """Resolve the library log level from argument, env var, or default (``INFO``). + + :param level: Explicitly requested level (e.g. ``"DEBUG"``) or None. + :type level: Optional[str] + :return: Validated, upper-cased log level string. + :rtype: str + :raises ValueError: If the value is not one of DEBUG/INFO/WARNING/ERROR/CRITICAL. + """ + if level is not None: + normalized = level.upper() + else: + normalized = os.environ.get(Constants.AGENT_LOG_LEVEL, "INFO").upper() + if normalized not in _VALID_LOG_LEVELS: + raise ValueError( + f"Invalid log level: {normalized!r} " + f"(expected one of {', '.join(_VALID_LOG_LEVELS)})" + ) + return normalized diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/__init__.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/__init__.py new file mode 100644 index 000000000000..d540fd20468c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/__init__.py @@ -0,0 +1,3 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py new file mode 100644 index 000000000000..3cee55dd426d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py @@ -0,0 +1,695 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# +# Why not use openapi-spec-validator / openapi-schema-validator / openapi-core? +# +# - openapi-spec-validator validates that an OpenAPI *document* is well-formed +# (meta-validation). It does not validate request/response data. +# +# - openapi-schema-validator (OAS30Validator) handles nullable, readOnly/ +# writeOnly, and OpenAPI keyword stripping natively, but does NOT provide: +# * $ref resolution from the full spec into extracted sub-schemas, +# * discriminator-aware oneOf/anyOf error collection (ported from the +# server-side C# JsonSchemaValidator), +# * JSON-path-prefixed error messages. +# Adopting it would save ~65 lines of preprocessing but adds three +# transitive dependencies (jsonschema-specifications, rfc3339-validator, +# jsonschema-path) while still requiring all custom error-collection code. +# +# - openapi-core is a full request/response middleware framework with its +# own routing and parsing. It conflicts with our Starlette middleware +# approach and is a much heavier dependency. +# +# Keeping only jsonschema as the single validation dependency gives us full +# control over error output and avoids unnecessary transitive packages. +# +import copy +import json +import logging +import re +from collections import Counter +from datetime import datetime, timezone +from typing import Any, Optional + +import jsonschema +from jsonschema import FormatChecker, ValidationError + +logger = logging.getLogger("azure.ai.agentserver") + +# --------------------------------------------------------------------------- +# Stdlib-only format checkers so we never depend on optional jsonschema extras +# (rfc3339-validator, fqdn, …). Registered on a module-level instance that +# is reused for every validation call. +# --------------------------------------------------------------------------- +_format_checker = FormatChecker(formats=()) # start with no built-in checks + +_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$") + + +@_format_checker.checks("date-time", raises=ValueError) +def _check_datetime(value: object) -> bool: + """Validate RFC 3339 / ISO 8601 date-time strings using stdlib only.""" + if not isinstance(value, str): + return True # non-string is not a format error + normalized = value.replace("Z", "+00:00") if value.endswith("Z") else value + datetime.fromisoformat(normalized) + return True + + +@_format_checker.checks("date", raises=ValueError) +def _check_date(value: object) -> bool: + """Validate ISO 8601 date strings (YYYY-MM-DD).""" + if not isinstance(value, str): + return True + datetime.strptime(value, "%Y-%m-%d") + return True + + +@_format_checker.checks("email", raises=ValueError) +def _check_email(value: object) -> bool: + """Basic RFC 5322 email format check (no DNS lookup).""" + if not isinstance(value, str): + return True + if not _EMAIL_RE.match(value): + raise ValueError(f"Invalid email: {value!r}") + return True + +# OpenAPI keywords that are not part of JSON Schema and must be stripped +# before handing a schema to a JSON Schema validator. +_OPENAPI_ONLY_KEYWORDS: frozenset[str] = frozenset( + {"discriminator", "xml", "externalDocs", "example"} +) + + +class OpenApiValidator: + """Validates request/response bodies against an OpenAPI spec. + + Extracts the request and response JSON schemas from the provided OpenAPI spec dict + and uses ``jsonschema`` to validate bodies at runtime. + + :param spec: An OpenAPI spec dictionary. + :type spec: dict[str, Any] + """ + + def __init__(self, spec: dict[str, Any], path: str = "/invocations") -> None: + self._spec = spec + self._path = path + self._request_body_required = self._is_request_body_required(spec, path) + raw_request = self._extract_request_schema(spec, path) + raw_response = self._extract_response_schema(spec, path) + self._request_schema = ( + self._preprocess_schema(raw_request, context="request") + if raw_request + else None + ) + self._response_schema = ( + self._preprocess_schema(raw_response, context="response") + if raw_response + else None + ) + + def validate_request(self, body: bytes, content_type: str) -> list[str]: + """Validate a request body against the spec's request schema. + + :param body: Raw request body bytes. + :type body: bytes + :param content_type: The Content-Type header value. + :type content_type: str + :return: List of validation error messages. Empty when valid. + :rtype: list[str] + """ + if self._request_schema is None: + return [] + # If requestBody.required is false, allow empty bodies + if not self._request_body_required and body.strip() == b"": + return [] + return self._validate_body(body, content_type, self._request_schema) + + def validate_response(self, body: bytes, content_type: str) -> list[str]: + """Validate a response body against the spec's response schema. + + :param body: Raw response body bytes. + :type body: bytes + :param content_type: The Content-Type header value. + :type content_type: str + :return: List of validation error messages. Empty when valid. + :rtype: list[str] + """ + if self._response_schema is None: + return [] + return self._validate_body(body, content_type, self._response_schema) + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _validate_body(body: bytes, content_type: str, schema: dict[str, Any]) -> list[str]: + """Parse body as JSON and validate against *schema*. + + Uses discriminator-aware error collection for ``oneOf`` / ``anyOf`` + schemas: when a discriminator property is detected across branches the + validator selects the matching branch and reports only *its* errors, + avoiding the noisy dump of every branch. + + Error messages are prefixed with the JSON-path of the failing element + (e.g. ``$.items[0].type: ...``) so callers can pinpoint the problem. + + :param body: Raw bytes to validate. + :type body: bytes + :param content_type: The Content-Type header value. + :type content_type: str + :param schema: JSON Schema dict to validate against. + :type schema: dict[str, Any] + :return: List of validation error strings. Empty when valid. + :rtype: list[str] + """ + if "json" not in content_type.lower(): + return [] # skip validation for non-JSON payloads + + try: + data = json.loads(body) + except (json.JSONDecodeError, UnicodeDecodeError) as exc: + return [f"Invalid JSON body: {exc}"] + + errors: list[str] = [] + validator = jsonschema.Draft7Validator( + schema, format_checker=_format_checker + ) + for error in validator.iter_errors(data): + errors.extend(_collect_errors(error)) + return errors + + @staticmethod + def _extract_request_schema(spec: dict[str, Any], path: str) -> Optional[dict[str, Any]]: + """Extract the request body JSON schema from the POST operation at *path*. + + :param spec: OpenAPI spec dictionary. + :type spec: dict[str, Any] + :param path: The API path (e.g. ``/invocations``). + :type path: str + :return: JSON Schema dict or None. + :rtype: Optional[dict[str, Any]] + """ + return OpenApiValidator._find_schema_in_paths( + spec, path, "post", "requestBody" + ) + + @staticmethod + def _extract_response_schema(spec: dict[str, Any], path: str) -> Optional[dict[str, Any]]: + """Extract the response body JSON schema from the POST operation at *path*. + + :param spec: OpenAPI spec dictionary. + :type spec: dict[str, Any] + :param path: The API path (e.g. ``/invocations``). + :type path: str + :return: JSON Schema dict or None. + :rtype: Optional[dict[str, Any]] + """ + return OpenApiValidator._find_schema_in_paths( + spec, path, "post", "responses" + ) + + @staticmethod + def _find_schema_in_paths( + spec: dict[str, Any], + path: str, + method: str, + section: str, + ) -> Optional[dict[str, Any]]: + """Walk the spec to find a JSON schema for the given path/method/section. + + :param spec: OpenAPI spec dictionary. + :type spec: dict[str, Any] + :param path: The API path (e.g. ``/invocations``). + :type path: str + :param method: HTTP method (e.g. ``post``). + :type method: str + :param section: Either ``requestBody`` or ``responses``. + :type section: str + :return: Resolved JSON Schema dict or None. + :rtype: Optional[dict[str, Any]] + """ + paths = spec.get("paths", {}) + operation = paths.get(path, {}).get(method, {}) + + if section == "requestBody": + request_body = operation.get("requestBody", {}) + content = request_body.get("content", {}) + json_content = content.get("application/json", {}) + schema = json_content.get("schema") + return _resolve_refs_deep(spec, schema) if schema else None + + if section == "responses": + responses = operation.get("responses", {}) + # Try 200, then 201, then first available + for code in ("200", "201"): + resp = responses.get(code, {}) + content = resp.get("content", {}) + json_content = content.get("application/json", {}) + schema = json_content.get("schema") + if schema: + return _resolve_refs_deep(spec, schema) + # Fallback: first response with JSON content + for resp in responses.values(): + if isinstance(resp, dict): + content = resp.get("content", {}) + json_content = content.get("application/json", {}) + schema = json_content.get("schema") + if schema: + return _resolve_refs_deep(spec, schema) + return None + + @staticmethod + def _is_request_body_required(spec: dict[str, Any], path: str) -> bool: + """Check whether ``requestBody.required`` is true for the POST operation. + + :param spec: OpenAPI spec dictionary. + :type spec: dict[str, Any] + :param path: The API path (e.g. ``/invocations``). + :type path: str + :return: True if the request body is explicitly required (default True). + :rtype: bool + """ + paths = spec.get("paths", {}) + operation = paths.get(path, {}).get("post", {}) + request_body = operation.get("requestBody", {}) + return request_body.get("required", True) + + @staticmethod + def _preprocess_schema( + schema: dict[str, Any], context: str = "request" + ) -> dict[str, Any]: + """Convert OpenAPI-specific keywords into pure JSON Schema. + + Performs a deep copy then applies the following transformations: + 1. ``nullable: true`` → ``type: [originalType, "null"]`` + 2. Strip ``readOnly`` properties in request context, + ``writeOnly`` in response context. + 3. Remove OpenAPI-only keywords (``discriminator``, ``xml``, etc.). + + :param schema: Resolved JSON Schema dict (may contain OpenAPI extensions). + :type schema: dict[str, Any] + :param context: Either ``"request"`` or ``"response"``. + :type context: str + :return: A pure JSON Schema dict safe for ``jsonschema`` validation. + :rtype: dict[str, Any] + """ + schema = copy.deepcopy(schema) + OpenApiValidator._apply_nullable(schema) + OpenApiValidator._strip_readonly_writeonly(schema, context) + OpenApiValidator._strip_openapi_keywords(schema) + return schema + + @staticmethod + def _apply_nullable(schema: dict[str, Any]) -> None: + """Convert ``nullable: true`` to JSON Schema union type in-place. + + Walks the schema tree and transforms + ``{"type": "string", "nullable": true}`` into + ``{"type": ["string", "null"]}``. + + :param schema: Schema dict to mutate in-place. + :type schema: dict[str, Any] + """ + if not isinstance(schema, dict): + return + if schema.pop("nullable", False): + original = schema.get("type") + if isinstance(original, str): + schema["type"] = [original, "null"] + elif isinstance(original, list) and "null" not in original: + schema["type"] = original + ["null"] + # Recurse into nested structures + for key in ("items", "additionalProperties"): + child = schema.get(key) + if isinstance(child, dict): + OpenApiValidator._apply_nullable(child) + for prop in schema.get("properties", {}).values(): + if isinstance(prop, dict): + OpenApiValidator._apply_nullable(prop) + for keyword in ("allOf", "oneOf", "anyOf"): + for sub in schema.get(keyword, []): + if isinstance(sub, dict): + OpenApiValidator._apply_nullable(sub) + + @staticmethod + def _strip_readonly_writeonly( + schema: dict[str, Any], context: str + ) -> None: + """Remove readOnly/writeOnly properties based on context. + + In **request** context, properties marked ``readOnly: true`` are removed + from ``properties`` and from the ``required`` list (the server generates + them; clients should not send them). + + In **response** context, ``writeOnly: true`` properties are removed + (e.g. passwords that the client sends but the server never echoes back). + + :param schema: Schema dict to mutate in-place. + :type schema: dict[str, Any] + :param context: ``"request"`` or ``"response"``. + :type context: str + """ + if not isinstance(schema, dict): + return + props = schema.get("properties", {}) + required = schema.get("required", []) + to_remove: list[str] = [] + for name, prop_schema in props.items(): + if not isinstance(prop_schema, dict): + continue + if context == "request" and prop_schema.get("readOnly"): + to_remove.append(name) + elif context == "response" and prop_schema.get("writeOnly"): + to_remove.append(name) + for name in to_remove: + props.pop(name, None) + if name in required: + required.remove(name) + # Recurse into nested objects + for prop in props.values(): + if isinstance(prop, dict): + OpenApiValidator._strip_readonly_writeonly(prop, context) + child = schema.get("items") + if isinstance(child, dict): + OpenApiValidator._strip_readonly_writeonly(child, context) + for keyword in ("allOf", "oneOf", "anyOf"): + for sub in schema.get(keyword, []): + if isinstance(sub, dict): + OpenApiValidator._strip_readonly_writeonly(sub, context) + + @staticmethod + def _strip_openapi_keywords(schema: dict[str, Any]) -> None: + """Remove OpenAPI-only keywords that confuse JSON Schema validators. + + Strips ``discriminator``, ``xml``, ``externalDocs``, and ``example`` + from the schema tree in-place. + + :param schema: Schema dict to mutate in-place. + :type schema: dict[str, Any] + """ + if not isinstance(schema, dict): + return + for kw in _OPENAPI_ONLY_KEYWORDS: + schema.pop(kw, None) + for key in ("items", "additionalProperties"): + child = schema.get(key) + if isinstance(child, dict): + OpenApiValidator._strip_openapi_keywords(child) + for prop in schema.get("properties", {}).values(): + if isinstance(prop, dict): + OpenApiValidator._strip_openapi_keywords(prop) + for keyword in ("allOf", "oneOf", "anyOf"): + for sub in schema.get(keyword, []): + if isinstance(sub, dict): + OpenApiValidator._strip_openapi_keywords(sub) + + +# ------------------------------------------------------------------ +# Discriminator-aware error collection helpers +# ------------------------------------------------------------------ + + +def _format_error(error: ValidationError) -> str: + """Format a single validation error with its JSON path prefix. + + :param error: A ``jsonschema`` validation error. + :type error: ValidationError + :return: Human-readable error string, optionally path-prefixed. + :rtype: str + """ + path = error.json_path + if path and path != "$": + return f"{path}: {error.message}" + return error.message + + +def _collect_errors(error: ValidationError) -> list[str]: + """Collect formatted error messages from a ``ValidationError``. + + For ``oneOf`` / ``anyOf`` errors the helper attempts discriminator-aware + branch selection (mirroring the server-side C# ``JsonSchemaValidator``). + When a discriminator property is detected, only the *matching* branch's + errors are reported, avoiding a noisy dump of every branch. + + :param error: A ``jsonschema`` validation error. + :type error: ValidationError + :return: List of formatted error strings. + :rtype: list[str] + """ + if error.validator in ("oneOf", "anyOf") and error.context: + return _collect_composition_errors(error) + return [_format_error(error)] + + +def _collect_composition_errors(error: ValidationError) -> list[str]: + """Handle ``oneOf`` / ``anyOf`` errors with discriminator-based branch selection. + + Algorithm (ported from the C# ``JsonSchemaValidator``): + + 1. Group errors by branch index (``schema_path[0]``). + 2. Detect a *discriminator path*: a ``const`` / ``type`` / ``enum`` + error that appears at the same ``absolute_path`` in the majority of + branches (threshold: ``max(2, ceil(n/2))``). + 3. The *matching branch* is the one **without** a discriminator error + at that path. + 4. Report only the matching branch's errors. If no branch matches, + report a concise ``"Invalid value. Expected one of: ..."`` message. + + Falls back to ``best_match`` when no discriminator can be identified. + + :param error: A ``oneOf`` / ``anyOf`` validation error. + :type error: ValidationError + :return: List of formatted error strings. + :rtype: list[str] + """ + # Group sub-errors by branch index (first element of schema_path) + branch_groups: dict[int, list[ValidationError]] = {} + for sub in error.context: + if sub.schema_path: + idx = sub.schema_path[0] + if isinstance(idx, int): + branch_groups.setdefault(idx, []).append(sub) + + if len(branch_groups) < 2: + # Cannot do branch analysis — fallback + best = jsonschema.exceptions.best_match([error]) + if best is not None and best is not error: + return _collect_errors(best) + return [_format_error(error)] + + disc_path = _find_discriminator_path(branch_groups) + + if disc_path is None: + best = jsonschema.exceptions.best_match([error]) + if best is not None and best is not error: + return _collect_errors(best) + return [_format_error(error)] + + # Find the branch that matches the discriminator value + matching_idx = _find_matching_branch(branch_groups, disc_path) + + if matching_idx is not None: + result: list[str] = [] + for sub in branch_groups[matching_idx]: + result.extend(_collect_errors(sub)) + return result if result else [_format_error(error)] + + # No matching branch — report the discriminator mismatch + return _report_discriminator_error(branch_groups, disc_path, error) + + +_DISCRIMINATOR_VALIDATORS: frozenset[str] = frozenset({"const", "type", "enum"}) + + +def _find_discriminator_path( + branch_groups: dict[int, list[ValidationError]], +) -> Optional[tuple[str | int, ...]]: + """Detect a discriminator property across ``oneOf`` / ``anyOf`` branches. + + A discriminator is a property whose ``const``, ``enum``, or ``type`` + constraint fails at the same ``absolute_path`` in the majority of + branches. + + :param branch_groups: Errors grouped by branch index. + :type branch_groups: dict[int, list[ValidationError]] + :return: The ``absolute_path`` of the discriminator as a tuple, or *None*. + :rtype: Optional[tuple[str | int, ...]] + """ + n_branches = len(branch_groups) + if n_branches < 2: + return None + + # Collect discriminator-error paths per branch + per_branch_paths: list[set[tuple[str | int, ...]]] = [] + for errors in branch_groups.values(): + paths: set[tuple[str | int, ...]] = set() + for err in errors: + if err.validator in _DISCRIMINATOR_VALIDATORS: + paths.add(tuple(err.absolute_path)) + per_branch_paths.append(paths) + + path_counts: Counter[tuple[str | int, ...]] = Counter() + for paths in per_branch_paths: + for p in paths: + path_counts[p] += 1 + + min_threshold = (n_branches + 1) // 2 # ceil(n/2) — at least half the branches + for path, count in path_counts.most_common(): + if count >= min_threshold: + return path + return None + + +def _find_matching_branch( + branch_groups: dict[int, list[ValidationError]], + disc_path: tuple[str | int, ...], +) -> Optional[int]: + """Return the branch index that has **no** discriminator error at *disc_path*. + + :param branch_groups: Errors grouped by branch index. + :type branch_groups: dict[int, list[ValidationError]] + :param disc_path: The discriminator property path. + :type disc_path: tuple[str | int, ...] + :return: The matching branch index or *None*. + :rtype: Optional[int] + """ + for idx, errors in branch_groups.items(): + has_disc_error = any( + err.validator in _DISCRIMINATOR_VALIDATORS + and tuple(err.absolute_path) == disc_path + for err in errors + ) + if not has_disc_error: + return idx + return None + + +def _report_discriminator_error( + branch_groups: dict[int, list[ValidationError]], + disc_path: tuple[str | int, ...], + parent: ValidationError, +) -> list[str]: + """Produce a concise discriminator-mismatch message. + + Collects all expected ``const`` / ``enum`` values from the branches and + reports them as ``"Invalid value. Expected one of: X, Y, Z"``. + + :param branch_groups: Errors grouped by branch index. + :type branch_groups: dict[int, list[ValidationError]] + :param disc_path: The discriminator property path. + :type disc_path: tuple[str | int, ...] + :param parent: The parent ``oneOf`` / ``anyOf`` error. + :type parent: ValidationError + :return: List containing a single formatted error string. + :rtype: list[str] + """ + # Build the JSON-path string for the discriminator property + path_str = "$" + for segment in disc_path: + if isinstance(segment, int): + path_str += f"[{segment}]" + else: + path_str += f".{segment}" + + # Check for type errors first (more fundamental than const/enum) + for errors in branch_groups.values(): + for err in errors: + if err.validator == "type" and tuple(err.absolute_path) == disc_path: + return [f"{path_str}: {err.message}"] + + # Collect expected values from const and enum errors + expected: list[str] = [] + for errors in branch_groups.values(): + for err in errors: + if tuple(err.absolute_path) != disc_path: + continue + if err.validator == "const": + val = err.schema.get("const") if isinstance(err.schema, dict) else None + if val is not None: + formatted = json.dumps(val) if not isinstance(val, str) else f'"{val}"' + if formatted not in expected: + expected.append(formatted) + elif err.validator == "enum": + enum_vals = err.schema.get("enum", []) if isinstance(err.schema, dict) else [] + for val in enum_vals: + formatted = json.dumps(val) if not isinstance(val, str) else f'"{val}"' + if formatted not in expected: + expected.append(formatted) + + if expected: + if len(expected) == 1: + return [f"{path_str}: Invalid value. Expected: {expected[0]}"] + return [f"{path_str}: Invalid value. Expected one of: {', '.join(expected)}"] + + # Fallback + return [_format_error(parent)] + + +def _resolve_ref(spec: dict[str, Any], schema: dict[str, Any]) -> dict[str, Any]: + """Resolve a single ``$ref`` pointer within the spec. + + :param spec: The full OpenAPI spec dictionary. + :type spec: dict[str, Any] + :param schema: A schema dict that may contain a ``$ref`` key. + :type schema: dict[str, Any] + :return: The resolved schema. + :rtype: dict[str, Any] + """ + if "$ref" not in schema: + return schema + ref_path = schema["$ref"] # e.g. "#/components/schemas/MyModel" + parts = ref_path.lstrip("#/").split("/") + current: Any = spec + for part in parts: + if isinstance(current, dict): + current = current.get(part) + else: + return schema # can't resolve, return as-is + return current if isinstance(current, dict) else schema + + +def _resolve_refs_deep(spec: dict[str, Any], node: Any, _seen: Optional[set[str]] = None) -> Any: + """Recursively resolve all ``$ref`` pointers in a schema tree. + + Walks the schema, replacing every ``{"$ref": "..."}`` with the referenced + definition inlined from *spec*. A *_seen* set guards against infinite + recursion from circular references. + + :param spec: The full OpenAPI spec dictionary. + :type spec: dict[str, Any] + :param node: The current schema node (dict, list, or scalar). + :type node: Any + :param _seen: Set of already-visited ``$ref`` paths (cycle guard). + :type _seen: Optional[set[str]] + :return: The schema tree with all ``$ref`` pointers inlined. + :rtype: Any + """ + if _seen is None: + _seen = set() + + if isinstance(node, dict): + if "$ref" in node: + ref_path = node["$ref"] + if ref_path in _seen: + return node # circular – leave the $ref as-is + _seen = _seen | {ref_path} + resolved = _resolve_ref(spec, node) + if resolved is node: + return node # couldn't resolve + # Preserve sibling keywords (e.g. nullable: true alongside $ref) + siblings = {k: v for k, v in node.items() if k != "$ref"} + resolved = _resolve_refs_deep(spec, resolved, _seen) + if siblings and isinstance(resolved, dict): + merged = dict(resolved) + merged.update(siblings) + return merged + return resolved + return {k: _resolve_refs_deep(spec, v, _seen) for k, v in node.items()} + + if isinstance(node, list): + return [_resolve_refs_deep(spec, item, _seen) for item in node] + + return node diff --git a/sdk/agentserver/azure-ai-agentserver/cspell.json b/sdk/agentserver/azure-ai-agentserver/cspell.json new file mode 100644 index 000000000000..5af59c8e52e0 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/cspell.json @@ -0,0 +1,36 @@ +{ + "ignoreWords": [ + "agentframework", + "agentserver", + "appinsights", + "ASGI", + "azureai", + "ainvoke", + "behaviour", + "caplog", + "delenv", + "hypercorn", + "invocations", + "langgraph", + "msgpack", + "openapi", + "paramtype", + "requestschema", + "rtype", + "serialisation", + "Specialised", + "Standardised", + "starlette", + "traceparent", + "tracestate", + "tracecontext", + "varint", + "writeonly" + ], + "ignorePaths": [ + "*.csv", + "*.json", + "*.rst", + "samples/**" + ] +} \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver/dev_requirements.txt b/sdk/agentserver/azure-ai-agentserver/dev_requirements.txt new file mode 100644 index 000000000000..a4d2cb770dbc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/dev_requirements.txt @@ -0,0 +1,9 @@ +-e ../../../eng/tools/azure-sdk-tools +pytest +httpx +httpx[http2] +pytest-asyncio +opentelemetry-api>=1.20.0 +opentelemetry-sdk>=1.20.0 +azure-monitor-opentelemetry-exporter>=1.0.0b21 +cryptography diff --git a/sdk/agentserver/azure-ai-agentserver/pyproject.toml b/sdk/agentserver/azure-ai-agentserver/pyproject.toml new file mode 100644 index 000000000000..19063c140e09 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/pyproject.toml @@ -0,0 +1,76 @@ +[project] +name = "azure-ai-agentserver" +dynamic = ["version", "readme"] +description = "Generic agent server for Azure AI with pluggable protocol heads" +requires-python = ">=3.10" +authors = [ + { name = "Microsoft Corporation", email = "azpysdkhelp@microsoft.com" }, +] +license = "MIT" +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +keywords = ["azure", "azure sdk", "agent", "agentserver"] + +dependencies = [ + "azure-core>=1.35.0", + "starlette>=0.45.0", + "hypercorn>=0.17.0", + "jsonschema>=4.0.0", +] + +[project.optional-dependencies] +tracing = [ + "opentelemetry-api>=1.20.0", + "opentelemetry-sdk>=1.20.0", + "azure-monitor-opentelemetry-exporter>=1.0.0b21", +] + +[build-system] +requires = ["setuptools>=69", "wheel"] +build-backend = "setuptools.build_meta" + +[project.urls] +repository = "https://github.com/Azure/azure-sdk-for-python" + +[tool.setuptools.packages.find] +exclude = [ + "tests*", + "samples*", + "doc*", + "azure", + "azure.ai", +] + +[tool.setuptools.dynamic] +version = { attr = "azure.ai.agentserver._version.VERSION" } +readme = { file = ["README.md"], content-type = "text/markdown" } + +[tool.setuptools.package-data] +pytyped = ["py.typed"] + +[tool.ruff] +line-length = 120 +target-version = "py310" +lint.select = ["E", "F", "B", "I"] +lint.ignore = [] +fix = false + +[tool.ruff.lint.isort] +known-first-party = ["azure.ai.agentserver"] +combine-as-imports = true + +[tool.azure-sdk-build] +breaking = false +mypy = true +pyright = true +verifytypes = true +pylint = true +type_check_samples = false # samples require env-specific deps (langchain, agent-framework, etc.) diff --git a/sdk/agentserver/azure-ai-agentserver/pyrightconfig.json b/sdk/agentserver/azure-ai-agentserver/pyrightconfig.json new file mode 100644 index 000000000000..5f81af3c9da7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/pyrightconfig.json @@ -0,0 +1,11 @@ +{ + "reportOptionalMemberAccess": "warning", + "reportArgumentType": "warning", + "reportAttributeAccessIssue": "warning", + "reportMissingImports": "warning", + "reportGeneralTypeIssues": "warning", + "reportReturnType": "warning", + "exclude": [ + "**/samples/**" + ] +} \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/.env.sample b/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/.env.sample new file mode 100644 index 000000000000..b2381c6f1b1b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/.env.sample @@ -0,0 +1,4 @@ +# Azure AI credentials (used by DefaultAzureCredential + AzureOpenAIChatClient) +AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/ +AZURE_OPENAI_API_VERSION=2024-12-01-preview +AZURE_OPENAI_CHAT_DEPLOYMENT_NAME=gpt-4.1 \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/agentframework_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/agentframework_invoke_agent.py new file mode 100644 index 000000000000..ffcb7eccf734 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/agentframework_invoke_agent.py @@ -0,0 +1,81 @@ +"""Agent Framework agent served via /invoke. + +Customer owns the AgentFramework <-> invoke conversion logic. +This replaces the need for azure-ai-agentserver-agentframework. + +Usage:: + + # Start the agent + python agentframework_invoke_agent.py + + # Send a request + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"input": "What is the weather in Seattle?"}' + # -> {"result": "The weather in Seattle is sunny with a high of 25°C."} +""" +import asyncio +import os +from random import randint +from typing import Annotated + +from dotenv import load_dotenv + +load_dotenv() + +from agent_framework import Agent +from agent_framework.azure import AzureOpenAIChatClient +from azure.identity import DefaultAzureCredential + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver import AgentServer + + +# -- Customer defines their tools -- + +def get_weather( + location: Annotated[str, "The location to get the weather for."], +) -> str: + """Get the weather for a given location.""" + conditions = ["sunny", "cloudy", "rainy", "stormy"] + return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C." + + +# -- Customer builds their Agent Framework agent -- + +def build_agent() -> Agent: + """Create an Agent Framework Agent with tools.""" + client = AzureOpenAIChatClient(credential=DefaultAzureCredential()) + return client.as_agent( + instructions="You are a helpful weather assistant.", + tools=get_weather, + ) + + +# -- Customer-managed adapter: Agent Framework <-> /invoke -- + +agent = build_agent() +server = AgentServer() + + +@server.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Process an invocation via Agent Framework. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON response with the agent result. + :rtype: starlette.responses.JSONResponse + """ + data = await request.json() + user_input = data.get("input", "") + + # Run the agent + response = await agent.run(user_input) + result = response.content if hasattr(response, "content") else str(response) + + return JSONResponse({"result": result}) + + +if __name__ == "__main__": + server.run() diff --git a/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/requirements.txt new file mode 100644 index 000000000000..bd3b80baf653 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/requirements.txt @@ -0,0 +1,4 @@ +azure-ai-agentserver +agent-framework>=1.0.0rc2 +azure-identity>=1.25.0 +python-dotenv>=1.0.0 diff --git a/sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/async_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/async_invoke_agent.py new file mode 100644 index 000000000000..ee6f77c2b16a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/async_invoke_agent.py @@ -0,0 +1,168 @@ +"""Async invoke agent example. + +Demonstrates get_invocation and cancel_invocation for long-running work. +Invocations run in background tasks; callers poll or cancel by ID. + +.. warning:: + + **In-memory demo only.** This sample stores all invocation state + (``self._tasks``, ``self._results``) in process memory. Both in-flight + ``asyncio.Task`` objects and completed results are lost on process restart + — which *will* happen during platform rolling updates, health-check + failures, and scaling events. + + For production long-running invocations: + + * Persist results to durable storage (Redis, Cosmos DB, etc.) inside + ``_do_work`` **before** the method returns. + * On startup, rehydrate any incomplete work or mark it as failed. + * Consider an external task queue (Celery, Azure Queue, etc.) instead + of ``asyncio.create_task`` for work that must survive restarts. + +Usage:: + + # Start the agent + python async_invoke_agent.py + + # Start a long-running invocation + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"query": "analyze dataset"}' + # -> x-agent-invocation-id: abc-123 + # -> {"invocation_id": "abc-123", "status": "running"} + + # Poll for result + curl http://localhost:8088/invocations/abc-123 + # -> {"invocation_id": "abc-123", "status": "running"} (still working) + # -> {"invocation_id": "abc-123", "status": "completed"} (done) + + # Or cancel + curl -X POST http://localhost:8088/invocations/abc-123/cancel + # -> {"invocation_id": "abc-123", "status": "cancelled"} +""" +import asyncio +import json + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver import AgentServer + + +# In-memory state for demo purposes (see module docstring for production caveats) +_tasks: dict[str, asyncio.Task] = {} +_results: dict[str, bytes] = {} + +server = AgentServer() + + +async def _do_work(invocation_id: str, data: dict) -> bytes: + """Simulate long-running work. + + :param invocation_id: The invocation ID for this task. + :type invocation_id: str + :param data: The parsed request data. + :type data: dict + :return: JSON result bytes. + :rtype: bytes + """ + await asyncio.sleep(10) + result = json.dumps({ + "invocation_id": invocation_id, + "status": "completed", + "output": f"Processed: {data}", + }).encode() + _results[invocation_id] = result + return result + + +@server.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Start a long-running invocation in a background task. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON status indicating the task is running. + :rtype: starlette.responses.JSONResponse + """ + data = await request.json() + invocation_id = request.state.invocation_id + + task = asyncio.create_task(_do_work(invocation_id, data)) + _tasks[invocation_id] = task + + return JSONResponse({ + "invocation_id": invocation_id, + "status": "running", + }) + + +@server.get_invocation_handler +async def handle_get_invocation(request: Request) -> Response: + """Retrieve a previous invocation result. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON status or result. + :rtype: starlette.responses.JSONResponse + """ + invocation_id = request.state.invocation_id + + if invocation_id in _results: + return Response(content=_results[invocation_id], media_type="application/json") + + if invocation_id in _tasks: + task = _tasks[invocation_id] + if not task.done(): + return JSONResponse({ + "invocation_id": invocation_id, + "status": "running", + }) + result = task.result() + _results[invocation_id] = result + del _tasks[invocation_id] + return Response(content=result, media_type="application/json") + + return JSONResponse({"error": "not found"}, status_code=404) + + +@server.cancel_invocation_handler +async def handle_cancel_invocation(request: Request) -> Response: + """Cancel a running invocation. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON cancellation status. + :rtype: starlette.responses.JSONResponse + """ + invocation_id = request.state.invocation_id + + # Already completed — cannot cancel + if invocation_id in _results: + return JSONResponse({ + "invocation_id": invocation_id, + "status": "completed", + "error": "invocation already completed", + }) + + if invocation_id in _tasks: + task = _tasks[invocation_id] + if task.done(): + # Task finished between check — treat as completed + _results[invocation_id] = task.result() + del _tasks[invocation_id] + return JSONResponse({ + "invocation_id": invocation_id, + "status": "completed", + "error": "invocation already completed", + }) + task.cancel() + del _tasks[invocation_id] + return JSONResponse({ + "invocation_id": invocation_id, + "status": "cancelled", + }) + + return JSONResponse({"error": "not found"}, status_code=404) + + +if __name__ == "__main__": + server.run() diff --git a/sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/requirements.txt new file mode 100644 index 000000000000..10ccd9f42648 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver diff --git a/sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/human_in_the_loop_agent.py b/sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/human_in_the_loop_agent.py new file mode 100644 index 000000000000..618a2532f59d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/human_in_the_loop_agent.py @@ -0,0 +1,74 @@ +"""Human-in-the-loop invoke agent example. + +Demonstrates a synchronous human-in-the-loop pattern using only +POST /invocations. The agent asks a clarifying question, and the client +replies in a second request. + +Flow: + 1. Client sends a message -> agent returns a question + invocation_id + 2. Client sends a reply -> agent returns the final result + +Usage:: + + # Start the agent + python human_in_the_loop_agent.py + + # Step 1: Send a request — agent asks a clarifying question + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"message": "Book me a flight"}' + # -> {"invocation_id": "", "status": "needs_input", "question": "Where would you like to fly to?"} + + # Step 2: Reply with the answer — agent completes + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"reply_to": "", "message": "Seattle"}' + # -> {"invocation_id": "", "status": "completed", "response": "Flight to Seattle booked."} +""" +from typing import Any + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver import AgentServer + + +# Holds questions waiting for a human reply, keyed by invocation_id +_waiting: dict[str, dict[str, Any]] = {} + +server = AgentServer() + + +@server.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Handle messages and replies. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON response indicating status. + :rtype: starlette.responses.JSONResponse + """ + data = await request.json() + invocation_id = request.state.invocation_id + + # --- Reply to a previous question --- + reply_to = data.get("reply_to") + if reply_to: + if reply_to not in _waiting: + return JSONResponse({"error": f"No pending question for {reply_to}"}) + + return JSONResponse({ + "invocation_id": reply_to, + "status": "completed", + "response": f"Flight to {data.get('message', '?')} booked.", + }) + + # --- New request: ask a clarifying question --- + _waiting[invocation_id] = { + "message": data.get("message", ""), + } + return JSONResponse({ + "invocation_id": invocation_id, + "status": "needs_input", + "question": "Where would you like to fly to?", + }) + + +if __name__ == "__main__": + server.run() diff --git a/sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/requirements.txt new file mode 100644 index 000000000000..10ccd9f42648 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver diff --git a/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/.env.sample b/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/.env.sample new file mode 100644 index 000000000000..a75e7aec8869 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/.env.sample @@ -0,0 +1,5 @@ +# Azure OpenAI credentials +AZURE_OPENAI_API_KEY=your-api-key-here +AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/ +AZURE_OPENAI_API_VERSION=2024-12-01-preview +AZURE_OPENAI_MODEL=gpt-4o diff --git a/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/langgraph_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/langgraph_invoke_agent.py new file mode 100644 index 000000000000..89611059b643 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/langgraph_invoke_agent.py @@ -0,0 +1,100 @@ +"""LangGraph agent served via /invoke. + +Customer owns the LangGraph <-> invoke conversion logic. +This replaces the need for azure-ai-agentserver-langgraph. + +Usage:: + + # Start the agent + python langgraph_invoke_agent.py + + # Non-streaming request + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"message": "What is the capital of France?"}' + # -> {"reply": "The capital of France is Paris."} + + # Streaming request + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"message": "Tell me a joke", "stream": true}' + # -> {"delta": "Why did..."} + # {"delta": " the chicken..."} +""" +import json +import os +from typing import AsyncGenerator + +from dotenv import load_dotenv + +load_dotenv() + +from langgraph.graph import END, START, MessagesState, StateGraph +from langchain_openai import AzureChatOpenAI + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver import AgentServer + + +def build_graph() -> StateGraph: + """Customer builds their LangGraph agent as usual.""" + llm = AzureChatOpenAI( + model=os.environ["AZURE_OPENAI_MODEL"], + api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview"), + ) + + def chatbot(state: MessagesState): + return {"messages": [llm.invoke(state["messages"])]} + + graph = StateGraph(MessagesState) + graph.add_node("chatbot", chatbot) + graph.add_edge(START, "chatbot") + graph.add_edge("chatbot", END) + return graph.compile() + + +graph = build_graph() +server = AgentServer() + + +async def _stream_response(user_message: str) -> AsyncGenerator[bytes, None]: + """Async generator that yields response chunks. + + :param user_message: The user message to process. + :type user_message: str + :return: An async generator yielding JSON-encoded byte chunks. + :rtype: AsyncGenerator[bytes, None] + """ + async for event in graph.astream_events( + {"messages": [{"role": "user", "content": user_message}]}, + version="v2", + ): + if event["event"] == "on_chat_model_stream": + chunk = event["data"]["chunk"].content + if chunk: + yield json.dumps({"delta": chunk}).encode() + b"\n" + + +@server.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Process the invocation via LangGraph. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON response or streaming response. + :rtype: starlette.responses.Response + """ + data = await request.json() + user_message = data["message"] + stream = data.get("stream", False) + + if stream: + return StreamingResponse(_stream_response(user_message)) + + result = await graph.ainvoke( + {"messages": [{"role": "user", "content": user_message}]} + ) + last_message = result["messages"][-1] + return JSONResponse({"reply": last_message.content}) + + +if __name__ == "__main__": + server.run() diff --git a/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/requirements.txt new file mode 100644 index 000000000000..980438cbf628 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/requirements.txt @@ -0,0 +1,4 @@ +azure-ai-agentserver +langgraph>=1.0.0 +langchain-openai>=1.0.0 +python-dotenv>=1.0.0 diff --git a/sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/openapi_validated_agent.py b/sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/openapi_validated_agent.py new file mode 100644 index 000000000000..1f8749b73d4f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/openapi_validated_agent.py @@ -0,0 +1,117 @@ +"""OpenAPI-validated agent example. + +Demonstrates how to supply an OpenAPI spec to AgentServer so that +incoming requests are validated automatically. Invalid requests receive +a 400 response before ``invoke`` is called. + +The spec is also served at ``GET /invocations/docs/openapi.json`` so +that callers can discover the agent's contract at runtime. + +Usage:: + + # Start the agent + python openapi_validated_agent.py + + # Fetch the OpenAPI spec + curl http://localhost:8088/invocations/docs/openapi.json + + # Valid request (200) + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"name": "Alice", "language": "fr"}' + # -> {"greeting": "Bonjour, Alice!"} + + # Invalid request — missing required "name" field (400) + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"language": "en"}' + # -> {"error": ["'name' is a required property"]} +""" +from typing import Any + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver import AgentServer + +# Define a simple OpenAPI 3.0 spec inline. In production this could be +# loaded from a YAML/JSON file. +OPENAPI_SPEC: dict[str, Any] = { + "openapi": "3.0.0", + "info": {"title": "Greeting Agent", "version": "1.0.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["name"], + "properties": { + "name": { + "type": "string", + "description": "Name of the person to greet.", + }, + "language": { + "type": "string", + "enum": ["en", "es", "fr"], + "description": "Language for the greeting.", + }, + }, + "additionalProperties": False, + } + } + }, + }, + "responses": { + "200": { + "description": "Greeting response", + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["greeting"], + "properties": { + "greeting": {"type": "string"}, + }, + } + } + }, + } + }, + } + } + }, +} + +GREETINGS = { + "en": "Hello", + "es": "Hola", + "fr": "Bonjour", +} + + +server = AgentServer(openapi_spec=OPENAPI_SPEC, enable_request_validation=True) + + +@server.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Return a localised greeting. + + The OpenAPI spec enforces that "name" is required, "language" must be + one of ``en``, ``es``, or ``fr``, and no extra fields are allowed. + Requests that violate the schema are rejected with 400 before reaching + the invoke handler. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON greeting response. + :rtype: starlette.responses.JSONResponse + """ + data = await request.json() + language = data.get("language", "en") + prefix = GREETINGS.get(language, "Hello") + greeting = f"{prefix}, {data['name']}!" + return JSONResponse({"greeting": greeting}) + + +if __name__ == "__main__": + server.run() diff --git a/sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/requirements.txt new file mode 100644 index 000000000000..10ccd9f42648 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver diff --git a/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/Dockerfile b/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/Dockerfile new file mode 100644 index 000000000000..97ed3fb21fd1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/Dockerfile @@ -0,0 +1,19 @@ +FROM python:3.12-slim + +WORKDIR /app + +# Install the agentserver package from the local source tree +COPY sdk/agentserver/azure-ai-agentserver /src/azure-ai-agentserver +RUN pip install --no-cache-dir /src/azure-ai-agentserver + +# Copy the sample agent +COPY sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/simple_invoke_agent.py . + +EXPOSE 8088 + +# Bind to 0.0.0.0 so the port is accessible from outside the container. +# The default is 127.0.0.1 which is only reachable inside the container. +CMD ["python", "-c", "\ +from simple_invoke_agent import server; \ +server.run(host='0.0.0.0') \ +"] diff --git a/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/requirements.txt new file mode 100644 index 000000000000..10ccd9f42648 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver diff --git a/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/simple_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/simple_invoke_agent.py new file mode 100644 index 000000000000..60d7f3520622 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/simple_invoke_agent.py @@ -0,0 +1,38 @@ +"""Simple invoke agent example. + +Accepts JSON requests, echoes back with a greeting. + +Usage:: + + # Start the agent + python simple_invoke_agent.py + + # Send a greeting request + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"name": "Alice"}' + # -> {"greeting": "Hello, Alice!"} +""" +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver import AgentServer + + +server = AgentServer() + + +@server.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Process the invocation by echoing a greeting. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON greeting response. + :rtype: starlette.responses.JSONResponse + """ + data = await request.json() + greeting = f"Hello, {data['name']}!" + return JSONResponse({"greeting": greeting}) + + +if __name__ == "__main__": + server.run() diff --git a/sdk/agentserver/azure-ai-agentserver/tests/conftest.py b/sdk/agentserver/azure-ai-agentserver/tests/conftest.py new file mode 100644 index 000000000000..c3647378fc75 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/conftest.py @@ -0,0 +1,205 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for azure-ai-agentserver tests.""" +import json + +import pytest +import pytest_asyncio +import httpx + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver import AgentServer + + +# --------------------------------------------------------------------------- +# Agent factory functions (decorator pattern) +# --------------------------------------------------------------------------- + + +def _make_echo_agent(**kwargs) -> AgentServer: + """Create an echo agent that returns the request body as-is.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + body = await request.body() + return Response(content=body, media_type="application/octet-stream") + + return server + + +def _make_streaming_agent(**kwargs) -> AgentServer: + """Create an agent that returns a multi-chunk streaming response.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> StreamingResponse: + async def generate(): + for i in range(3): + yield json.dumps({"chunk": i}).encode() + b"\n" + + return StreamingResponse(generate()) + + return server + + +def _make_async_storage_agent(**kwargs) -> AgentServer: + """Create an agent with get/cancel support via in-memory storage.""" + store: dict[str, bytes] = {} + server = AgentServer(**kwargs) + server._store = store # expose for test access + + @server.invoke_handler + async def invoke(request: Request) -> Response: + body = await request.body() + invocation_id = request.state.invocation_id + result = json.dumps({"echo": body.decode()}).encode() + store[invocation_id] = result + return Response(content=result, media_type="application/json") + + @server.get_invocation_handler + async def get_invocation(request: Request) -> Response: + invocation_id = request.state.invocation_id + if invocation_id in store: + return Response(content=store[invocation_id], media_type="application/json") + return JSONResponse({"error": "not found"}, status_code=404) + + @server.cancel_invocation_handler + async def cancel_invocation(request: Request) -> Response: + invocation_id = request.state.invocation_id + if invocation_id in store: + del store[invocation_id] + return JSONResponse({"status": "cancelled"}) + return JSONResponse({"error": "not found"}, status_code=404) + + return server + + +SAMPLE_OPENAPI_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "Test", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + "required": ["name"], + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "greeting": {"type": "string"}, + }, + "required": ["greeting"], + } + } + } + } + }, + } + } + }, +} + + +def _make_validated_agent() -> AgentServer: + """Create an agent with OpenAPI validation that returns a greeting.""" + server = AgentServer(openapi_spec=SAMPLE_OPENAPI_SPEC, enable_request_validation=True) + + @server.invoke_handler + async def handle(request: Request) -> Response: + data = await request.json() + return JSONResponse({"greeting": f"Hello, {data['name']}!"}) + + return server + + +def _make_failing_agent(**kwargs) -> AgentServer: + """Create an agent whose invoke handler always raises.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + raise ValueError("something went wrong") + + return server + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def echo_client(): + """httpx.AsyncClient wired to an echo agent's ASGI app.""" + server = _make_echo_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest_asyncio.fixture +async def streaming_client(): + """httpx.AsyncClient wired to a streaming agent's ASGI app.""" + server = _make_streaming_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest.fixture +def async_storage_server(): + """An async storage agent instance.""" + return _make_async_storage_agent() + + +@pytest_asyncio.fixture +async def async_storage_client(async_storage_server): + """httpx.AsyncClient wired to an async storage agent's ASGI app.""" + transport = httpx.ASGITransport(app=async_storage_server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest_asyncio.fixture +async def validated_client(): + """httpx.AsyncClient wired to a validated agent's ASGI app.""" + server = _make_validated_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest_asyncio.fixture +async def no_spec_client(): + """httpx.AsyncClient wired to an echo agent (no OpenAPI spec).""" + server = _make_echo_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest_asyncio.fixture +async def failing_client(): + """httpx.AsyncClient wired to a failing agent's ASGI app.""" + server = _make_failing_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_decorator_pattern.py b/sdk/agentserver/azure-ai-agentserver/tests/test_decorator_pattern.py new file mode 100644 index 000000000000..018e9478a7e3 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_decorator_pattern.py @@ -0,0 +1,337 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for the decorator-based handler registration pattern.""" +from __future__ import annotations + +import httpx +import pytest +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver import AgentServer + + +# --------------------------------------------------------------------------- +# Decorator registration +# --------------------------------------------------------------------------- + + +class TestDecoratorRegistration: + """Verify that decorators store the function and return it unchanged.""" + + def test_invoke_handler_stores_function(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + assert server._invoke_fn is handle + + def test_invoke_handler_returns_original_function(self): + server = AgentServer() + original = None + + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + original = handle + result = server.invoke_handler(handle) + assert result is original + + def test_get_invocation_handler_stores_function(self): + server = AgentServer() + + @server.get_invocation_handler + async def handle(request: Request) -> Response: + return JSONResponse({"found": True}) + + assert server._get_invocation_fn is handle + + def test_cancel_invocation_handler_stores_function(self): + server = AgentServer() + + @server.cancel_invocation_handler + async def handle(request: Request) -> Response: + return JSONResponse({"cancelled": True}) + + assert server._cancel_invocation_fn is handle + + def test_shutdown_handler_stores_function(self): + server = AgentServer() + + @server.shutdown_handler + async def handle(): + pass + + assert server._shutdown_fn is handle + + +# --------------------------------------------------------------------------- +# Invoke handler — full request flow +# --------------------------------------------------------------------------- + + +class TestInvokeHandlerFlow: + """Verify that POST /invocations delegates to @invoke_handler.""" + + @pytest.mark.asyncio + async def test_invoke_returns_handler_response(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + data = await request.json() + return JSONResponse({"echo": data["msg"]}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", json={"msg": "hello"}) + assert resp.status_code == 200 + assert resp.json()["echo"] == "hello" + + @pytest.mark.asyncio + async def test_invoke_includes_invocation_id_header(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", json={}) + assert "x-agent-invocation-id" in resp.headers + + @pytest.mark.asyncio + async def test_invoke_request_has_invocation_id_in_state(self): + """The handler receives request.state.invocation_id.""" + server = AgentServer() + captured_id = None + + @server.invoke_handler + async def handle(request: Request) -> Response: + nonlocal captured_id + captured_id = request.state.invocation_id + return JSONResponse({"id": captured_id}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", json={}) + assert resp.status_code == 200 + assert captured_id is not None + assert resp.json()["id"] == captured_id + + +# --------------------------------------------------------------------------- +# Missing invoke handler +# --------------------------------------------------------------------------- + + +class TestMissingInvokeHandler: + """When no invoke handler is registered and invoke() is not overridden, 500.""" + + @pytest.mark.asyncio + async def test_no_handler_returns_500(self): + server = AgentServer() + # No @invoke_handler registered + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", json={}) + assert resp.status_code == 500 + + +# --------------------------------------------------------------------------- +# Optional handler defaults +# --------------------------------------------------------------------------- + + +class TestOptionalHandlerDefaults: + """get_invocation and cancel_invocation return 404 by default.""" + + @pytest.mark.asyncio + async def test_get_invocation_returns_404_by_default(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/some-id") + assert resp.status_code == 404 + + @pytest.mark.asyncio + async def test_cancel_invocation_returns_404_by_default(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/some-id/cancel") + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Optional handler overrides via decorator +# --------------------------------------------------------------------------- + + +class TestOptionalHandlerOverrides: + """Registered optional handlers are called instead of the defaults.""" + + @pytest.mark.asyncio + async def test_get_invocation_handler_called(self): + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({"ok": True}) + + @server.get_invocation_handler + async def get_inv(request: Request) -> Response: + inv_id = request.state.invocation_id + return JSONResponse({"id": inv_id, "status": "completed"}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/test-123") + assert resp.status_code == 200 + assert resp.json()["id"] == "test-123" + assert resp.json()["status"] == "completed" + + @pytest.mark.asyncio + async def test_cancel_invocation_handler_called(self): + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({"ok": True}) + + @server.cancel_invocation_handler + async def cancel(request: Request) -> Response: + inv_id = request.state.invocation_id + return JSONResponse({"id": inv_id, "cancelled": True}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/test-456/cancel") + assert resp.status_code == 200 + assert resp.json()["cancelled"] is True + + +# --------------------------------------------------------------------------- +# Shutdown handler via decorator +# --------------------------------------------------------------------------- + + +class TestShutdownHandler: + """@shutdown_handler is called during lifespan teardown.""" + + @pytest.mark.asyncio + async def test_shutdown_handler_called(self): + server = AgentServer() + shutdown_called = False + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({"ok": True}) + + @server.shutdown_handler + async def cleanup(): + nonlocal shutdown_called + shutdown_called = True + + # Exercise the lifespan directly + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass # startup; on exit → shutdown + + assert shutdown_called is True + + @pytest.mark.asyncio + async def test_no_shutdown_handler_is_noop(self): + """Without @shutdown_handler, on_shutdown is a silent no-op.""" + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({"ok": True}) + + # Should not raise + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass + + +# --------------------------------------------------------------------------- +# Config passthrough +# --------------------------------------------------------------------------- + + +class TestConfigPassthrough: + """All AgentServer kwargs still work in decorator mode.""" + + def test_request_timeout_resolved(self): + server = AgentServer(request_timeout=42) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + assert server._request_timeout == 42 + + def test_graceful_shutdown_timeout_resolved(self): + server = AgentServer(graceful_shutdown_timeout=10) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + assert server._graceful_shutdown_timeout == 10 + + def test_debug_errors_resolved(self): + server = AgentServer(debug_errors=True) + assert server._debug_errors is True + + +# --------------------------------------------------------------------------- +# Health endpoints work in decorator mode +# --------------------------------------------------------------------------- + + +class TestHealthEndpointsDecoratorMode: + """Health endpoints respond even when using the decorator pattern.""" + + @pytest.mark.asyncio + async def test_liveness(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/liveness") + assert resp.status_code == 200 + assert resp.json()["status"] == "alive" + + @pytest.mark.asyncio + async def test_readiness(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/readiness") + assert resp.status_code == 200 + assert resp.json()["status"] == "ready" diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_edge_cases.py b/sdk/agentserver/azure-ai-agentserver/tests/test_edge_cases.py new file mode 100644 index 000000000000..6907187b80de --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_edge_cases.py @@ -0,0 +1,491 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""End-to-end edge-case tests for AgentServer.""" +import json +import uuid + +import httpx +import pytest +import pytest_asyncio + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver import AgentServer + + +# --------------------------------------------------------------------------- +# Agent factory functions for edge cases +# --------------------------------------------------------------------------- + + +def _make_custom_header_agent(**kwargs) -> AgentServer: + """Create an agent that sets its own x-agent-invocation-id header.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse( + {"ok": True}, + headers={"x-agent-invocation-id": "custom-id-from-agent"}, + ) + + return server + + +def _make_empty_streaming_agent(**kwargs) -> AgentServer: + """Create an agent that returns an empty streaming response (0 chunks).""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> StreamingResponse: + async def generate(): + return + yield # noqa: RUF028 — makes this an async generator + + return StreamingResponse(generate(), media_type="text/event-stream") + + return server + + +def _make_slow_failing_get_agent(**kwargs) -> AgentServer: + """Create an agent whose get_invocation raises so we can test debug errors on GET.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({"ok": True}) + + @server.get_invocation_handler + async def get_invocation(request: Request) -> Response: + raise ValueError("get-debug-detail") + + return server + + +def _make_slow_failing_cancel_agent(**kwargs) -> AgentServer: + """Create an agent whose cancel_invocation raises so we can test debug errors on cancel.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({"ok": True}) + + @server.cancel_invocation_handler + async def cancel_invocation(request: Request) -> Response: + raise ValueError("cancel-debug-detail") + + return server + + +def _make_large_payload_agent(**kwargs) -> AgentServer: + """Create an agent that echoes the request body length as JSON.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + body = await request.body() + return JSONResponse({"length": len(body)}) + + return server + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def custom_header_client(): + server = _make_custom_header_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest_asyncio.fixture +async def empty_streaming_client(): + server = _make_empty_streaming_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest_asyncio.fixture +async def large_payload_client(): + server = _make_large_payload_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +# --------------------------------------------------------------------------- +# Tests: wrong HTTP methods +# --------------------------------------------------------------------------- + + +class TestMethodNotAllowed: + """Verify that wrong HTTP methods return 405.""" + + @pytest.mark.asyncio + async def test_get_on_invocations_returns_405(self, echo_client): + """GET /invocations is not allowed (POST only).""" + resp = await echo_client.get("/invocations") + assert resp.status_code == 405 + + @pytest.mark.asyncio + async def test_post_on_liveness_returns_405(self, echo_client): + """POST /liveness is not allowed (GET only).""" + resp = await echo_client.post("/liveness") + assert resp.status_code == 405 + + @pytest.mark.asyncio + async def test_post_on_readiness_returns_405(self, echo_client): + """POST /readiness is not allowed (GET only).""" + resp = await echo_client.post("/readiness") + assert resp.status_code == 405 + + @pytest.mark.asyncio + async def test_put_on_invocations_returns_405(self, echo_client): + """PUT /invocations is not allowed.""" + resp = await echo_client.put("/invocations", content=b"{}") + assert resp.status_code == 405 + + @pytest.mark.asyncio + async def test_delete_on_invocations_id_returns_405(self, echo_client): + """DELETE /invocations/{id} is not allowed.""" + resp = await echo_client.delete(f"/invocations/{uuid.uuid4()}") + assert resp.status_code == 405 + + @pytest.mark.asyncio + async def test_post_on_openapi_spec_returns_405(self, echo_client): + """POST /invocations/docs/openapi.json is not allowed (GET only).""" + resp = await echo_client.post("/invocations/docs/openapi.json", content=b"{}") + assert resp.status_code == 405 + + +# --------------------------------------------------------------------------- +# Tests: OpenAPI validation edge cases +# --------------------------------------------------------------------------- + + +class TestOpenAPIValidation: + """End-to-end OpenAPI validation via HTTP.""" + + @pytest.mark.asyncio + async def test_valid_request_returns_200(self, validated_client): + """POST with valid body passes validation and returns greeting.""" + resp = await validated_client.post( + "/invocations", + content=json.dumps({"name": "Alice"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + assert resp.json()["greeting"] == "Hello, Alice!" + + @pytest.mark.asyncio + async def test_missing_required_field_returns_400(self, validated_client): + """POST missing required 'name' field returns 400 with validation details.""" + resp = await validated_client.post( + "/invocations", + content=json.dumps({"wrong_field": "foo"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + body = resp.json() + assert body["error"]["code"] == "invalid_payload" + assert "details" in body["error"] + + @pytest.mark.asyncio + async def test_malformed_json_returns_400(self, validated_client): + """POST with non-JSON body when spec expects JSON returns 400.""" + resp = await validated_client.post( + "/invocations", + content=b"this is not json", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + assert resp.json()["error"]["code"] == "invalid_payload" + + @pytest.mark.asyncio + async def test_extra_field_accepted(self, validated_client): + """Extra fields are accepted when additionalProperties is not false.""" + resp = await validated_client.post( + "/invocations", + content=json.dumps({"name": "Bob", "bonus": 42}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + assert resp.json()["greeting"] == "Hello, Bob!" + + @pytest.mark.asyncio + async def test_no_spec_skips_validation(self, no_spec_client): + """Agent with no OpenAPI spec accepts any payload without validation.""" + resp = await no_spec_client.post( + "/invocations", + content=b"arbitrary bytes not json at all!", + ) + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# Tests: response header behaviour +# --------------------------------------------------------------------------- + + +class TestResponseHeaders: + """Verify invocation-id header auto-injection and passthrough.""" + + @pytest.mark.asyncio + async def test_agent_custom_invocation_id_not_overwritten(self, custom_header_client): + """If agent sets x-agent-invocation-id, the server doesn't overwrite it.""" + resp = await custom_header_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert resp.headers["x-agent-invocation-id"] == "custom-id-from-agent" + + @pytest.mark.asyncio + async def test_invocation_id_injected_when_missing(self, echo_client): + """EchoAgent doesn't set the header; server auto-injects it.""" + resp = await echo_client.post("/invocations", content=b"hello") + assert resp.status_code == 200 + invocation_id = resp.headers.get("x-agent-invocation-id") + assert invocation_id is not None + uuid.UUID(invocation_id) # valid UUID + + +# --------------------------------------------------------------------------- +# Tests: payload edge cases +# --------------------------------------------------------------------------- + + +class TestPayloadEdgeCases: + """Body content edge cases.""" + + @pytest.mark.asyncio + async def test_large_payload(self, large_payload_client): + """Server handles a 1 MB payload correctly.""" + big = b"A" * (1024 * 1024) + resp = await large_payload_client.post("/invocations", content=big) + assert resp.status_code == 200 + assert resp.json()["length"] == 1024 * 1024 + + @pytest.mark.asyncio + async def test_unicode_body(self, echo_client): + """Unicode characters round-trip correctly.""" + data = "\u3053\u3093\u306b\u3061\u306f\u4e16\u754c \U0001f30d \u0645\u0631\u062d\u0628\u0627" + resp = await echo_client.post("/invocations", content=data.encode("utf-8")) + assert resp.status_code == 200 + assert resp.content.decode("utf-8") == data + + @pytest.mark.asyncio + async def test_binary_body(self, echo_client): + """Binary (non-UTF-8) bytes round-trip correctly.""" + data = bytes(range(256)) + resp = await echo_client.post("/invocations", content=data) + assert resp.status_code == 200 + assert resp.content == data + + +# --------------------------------------------------------------------------- +# Tests: streaming edge cases +# --------------------------------------------------------------------------- + + +class TestStreamingEdgeCases: + """Streaming response edge cases.""" + + @pytest.mark.asyncio + async def test_empty_streaming_response(self, empty_streaming_client): + """Empty streaming response returns 200 with no body.""" + resp = await empty_streaming_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert resp.content == b"" + + @pytest.mark.asyncio + async def test_streaming_response_has_invocation_id(self, streaming_client): + """Streaming response still gets the auto-injected invocation-id header.""" + resp = await streaming_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + invocation_id = resp.headers.get("x-agent-invocation-id") + assert invocation_id is not None + uuid.UUID(invocation_id) # valid UUID + + +# --------------------------------------------------------------------------- +# Tests: invocation lifecycle (storage agent) +# --------------------------------------------------------------------------- + + +class TestInvocationLifecycle: + """Multi-step invocation workflows.""" + + @pytest.mark.asyncio + async def test_multiple_gets_on_same_invocation(self, async_storage_client): + """GET /invocations/{id} returns the same data on repeated calls.""" + post_resp = await async_storage_client.post("/invocations", content=b'{"data":"hello"}') + invocation_id = post_resp.headers["x-agent-invocation-id"] + + get1 = await async_storage_client.get(f"/invocations/{invocation_id}") + get2 = await async_storage_client.get(f"/invocations/{invocation_id}") + assert get1.status_code == 200 + assert get2.status_code == 200 + assert get1.content == get2.content + + @pytest.mark.asyncio + async def test_double_cancel_returns_404_second_time(self, async_storage_client): + """Cancelling the same invocation twice: first succeeds, second 404.""" + post_resp = await async_storage_client.post("/invocations", content=b'{"data":"x"}') + invocation_id = post_resp.headers["x-agent-invocation-id"] + + cancel1 = await async_storage_client.post(f"/invocations/{invocation_id}/cancel") + assert cancel1.status_code == 200 + + cancel2 = await async_storage_client.post(f"/invocations/{invocation_id}/cancel") + assert cancel2.status_code == 404 + + @pytest.mark.asyncio + async def test_invoke_then_cancel_then_get_returns_404(self, async_storage_client): + """Full lifecycle: invoke → cancel → get returns 404.""" + post_resp = await async_storage_client.post("/invocations", content=b'{"val":1}') + invocation_id = post_resp.headers["x-agent-invocation-id"] + + await async_storage_client.post(f"/invocations/{invocation_id}/cancel") + + get_resp = await async_storage_client.get(f"/invocations/{invocation_id}") + assert get_resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Tests: debug errors on get/cancel endpoints +# --------------------------------------------------------------------------- + + +class TestDebugErrorsOnGetCancel: + """AGENT_DEBUG_ERRORS exposes details on GET and CANCEL error responses too.""" + + @pytest.mark.asyncio + async def test_get_hides_details_by_default(self): + """GET error hides exception detail without AGENT_DEBUG_ERRORS.""" + agent = _make_slow_failing_get_agent() + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/some-id") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "Internal server error" + + @pytest.mark.asyncio + async def test_get_exposes_details_with_debug(self, monkeypatch): + """GET error exposes exception detail with AGENT_DEBUG_ERRORS set.""" + monkeypatch.setenv("AGENT_DEBUG_ERRORS", "true") + agent = _make_slow_failing_get_agent() + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/some-id") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "get-debug-detail" + + @pytest.mark.asyncio + async def test_cancel_hides_details_by_default(self): + """CANCEL error hides exception detail without AGENT_DEBUG_ERRORS.""" + agent = _make_slow_failing_cancel_agent() + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/some-id/cancel") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "Internal server error" + + @pytest.mark.asyncio + async def test_cancel_exposes_details_with_debug(self, monkeypatch): + """CANCEL error exposes exception detail with AGENT_DEBUG_ERRORS set.""" + monkeypatch.setenv("AGENT_DEBUG_ERRORS", "true") + agent = _make_slow_failing_cancel_agent() + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/some-id/cancel") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "cancel-debug-detail" + + @pytest.mark.asyncio + async def test_get_exposes_details_with_constructor_param(self): + """debug_errors=True in constructor exposes GET exception detail.""" + agent = _make_slow_failing_get_agent(debug_errors=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/some-id") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "get-debug-detail" + + @pytest.mark.asyncio + async def test_cancel_exposes_details_with_constructor_param(self): + """debug_errors=True in constructor exposes CANCEL exception detail.""" + agent = _make_slow_failing_cancel_agent(debug_errors=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/some-id/cancel") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "cancel-debug-detail" + + @pytest.mark.asyncio + async def test_constructor_overrides_env_var(self, monkeypatch): + """debug_errors=False in constructor overrides AGENT_DEBUG_ERRORS=true.""" + monkeypatch.setenv("AGENT_DEBUG_ERRORS", "true") + agent = _make_slow_failing_get_agent(debug_errors=False) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/some-id") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "Internal server error" + + +# --------------------------------------------------------------------------- +# Tests: log_level constructor parameter +# --------------------------------------------------------------------------- + + +class TestLogLevel: + """log_level parameter controls the library logger level.""" + + def test_log_level_via_constructor(self): + """log_level='DEBUG' sets library logger to DEBUG.""" + import logging + agent = _make_custom_header_agent(log_level="DEBUG") + lib_logger = logging.getLogger("azure.ai.agentserver") + assert lib_logger.level == logging.DEBUG + + def test_log_level_via_env_var(self, monkeypatch): + """AGENT_LOG_LEVEL=info sets library logger to INFO.""" + import logging + monkeypatch.setenv("AGENT_LOG_LEVEL", "info") + agent = _make_custom_header_agent() + lib_logger = logging.getLogger("azure.ai.agentserver") + assert lib_logger.level == logging.INFO + + def test_log_level_constructor_overrides_env_var(self, monkeypatch): + """Constructor log_level overrides AGENT_LOG_LEVEL env var.""" + import logging + monkeypatch.setenv("AGENT_LOG_LEVEL", "DEBUG") + agent = _make_custom_header_agent(log_level="ERROR") + lib_logger = logging.getLogger("azure.ai.agentserver") + assert lib_logger.level == logging.ERROR + + def test_invalid_log_level_raises(self): + """Invalid log_level raises ValueError.""" + with pytest.raises(ValueError, match="Invalid log level"): + _make_custom_header_agent(log_level="BOGUS") + + +class TestConcurrency: + """Verify multiple concurrent requests produce unique invocation IDs.""" + + @pytest.mark.asyncio + async def test_concurrent_invocations_unique_ids(self, echo_client): + """10 concurrent POSTs each get a unique invocation ID.""" + import asyncio + + async def do_post(): + return await echo_client.post("/invocations", content=b"concurrent") + + responses = await asyncio.gather(*[do_post() for _ in range(10)]) + ids = {r.headers["x-agent-invocation-id"] for r in responses} + assert len(ids) == 10 # all unique diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_get_cancel.py b/sdk/agentserver/azure-ai-agentserver/tests/test_get_cancel.py new file mode 100644 index 000000000000..5a6d1b963abd --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_get_cancel.py @@ -0,0 +1,111 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for get_invocation and cancel_invocation.""" +import json +import uuid + +import pytest + + +@pytest.mark.asyncio +async def test_get_invocation_after_invoke(async_storage_client): + """Invoke, then GET /invocations/{id} returns stored result.""" + resp = await async_storage_client.post("/invocations", content=b'{"key":"value"}') + invocation_id = resp.headers["x-agent-invocation-id"] + + get_resp = await async_storage_client.get(f"/invocations/{invocation_id}") + assert get_resp.status_code == 200 + data = json.loads(get_resp.content) + assert "echo" in data + + +@pytest.mark.asyncio +async def test_get_invocation_unknown_id_returns_404(async_storage_client): + """GET /invocations/{unknown} returns 404.""" + resp = await async_storage_client.get(f"/invocations/{uuid.uuid4()}") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_cancel_invocation_after_invoke(async_storage_client): + """Invoke, then POST /invocations/{id}/cancel returns cancelled status.""" + resp = await async_storage_client.post("/invocations", content=b'{"key":"value"}') + invocation_id = resp.headers["x-agent-invocation-id"] + + cancel_resp = await async_storage_client.post(f"/invocations/{invocation_id}/cancel") + assert cancel_resp.status_code == 200 + data = json.loads(cancel_resp.content) + assert data["status"] == "cancelled" + + +@pytest.mark.asyncio +async def test_cancel_invocation_unknown_id_returns_404(async_storage_client): + """POST /invocations/{unknown}/cancel returns 404.""" + resp = await async_storage_client.post(f"/invocations/{uuid.uuid4()}/cancel") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_after_cancel_returns_404(async_storage_client): + """Cancel, then get same ID returns 404.""" + resp = await async_storage_client.post("/invocations", content=b'{"key":"value"}') + invocation_id = resp.headers["x-agent-invocation-id"] + + await async_storage_client.post(f"/invocations/{invocation_id}/cancel") + get_resp = await async_storage_client.get(f"/invocations/{invocation_id}") + assert get_resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_invocation_error_returns_500(): + """GET /invocations/{id} returns 500 when customer code raises an error.""" + import httpx + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + from azure.ai.agentserver import AgentServer + + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({}) + + @server.get_invocation_handler + async def get_inv(request: Request) -> Response: + raise RuntimeError("storage unavailable") + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/some-id") + assert resp.status_code == 500 + assert resp.json()["error"]["code"] == "internal_error" + assert resp.json()["error"]["message"] == "Internal server error" + + +@pytest.mark.asyncio +async def test_cancel_invocation_error_returns_500(): + """POST /invocations/{id}/cancel returns 500 when customer code raises an error.""" + import httpx + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + from azure.ai.agentserver import AgentServer + + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({}) + + @server.cancel_invocation_handler + async def cancel(request: Request) -> Response: + raise RuntimeError("cancel failed") + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/some-id/cancel") + assert resp.status_code == 500 + assert resp.json()["error"]["code"] == "internal_error" + assert resp.json()["error"]["message"] == "Internal server error" diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_graceful_shutdown.py b/sdk/agentserver/azure-ai-agentserver/tests/test_graceful_shutdown.py new file mode 100644 index 000000000000..73d0d8b63479 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_graceful_shutdown.py @@ -0,0 +1,438 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for graceful shutdown configuration and lifecycle behaviour.""" +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from starlette.requests import Request +from starlette.responses import Response + +from azure.ai.agentserver import AgentServer +from azure.ai.agentserver._constants import Constants + + +# --------------------------------------------------------------------------- +# Agent factory functions +# --------------------------------------------------------------------------- + + +def _make_stub_agent(**kwargs) -> AgentServer: + """Create a no-op agent used to inspect internal state.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return Response(content=b"ok") + + return server + + +# --------------------------------------------------------------------------- +# _resolve_graceful_shutdown_timeout +# --------------------------------------------------------------------------- + + +class TestResolveGracefulShutdownTimeout: + """Unit tests for the timeout resolution hierarchy: explicit > env > default.""" + + def test_explicit_value_takes_precedence(self): + agent = _make_stub_agent(graceful_shutdown_timeout=10) + assert agent._graceful_shutdown_timeout == 10 + + def test_explicit_zero_disables_drain(self): + agent = _make_stub_agent(graceful_shutdown_timeout=0) + assert agent._graceful_shutdown_timeout == 0 + + def test_env_var_used_when_no_explicit(self): + with patch.dict(os.environ, {Constants.AGENT_GRACEFUL_SHUTDOWN_TIMEOUT: "45"}): + agent = _make_stub_agent() + assert agent._graceful_shutdown_timeout == 45 + + def test_invalid_env_var_raises(self): + with patch.dict(os.environ, {Constants.AGENT_GRACEFUL_SHUTDOWN_TIMEOUT: "not-a-number"}): + with pytest.raises(ValueError, match="AGENT_GRACEFUL_SHUTDOWN_TIMEOUT"): + _make_stub_agent() + + def test_non_int_explicit_value_raises(self): + with pytest.raises(ValueError, match="expected an integer"): + _make_stub_agent(graceful_shutdown_timeout="ten") # type: ignore[arg-type] + + def test_default_when_nothing_set(self): + with patch.dict(os.environ, {}, clear=True): + agent = _make_stub_agent() + assert agent._graceful_shutdown_timeout == Constants.DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT + + def test_default_is_30_seconds(self): + assert Constants.DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT == 30 + + +# --------------------------------------------------------------------------- +# Constant exists +# --------------------------------------------------------------------------- + + +class TestConstants: + """Verify the new constant is wired correctly.""" + + def test_env_var_name(self): + assert Constants.AGENT_GRACEFUL_SHUTDOWN_TIMEOUT == "AGENT_GRACEFUL_SHUTDOWN_TIMEOUT" + + +# --------------------------------------------------------------------------- +# Hypercorn config receives graceful_timeout (sync run) +# --------------------------------------------------------------------------- + + +class TestRunPassesTimeout: + """Ensure run() forwards the timeout to Hypercorn config.""" + + @patch("hypercorn.asyncio.serve", new_callable=AsyncMock) + @patch("azure.ai.agentserver.server._base.asyncio") + def test_run_passes_timeout(self, mock_asyncio, _mock_serve): + agent = _make_stub_agent(graceful_shutdown_timeout=15) + agent.run() + mock_asyncio.run.assert_called_once() + # Verify the config built internally has the right graceful_timeout + config = agent._build_hypercorn_config("127.0.0.1", 8088) + assert config.graceful_timeout == 15.0 + + @patch("hypercorn.asyncio.serve", new_callable=AsyncMock) + @patch("azure.ai.agentserver.server._base.asyncio") + def test_run_passes_default_timeout(self, mock_asyncio, _mock_serve): + agent = _make_stub_agent() + agent.run() + config = agent._build_hypercorn_config("127.0.0.1", 8088) + assert config.graceful_timeout == 30.0 + + +# --------------------------------------------------------------------------- +# Hypercorn config receives graceful_timeout (async run) +# --------------------------------------------------------------------------- + + +class TestRunAsyncPassesTimeout: + """Ensure run_async() forwards the timeout to Hypercorn config.""" + + @pytest.mark.asyncio + @patch("hypercorn.asyncio.serve", new_callable=AsyncMock) + async def test_run_async_passes_timeout(self, mock_serve): + agent = _make_stub_agent(graceful_shutdown_timeout=20) + await agent.run_async() + mock_serve.assert_awaited_once() + # Check the config passed to serve + call_args = mock_serve.call_args + config = call_args[0][1] if len(call_args[0]) > 1 else call_args.kwargs.get("config") + assert config.graceful_timeout == 20.0 + + @pytest.mark.asyncio + @patch("hypercorn.asyncio.serve", new_callable=AsyncMock) + async def test_run_async_passes_default_timeout(self, mock_serve): + agent = _make_stub_agent() + await agent.run_async() + call_args = mock_serve.call_args + config = call_args[0][1] if len(call_args[0]) > 1 else call_args.kwargs.get("config") + assert config.graceful_timeout == 30.0 + + +# --------------------------------------------------------------------------- +# Lifespan shutdown log +# --------------------------------------------------------------------------- + + +class TestLifespanShutdown: + """Verify the lifespan emits a shutdown log message.""" + + @pytest.mark.asyncio + async def test_shutdown_log_emitted(self): + agent = _make_stub_agent(graceful_shutdown_timeout=42) + # Exercise the full lifespan by sending a request through the ASGI app + import httpx + + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/liveness") + assert resp.status_code == 200 + + # The shutdown log is emitted after the lifespan exits. + # We verify that the app is configured with the correct timeout so the + # log message references it. A direct lifespan exercise is below. + + @pytest.mark.asyncio + async def test_lifespan_shutdown_logs(self): + """Directly exercise the lifespan context manager and verify shutdown log.""" + import logging + + agent = _make_stub_agent(graceful_shutdown_timeout=99) + + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + # Grab the lifespan from the Starlette app + lifespan = agent.app.router.lifespan_context + async with lifespan(agent.app): + pass # startup yields here + + # After exiting the context manager, shutdown log should fire + shutdown_calls = [ + c + for c in mock_logger.info.call_args_list + if "shutting down" in str(c).lower() + ] + assert len(shutdown_calls) == 1 + assert "99" in str(shutdown_calls[0]) + + +# --------------------------------------------------------------------------- +# on_shutdown overridable method +# --------------------------------------------------------------------------- + + +def _make_shutdown_recording_agent(**kwargs) -> AgentServer: + """Create an agent that records on_shutdown calls.""" + server = AgentServer(**kwargs) + server.shutdown_log: list[str] = [] + + @server.invoke_handler + async def handle(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + server.shutdown_log.append("shutdown") + + return server + + +def _make_async_checkpoint_agent(**kwargs) -> AgentServer: + """Create an agent whose shutdown handler performs async work.""" + server = AgentServer(**kwargs) + server.flushed: list[str] = [] + + @server.invoke_handler + async def handle(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + import asyncio + await asyncio.sleep(0) + server.flushed.append("async-checkpoint") + + return server + + +def _make_failing_shutdown_agent(**kwargs) -> AgentServer: + """Create an agent whose shutdown handler raises an exception.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + raise RuntimeError("disk full") + + return server + + +class TestOnShutdownMethod: + """Verify the overridable on_shutdown() method is called during lifespan teardown.""" + + @pytest.mark.asyncio + async def test_on_shutdown_called(self): + agent = _make_shutdown_recording_agent() + lifespan = agent.app.router.lifespan_context + async with lifespan(agent.app): + pass + assert agent.shutdown_log == ["shutdown"] + + @pytest.mark.asyncio + async def test_async_work_in_on_shutdown(self): + agent = _make_async_checkpoint_agent() + lifespan = agent.app.router.lifespan_context + async with lifespan(agent.app): + pass + assert agent.flushed == ["async-checkpoint"] + + @pytest.mark.asyncio + async def test_default_on_shutdown_is_noop(self): + """Base class on_shutdown does nothing and doesn't raise.""" + agent = _make_stub_agent() + lifespan = agent.app.router.lifespan_context + async with lifespan(agent.app): + pass # should complete without error + + @pytest.mark.asyncio + async def test_on_shutdown_exception_is_logged_not_raised(self): + """A failing on_shutdown must not crash the shutdown sequence.""" + agent = _make_failing_shutdown_agent() + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + lifespan = agent.app.router.lifespan_context + async with lifespan(agent.app): + pass # should NOT raise + + exception_calls = [ + c + for c in mock_logger.exception.call_args_list + if "on_shutdown" in str(c).lower() + ] + assert len(exception_calls) == 1 + + @pytest.mark.asyncio + async def test_on_shutdown_runs_after_shutdown_log(self): + """on_shutdown fires after the shutdown log message.""" + order: list[str] = [] + + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + order.append("callback") + + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + + def tracking_info(*args, **kwargs): + if args and "shutting down" in str(args[0]).lower(): + order.append("log") + + mock_logger.info.side_effect = tracking_info + + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass + + assert order == ["log", "callback"] + + @pytest.mark.asyncio + async def test_on_shutdown_has_access_to_state(self): + """Shutdown handler can access state via closure.""" + connections: list[str] = ["db", "cache"] + closed: list[str] = [] + + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + for conn in connections: + closed.append(conn) + + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass + assert closed == ["db", "cache"] + + +# --------------------------------------------------------------------------- +# on_shutdown timeout enforcement +# --------------------------------------------------------------------------- + + +class TestOnShutdownTimeout: + """Verify on_shutdown is bounded by graceful_shutdown_timeout.""" + + @pytest.mark.asyncio + async def test_slow_on_shutdown_is_cancelled_and_warning_logged(self): + """If on_shutdown exceeds the timeout, it is cancelled and a warning is logged.""" + import asyncio + + server = AgentServer(graceful_shutdown_timeout=1) + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + await asyncio.sleep(999) # way longer than the 1s timeout + + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass # should NOT hang + + warning_calls = [ + c + for c in mock_logger.warning.call_args_list + if "did not complete" in str(c).lower() + ] + assert len(warning_calls) == 1 + assert "1" in str(warning_calls[0]) + + @pytest.mark.asyncio + async def test_fast_on_shutdown_completes_normally(self): + """on_shutdown that finishes within the timeout succeeds without warnings.""" + import asyncio + + done = False + server = AgentServer(graceful_shutdown_timeout=5) + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + nonlocal done + await asyncio.sleep(0) + done = True + + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass + + warning_calls = [ + c + for c in mock_logger.warning.call_args_list + if "did not complete" in str(c).lower() + ] + assert len(warning_calls) == 0 + assert done is True + + @pytest.mark.asyncio + async def test_zero_timeout_disables_wait(self): + """graceful_shutdown_timeout=0 passes None to wait_for (no timeout).""" + import asyncio + + ran = False + server = AgentServer(graceful_shutdown_timeout=0) + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + nonlocal ran + ran = True + + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass + assert ran is True + + @pytest.mark.asyncio + async def test_timeout_does_not_suppress_exceptions(self): + """Exceptions from on_shutdown are still logged even if within timeout.""" + agent = _make_failing_shutdown_agent(graceful_shutdown_timeout=5) + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + lifespan = agent.app.router.lifespan_context + async with lifespan(agent.app): + pass + + exception_calls = [ + c + for c in mock_logger.exception.call_args_list + if "on_shutdown" in str(c).lower() + ] + assert len(exception_calls) == 1 diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_health.py b/sdk/agentserver/azure-ai-agentserver/tests/test_health.py new file mode 100644 index 000000000000..3c0786f8eeac --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_health.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for health check endpoints.""" +import pytest + + +@pytest.mark.asyncio +async def test_liveness_returns_200(echo_client): + """GET /liveness returns 200.""" + resp = await echo_client.get("/liveness") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "alive" + + +@pytest.mark.asyncio +async def test_readiness_returns_200(echo_client): + """GET /readiness returns 200.""" + resp = await echo_client.get("/readiness") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ready" diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_http2.py b/sdk/agentserver/azure-ai-agentserver/tests/test_http2.py new file mode 100644 index 000000000000..15f013804bc9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_http2.py @@ -0,0 +1,473 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Integration tests verifying HTTP/2 support via Hypercorn. + +**TLS tests** start a real Hypercorn server with a self-signed cert and +connect via HTTP/2 using ``httpx`` with ALPN negotiation. # cspell:ignore ALPN + +**h2c tests** (HTTP/2 cleartext) start a plain-TCP Hypercorn server and +connect via HTTP/2 *prior knowledge* (``http1=False, http2=True``), +proving that HTTP/2 works without any TLS certificates. + +Requirements (auto-skipped when missing): +- ``cryptography`` — self-signed certificate generation (TLS tests only) +- ``h2`` — HTTP/2 client support in httpx (installed with hypercorn) +""" +from __future__ import annotations + +import asyncio +import datetime +import ipaddress +import socket +import tempfile +from pathlib import Path +from typing import AsyncGenerator + +import pytest +import pytest_asyncio + +cryptography = pytest.importorskip("cryptography", reason="cryptography needed for TLS cert generation") +pytest.importorskip("h2", reason="h2 needed for HTTP/2 client support") + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + +import httpx +from hypercorn.asyncio import serve as _hypercorn_serve +from hypercorn.config import Config as HypercornConfig +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver import AgentServer + + +# --------------------------------------------------------------------------- +# Test agents +# --------------------------------------------------------------------------- + + +def _make_echo_agent(**kwargs): + """Agent that echoes the request body in the response.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + body = await request.body() + return JSONResponse({"echo": body.decode()}) + + return server + + +def _make_stream_agent(**kwargs): + """Agent that returns a streaming response.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + async def generate() -> AsyncGenerator[bytes, None]: + for chunk in [b"chunk1", b"chunk2", b"chunk3"]: + yield chunk + + return StreamingResponse(generate(), media_type="application/octet-stream") + + return server + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_free_port() -> int: + """Find an available TCP port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def _generate_self_signed_cert(cert_path: Path, key_path: Path) -> None: + """Generate a self-signed TLS certificate for localhost / 127.0.0.1.""" + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), + ]) + + now = datetime.datetime.now(datetime.timezone.utc) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(hours=1)) + .add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ]), + critical=False, + ) + .sign(key, hashes.SHA256()) + ) + + key_path.write_bytes( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) + + +async def _wait_for_server(host: str, port: int, timeout: float = 5.0) -> None: + """Poll until the server is accepting connections.""" + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: + try: + _, writer = await asyncio.open_connection(host, port) + writer.close() + await writer.wait_closed() + return + except OSError: + await asyncio.sleep(0.05) + raise TimeoutError(f"Server on {host}:{port} did not start within {timeout}s") + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def tls_cert_pair(): + """Generate a self-signed cert + key in a temp directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + cert_path = Path(tmpdir) / "cert.pem" + key_path = Path(tmpdir) / "key.pem" + _generate_self_signed_cert(cert_path, key_path) + yield str(cert_path), str(key_path) + + +@pytest_asyncio.fixture() +async def echo_h2_server(tls_cert_pair): + """Start _EchoAgent on a random port with TLS, yield the port.""" + cert_path, key_path = tls_cert_pair + port = _get_free_port() + + agent = _make_echo_agent() + config = HypercornConfig() + config.bind = [f"127.0.0.1:{port}"] + config.certfile = cert_path + config.keyfile = key_path + config.graceful_timeout = 1.0 + + shutdown_event = asyncio.Event() + task = asyncio.create_task( + _hypercorn_serve(agent.app, config, shutdown_trigger=shutdown_event.wait) + ) + await _wait_for_server("127.0.0.1", port) + yield port + + shutdown_event.set() + await task + + +@pytest_asyncio.fixture() +async def stream_h2_server(tls_cert_pair): + """Start _StreamAgent on a random port with TLS, yield the port.""" + cert_path, key_path = tls_cert_pair + port = _get_free_port() + + agent = _make_stream_agent() + config = HypercornConfig() + config.bind = [f"127.0.0.1:{port}"] + config.certfile = cert_path + config.keyfile = key_path + config.graceful_timeout = 1.0 + + shutdown_event = asyncio.Event() + task = asyncio.create_task( + _hypercorn_serve(agent.app, config, shutdown_trigger=shutdown_event.wait) + ) + await _wait_for_server("127.0.0.1", port) + yield port + + shutdown_event.set() + await task + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestHttp2Health: + """Verify health endpoints work over HTTP/2.""" + + @pytest.mark.asyncio + async def test_h2_liveness(self, echo_h2_server): + port = echo_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + resp = await client.get(f"https://127.0.0.1:{port}/liveness") + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + assert resp.json() == {"status": "alive"} + + @pytest.mark.asyncio + async def test_h2_readiness(self, echo_h2_server): + port = echo_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + resp = await client.get(f"https://127.0.0.1:{port}/readiness") + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + assert resp.json() == {"status": "ready"} + + +class TestHttp2Invoke: + """Verify invocation works over HTTP/2.""" + + @pytest.mark.asyncio + async def test_h2_invoke_returns_200(self, echo_h2_server): + port = echo_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + resp = await client.post( + f"https://127.0.0.1:{port}/invocations", + json={"message": "hello"}, + ) + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + data = resp.json() + assert "echo" in data + + @pytest.mark.asyncio + async def test_h2_invocation_id_header(self, echo_h2_server): + port = echo_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + resp = await client.post( + f"https://127.0.0.1:{port}/invocations", + json={"message": "test"}, + ) + assert resp.http_version == "HTTP/2" + assert "x-agent-invocation-id" in resp.headers + + @pytest.mark.asyncio + async def test_h2_multiple_requests_on_same_connection(self, echo_h2_server): + """HTTP/2 multiplexing: multiple requests on a single connection.""" + port = echo_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + responses = await asyncio.gather( + client.post(f"https://127.0.0.1:{port}/invocations", json={"n": 1}), + client.post(f"https://127.0.0.1:{port}/invocations", json={"n": 2}), + client.post(f"https://127.0.0.1:{port}/invocations", json={"n": 3}), + ) + for resp in responses: + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + # Each invocation gets a unique ID + ids = {resp.headers["x-agent-invocation-id"] for resp in responses} + assert len(ids) == 3 + + +class TestHttp2Streaming: + """Verify streaming responses work over HTTP/2.""" + + @pytest.mark.asyncio + async def test_h2_streaming_response(self, stream_h2_server): + port = stream_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + resp = await client.post( + f"https://127.0.0.1:{port}/invocations", + json={}, + ) + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + assert resp.content == b"chunk1chunk2chunk3" + + @pytest.mark.asyncio + async def test_h2_streaming_has_invocation_id(self, stream_h2_server): + port = stream_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + resp = await client.post( + f"https://127.0.0.1:{port}/invocations", + json={}, + ) + assert resp.http_version == "HTTP/2" + assert "x-agent-invocation-id" in resp.headers + + +class TestHttp1Fallback: + """Verify HTTP/1.1 still works (client without h2 negotiation).""" + + @pytest.mark.asyncio + async def test_http1_invoke(self, echo_h2_server): + """A client that does not request HTTP/2 gets HTTP/1.1.""" + port = echo_h2_server + async with httpx.AsyncClient(http2=False, verify=False) as client: + resp = await client.post( + f"https://127.0.0.1:{port}/invocations", + json={"message": "http1"}, + ) + assert resp.status_code == 200 + assert resp.http_version == "HTTP/1.1" + + +# =========================================================================== +# h2c (HTTP/2 cleartext) — no TLS certificates required +# =========================================================================== + + +@pytest_asyncio.fixture() +async def echo_h2c_server(): + """Start _EchoAgent on a random port *without* TLS, yield the port.""" + port = _get_free_port() + + agent = _make_echo_agent() + config = HypercornConfig() + config.bind = [f"127.0.0.1:{port}"] + config.graceful_timeout = 1.0 + + shutdown_event = asyncio.Event() + task = asyncio.create_task( + _hypercorn_serve(agent.app, config, shutdown_trigger=shutdown_event.wait) + ) + await _wait_for_server("127.0.0.1", port) + yield port + + shutdown_event.set() + await task + + +@pytest_asyncio.fixture() +async def stream_h2c_server(): + """Start _StreamAgent on a random port *without* TLS, yield the port.""" + port = _get_free_port() + + agent = _make_stream_agent() + config = HypercornConfig() + config.bind = [f"127.0.0.1:{port}"] + config.graceful_timeout = 1.0 + + shutdown_event = asyncio.Event() + task = asyncio.create_task( + _hypercorn_serve(agent.app, config, shutdown_trigger=shutdown_event.wait) + ) + await _wait_for_server("127.0.0.1", port) + yield port + + shutdown_event.set() + await task + + +class TestH2cHealth: + """Verify health endpoints work over h2c (HTTP/2 cleartext).""" + + @pytest.mark.asyncio + async def test_h2c_liveness(self, echo_h2c_server): + port = echo_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + resp = await client.get(f"http://127.0.0.1:{port}/liveness") + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + assert resp.json() == {"status": "alive"} + + @pytest.mark.asyncio + async def test_h2c_readiness(self, echo_h2c_server): + port = echo_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + resp = await client.get(f"http://127.0.0.1:{port}/readiness") + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + assert resp.json() == {"status": "ready"} + + +class TestH2cInvoke: + """Verify invocation works over h2c (HTTP/2 cleartext).""" + + @pytest.mark.asyncio + async def test_h2c_invoke_returns_200(self, echo_h2c_server): + port = echo_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + resp = await client.post( + f"http://127.0.0.1:{port}/invocations", + json={"message": "hello"}, + ) + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + data = resp.json() + assert "echo" in data + + @pytest.mark.asyncio + async def test_h2c_invocation_id_header(self, echo_h2c_server): + port = echo_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + resp = await client.post( + f"http://127.0.0.1:{port}/invocations", + json={"message": "test"}, + ) + assert resp.http_version == "HTTP/2" + assert "x-agent-invocation-id" in resp.headers + + @pytest.mark.asyncio + async def test_h2c_multiplexing(self, echo_h2c_server): + """HTTP/2 multiplexing over cleartext: concurrent requests, unique IDs.""" + port = echo_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + responses = await asyncio.gather( + client.post(f"http://127.0.0.1:{port}/invocations", json={"n": 1}), + client.post(f"http://127.0.0.1:{port}/invocations", json={"n": 2}), + client.post(f"http://127.0.0.1:{port}/invocations", json={"n": 3}), + ) + for resp in responses: + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + ids = {resp.headers["x-agent-invocation-id"] for resp in responses} + assert len(ids) == 3 + + +class TestH2cStreaming: + """Verify streaming responses work over h2c (HTTP/2 cleartext).""" + + @pytest.mark.asyncio + async def test_h2c_streaming_response(self, stream_h2c_server): + port = stream_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + resp = await client.post( + f"http://127.0.0.1:{port}/invocations", + json={}, + ) + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + assert resp.content == b"chunk1chunk2chunk3" + + @pytest.mark.asyncio + async def test_h2c_streaming_has_invocation_id(self, stream_h2c_server): + port = stream_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + resp = await client.post( + f"http://127.0.0.1:{port}/invocations", + json={}, + ) + assert resp.http_version == "HTTP/2" + assert "x-agent-invocation-id" in resp.headers + + +class TestH2cFallback: + """Verify HTTP/1.1 still works on the plain-TCP server.""" + + @pytest.mark.asyncio + async def test_h2c_server_accepts_http1(self, echo_h2c_server): + """A client that only speaks HTTP/1.1 is served normally.""" + port = echo_h2c_server + async with httpx.AsyncClient(http1=True, http2=False) as client: + resp = await client.post( + f"http://127.0.0.1:{port}/invocations", + json={"message": "http1"}, + ) + assert resp.status_code == 200 + assert resp.http_version == "HTTP/1.1" diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_invoke.py b/sdk/agentserver/azure-ai-agentserver/tests/test_invoke.py new file mode 100644 index 000000000000..2192e96c477b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_invoke.py @@ -0,0 +1,119 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for invoke dispatch (streaming and non-streaming).""" +import json +import uuid + +import httpx +import pytest + +from conftest import _make_failing_agent + + +@pytest.mark.asyncio +async def test_invoke_echoes_body(echo_client): + """POST /invocations body is passed to invoke() and echoed back.""" + payload = b'{"message":"ping"}' + resp = await echo_client.post("/invocations", content=payload) + assert resp.status_code == 200 + assert resp.content == payload + + +@pytest.mark.asyncio +async def test_invoke_receives_headers(echo_client): + """request.headers contains sent HTTP headers.""" + # EchoAgent echoes body; we just confirm the request succeeds with custom headers. + resp = await echo_client.post( + "/invocations", + content=b"{}", + headers={"X-Custom-Header": "test-value", "Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_invoke_receives_invocation_id(echo_client): + """Response x-agent-invocation-id is a non-empty UUID string.""" + resp = await echo_client.post("/invocations", content=b"{}") + invocation_id = resp.headers["x-agent-invocation-id"] + assert invocation_id + uuid.UUID(invocation_id) # raises if not valid UUID + + +@pytest.mark.asyncio +async def test_invoke_invocation_id_unique(echo_client): + """Two consecutive POST /invocations return different x-agent-invocation-id values.""" + resp1 = await echo_client.post("/invocations", content=b"{}") + resp2 = await echo_client.post("/invocations", content=b"{}") + id1 = resp1.headers["x-agent-invocation-id"] + id2 = resp2.headers["x-agent-invocation-id"] + assert id1 != id2 + + +@pytest.mark.asyncio +async def test_invoke_streaming_returns_chunked(streaming_client): + """Streaming agent returns a StreamingResponse.""" + resp = await streaming_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + # The response should contain all chunks + assert len(resp.content) > 0 + + +@pytest.mark.asyncio +async def test_invoke_streaming_yields_all_chunks(streaming_client): + """All chunks from the async generator are received by the client.""" + resp = await streaming_client.post("/invocations", content=b"{}") + lines = resp.content.decode().strip().split("\n") + chunks = [json.loads(line) for line in lines] + assert len(chunks) == 3 + assert chunks[0]["chunk"] == 0 + assert chunks[1]["chunk"] == 1 + assert chunks[2]["chunk"] == 2 + + +@pytest.mark.asyncio +async def test_invoke_streaming_has_invocation_id_header(streaming_client): + """Streaming response also includes x-agent-invocation-id header.""" + resp = await streaming_client.post("/invocations", content=b"{}") + invocation_id = resp.headers.get("x-agent-invocation-id") + assert invocation_id is not None + uuid.UUID(invocation_id) + + +@pytest.mark.asyncio +async def test_invoke_empty_body(echo_client): + """POST /invocations with empty body doesn't crash.""" + resp = await echo_client.post("/invocations", content=b"") + assert resp.status_code == 200 + assert resp.content == b"" + + +@pytest.mark.asyncio +async def test_invoke_error_returns_500(failing_client): + """When invoke() raises, server returns 500 with generic error message and invocation id.""" + resp = await failing_client.post("/invocations", content=b'{"key":"value"}') + assert resp.status_code == 500 + data = resp.json() + assert data["error"]["code"] == "internal_error" + assert data["error"]["message"] == "Internal server error" + assert resp.headers.get("x-agent-invocation-id") is not None + + +@pytest.mark.asyncio +async def test_invoke_error_hides_details_by_default(failing_client): + """Without AGENT_DEBUG_ERRORS, the actual exception message is hidden.""" + resp = await failing_client.post("/invocations", content=b'{}') + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "Internal server error" + + +@pytest.mark.asyncio +async def test_invoke_error_exposes_details_with_debug(): + """With debug_errors=True, the actual exception message is returned.""" + agent = _make_failing_agent(debug_errors=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b'{}') + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "something went wrong" diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_logger.py b/sdk/agentserver/azure-ai-agentserver/tests/test_logger.py new file mode 100644 index 000000000000..314211cec644 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_logger.py @@ -0,0 +1,22 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for the library-scoped logger.""" +import logging + + +def test_library_logger_exists(): + """The azure.ai.agentserver logger is a standard named logger.""" + lib_logger = logging.getLogger("azure.ai.agentserver") + assert lib_logger.name == "azure.ai.agentserver" + + +def test_log_level_preserved_across_imports(): + """Importing the server module does not reset a level already set.""" + lib_logger = logging.getLogger("azure.ai.agentserver") + lib_logger.setLevel(logging.ERROR) + + # Re-importing the base module should not override the level. + from azure.ai.agentserver.server import _base # noqa: F401 + + assert lib_logger.level == logging.ERROR diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_multimodal_protocol.py b/sdk/agentserver/azure-ai-agentserver/tests/test_multimodal_protocol.py new file mode 100644 index 000000000000..a8d718c7b268 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_multimodal_protocol.py @@ -0,0 +1,807 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for multi-modality payloads, content types, and protocol edge cases. + +AgentServer is an HTTP/ASGI server. It does NOT define WebSocket or gRPC +routes, so those protocol attempts must be handled gracefully. The +server is content-type agnostic: agents can receive and return any +media type (images, audio, protobuf, SSE, etc.). +""" +import base64 +import io +import json +import uuid + +import httpx +import pytest +import pytest_asyncio + +from starlette.requests import Request +from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse, Response, StreamingResponse + +from azure.ai.agentserver import AgentServer + + +# --------------------------------------------------------------------------- +# Minimal binary blobs for realistic content types +# --------------------------------------------------------------------------- + +# 1x1 red pixel PNG (67 bytes — smallest valid PNG) +_MINIMAL_PNG = bytes.fromhex( + "89504e470d0a1a0a" # PNG signature + "0000000d49484452" "00000001000000010802" "000000907753de" # header + "0000000c49444154" "789c63f80f00000101000518d84e" # image data + "0000000049454e44ae426082" # end +) + +# Minimal WAV header (44 bytes) — 1 sample, 16-bit mono 8 kHz +_MINIMAL_WAV = bytes.fromhex( + "52494646" # RIFF + "26000000" # file size - 8 (38) + "57415645" # WAVE + "666d7420" # fmt + "10000000" # chunk size (16) + "0100" # PCM + "0100" # mono + "401f0000" # sample rate 8000 + "803e0000" # byte rate 16000 + "0200" # block align + "1000" # bits per sample (16) + "64617461" # data + "02000000" # data bytes (2) + "0000" # 1 silent sample +) + +# Tiny protobuf-like payload (field 1, varint value 150) +_PROTO_PAYLOAD = b"\x08\x96\x01" + + +# --------------------------------------------------------------------------- +# Specialised agent implementations: multi-modal +# --------------------------------------------------------------------------- + + +def _make_ctype_echo_agent(**kwargs): + """Returns the request Content-Type and body length as JSON.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + ctype = request.headers.get("content-type", "") + body = await request.body() + return JSONResponse({"content_type": ctype, "length": len(body)}) + + return server + + +def _make_image_agent(**kwargs): + """Returns a pre-canned PNG image.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return Response(content=_MINIMAL_PNG, media_type="image/png") + + return server + + +def _make_audio_agent(**kwargs): + """Returns a pre-canned WAV clip.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return Response(content=_MINIMAL_WAV, media_type="audio/wav") + + return server + + +def _make_html_agent(**kwargs): + """Returns an HTML response.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return HTMLResponse("

Hello Agent

") + + return server + + +def _make_plaintext_agent(**kwargs): + """Returns plain text.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return PlainTextResponse("Hello, plain text world!") + + return server + + +def _make_xml_agent(**kwargs): + """Returns XML content.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + xml = 'ok' + return Response(content=xml.encode(), media_type="application/xml") + + return server + + +def _make_sse_agent(**kwargs): + """Returns a Server-Sent Events stream.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> StreamingResponse: + async def event_stream(): + for i in range(3): + yield f"event: message\ndata: {json.dumps({'n': i})}\n\n".encode() + + return StreamingResponse(event_stream(), media_type="text/event-stream") + + return server + + +def _make_custom_status_agent(**kwargs): + """Returns whichever HTTP status code the client requests in body.status_code.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + data = await request.json() + code = int(data.get("status_code", 200)) + return JSONResponse({"status_code": code}, status_code=code) + + return server + + +def _make_multipart_raw_agent(**kwargs): + """Echoes the raw multipart body's content-type and length.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + ctype = request.headers.get("content-type", "") + body = await request.body() + return JSONResponse({"content_type": ctype, "length": len(body), "body_prefix": body[:80].decode("latin-1")}) + + return server + + +def _make_query_string_agent(**kwargs): + """Echoes query parameters as JSON.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"query": dict(request.query_params)}) + + return server + + +def _make_b64image_agent(**kwargs): + """Accepts JSON with base64-encoded image data and decodes it.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + data = await request.json() + image_b64 = data.get("image", "") + decoded = base64.b64decode(image_b64) + return JSONResponse({"decoded_size": len(decoded), "starts_with_png": decoded[:4] == b"\x89PNG"}) + + return server + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def ctype_echo_client(): + server = _make_ctype_echo_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def image_client(): + server = _make_image_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def audio_client(): + server = _make_audio_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def html_client(): + server = _make_html_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def plaintext_client(): + server = _make_plaintext_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def xml_client(): + server = _make_xml_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def sse_client(): + server = _make_sse_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def custom_status_client(): + server = _make_custom_status_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def multipart_raw_client(): + server = _make_multipart_raw_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def query_client(): + server = _make_query_string_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def b64image_client(): + server = _make_b64image_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +# ======================================================================== +# MULTI-MODALITY: requests with different content types +# ======================================================================== + + +class TestMultiModalRequests: + """Verify the server accepts and forwards any content type to the agent.""" + + @pytest.mark.asyncio + async def test_image_png_payload(self, ctype_echo_client): + """POST with image/png content type and binary PNG data.""" + resp = await ctype_echo_client.post( + "/invocations", + content=_MINIMAL_PNG, + headers={"Content-Type": "image/png"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["content_type"] == "image/png" + assert body["length"] == len(_MINIMAL_PNG) + + @pytest.mark.asyncio + async def test_image_jpeg_payload(self, ctype_echo_client): + """POST with image/jpeg content type.""" + fake_jpeg = b"\xff\xd8\xff\xe0" + b"\x00" * 100 # JPEG SOI marker + padding + resp = await ctype_echo_client.post( + "/invocations", + content=fake_jpeg, + headers={"Content-Type": "image/jpeg"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "image/jpeg" + assert resp.json()["length"] == len(fake_jpeg) + + @pytest.mark.asyncio + async def test_audio_wav_payload(self, ctype_echo_client): + """POST with audio/wav content type and WAV header.""" + resp = await ctype_echo_client.post( + "/invocations", + content=_MINIMAL_WAV, + headers={"Content-Type": "audio/wav"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "audio/wav" + assert resp.json()["length"] == len(_MINIMAL_WAV) + + @pytest.mark.asyncio + async def test_video_mp4_payload(self, ctype_echo_client): + """POST with video/mp4 content type and opaque bytes.""" + fake_mp4 = bytes.fromhex("00000018667479706d703432") + b"\x00" * 200 + resp = await ctype_echo_client.post( + "/invocations", + content=fake_mp4, + headers={"Content-Type": "video/mp4"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "video/mp4" + + @pytest.mark.asyncio + async def test_protobuf_payload(self, ctype_echo_client): + """POST with application/x-protobuf content type.""" + resp = await ctype_echo_client.post( + "/invocations", + content=_PROTO_PAYLOAD, + headers={"Content-Type": "application/x-protobuf"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "application/x-protobuf" + assert resp.json()["length"] == len(_PROTO_PAYLOAD) + + @pytest.mark.asyncio + async def test_octet_stream_payload(self, ctype_echo_client): + """POST with application/octet-stream content type.""" + resp = await ctype_echo_client.post( + "/invocations", + content=bytes(range(256)), + headers={"Content-Type": "application/octet-stream"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "application/octet-stream" + assert resp.json()["length"] == 256 + + @pytest.mark.asyncio + async def test_msgpack_payload(self, ctype_echo_client): + """POST with application/msgpack content type (binary serialisation).""" + resp = await ctype_echo_client.post( + "/invocations", + content=b"\x82\xa3key\xa5value\xa3num\x2a", # {"key":"value","num":42} + headers={"Content-Type": "application/msgpack"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "application/msgpack" + + @pytest.mark.asyncio + async def test_base64_encoded_image_in_json(self, b64image_client): + """JSON body carrying a base64-encoded PNG (multi-modal pattern).""" + encoded = base64.b64encode(_MINIMAL_PNG).decode("ascii") + resp = await b64image_client.post( + "/invocations", + content=json.dumps({"image": encoded}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["decoded_size"] == len(_MINIMAL_PNG) + assert body["starts_with_png"] is True + + +# ======================================================================== +# MULTI-MODALITY: responses with different content types +# ======================================================================== + + +class TestMultiModalResponses: + """Agents can return any media type — verify end-to-end.""" + + @pytest.mark.asyncio + async def test_agent_returns_png(self, image_client): + """Agent returning image/png is received correctly by the client.""" + resp = await image_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "image/png" + assert resp.content[:4] == b"\x89PNG" + assert resp.headers.get("x-agent-invocation-id") is not None + + @pytest.mark.asyncio + async def test_agent_returns_wav(self, audio_client): + """Agent returning audio/wav is received correctly.""" + resp = await audio_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "audio/wav" + assert resp.content[:4] == b"RIFF" + + @pytest.mark.asyncio + async def test_agent_returns_html(self, html_client): + """Agent returning HTML is received with correct content type.""" + resp = await html_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert "text/html" in resp.headers["content-type"] + assert "

Hello Agent

" in resp.text + + @pytest.mark.asyncio + async def test_agent_returns_plain_text(self, plaintext_client): + """Agent returning plain text is received correctly.""" + resp = await plaintext_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert "text/plain" in resp.headers["content-type"] + assert resp.text == "Hello, plain text world!" + + @pytest.mark.asyncio + async def test_agent_returns_xml(self, xml_client): + """Agent returning application/xml is received correctly.""" + resp = await xml_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "application/xml" + assert "ok" in resp.text + + +# ======================================================================== +# STREAMING: SSE (Server-Sent Events) protocol +# ======================================================================== + + +class TestServerSentEvents: + """Verify SSE streaming works end-to-end.""" + + @pytest.mark.asyncio + async def test_sse_stream_yields_events(self, sse_client): + """Agent producing SSE events delivers them to the client.""" + resp = await sse_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert "text/event-stream" in resp.headers.get("content-type", "") + + # Parse the SSE stream + raw = resp.text + events = [line for line in raw.split("\n") if line.startswith("data:")] + assert len(events) == 3 + for i, ev in enumerate(events): + payload = json.loads(ev.removeprefix("data:").strip()) + assert payload["n"] == i + + @pytest.mark.asyncio + async def test_sse_has_invocation_id(self, sse_client): + """SSE response still carries x-agent-invocation-id.""" + resp = await sse_client.post("/invocations", content=b"{}") + assert resp.headers.get("x-agent-invocation-id") is not None + + +# ======================================================================== +# MULTIPART form data (file uploads) +# ======================================================================== + + +class TestMultipartFormData: + """Verify the server passes multipart/form-data payloads to the agent.""" + + @pytest.mark.asyncio + async def test_single_file_upload(self, multipart_raw_client): + """Upload a single file via multipart form data.""" + resp = await multipart_raw_client.post( + "/invocations", + files={"image": ("photo.png", _MINIMAL_PNG, "image/png")}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "multipart/form-data" in body["content_type"] + assert body["length"] > len(_MINIMAL_PNG) # includes boundary overhead + + @pytest.mark.asyncio + async def test_multiple_files_upload(self, multipart_raw_client): + """Upload multiple files in one request.""" + resp = await multipart_raw_client.post( + "/invocations", + files=[ + ("file1", ("a.png", _MINIMAL_PNG, "image/png")), + ("file2", ("b.wav", _MINIMAL_WAV, "audio/wav")), + ], + ) + assert resp.status_code == 200 + body = resp.json() + assert "multipart/form-data" in body["content_type"] + # Total body must be larger than both payloads combined + assert body["length"] > len(_MINIMAL_PNG) + len(_MINIMAL_WAV) + + @pytest.mark.asyncio + async def test_mixed_form_fields_and_files(self, multipart_raw_client): + """Multipart with both plain text fields and file uploads.""" + resp = await multipart_raw_client.post( + "/invocations", + data={"prompt": "describe this image"}, + files={"image": ("pic.png", _MINIMAL_PNG, "image/png")}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "multipart/form-data" in body["content_type"] + assert body["length"] > 0 + + +# ======================================================================== +# CUSTOM HTTP STATUS CODES +# ======================================================================== + + +class TestCustomStatusCodes: + """AgentServer preserves whatever HTTP status the agent returns.""" + + @pytest.mark.asyncio + async def test_201_created(self, custom_status_client): + """Agent returning 201 Created.""" + resp = await custom_status_client.post( + "/invocations", + content=json.dumps({"status_code": 201}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 201 + assert resp.headers.get("x-agent-invocation-id") is not None + + @pytest.mark.asyncio + async def test_202_accepted(self, custom_status_client): + """Agent returning 202 Accepted (async processing).""" + resp = await custom_status_client.post( + "/invocations", + content=json.dumps({"status_code": 202}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 202 + + @pytest.mark.asyncio + async def test_204_no_content(self, custom_status_client): + """Agent returning 204 No Content.""" + resp = await custom_status_client.post( + "/invocations", + content=json.dumps({"status_code": 204}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 204 + + @pytest.mark.asyncio + async def test_400_bad_request(self, custom_status_client): + """Agent can signal a 400 back to caller.""" + resp = await custom_status_client.post( + "/invocations", + content=json.dumps({"status_code": 400}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + @pytest.mark.asyncio + async def test_422_unprocessable(self, custom_status_client): + """Agent returning 422 Unprocessable Entity.""" + resp = await custom_status_client.post( + "/invocations", + content=json.dumps({"status_code": 422}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 422 + + @pytest.mark.asyncio + async def test_503_service_unavailable(self, custom_status_client): + """Agent signalling downstream unavailability.""" + resp = await custom_status_client.post( + "/invocations", + content=json.dumps({"status_code": 503}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 503 + + +# ======================================================================== +# HTTP PROTOCOL: query strings, HEAD, path params +# ======================================================================== + + +class TestHttpProtocol: + """HTTP protocol-level edge cases.""" + + @pytest.mark.asyncio + async def test_query_string_forwarded_to_agent(self, query_client): + """Query parameters on /invocations are accessible to the agent.""" + resp = await query_client.post( + "/invocations?model=gpt-4&temperature=0.7", + content=b"{}", + ) + assert resp.status_code == 200 + q = resp.json()["query"] + assert q["model"] == "gpt-4" + assert q["temperature"] == "0.7" + + @pytest.mark.asyncio + async def test_head_liveness(self, echo_client): + """HEAD /liveness returns 200 with no body (standard HTTP HEAD).""" + resp = await echo_client.head("/liveness") + # Starlette may return 200 or 405 — both are valid server behavior + assert resp.status_code in (200, 405) + + @pytest.mark.asyncio + async def test_path_param_special_characters(self, echo_client): + """GET /invocations/{id} with URL-encoded special chars in the ID.""" + weird_id = "abc%20def%2F123" + resp = await echo_client.get(f"/invocations/{weird_id}") + # default agent returns 404 — ensure it doesn't crash + assert resp.status_code == 404 + + @pytest.mark.asyncio + async def test_accept_header_passthrough(self, ctype_echo_client): + """Accept header doesn't interfere with agent processing.""" + resp = await ctype_echo_client.post( + "/invocations", + content=b"hello", + headers={ + "Accept": "application/xml", + "Content-Type": "text/plain", + }, + ) + assert resp.status_code == 200 + # Agent returns JSON regardless — server doesn't enforce Accept + assert resp.json()["content_type"] == "text/plain" + + @pytest.mark.asyncio + async def test_multiple_custom_headers_forwarded(self, ctype_echo_client): + """Custom X- headers reach the agent without interference.""" + resp = await ctype_echo_client.post( + "/invocations", + content=b"{}", + headers={ + "Content-Type": "application/json", + "X-Request-Id": "req-42", + "X-Session-Token": "tok-abc", + }, + ) + assert resp.status_code == 200 + + +# ======================================================================== +# WebSocket: upgrade attempt to an HTTP-only server +# ======================================================================== + + +class TestWebSocketUpgradeRejected: + """The server has no WebSocket routes. WS upgrade attempts must fail + gracefully rather than crashing.""" + + @pytest.mark.asyncio + async def test_websocket_upgrade_on_invocations(self): + """WebSocket handshake on /invocations fails cleanly.""" + agent = _make_ctype_echo_agent() + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + # Simulate a WebSocket upgrade via HTTP headers + resp = await client.get( + "/invocations", + headers={ + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "13", + }, + ) + # Server should reject: no WS route → 4xx response + assert resp.status_code in (400, 403, 404, 405, 426) + + @pytest.mark.asyncio + async def test_websocket_upgrade_on_unknown_path(self): + """WebSocket handshake on an undefined path fails cleanly.""" + agent = _make_ctype_echo_agent() + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get( + "/ws/invocations", + headers={ + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "13", + }, + ) + assert resp.status_code in (400, 403, 404, 405, 426) + + +# ======================================================================== +# gRPC-like: binary framed payload over HTTP/2 style headers +# ======================================================================== + + +class TestGrpcLikePayloads: + """The server is HTTP/1.1, but clients may send gRPC-style payloads + (application/grpc, length-prefixed protobuf). Server should accept + bytes and let the agent deal with them — no crash.""" + + @pytest.mark.asyncio + async def test_grpc_content_type_accepted(self, ctype_echo_client): + """POST with application/grpc content type is forwarded to the agent.""" + # gRPC frames: 1-byte compressed flag + 4-byte length + payload + grpc_frame = b"\x00" + len(_PROTO_PAYLOAD).to_bytes(4, "big") + _PROTO_PAYLOAD + resp = await ctype_echo_client.post( + "/invocations", + content=grpc_frame, + headers={"Content-Type": "application/grpc"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "application/grpc" + assert resp.json()["length"] == len(grpc_frame) + + @pytest.mark.asyncio + async def test_grpc_web_content_type_accepted(self, ctype_echo_client): + """POST with application/grpc-web content type (for gRPC-Web proxy).""" + resp = await ctype_echo_client.post( + "/invocations", + content=_PROTO_PAYLOAD, + headers={"Content-Type": "application/grpc-web"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "application/grpc-web" + + @pytest.mark.asyncio + async def test_grpc_plus_proto_content_type(self, ctype_echo_client): + """POST with application/grpc+proto content type.""" + resp = await ctype_echo_client.post( + "/invocations", + content=_PROTO_PAYLOAD, + headers={"Content-Type": "application/grpc+proto"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "application/grpc+proto" + + +# ======================================================================== +# Multiple modalities in one request / response +# ======================================================================== + + +class TestMixedModality: + """Scenarios combining text + binary in a single invocation.""" + + @pytest.mark.asyncio + async def test_json_with_base64_image_roundtrip(self, b64image_client): + """JSON containing base64-encoded image data round-trips correctly.""" + encoded = base64.b64encode(_MINIMAL_PNG).decode("ascii") + payload = {"image": encoded, "prompt": "What is in this image?"} + resp = await b64image_client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["decoded_size"] == len(_MINIMAL_PNG) + assert body["starts_with_png"] is True + + @pytest.mark.asyncio + async def test_multipart_image_plus_audio(self, multipart_raw_client): + """Multipart upload combining image and audio files.""" + resp = await multipart_raw_client.post( + "/invocations", + files=[ + ("image", ("pic.png", _MINIMAL_PNG, "image/png")), + ("audio", ("clip.wav", _MINIMAL_WAV, "audio/wav")), + ], + data={"instruction": "transcribe the audio and describe the image"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "multipart/form-data" in body["content_type"] + assert body["length"] > len(_MINIMAL_PNG) + len(_MINIMAL_WAV) + + @pytest.mark.asyncio + async def test_multipart_large_file(self, multipart_raw_client): + """Multipart upload with a large binary file (512 KB).""" + big_file = b"\x00" * (512 * 1024) + resp = await multipart_raw_client.post( + "/invocations", + files={"bigfile": ("large.bin", big_file, "application/octet-stream")}, + ) + assert resp.status_code == 200 + assert resp.json()["length"] > 512 * 1024 diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_openapi_validation.py b/sdk/agentserver/azure-ai-agentserver/tests/test_openapi_validation.py new file mode 100644 index 000000000000..6806da272fe5 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_openapi_validation.py @@ -0,0 +1,2167 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for OpenAPI spec validation.""" +import json + +import httpx +import pytest + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + +# --------------------------------------------------------------------------- +# Helpers – inline agent + client builder +# --------------------------------------------------------------------------- + + +def _make_echo_agent(spec: dict) -> AgentServer: + """Create an agent that echoes the parsed JSON body with validation.""" + server = AgentServer(openapi_spec=spec, enable_request_validation=True) + + @server.invoke_handler + async def handle(request: Request) -> Response: + data = await request.json() + return JSONResponse(data) + + return server + + +async def _client_for(agent: AgentServer): + transport = httpx.ASGITransport(app=agent.app) + return httpx.AsyncClient(transport=transport, base_url="http://testserver") + + +# --------------------------------------------------------------------------- +# Complex OpenAPI spec with nested objects, arrays, enums, $ref, constraints +# --------------------------------------------------------------------------- + +COMPLEX_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "Complex Agent", "version": "2.0"}, + "components": { + "schemas": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string", "minLength": 1}, + "city": {"type": "string"}, + "zip": {"type": "string", "pattern": "^[0-9]{5}$"}, + "country": {"type": "string", "enum": ["US", "CA", "MX"]}, + }, + "required": ["street", "city", "zip", "country"], + }, + "OrderItem": { + "type": "object", + "properties": { + "sku": {"type": "string", "minLength": 3, "maxLength": 12}, + "quantity": {"type": "integer", "minimum": 1, "maximum": 999}, + "unit_price": {"type": "number", "exclusiveMinimum": 0}, + }, + "required": ["sku", "quantity", "unit_price"], + }, + } + }, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "customer_name": { + "type": "string", + "minLength": 1, + "maxLength": 100, + }, + "email": { + "type": "string", + "format": "email", + }, + "age": { + "type": "integer", + "minimum": 18, + "maximum": 150, + }, + "tier": { + "type": "string", + "enum": ["bronze", "silver", "gold", "platinum"], + }, + "shipping_address": { + "$ref": "#/components/schemas/Address" + }, + "items": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OrderItem" + }, + "minItems": 1, + "maxItems": 50, + }, + "notes": { + "type": "string", + }, + "tags": { + "type": "array", + "items": {"type": "string"}, + "uniqueItems": True, + }, + }, + "required": [ + "customer_name", + "age", + "tier", + "shipping_address", + "items", + ], + "additionalProperties": False, + } + } + } + }, + "responses": { + "200": { + "description": "Order accepted", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "order_id": {"type": "string"}, + }, + } + } + }, + } + }, + } + } + }, +} + + +def _valid_order(**overrides) -> dict: + """Return a fully valid order payload, with optional field overrides.""" + base = { + "customer_name": "Alice Smith", + "age": 30, + "tier": "gold", + "shipping_address": { + "street": "123 Main St", + "city": "Redmond", + "zip": "98052", + "country": "US", + }, + "items": [{"sku": "ABC123", "quantity": 2, "unit_price": 19.99}], + } + base.update(overrides) + return base + + +# =================================================================== +# Original tests (kept as-is) +# =================================================================== + + +@pytest.mark.asyncio +async def test_valid_request_passes(validated_client): + """Request matching schema returns 200.""" + resp = await validated_client.post( + "/invocations", + content=json.dumps({"name": "World"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + data = json.loads(resp.content) + assert data["greeting"] == "Hello, World!" + + +@pytest.mark.asyncio +async def test_invalid_request_returns_400(validated_client): + """Request missing required field returns 400 with error details.""" + resp = await validated_client.post( + "/invocations", + content=json.dumps({"wrong_field": "oops"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + data = resp.json() + assert data["error"]["code"] == "invalid_payload" + assert "details" in data["error"] + + +@pytest.mark.asyncio +async def test_invalid_request_wrong_type_returns_400(validated_client): + """Request with wrong field type returns 400.""" + resp = await validated_client.post( + "/invocations", + content=json.dumps({"name": 12345}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + data = resp.json() + assert "details" in data["error"] + + +@pytest.mark.asyncio +async def test_no_spec_skips_validation(no_spec_client): + """Agent with no spec accepts any request body.""" + resp = await no_spec_client.post( + "/invocations", + content=b"this is not json at all", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_spec_endpoint_returns_spec(validated_client): + """GET /invocations/docs/openapi.json returns the registered spec.""" + resp = await validated_client.get("/invocations/docs/openapi.json") + assert resp.status_code == 200 + data = resp.json() + assert "paths" in data + + +@pytest.mark.asyncio +async def test_non_json_body_skips_validation(no_spec_client): + """Non-JSON content type bypasses JSON schema validation.""" + server = AgentServer(openapi_spec={ + "openapi": "3.0.0", + "info": {"title": "Test", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": {"type": "object", "required": ["name"]} + } + } + }, + "responses": {"200": {"description": "OK"}} + } + } + } + }, enable_request_validation=True) + + @server.invoke_handler + async def handle(request: Request) -> Response: + body = await request.body() + return Response(content=body, media_type="text/plain") + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b"plain text body", + headers={"Content-Type": "text/plain"}, + ) + assert resp.status_code == 200 + + +# =================================================================== +# Complex spec – valid payloads +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_valid_full_payload(): + """A fully-populated valid order is accepted.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + payload = _valid_order( + email="alice@example.com", + notes="Leave at the door", + tags=["priority", "fragile"], + ) + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_complex_valid_minimal_payload(): + """Only required fields are present — should pass.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order()).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_complex_valid_multiple_items(): + """Order with several line items passes.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + items = [ + {"sku": "AAA", "quantity": 1, "unit_price": 5.0}, + {"sku": "BBB", "quantity": 10, "unit_price": 12.50}, + {"sku": "CCC", "quantity": 999, "unit_price": 0.01}, + ] + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=items)).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +# =================================================================== +# Complex spec – missing required fields +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_missing_top_level_required(): + """Missing top-level required field 'tier' triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + payload = _valid_order() + del payload["tier"] + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + details = resp.json()["error"]["details"] + assert any("tier" in d["message"] for d in details) + + +@pytest.mark.asyncio +async def test_complex_missing_nested_required(): + """Missing required 'city' in shipping_address triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + payload = _valid_order() + del payload["shipping_address"]["city"] + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + details = resp.json()["error"]["details"] + assert any("city" in d["message"] for d in details) + + +@pytest.mark.asyncio +async def test_complex_missing_array_item_field(): + """Order item missing required 'sku' field triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + bad_item = {"quantity": 1, "unit_price": 5.0} # no sku + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[bad_item])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + details = resp.json()["error"]["details"] + assert any("sku" in d["message"] for d in details) + + +# =================================================================== +# Complex spec – type mismatches +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_wrong_type_age_string(): + """'age' must be integer, not string.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(age="thirty")).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_wrong_type_items_not_array(): + """'items' must be an array, not an object.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps( + _valid_order(items={"sku": "X", "quantity": 1, "unit_price": 1.0}) + ).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_wrong_type_quantity_float(): + """'quantity' must be integer, not float.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + bad_item = {"sku": "ABC", "quantity": 2.5, "unit_price": 10.0} + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[bad_item])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +# =================================================================== +# Complex spec – enum validation +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_invalid_tier_enum(): + """'tier' value not in enum list triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(tier="diamond")).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + details = resp.json()["error"]["details"] + assert any("diamond" in d["message"] for d in details) + + +@pytest.mark.asyncio +async def test_complex_invalid_country_enum(): + """'country' outside allowed enum triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + payload = _valid_order() + payload["shipping_address"]["country"] = "FR" + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + details = resp.json()["error"]["details"] + assert any("FR" in d["message"] for d in details) + + +# =================================================================== +# Complex spec – numeric constraints (minimum / maximum) +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_age_below_minimum(): + """'age' below minimum (18) triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(age=10)).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_age_above_maximum(): + """'age' above maximum (150) triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(age=200)).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_quantity_below_minimum(): + """Order item quantity below 1 triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + bad_item = {"sku": "ABC", "quantity": 0, "unit_price": 10.0} + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[bad_item])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_quantity_above_maximum(): + """Order item quantity above 999 triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + bad_item = {"sku": "ABC", "quantity": 1000, "unit_price": 10.0} + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[bad_item])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_unit_price_zero_exclusive(): + """unit_price with exclusiveMinimum: 0 rejects exactly 0.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + bad_item = {"sku": "ABC", "quantity": 1, "unit_price": 0} + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[bad_item])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +# =================================================================== +# Complex spec – string constraints (minLength / maxLength / pattern) +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_customer_name_empty(): + """Empty customer_name violates minLength: 1.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(customer_name="")).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_customer_name_too_long(): + """customer_name exceeding maxLength: 100 triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(customer_name="A" * 101)).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_sku_too_short(): + """SKU shorter than minLength: 3 triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + bad_item = {"sku": "AB", "quantity": 1, "unit_price": 5.0} + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[bad_item])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_sku_too_long(): + """SKU longer than maxLength: 12 triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + bad_item = {"sku": "A" * 13, "quantity": 1, "unit_price": 5.0} + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[bad_item])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_zip_pattern_invalid(): + """Zip code not matching '^[0-9]{5}$' triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + payload = _valid_order() + payload["shipping_address"]["zip"] = "ABCDE" + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_zip_pattern_too_short(): + """Zip code with fewer than 5 digits fails pattern.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + payload = _valid_order() + payload["shipping_address"]["zip"] = "1234" + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +# =================================================================== +# Complex spec – array constraints (minItems / maxItems / uniqueItems) +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_empty_items_array(): + """Empty items array violates minItems: 1.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_duplicate_tags(): + """Duplicate entries in 'tags' violates uniqueItems.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps( + _valid_order(tags=["urgent", "urgent"]) + ).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +# =================================================================== +# Complex spec – additionalProperties: false +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_additional_properties_rejected(): + """Extra top-level field rejected by additionalProperties: false.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + payload = _valid_order() + payload["surprise_field"] = "not allowed" + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + details = resp.json()["error"]["details"] + assert any("surprise_field" in d["message"] for d in details) + + +# =================================================================== +# Complex spec – $ref resolution +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_ref_address_validated(): + """$ref to Address schema is resolved and validated.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + payload = _valid_order() + # Replace structured address with a scalar — violates $ref schema + payload["shipping_address"] = "123 Main St" + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_ref_order_item_validated(): + """$ref to OrderItem schema is resolved; invalid item caught.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + # unit_price negative — violates exclusiveMinimum + bad_item = {"sku": "ABC", "quantity": 1, "unit_price": -5.0} + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[bad_item])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +# =================================================================== +# Complex spec – malformed JSON body +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_malformed_json(): + """Malformed JSON body returns 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=b"{not json at all", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + data = resp.json() + assert data["error"]["code"] == "invalid_payload" + + +# =================================================================== +# Complex spec – boundary / edge-case valid values +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_boundary_age_exactly_minimum(): + """Age exactly at minimum (18) should pass.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(age=18)).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_complex_boundary_age_exactly_maximum(): + """Age exactly at maximum (150) should pass.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(age=150)).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_complex_boundary_sku_exactly_min_length(): + """SKU at exactly minLength: 3 should pass.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + item = {"sku": "XYZ", "quantity": 1, "unit_price": 1.0} + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[item])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_complex_boundary_sku_exactly_max_length(): + """SKU at exactly maxLength: 12 should pass.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + item = {"sku": "A" * 12, "quantity": 1, "unit_price": 1.0} + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[item])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_complex_multiple_errors_reported(): + """Multiple validation errors are all reported in 'details'.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + # Wrong type for age, invalid enum for tier, empty items array + payload = { + "customer_name": "Bob", + "age": "old", + "tier": "diamond", + "shipping_address": { + "street": "1 Elm", + "city": "X", + "zip": "00000", + "country": "US", + }, + "items": [], + } + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + details = resp.json()["error"]["details"] + assert len(details) >= 3 # at least age, tier, items errors + + +# --------------------------------------------------------------------------- +# Tests: validate_response +# --------------------------------------------------------------------------- + +RESPONSE_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "Resp", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": {"type": "object"}, + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "result": {"type": "string"}, + }, + "required": ["result"], + } + } + } + } + }, + } + } + }, +} + + +@pytest.mark.asyncio +async def test_validate_response_valid(): + """Valid response body passes validation.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + v = OpenApiValidator(RESPONSE_SPEC) + errors = v.validate_response(b'{"result": "ok"}', "application/json") + assert errors == [] + + +@pytest.mark.asyncio +async def test_validate_response_invalid(): + """Invalid response body returns errors.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + v = OpenApiValidator(RESPONSE_SPEC) + errors = v.validate_response(b'{"wrong": 42}', "application/json") + assert len(errors) > 0 + + +@pytest.mark.asyncio +async def test_validate_response_no_schema(): + """When no response schema exists, validation passes (no-op).""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + spec_no_resp = { + "openapi": "3.0.0", + "info": {"title": "NoResp", "version": "1.0"}, + "paths": {"/invocations": {"post": {}}}, + } + v = OpenApiValidator(spec_no_resp) + errors = v.validate_response(b'{"anything": true}', "application/json") + assert errors == [] + + +# --------------------------------------------------------------------------- +# Tests: no request body schema +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_validate_request_no_schema(): + """When no request schema exists, validation passes (no-op).""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + spec_no_req = { + "openapi": "3.0.0", + "info": {"title": "NoReq", "version": "1.0"}, + "paths": {"/invocations": {"post": {}}}, + } + v = OpenApiValidator(spec_no_req) + errors = v.validate_request(b'{"anything": true}', "application/json") + assert errors == [] + + +# --------------------------------------------------------------------------- +# Tests: response schema from non-200/201 status +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_response_schema_fallback_to_first_available(): + """Response schema extraction falls back to first response with JSON.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + spec = { + "openapi": "3.0.0", + "info": {"title": "Fallback", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "responses": { + "202": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"status": {"type": "string"}}, + "required": ["status"], + } + } + } + } + } + } + } + }, + } + v = OpenApiValidator(spec) + # Valid against the 202 schema + assert v.validate_response(b'{"status": "accepted"}', "application/json") == [] + # Invalid — missing "status" + errors = v.validate_response(b'{"other": 1}', "application/json") + assert len(errors) > 0 + + +# --------------------------------------------------------------------------- +# Tests: $ref edge cases +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_unresolvable_ref(): + """An unresolvable $ref leaves the node as-is (no crash).""" + from azure.ai.agentserver.validation._openapi_validator import _resolve_ref + + spec: dict = {"components": {"schemas": {}}} + node = {"$ref": "#/components/schemas/DoesNotExist"} + result = _resolve_ref(spec, node) + # Can't resolve — returns the original node + assert result is node + + +@pytest.mark.asyncio +async def test_ref_path_hits_non_dict(): + """A $ref path that traverses a non-dict returns the original node.""" + from azure.ai.agentserver.validation._openapi_validator import _resolve_ref + + spec: dict = {"components": {"schemas": "not-a-dict"}} + node = {"$ref": "#/components/schemas/Foo"} + result = _resolve_ref(spec, node) + assert result is node + + +@pytest.mark.asyncio +async def test_circular_ref_stops_recursion(): + """Circular $ref does not cause infinite recursion.""" + from azure.ai.agentserver.validation._openapi_validator import _resolve_refs_deep + + spec: dict = { + "components": { + "schemas": { + "Node": { + "type": "object", + "properties": { + "child": {"$ref": "#/components/schemas/Node"}, + }, + } + } + } + } + node = {"$ref": "#/components/schemas/Node"} + result = _resolve_refs_deep(spec, node) + # Should resolve at least the first level, and leave the circular ref as-is + assert result["type"] == "object" + child = result["properties"]["child"] + # The circular reference should be left unresolved + assert "$ref" in child + + +@pytest.mark.asyncio +async def test_ref_resolves_to_non_dict(): + """A $ref that resolves to a non-dict value returns the original node.""" + from azure.ai.agentserver.validation._openapi_validator import _resolve_ref + + spec: dict = {"components": {"schemas": {"Bad": 42}}} + node = {"$ref": "#/components/schemas/Bad"} + result = _resolve_ref(spec, node) + assert result is node + + +# =================================================================== +# OpenAPI nullable support +# =================================================================== + +NULLABLE_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "Nullable", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "nickname": { + "type": "string", + "nullable": True, + }, + "age": { + "type": "integer", + "nullable": True, + }, + }, + "required": ["name", "nickname"], + } + } + } + }, + "responses": {"200": {"description": "OK"}}, + } + } + }, +} + + +@pytest.mark.asyncio +async def test_nullable_accepts_null(): + """A nullable field accepts a JSON null value.""" + agent = _make_echo_agent(NULLABLE_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"name": "Alice", "nickname": None}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_nullable_accepts_value(): + """A nullable field also accepts a normal string value.""" + agent = _make_echo_agent(NULLABLE_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"name": "Alice", "nickname": "Ali"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_nullable_rejects_wrong_type(): + """A nullable string still rejects integers.""" + agent = _make_echo_agent(NULLABLE_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"name": "Alice", "nickname": 42}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_non_nullable_rejects_null(): + """A non-nullable field (name) rejects null.""" + agent = _make_echo_agent(NULLABLE_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"name": None, "nickname": "Ali"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_nullable_integer_accepts_null(): + """A nullable integer accepts null.""" + agent = _make_echo_agent(NULLABLE_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"name": "A", "nickname": "B", "age": None}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +# =================================================================== +# readOnly / writeOnly support +# =================================================================== + +READONLY_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "ReadOnly", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "id": { + "type": "string", + "readOnly": True, + }, + "password": { + "type": "string", + "writeOnly": True, + }, + }, + "required": ["name", "id"], + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "id": { + "type": "string", + "readOnly": True, + }, + "password": { + "type": "string", + "writeOnly": True, + }, + }, + "required": ["name", "id"], + } + } + } + } + }, + } + } + }, +} + + +@pytest.mark.asyncio +async def test_readonly_not_required_in_request(): + """readOnly field 'id' is stripped from required in request context.""" + agent = _make_echo_agent(READONLY_SPEC) + async with await _client_for(agent) as client: + # Don't send 'id' — it's readOnly, should not be required + resp = await client.post( + "/invocations", + content=json.dumps({"name": "Alice"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_writeonly_allowed_in_request(): + """writeOnly field 'password' is allowed in request.""" + agent = _make_echo_agent(READONLY_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"name": "Alice", "password": "secret"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_readonly_schema_introspection_request(): + """readOnly properties are removed from the preprocessed request schema.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + v = OpenApiValidator(READONLY_SPEC) + # 'id' should be gone from requestschema properties + assert "id" not in v._request_schema.get("properties", {}) + + +@pytest.mark.asyncio +async def test_readonly_schema_introspection_response(): + """readOnly properties remain in the preprocessed response schema.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + v = OpenApiValidator(READONLY_SPEC) + # 'id' should still be in response schema + assert "id" in v._response_schema.get("properties", {}) + + +@pytest.mark.asyncio +async def test_writeonly_stripped_in_response(): + """writeOnly properties are removed from the preprocessed response schema.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + v = OpenApiValidator(READONLY_SPEC) + assert "password" not in v._response_schema.get("properties", {}) + + +@pytest.mark.asyncio +async def test_writeonly_present_in_request(): + """writeOnly properties remain in the preprocessed request schema.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + v = OpenApiValidator(READONLY_SPEC) + assert "password" in v._request_schema.get("properties", {}) + + +# =================================================================== +# requestBody.required: false +# =================================================================== + +OPTIONAL_BODY_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "OptionalBody", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "required": False, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "query": {"type": "string"}, + }, + "required": ["query"], + } + } + }, + }, + "responses": {"200": {"description": "OK"}}, + } + } + }, +} + +REQUIRED_BODY_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "RequiredBody", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "query": {"type": "string"}, + }, + "required": ["query"], + } + } + }, + }, + "responses": {"200": {"description": "OK"}}, + } + } + }, +} + + +@pytest.mark.asyncio +async def test_optional_body_empty_accepted(): + """Empty body is accepted when requestBody.required is false.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + v = OpenApiValidator(OPTIONAL_BODY_SPEC) + errors = v.validate_request(b"", "application/json") + assert errors == [] + + +@pytest.mark.asyncio +async def test_optional_body_whitespace_accepted(): + """Whitespace-only body is accepted when requestBody.required is false.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + v = OpenApiValidator(OPTIONAL_BODY_SPEC) + errors = v.validate_request(b" ", "application/json") + assert errors == [] + + +@pytest.mark.asyncio +async def test_optional_body_present_still_validated(): + """When body IS present with optional requestBody, it still must be valid.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + v = OpenApiValidator(OPTIONAL_BODY_SPEC) + errors = v.validate_request(b'{"wrong": 1}', "application/json") + assert len(errors) > 0 + + +@pytest.mark.asyncio +async def test_required_body_empty_rejected(): + """Empty body is rejected when requestBody.required is true.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + v = OpenApiValidator(REQUIRED_BODY_SPEC) + errors = v.validate_request(b"", "application/json") + assert len(errors) > 0 # "Invalid JSON body" + + +@pytest.mark.asyncio +async def test_default_body_required_behavior(): + """When requestBody.required is omitted, body is required by default.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + spec = { + "openapi": "3.0.0", + "info": {"title": "Default", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": {"type": "object", "required": ["q"]}, + } + } + }, + "responses": {"200": {"description": "OK"}}, + } + } + }, + } + v = OpenApiValidator(spec) + errors = v.validate_request(b"", "application/json") + assert len(errors) > 0 + + +# =================================================================== +# OpenAPI keyword stripping +# =================================================================== + + +@pytest.mark.asyncio +async def test_openapi_keywords_stripped(): + """discriminator, xml, externalDocs, example are stripped from schemas.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + schema = { + "type": "object", + "discriminator": {"propertyName": "type"}, + "xml": {"name": "Foo"}, + "externalDocs": {"url": "https://example.com"}, + "example": {"type": "bar"}, + "properties": {"name": {"type": "string"}}, + } + result = OpenApiValidator._preprocess_schema(schema) + assert "discriminator" not in result + assert "xml" not in result + assert "externalDocs" not in result + assert "example" not in result + assert result["properties"]["name"]["type"] == "string" + + +@pytest.mark.asyncio +async def test_openapi_keywords_stripped_nested(): + """OpenAPI keywords are stripped from nested properties too.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + schema = { + "type": "object", + "properties": { + "child": { + "type": "object", + "example": {"key": "value"}, + "xml": {"wrapped": True}, + "properties": {"inner": {"type": "string"}}, + } + }, + } + result = OpenApiValidator._preprocess_schema(schema) + child = result["properties"]["child"] + assert "example" not in child + assert "xml" not in child + + +# =================================================================== +# Format validation (enabled via FormatChecker) +# =================================================================== + +FORMAT_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "Format", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "email": { + "type": "string", + "format": "email", + }, + "created_at": { + "type": "string", + "format": "date-time", + }, + }, + "required": ["email"], + } + } + } + }, + "responses": {"200": {"description": "OK"}}, + } + } + }, +} + + +@pytest.mark.asyncio +async def test_format_email_valid(): + """Valid email passes format validation.""" + agent = _make_echo_agent(FORMAT_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"email": "alice@example.com"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_format_email_invalid(): + """Invalid email fails format validation.""" + agent = _make_echo_agent(FORMAT_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"email": "not-an-email"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_format_datetime_valid(): + """Valid ISO 8601 date-time passes format validation.""" + agent = _make_echo_agent(FORMAT_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps( + {"email": "a@b.com", "created_at": "2025-01-01T00:00:00Z"} + ).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_format_datetime_invalid(): + """Invalid date-time fails format validation.""" + agent = _make_echo_agent(FORMAT_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps( + {"email": "a@b.com", "created_at": "not-a-date"} + ).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +# =================================================================== +# allOf / oneOf / anyOf composition +# =================================================================== + +COMPOSITION_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "Composition", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "allOf": [ + { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + { + "type": "object", + "properties": {"age": {"type": "integer"}}, + "required": ["age"], + }, + ] + } + } + } + }, + "responses": {"200": {"description": "OK"}}, + } + } + }, +} + +ONEOF_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "OneOf", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "oneOf": [ + { + "type": "object", + "properties": { + "kind": { + "type": "string", + "enum": ["cat"], + }, + "purrs": {"type": "boolean"}, + }, + "required": ["kind", "purrs"], + }, + { + "type": "object", + "properties": { + "kind": { + "type": "string", + "enum": ["dog"], + }, + "barks": {"type": "boolean"}, + }, + "required": ["kind", "barks"], + }, + ] + } + } + } + }, + "responses": {"200": {"description": "OK"}}, + } + } + }, +} + +ANYOF_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "AnyOf", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + ] + } + } + } + }, + "responses": {"200": {"description": "OK"}}, + } + } + }, +} + + +@pytest.mark.asyncio +async def test_allof_valid(): + """allOf: both sub-schemas satisfied → 200.""" + agent = _make_echo_agent(COMPOSITION_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"name": "Alice", "age": 30}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_allof_missing_field(): + """allOf: missing field from second schema → 400.""" + agent = _make_echo_agent(COMPOSITION_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"name": "Alice"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_oneof_first_branch(): + """oneOf: matching first branch (cat) → 200.""" + agent = _make_echo_agent(ONEOF_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"kind": "cat", "purrs": True}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_oneof_second_branch(): + """oneOf: matching second branch (dog) → 200.""" + agent = _make_echo_agent(ONEOF_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"kind": "dog", "barks": True}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_oneof_no_match(): + """oneOf: matching neither branch → 400.""" + agent = _make_echo_agent(ONEOF_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps({"kind": "fish"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_anyof_string(): + """anyOf: string value matches first branch → 200.""" + agent = _make_echo_agent(ANYOF_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps("hello").encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_anyof_integer(): + """anyOf: integer value matches second branch → 200.""" + agent = _make_echo_agent(ANYOF_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(42).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_anyof_wrong_type(): + """anyOf: boolean matches neither string nor integer → 400.""" + agent = _make_echo_agent(ANYOF_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(True).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +# =================================================================== +# nullable + $ref combination +# =================================================================== + + +@pytest.mark.asyncio +async def test_nullable_ref_accepts_null(): + """A nullable $ref field accepts null after preprocessing.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + spec: dict = { + "openapi": "3.0.0", + "info": {"title": "NullRef", "version": "1.0"}, + "components": { + "schemas": { + "Address": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + } + }, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "addr": { + "$ref": "#/components/schemas/Address", + "nullable": True, + } + }, + } + } + } + }, + "responses": {"200": {"description": "OK"}}, + } + } + }, + } + v = OpenApiValidator(spec) + errors = v.validate_request(b'{"addr": null}', "application/json") + assert errors == [] + + +# =================================================================== +# Preprocessing unit tests +# =================================================================== + + +@pytest.mark.asyncio +async def test_apply_nullable_no_duplicate_null(): + """_apply_nullable does not add 'null' twice if type is already a list with null.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + schema: dict = {"type": ["string", "null"], "nullable": True} + OpenApiValidator._apply_nullable(schema) + assert schema["type"] == ["string", "null"] + + +@pytest.mark.asyncio +async def test_apply_nullable_false(): + """nullable: false is a no-op (just removes the key).""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + schema: dict = {"type": "string", "nullable": False} + OpenApiValidator._apply_nullable(schema) + assert schema["type"] == "string" + assert "nullable" not in schema + + +@pytest.mark.asyncio +async def test_strip_openapi_keywords_nested_deeply(): + """OpenAPI keywords are stripped from deeply nested schemas.""" + from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + + schema: dict = { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "example": {"deep": True}, + "properties": { + "val": {"type": "string", "xml": {"attr": True}}, + }, + }, + } + }, + } + OpenApiValidator._strip_openapi_keywords(schema) + items_schema = schema["properties"]["items"]["items"] + assert "example" not in items_schema + assert "xml" not in items_schema["properties"]["val"] + + +# =================================================================== +# Discriminator-aware error collection (ported from C# JsonSchemaValidator) +# =================================================================== + +# --- Helper: direct validator call for unit-level assertions --- + +def _validate_request(spec: dict, body: dict) -> list[str]: + """Shortcut: build an OpenApiValidator and return request errors.""" + v = OpenApiValidator(spec) + return v.validate_request( + json.dumps(body).encode(), + "application/json", + ) + + +# Spec with const-based discriminator (Azure-style polymorphic schema) +CONST_DISCRIMINATOR_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "Disc", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": {"type": "string", "const": "workflow"}, + "steps": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["type", "steps"], + }, + { + "type": "object", + "properties": { + "type": {"type": "string", "const": "prompt"}, + "text": {"type": "string"}, + }, + "required": ["type", "text"], + }, + { + "type": "object", + "properties": { + "type": {"type": "string", "const": "tool_call"}, + "tool_name": {"type": "string"}, + "args": {"type": "object"}, + }, + "required": ["type", "tool_name"], + }, + ] + } + } + } + }, + "responses": {"200": {"description": "OK"}}, + } + } + }, +} + + +@pytest.mark.asyncio +async def test_discriminator_valid_workflow(): + """const discriminator: valid workflow branch → passes.""" + errors = _validate_request( + CONST_DISCRIMINATOR_SPEC, + {"type": "workflow", "steps": ["a", "b"]}, + ) + assert errors == [] + + +@pytest.mark.asyncio +async def test_discriminator_valid_prompt(): + """const discriminator: valid prompt branch → passes.""" + errors = _validate_request( + CONST_DISCRIMINATOR_SPEC, + {"type": "prompt", "text": "hello"}, + ) + assert errors == [] + + +@pytest.mark.asyncio +async def test_discriminator_valid_tool_call(): + """const discriminator: valid tool_call branch → passes.""" + errors = _validate_request( + CONST_DISCRIMINATOR_SPEC, + {"type": "tool_call", "tool_name": "search"}, + ) + assert errors == [] + + +@pytest.mark.asyncio +async def test_discriminator_reports_matching_branch_errors(): + """const discriminator: correct type but missing required field → reports only that branch.""" + errors = _validate_request( + CONST_DISCRIMINATOR_SPEC, + {"type": "workflow"}, # missing "steps" + ) + assert len(errors) >= 1 + # Should report the missing "steps" field, not errors from prompt/tool_call branches + assert any("steps" in e for e in errors) + # Should NOT mention fields from other branches + assert not any("text" in e for e in errors) + assert not any("tool_name" in e for e in errors) + + +@pytest.mark.asyncio +async def test_discriminator_unknown_value_reports_expected(): + """const discriminator: unknown type value → concise 'Expected one of' message.""" + errors = _validate_request( + CONST_DISCRIMINATOR_SPEC, + {"type": "unknown_thing", "steps": ["a"]}, + ) + assert len(errors) >= 1 + # Should mention the expected values, not dump all branch errors + combined = " ".join(errors) + assert "Expected" in combined or "type" in combined + + +@pytest.mark.asyncio +async def test_discriminator_wrong_type_value(): + """const discriminator: type field is integer instead of string → type error reported.""" + errors = _validate_request( + CONST_DISCRIMINATOR_SPEC, + {"type": 123, "steps": ["a"]}, + ) + assert len(errors) >= 1 + combined = " ".join(errors) + assert "type" in combined + + +@pytest.mark.asyncio +async def test_discriminator_missing_type_field(): + """const discriminator: missing discriminator field entirely → error.""" + errors = _validate_request( + CONST_DISCRIMINATOR_SPEC, + {"steps": ["a"]}, + ) + assert len(errors) >= 1 + + +# Spec with enum-based discriminator (the existing ONEOF_SPEC pattern) +@pytest.mark.asyncio +async def test_enum_discriminator_matching_branch_errors(): + """enum discriminator: correct kind but missing required field → branch-specific errors.""" + errors = _validate_request( + ONEOF_SPEC, + {"kind": "cat"}, # missing "purrs" + ) + assert len(errors) >= 1 + assert any("purrs" in e for e in errors) + # Should not mention "barks" from the dog branch + assert not any("barks" in e for e in errors) + + +# --- JSON path in error messages --- + +@pytest.mark.asyncio +async def test_error_includes_json_path_for_nested_property(): + """Validation errors for nested properties include the JSON path prefix.""" + errors = _validate_request( + COMPLEX_SPEC, + { + "customer_name": "Bob", + "tier": "gold", + "shipping_address": { + "street": "1 Elm", + # missing "city" + "zip": "00000", + "country": "US", + }, + "items": [{"sku": "ABCD", "quantity": 1, "unit_price": 9.99}], + }, + ) + assert len(errors) >= 1 + # At least one error should contain a JSON-path-like prefix + assert any("$." in e or "city" in e for e in errors) + + +@pytest.mark.asyncio +async def test_error_includes_json_path_for_array_item(): + """Validation errors inside array items include the element index in the path.""" + errors = _validate_request( + COMPLEX_SPEC, + { + "customer_name": "Bob", + "tier": "gold", + "shipping_address": { + "street": "1 Elm", + "city": "X", + "zip": "00000", + "country": "US", + }, + "items": [{"quantity": 1, "unit_price": 9.99}], # missing "sku" + }, + ) + assert len(errors) >= 1 + assert any("items" in e for e in errors) + + +# --- Spec with nested oneOf inside properties --- + +NESTED_ONEOF_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "NestedOneOf", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "payload": { + "oneOf": [ + { + "type": "object", + "properties": { + "kind": {"type": "string", "const": "text"}, + "content": {"type": "string"}, + }, + "required": ["kind", "content"], + }, + { + "type": "object", + "properties": { + "kind": {"type": "string", "const": "image"}, + "url": {"type": "string", "format": "uri"}, + }, + "required": ["kind", "url"], + }, + ], + }, + }, + "required": ["name", "payload"], + } + } + } + }, + "responses": {"200": {"description": "OK"}}, + } + } + }, +} + + +@pytest.mark.asyncio +async def test_nested_oneof_valid_text(): + """Nested oneOf: valid text payload → passes.""" + errors = _validate_request( + NESTED_ONEOF_SPEC, + {"name": "test", "payload": {"kind": "text", "content": "hello"}}, + ) + assert errors == [] + + +@pytest.mark.asyncio +async def test_nested_oneof_valid_image(): + """Nested oneOf: valid image payload → passes.""" + errors = _validate_request( + NESTED_ONEOF_SPEC, + {"name": "test", "payload": {"kind": "image", "url": "https://example.com/img.png"}}, + ) + assert errors == [] + + +@pytest.mark.asyncio +async def test_nested_oneof_wrong_discriminator(): + """Nested oneOf: unknown discriminator value → error mentioning payload path.""" + errors = _validate_request( + NESTED_ONEOF_SPEC, + {"name": "test", "payload": {"kind": "video", "url": "https://example.com/v.mp4"}}, + ) + assert len(errors) >= 1 + combined = " ".join(errors) + # Should mention "payload" in the path + assert "payload" in combined + + +@pytest.mark.asyncio +async def test_nested_oneof_matching_branch_missing_field(): + """Nested oneOf: correct kind=text but missing content → branch-specific error.""" + errors = _validate_request( + NESTED_ONEOF_SPEC, + {"name": "test", "payload": {"kind": "text"}}, + ) + assert len(errors) >= 1 + assert any("content" in e for e in errors) + # Should NOT mention "url" from the image branch + assert not any("url" in e for e in errors) + + +# --- Unit-level tests for helper functions --- + +@pytest.mark.asyncio +async def test_format_error_includes_path(): + """_format_error prefixes with JSON path when not root.""" + from azure.ai.agentserver.validation._openapi_validator import _format_error + import jsonschema + + schema = {"type": "object", "properties": {"age": {"type": "integer"}}, "required": ["age"]} + validator = jsonschema.Draft7Validator(schema) + errors = list(validator.iter_errors({"age": "not_int"})) + assert len(errors) == 1 + formatted = _format_error(errors[0]) + assert "$.age" in formatted + + +@pytest.mark.asyncio +async def test_format_error_root_path_no_prefix(): + """_format_error does not prefix when error is at root ($).""" + from azure.ai.agentserver.validation._openapi_validator import _format_error + import jsonschema + + schema = {"type": "object", "required": ["name"]} + validator = jsonschema.Draft7Validator(schema) + errors = list(validator.iter_errors({})) + assert len(errors) == 1 + formatted = _format_error(errors[0]) + # Root-level error — should be the message without "$.xxx:" prefix + assert formatted == errors[0].message + + +@pytest.mark.asyncio +async def test_collect_errors_flat(): + """_collect_errors for a simple (non-composition) error returns path-prefixed message.""" + from azure.ai.agentserver.validation._openapi_validator import _collect_errors + import jsonschema + + schema = {"type": "object", "properties": {"x": {"type": "integer"}}} + validator = jsonschema.Draft7Validator(schema) + errors = list(validator.iter_errors({"x": "abc"})) + collected = [] + for e in errors: + collected.extend(_collect_errors(e)) + assert len(collected) == 1 + assert "$.x" in collected[0] diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_request_limits.py b/sdk/agentserver/azure-ai-agentserver/tests/test_request_limits.py new file mode 100644 index 000000000000..e8161122a3c9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_request_limits.py @@ -0,0 +1,147 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for invoke timeout guards.""" +import asyncio +import os +from unittest.mock import patch + +import httpx +import pytest + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver import AgentServer +from azure.ai.agentserver._constants import Constants + + +# --------------------------------------------------------------------------- +# Agent factory functions +# --------------------------------------------------------------------------- + + +def _make_slow_agent(**kwargs) -> AgentServer: + """Create an agent that sleeps forever — used to test invoke timeout.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + await asyncio.sleep(999) + return Response(content=b"done") + + return server + + +def _make_fast_agent(**kwargs) -> AgentServer: + """Create an agent that returns immediately — used to verify no false timeout.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + return server + + +# =========================================================================== +# Request timeout +# =========================================================================== + + +class TestRequestTimeoutConstants: + """Verify constants are wired correctly.""" + + def test_default_is_300_seconds(self): + assert Constants.DEFAULT_REQUEST_TIMEOUT == 300 + + def test_env_var_name(self): + assert Constants.AGENT_REQUEST_TIMEOUT == "AGENT_REQUEST_TIMEOUT" + + +class TestRequestTimeoutResolution: + """Test the resolution hierarchy: explicit > env > default.""" + + def test_explicit_value(self): + agent = _make_fast_agent(request_timeout=60) + assert agent._request_timeout == 60 + + def test_explicit_zero_disables(self): + agent = _make_fast_agent(request_timeout=0) + assert agent._request_timeout == 0 + + def test_env_var_used_when_no_explicit(self): + with patch.dict(os.environ, {Constants.AGENT_REQUEST_TIMEOUT: "120"}): + agent = _make_fast_agent() + assert agent._request_timeout == 120 + + def test_invalid_env_var_raises(self): + with patch.dict(os.environ, {Constants.AGENT_REQUEST_TIMEOUT: "abc"}): + with pytest.raises(ValueError, match="AGENT_REQUEST_TIMEOUT"): + _make_fast_agent() + + def test_default_when_nothing_set(self): + with patch.dict(os.environ, {}, clear=True): + agent = _make_fast_agent() + assert agent._request_timeout == Constants.DEFAULT_REQUEST_TIMEOUT + + +class TestInvokeTimeoutEnforcement: + """Test that long-running invoke() calls are cancelled.""" + + @pytest.mark.asyncio + async def test_slow_invoke_returns_504(self): + agent = _make_slow_agent(request_timeout=1) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b"{}") + assert resp.status_code == 504 + data = resp.json() + assert data["error"]["code"] == "request_timeout" + assert "timed out" in data["error"]["message"].lower() + assert "1s" in data["error"]["message"] + + @pytest.mark.asyncio + async def test_slow_invoke_includes_invocation_id_header(self): + agent = _make_slow_agent(request_timeout=1) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b"{}") + assert resp.status_code == 504 + assert Constants.INVOCATION_ID_HEADER in resp.headers + + @pytest.mark.asyncio + async def test_fast_invoke_not_affected(self): + agent = _make_fast_agent(request_timeout=5) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert resp.json() == {"ok": True} + + @pytest.mark.asyncio + async def test_timeout_disabled_allows_no_limit(self): + """request_timeout=0 means no timeout (passes None to wait_for).""" + agent = _make_fast_agent(request_timeout=0) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + + @pytest.mark.asyncio + async def test_timeout_error_is_logged(self): + agent = _make_slow_agent(request_timeout=1) + transport = httpx.ASGITransport(app=agent.app) + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + async with httpx.AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: + resp = await client.post("/invocations", content=b"{}") + assert resp.status_code == 504 + + timeout_calls = [ + c + for c in mock_logger.error.call_args_list + if "timed out" in str(c).lower() + ] + assert len(timeout_calls) == 1 diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_server_routes.py b/sdk/agentserver/azure-ai-agentserver/tests/test_server_routes.py new file mode 100644 index 000000000000..f4a2882fc0cc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_server_routes.py @@ -0,0 +1,116 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for route registration and basic endpoint behavior.""" +import uuid + +import pytest + + +@pytest.mark.asyncio +async def test_post_invocations_returns_200(echo_client): + """POST /invocations with valid body returns 200.""" + resp = await echo_client.post("/invocations", content=b'{"hello":"world"}') + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_post_invocations_returns_invocation_id_header(echo_client): + """Response includes x-agent-invocation-id header in UUID format.""" + resp = await echo_client.post("/invocations", content=b'{"hello":"world"}') + invocation_id = resp.headers.get("x-agent-invocation-id") + assert invocation_id is not None + # Validate UUID format + uuid.UUID(invocation_id) + + +@pytest.mark.asyncio +async def test_get_openapi_spec_returns_404_when_not_set(echo_client): + """GET /invocations/docs/openapi.json returns 404 if no spec registered.""" + resp = await echo_client.get("/invocations/docs/openapi.json") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_openapi_spec_returns_spec(validated_client): + """GET /invocations/docs/openapi.json returns registered spec as JSON.""" + resp = await validated_client.get("/invocations/docs/openapi.json") + assert resp.status_code == 200 + data = resp.json() + assert data["openapi"] == "3.0.0" + assert "/invocations" in data["paths"] + + +@pytest.mark.asyncio +async def test_get_invocation_returns_404_default(echo_client): + """GET /invocations/{id} returns 404 when not overridden.""" + resp = await echo_client.get(f"/invocations/{uuid.uuid4()}") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_cancel_invocation_returns_404_default(echo_client): + """POST /invocations/{id}/cancel returns 404 when not overridden.""" + resp = await echo_client.post(f"/invocations/{uuid.uuid4()}/cancel") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_unknown_route_returns_404(echo_client): + """GET /unknown returns 404.""" + resp = await echo_client.get("/unknown") + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# _resolve_port tests +# --------------------------------------------------------------------------- + + +class TestResolvePort: + """Unit tests for resolve_port.""" + + def test_explicit_port_wins(self): + """Explicit port argument takes precedence over everything.""" + from azure.ai.agentserver.server._config import resolve_port + assert resolve_port(9090) == 9090 + + def test_env_var_used_when_no_explicit(self, monkeypatch): + """AGENT_SERVER_PORT env var is used when no explicit port given.""" + from azure.ai.agentserver.server._config import resolve_port + monkeypatch.setenv("AGENT_SERVER_PORT", "7777") + assert resolve_port(None) == 7777 + + def test_default_port_when_nothing_set(self, monkeypatch): + """Falls back to 8088 when no explicit port and no env var.""" + from azure.ai.agentserver.server._config import resolve_port + monkeypatch.delenv("AGENT_SERVER_PORT", raising=False) + assert resolve_port(None) == 8088 + + def test_invalid_env_var_raises(self, monkeypatch): + """Non-numeric AGENT_SERVER_PORT env var raises ValueError.""" + from azure.ai.agentserver.server._config import resolve_port + monkeypatch.setenv("AGENT_SERVER_PORT", "not_a_number") + with pytest.raises(ValueError, match="AGENT_SERVER_PORT"): + resolve_port(None) + + def test_non_int_explicit_port_raises(self): + """Passing a non-integer port raises ValueError.""" + from azure.ai.agentserver.server._config import resolve_port + with pytest.raises(ValueError, match="expected an integer"): + resolve_port("sss") # type: ignore[arg-type] + + def test_port_out_of_range_raises(self): + """Port outside 1-65535 raises ValueError.""" + from azure.ai.agentserver.server._config import resolve_port + with pytest.raises(ValueError, match="1-65535"): + resolve_port(0) + with pytest.raises(ValueError, match="1-65535"): + resolve_port(70000) + + def test_env_var_port_out_of_range_raises(self, monkeypatch): + """AGENT_SERVER_PORT outside 1-65535 raises ValueError.""" + from azure.ai.agentserver.server._config import resolve_port + monkeypatch.setenv("AGENT_SERVER_PORT", "0") + with pytest.raises(ValueError, match="1-65535"): + resolve_port(None) diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_tracing.py b/sdk/agentserver/azure-ai-agentserver/tests/test_tracing.py new file mode 100644 index 000000000000..f5a9c945e5ea --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_tracing.py @@ -0,0 +1,523 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for optional OpenTelemetry tracing.""" +import asyncio +import json +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver import AgentServer + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _InMemoryExporter(SpanExporter): + """Minimal in-memory span exporter for tests.""" + + def __init__(self): + self._spans: list = [] + + def export(self, spans): # type: ignore[override] + self._spans.extend(spans) + return SpanExportResult.SUCCESS + + def get_finished_spans(self): + return list(self._spans) + + def shutdown(self): + self._spans.clear() + + +def _make_echo_traced_agent(**kwargs) -> AgentServer: + """Create a simple agent for tracing tests.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + body = await request.body() + return Response(content=body, media_type="application/octet-stream") + + return server + + +def _make_failing_traced_agent(**kwargs) -> AgentServer: + """Create an agent whose invoke raises — so the span records the error.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + raise RuntimeError("trace-this-error") + + return server + + +@pytest.fixture() +def span_exporter(): + """Return the module-level exporter with a clean slate for each test.""" + _MODULE_EXPORTER._spans.clear() + yield _MODULE_EXPORTER + + +# Module-level OTel setup — set once to avoid +# "Overriding of current TracerProvider is not allowed" warnings. +_MODULE_EXPORTER = _InMemoryExporter() +_MODULE_PROVIDER = TracerProvider() +_MODULE_PROVIDER.add_span_processor(SimpleSpanProcessor(_MODULE_EXPORTER)) +trace.set_tracer_provider(_MODULE_PROVIDER) + + +# --------------------------------------------------------------------------- +# Tests: tracing disabled (default) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_tracing_disabled_by_default(): + """Agent created without enable_tracing has tracing off.""" + agent = _make_echo_traced_agent() + assert agent._tracing is None # noqa: SLF001 + + +@pytest.mark.asyncio +async def test_tracing_disabled_no_spans(span_exporter): + """When tracing is disabled, no spans are produced.""" + agent = _make_echo_traced_agent() # default: tracing off + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b'{"hi": "there"}') + + spans = span_exporter.get_finished_spans() + assert len(spans) == 0 + + +# --------------------------------------------------------------------------- +# Tests: tracing enabled via constructor +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_tracing_enabled_creates_invoke_span(span_exporter): + """POST /invocations with tracing enabled creates a span.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b'{"data": 1}') + assert resp.status_code == 200 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + assert len(invoke_spans) == 1 + + span = invoke_spans[0] + assert "invocation.id" in dict(span.attributes) + assert span.status.status_code == trace.StatusCode.UNSET # success + + +@pytest.mark.asyncio +async def test_tracing_invoke_error_records_exception(span_exporter): + """When invoke() raises, the span records the error status and exception.""" + agent = _make_failing_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b'{}') + assert resp.status_code == 500 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + assert len(invoke_spans) == 1 + + span = invoke_spans[0] + assert span.status.status_code == trace.StatusCode.ERROR + # Exception should be recorded in events + events = span.events + assert any("trace-this-error" in str(e.attributes) for e in events) + + +@pytest.mark.asyncio +async def test_tracing_get_invocation_creates_span(span_exporter): + """GET /invocations/{id} with tracing creates a span.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/test-id-123") + # Default returns 404 — but span should still exist + assert resp.status_code == 404 + + spans = span_exporter.get_finished_spans() + get_spans = [s for s in spans if s.name == "AgentServer.get_invocation"] + assert len(get_spans) == 1 + assert dict(get_spans[0].attributes)["invocation.id"] == "test-id-123" + + +@pytest.mark.asyncio +async def test_tracing_cancel_invocation_creates_span(span_exporter): + """POST /invocations/{id}/cancel with tracing creates a span.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/test-cancel-456/cancel") + assert resp.status_code == 404 + + spans = span_exporter.get_finished_spans() + cancel_spans = [s for s in spans if s.name == "AgentServer.cancel_invocation"] + assert len(cancel_spans) == 1 + assert dict(cancel_spans[0].attributes)["invocation.id"] == "test-cancel-456" + + +# --------------------------------------------------------------------------- +# Tests: tracing enabled via env var +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_tracing_enabled_via_env_var(monkeypatch, span_exporter): + """AGENT_ENABLE_TRACING=true activates tracing.""" + monkeypatch.setenv("AGENT_ENABLE_TRACING", "true") + agent = _make_echo_traced_agent() + assert agent._tracing is not None # noqa: SLF001 + + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b'{}') + + spans = span_exporter.get_finished_spans() + assert any(s.name == "AgentServer.invoke" for s in spans) + + +@pytest.mark.asyncio +async def test_tracing_constructor_overrides_env_var(monkeypatch, span_exporter): + """Constructor enable_tracing=False overrides AGENT_ENABLE_TRACING=true.""" + monkeypatch.setenv("AGENT_ENABLE_TRACING", "true") + agent = _make_echo_traced_agent(enable_tracing=False) + assert agent._tracing is None # noqa: SLF001 + + +@pytest.mark.asyncio +async def test_tracing_propagates_traceparent(span_exporter): + """Incoming traceparent header is extracted as parent context.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + # Valid W3C traceparent — version-trace-id-parent-id-flags + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b'{"hello": "world"}', + headers={"traceparent": traceparent}, + ) + assert resp.status_code == 200 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + assert len(invoke_spans) == 1 + span = invoke_spans[0] + # The span's trace ID should match the traceparent's trace ID + expected_trace_id = int("0af7651916cd43dd8448eb211c80319c", 16) + assert span.context.trace_id == expected_trace_id + + +# --------------------------------------------------------------------------- +# Tests: streaming response tracing +# --------------------------------------------------------------------------- + + +def _make_streaming_traced_agent(**kwargs) -> AgentServer: + """Create an agent that returns a StreamingResponse.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + async def _generate(): + for i in range(5): + yield f"chunk-{i}\n".encode() + + return StreamingResponse(_generate(), media_type="application/octet-stream") + + return server + + +def _make_slow_streaming_agent(**kwargs) -> AgentServer: + """Create an agent that streams with deliberate delays per chunk.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + async def _generate(): + for i in range(3): + await asyncio.sleep(0.05) + yield f"slow-{i}\n".encode() + + return StreamingResponse(_generate(), media_type="text/plain") + + return server + + +def _make_failing_stream_agent(**kwargs) -> AgentServer: + """Create an agent whose streaming body raises mid-stream.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + async def _generate(): + yield b"ok-chunk\n" + raise RuntimeError("stream-exploded") + + return StreamingResponse(_generate(), media_type="text/plain") + + return server + + +@pytest.mark.asyncio +async def test_streaming_response_creates_span(span_exporter): + """Streaming response still produces a span.""" + agent = _make_streaming_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b'{}') + assert resp.status_code == 200 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + assert len(invoke_spans) == 1 + assert invoke_spans[0].status.status_code == trace.StatusCode.UNSET + + +@pytest.mark.asyncio +async def test_streaming_span_covers_full_body(span_exporter): + """Span for streaming response covers the full streaming duration, + not just the invoke() call that creates the StreamingResponse.""" + agent = _make_slow_streaming_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b'{}') + assert resp.status_code == 200 + # All chunks received + assert b"slow-0" in resp.content + assert b"slow-2" in resp.content + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + assert len(invoke_spans) == 1 + span = invoke_spans[0] + + # Span duration should cover the streaming time (~150ms for 3×50ms), + # not just the instant invoke() returns the StreamingResponse object. + duration_ns = span.end_time - span.start_time + # At minimum 100ms (conservative) — would be <1ms without the fix. + assert duration_ns > 50_000_000, f"Span duration {duration_ns}ns is too short for streaming" + + +@pytest.mark.asyncio +async def test_streaming_body_fully_received(span_exporter): + """All chunks from a streaming response are delivered to the client.""" + agent = _make_streaming_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b'{}') + assert resp.status_code == 200 + body = resp.content.decode() + for i in range(5): + assert f"chunk-{i}" in body + + +@pytest.mark.asyncio +async def test_streaming_error_recorded_in_span(span_exporter): + """Errors during streaming are recorded on the span.""" + agent = _make_failing_stream_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + # The server will encounter a mid-stream error; httpx may raise or + # return a partial response depending on ASGI transport behaviour. + try: + await client.post("/invocations", content=b'{}') + except Exception: + pass # connection reset / partial read is acceptable + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + assert len(invoke_spans) == 1 + + span = invoke_spans[0] + assert span.status.status_code == trace.StatusCode.ERROR + events = span.events + assert any("stream-exploded" in str(e.attributes) for e in events) + + +@pytest.mark.asyncio +async def test_streaming_tracing_disabled_no_span(span_exporter): + """When tracing is disabled, streaming responses produce no spans.""" + agent = _make_streaming_traced_agent() # tracing off (default) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b'{}') + assert resp.status_code == 200 + body = resp.content.decode() + assert "chunk-0" in body # body still delivered + + spans = span_exporter.get_finished_spans() + assert len(spans) == 0 + + +@pytest.mark.asyncio +async def test_streaming_propagates_traceparent(span_exporter): + """Incoming traceparent header is propagated for streaming responses.""" + agent = _make_streaming_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b'{}', + headers={"traceparent": traceparent}, + ) + assert resp.status_code == 200 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + assert len(invoke_spans) == 1 + expected_trace_id = int("0af7651916cd43dd8448eb211c80319c", 16) + assert invoke_spans[0].context.trace_id == expected_trace_id + + +# --------------------------------------------------------------------------- +# Tests: Application Insights connection string resolution +# --------------------------------------------------------------------------- + + +class TestAppInsightsConnectionStringResolution: + """Tests for resolve_appinsights_connection_string and constructor wiring.""" + + def test_explicit_param_takes_priority(self, monkeypatch): + """Constructor param beats env var.""" + from azure.ai.agentserver.server._config import resolve_appinsights_connection_string + + monkeypatch.setenv("APPLICATIONINSIGHTS_CONNECTION_STRING", "env-standard") + result = resolve_appinsights_connection_string("explicit-value") + assert result == "explicit-value" + + def test_standard_env_var_fallback(self, monkeypatch): + """Falls back to APPLICATIONINSIGHTS_CONNECTION_STRING.""" + from azure.ai.agentserver.server._config import resolve_appinsights_connection_string + + monkeypatch.setenv("APPLICATIONINSIGHTS_CONNECTION_STRING", "env-standard") + result = resolve_appinsights_connection_string(None) + assert result == "env-standard" + + def test_no_connection_string_returns_none(self, monkeypatch): + """Returns None when no source provides a connection string.""" + from azure.ai.agentserver.server._config import resolve_appinsights_connection_string + + monkeypatch.delenv("APPLICATIONINSIGHTS_CONNECTION_STRING", raising=False) + result = resolve_appinsights_connection_string(None) + assert result is None + + +# --------------------------------------------------------------------------- +# Tests: _setup_azure_monitor +# --------------------------------------------------------------------------- + + +class TestSetupAzureMonitor: + """Tests for TracingHelper._setup_azure_monitor (mocked exporter imports).""" + + def test_setup_configures_tracer_provider(self): + """_setup_azure_monitor sets a global TracerProvider with exporter.""" + from azure.ai.agentserver._tracing import TracingHelper + + mock_exporter = MagicMock() + mock_exporter_cls = MagicMock(return_value=mock_exporter) + + with patch.dict( + "sys.modules", + { + "azure.monitor.opentelemetry.exporter": MagicMock( + AzureMonitorTraceExporter=mock_exporter_cls, + AzureMonitorLogExporter=MagicMock(return_value=MagicMock()), + ), + }, + ), patch("opentelemetry.trace.set_tracer_provider") as mock_set_provider: + TracingHelper._setup_azure_monitor("InstrumentationKey=test") + + mock_exporter_cls.assert_called_once_with( + connection_string="InstrumentationKey=test" + ) + mock_set_provider.assert_called_once() + + def test_setup_logs_warning_when_packages_missing(self, caplog): + """Warns gracefully when azure-monitor exporter is not installed.""" + import builtins + from azure.ai.agentserver._tracing import TracingHelper + + real_import = builtins.__import__ + + def _block_monitor(name, *args, **kwargs): + if "azure.monitor" in name: + raise ImportError("no monitor") + return real_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=_block_monitor): + import logging + + with caplog.at_level(logging.WARNING, logger="azure.ai.agentserver"): + TracingHelper._setup_azure_monitor("InstrumentationKey=test") + + assert "required packages are not installed" in caplog.text + + def test_constructor_passes_connection_string(self, monkeypatch): + """AgentServer passes resolved connection string to TracingHelper.""" + monkeypatch.delenv("APPLICATIONINSIGHTS_CONNECTION_STRING", raising=False) + + with patch( + "azure.ai.agentserver._tracing.TracingHelper._setup_azure_monitor" + ) as mock_setup: + _make_echo_traced_agent( + enable_tracing=True, + application_insights_connection_string="InstrumentationKey=from-param", + ) + mock_setup.assert_called_once_with("InstrumentationKey=from-param") + + def test_constructor_no_connection_string_skips_setup(self, monkeypatch): + """When no connection string is available, _setup_azure_monitor is not called.""" + monkeypatch.delenv("APPLICATIONINSIGHTS_CONNECTION_STRING", raising=False) + + with patch( + "azure.ai.agentserver._tracing.TracingHelper._setup_azure_monitor" + ) as mock_setup: + _make_echo_traced_agent(enable_tracing=True) + mock_setup.assert_not_called() + + def test_constructor_env_var_connection_string(self, monkeypatch): + """Connection string from env var is passed to _setup_azure_monitor.""" + monkeypatch.setenv( + "APPLICATIONINSIGHTS_CONNECTION_STRING", + "InstrumentationKey=from-env", + ) + + with patch( + "azure.ai.agentserver._tracing.TracingHelper._setup_azure_monitor" + ) as mock_setup: + _make_echo_traced_agent(enable_tracing=True) + mock_setup.assert_called_once_with("InstrumentationKey=from-env") + + def test_tracing_disabled_skips_connection_string_resolution(self, monkeypatch): + """When tracing is disabled, connection string is not resolved.""" + monkeypatch.setenv( + "APPLICATIONINSIGHTS_CONNECTION_STRING", + "InstrumentationKey=should-not-use", + ) + + agent = _make_echo_traced_agent(enable_tracing=False) + assert agent._tracing is None # noqa: SLF001 diff --git a/sdk/agentserver/ci.yml b/sdk/agentserver/ci.yml index bb2d6f479b00..7e718801a805 100644 --- a/sdk/agentserver/ci.yml +++ b/sdk/agentserver/ci.yml @@ -40,6 +40,8 @@ extends: Selection: sparse GenerateVMJobs: true Artifacts: + - name: azure-ai-agentserver + safeName: azureaiagentserver - name: azure-ai-agentserver-core safeName: azureaiagentservercore - name: azure-ai-agentserver-agentframework diff --git a/shared_requirements.txt b/shared_requirements.txt index b5e1b85f184a..518ceb2272f0 100644 --- a/shared_requirements.txt +++ b/shared_requirements.txt @@ -87,6 +87,7 @@ azure-confidentialledger-certificate azure-ai-projects starlette uvicorn +hypercorn opentelemetry-exporter-otlp-proto-http agent-framework-azure-ai langgraph From bdb9d35ee045887bc7b1807dd37d9a823d9a2021 Mon Sep 17 00:00:00 2001 From: Zhiyong Yang Date: Tue, 10 Mar 2026 10:02:16 -0700 Subject: [PATCH 02/10] Fix pylint/pyright CI failures and accept invocation ID from request header - Fix pylint: trailing newline, too-many-instance-attributes, missing docstrings, unused import, too-many-return-statements - Fix pyright: guard error.context for reportOptionalIterable - Accept x-agent-invocation-id from request header (generate UUID if absent) - Server always controls the response header (handler cannot override) Co-Authored-By: Claude Opus 4.6 --- .../azure/ai/agentserver/server/_base.py | 38 ++++++++++++++----- .../validation/_openapi_validator.py | 32 +++++++++++++--- .../tests/test_edge_cases.py | 32 +++++++++++++++- 3 files changed, 84 insertions(+), 18 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py index 5833eb9dd906..06cbdb97bded 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py @@ -22,7 +22,7 @@ logger = logging.getLogger("azure.ai.agentserver") -class AgentServer: +class AgentServer: # pylint: disable=too-many-instance-attributes """Agent server with pluggable protocol heads. Instantiate and register handlers with decorators:: @@ -205,7 +205,13 @@ def shutdown_handler(self, fn: Callable[[], Awaitable[None]]) -> Callable[[], Aw # ------------------------------------------------------------------ async def _dispatch_invoke(self, request: Request) -> Response: - """Dispatch to the registered invoke handler.""" + """Dispatch to the registered invoke handler. + + :param request: The incoming Starlette request. + :type request: Request + :return: The response from the invoke handler. + :rtype: Response + """ if self._invoke_fn is not None: return await self._invoke_fn(request) raise RuntimeError( @@ -213,13 +219,25 @@ async def _dispatch_invoke(self, request: Request) -> Response: ) async def _dispatch_get_invocation(self, request: Request) -> Response: - """Dispatch to the registered get-invocation handler, or return 404.""" + """Dispatch to the registered get-invocation handler, or return 404. + + :param request: The incoming Starlette request. + :type request: Request + :return: The response from the get-invocation handler. + :rtype: Response + """ if self._get_invocation_fn is not None: return await self._get_invocation_fn(request) return error_response("not_supported", "get_invocation not supported", status_code=404) async def _dispatch_cancel_invocation(self, request: Request) -> Response: - """Dispatch to the registered cancel-invocation handler, or return 404.""" + """Dispatch to the registered cancel-invocation handler, or return 404. + + :param request: The incoming Starlette request. + :type request: Request + :return: The response from the cancel-invocation handler. + :rtype: Response + """ if self._cancel_invocation_fn is not None: return await self._cancel_invocation_fn(request) return error_response("not_supported", "cancel_invocation not supported", status_code=404) @@ -380,7 +398,10 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: :return: The invocation result or error response. :rtype: Response """ - invocation_id = str(uuid.uuid4()) + invocation_id = ( + request.headers.get(Constants.INVOCATION_ID_HEADER) + or str(uuid.uuid4()) + ) request.state.invocation_id = invocation_id # Validate request body against OpenAPI spec @@ -449,9 +470,8 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: elif self._tracing is not None: self._tracing.end_span(otel_span) - # Auto-inject invocation_id header if developer didn't set it - if Constants.INVOCATION_ID_HEADER not in response.headers: - response.headers[Constants.INVOCATION_ID_HEADER] = invocation_id + # Always set invocation_id header (overrides any handler-set value) + response.headers[Constants.INVOCATION_ID_HEADER] = invocation_id return response @@ -542,5 +562,3 @@ async def _readiness_endpoint(self, request: Request) -> Response: # pylint: di :rtype: Response """ return JSONResponse({"status": "ready"}) - - diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py index 3cee55dd426d..cfe352ff839a 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py @@ -29,7 +29,7 @@ import logging import re from collections import Counter -from datetime import datetime, timezone +from datetime import datetime from typing import Any, Optional import jsonschema @@ -49,7 +49,13 @@ @_format_checker.checks("date-time", raises=ValueError) def _check_datetime(value: object) -> bool: - """Validate RFC 3339 / ISO 8601 date-time strings using stdlib only.""" + """Validate RFC 3339 / ISO 8601 date-time strings using stdlib only. + + :param value: The value to validate. + :type value: object + :return: True if valid. + :rtype: bool + """ if not isinstance(value, str): return True # non-string is not a format error normalized = value.replace("Z", "+00:00") if value.endswith("Z") else value @@ -59,7 +65,13 @@ def _check_datetime(value: object) -> bool: @_format_checker.checks("date", raises=ValueError) def _check_date(value: object) -> bool: - """Validate ISO 8601 date strings (YYYY-MM-DD).""" + """Validate ISO 8601 date strings (YYYY-MM-DD). + + :param value: The value to validate. + :type value: object + :return: True if valid. + :rtype: bool + """ if not isinstance(value, str): return True datetime.strptime(value, "%Y-%m-%d") @@ -68,7 +80,13 @@ def _check_date(value: object) -> bool: @_format_checker.checks("email", raises=ValueError) def _check_email(value: object) -> bool: - """Basic RFC 5322 email format check (no DNS lookup).""" + """Basic RFC 5322 email format check (no DNS lookup). + + :param value: The value to validate. + :type value: object + :return: True if valid. + :rtype: bool + """ if not isinstance(value, str): return True if not _EMAIL_RE.match(value): @@ -467,7 +485,7 @@ def _collect_composition_errors(error: ValidationError) -> list[str]: """ # Group sub-errors by branch index (first element of schema_path) branch_groups: dict[int, list[ValidationError]] = {} - for sub in error.context: + for sub in error.context or []: if sub.schema_path: idx = sub.schema_path[0] if isinstance(idx, int): @@ -651,7 +669,9 @@ def _resolve_ref(spec: dict[str, Any], schema: dict[str, Any]) -> dict[str, Any] return current if isinstance(current, dict) else schema -def _resolve_refs_deep(spec: dict[str, Any], node: Any, _seen: Optional[set[str]] = None) -> Any: +def _resolve_refs_deep( # pylint: disable=too-many-return-statements + spec: dict[str, Any], node: Any, _seen: Optional[set[str]] = None +) -> Any: """Recursively resolve all ``$ref`` pointers in a schema tree. Walks the schema, replacing every ``{"$ref": "..."}`` with the referenced diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_edge_cases.py b/sdk/agentserver/azure-ai-agentserver/tests/test_edge_cases.py index 6907187b80de..2ea55294ee51 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_edge_cases.py +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_edge_cases.py @@ -239,10 +239,13 @@ class TestResponseHeaders: @pytest.mark.asyncio async def test_agent_custom_invocation_id_not_overwritten(self, custom_header_client): - """If agent sets x-agent-invocation-id, the server doesn't overwrite it.""" + """If agent sets x-agent-invocation-id, the server overwrites it with the canonical value.""" resp = await custom_header_client.post("/invocations", content=b"{}") assert resp.status_code == 200 - assert resp.headers["x-agent-invocation-id"] == "custom-id-from-agent" + # Server always controls the invocation-id header; handler cannot override it. + invocation_id = resp.headers["x-agent-invocation-id"] + assert invocation_id != "custom-id-from-agent" + uuid.UUID(invocation_id) # must be a valid server-generated UUID @pytest.mark.asyncio async def test_invocation_id_injected_when_missing(self, echo_client): @@ -253,6 +256,31 @@ async def test_invocation_id_injected_when_missing(self, echo_client): assert invocation_id is not None uuid.UUID(invocation_id) # valid UUID + @pytest.mark.asyncio + async def test_invocation_id_accepted_from_request_header(self, echo_client): + """If caller sends x-agent-invocation-id in the request, server uses it.""" + custom_id = "caller-provided-id-12345" + resp = await echo_client.post( + "/invocations", + content=b"hello", + headers={"x-agent-invocation-id": custom_id}, + ) + assert resp.status_code == 200 + assert resp.headers["x-agent-invocation-id"] == custom_id + + @pytest.mark.asyncio + async def test_invocation_id_generated_when_request_header_empty(self, echo_client): + """If caller sends empty x-agent-invocation-id, server generates a new one.""" + resp = await echo_client.post( + "/invocations", + content=b"hello", + headers={"x-agent-invocation-id": ""}, + ) + assert resp.status_code == 200 + invocation_id = resp.headers["x-agent-invocation-id"] + assert invocation_id != "" + uuid.UUID(invocation_id) # valid UUID + # --------------------------------------------------------------------------- # Tests: payload edge cases From 36a8826db5691c33547eb06f51c3ca98c6db57f3 Mon Sep 17 00:00:00 2001 From: Zhiyong Yang Date: Wed, 11 Mar 2026 09:11:04 -0700 Subject: [PATCH 03/10] a few improvements --- .../azure/ai/agentserver/_tracing.py | 36 ++++++--- .../azure/ai/agentserver/server/_base.py | 75 ++++++++++--------- .../azure/ai/agentserver/server/_config.py | 31 +++++--- .../validation/_openapi_validator.py | 72 +++++++++--------- 4 files changed, 122 insertions(+), 92 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_tracing.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_tracing.py index 2c612424ecaa..4928662fffa0 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_tracing.py @@ -28,13 +28,15 @@ from collections.abc import AsyncIterable, AsyncIterator, Mapping # pylint: disable=import-error from typing import TYPE_CHECKING, Any, Iterator, Optional, Union +from ._logger import get_logger + #: Starlette's ``Content`` type — the element type for streaming bodies. _Content = Union[str, bytes, memoryview] #: W3C Trace Context header names used for distributed trace propagation. -_W3C_HEADERS = frozenset(("traceparent", "tracestate")) +_W3C_HEADERS = ("traceparent", "tracestate") -logger = logging.getLogger("azure.ai.agentserver") +logger = get_logger() _HAS_OTEL = False try: @@ -64,6 +66,7 @@ class TracingHelper: def __init__(self, connection_string: Optional[str] = None) -> None: self._enabled = _HAS_OTEL self._tracer: Any = None + self._propagator: Any = None if not self._enabled: logger.warning( @@ -76,11 +79,24 @@ def __init__(self, connection_string: Optional[str] = None) -> None: self._setup_azure_monitor(connection_string) self._tracer = trace.get_tracer("azure.ai.agentserver") + self._propagator = TraceContextTextMapPropagator() # ------------------------------------------------------------------ # Azure Monitor auto-configuration # ------------------------------------------------------------------ + def _extract_context(self, carrier: Optional[dict[str, str]]) -> Any: + """Extract parent trace context from a W3C carrier dict. + + :param carrier: W3C trace-context headers or None. + :type carrier: Optional[dict[str, str]] + :return: The extracted OTel context, or None. + :rtype: Any + """ + if carrier and self._propagator is not None: + return self._propagator.extract(carrier=carrier) + return None + @staticmethod def _setup_azure_monitor(connection_string: str) -> None: """Configure global TracerProvider and LoggerProvider for App Insights. @@ -176,10 +192,7 @@ def span( yield None return - # Extract parent context from W3C traceparent header if present - ctx = None - if carrier: - ctx = TraceContextTextMapPropagator().extract(carrier=carrier) + ctx = self._extract_context(carrier) with self._tracer.start_as_current_span( name=name, @@ -213,9 +226,7 @@ def start_span( if not self._enabled or self._tracer is None: return None - ctx = None - if carrier: - ctx = TraceContextTextMapPropagator().extract(carrier=carrier) + ctx = self._extract_context(carrier) return self._tracer.start_span( name=name, @@ -304,4 +315,9 @@ def extract_w3c_carrier(headers: Mapping[str, str]) -> dict[str, str]: in *headers*. :rtype: dict[str, str] """ - return {k: v for k, v in headers.items() if k in _W3C_HEADERS} + result: dict[str, str] = {} + for key in _W3C_HEADERS: + val = headers.get(key) + if val is not None: + result[key] = val + return result diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py index 06cbdb97bded..9a879f26ea91 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py @@ -15,11 +15,16 @@ from .._constants import Constants from .._errors import error_response +from .._logger import get_logger from .._tracing import TracingHelper, extract_w3c_carrier from ..validation._openapi_validator import OpenApiValidator from . import _config -logger = logging.getLogger("azure.ai.agentserver") +logger = get_logger() + +# Pre-built health-check responses to avoid per-request allocation. +_LIVENESS_BODY = b'{"status":"alive"}' +_READINESS_BODY = b'{"status":"ready"}' class AgentServer: # pylint: disable=too-many-instance-attributes @@ -475,20 +480,32 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: return response - async def _get_invocation_endpoint(self, request: Request) -> Response: - """GET /invocations/{invocation_id} — retrieve an invocation result. + async def _traced_invocation_endpoint( + self, + request: Request, + span_name: str, + dispatch: Callable[[Request], Awaitable[Response]], + ) -> Response: + """Shared implementation for get/cancel invocation endpoints. + + Extracts the invocation ID from path params, optionally creates a + tracing span, dispatches to the handler, and handles errors. :param request: The incoming Starlette request. :type request: Request - :return: The stored result or 404. + :param span_name: OTel span name (e.g. ``"AgentServer.get_invocation"``). + :type span_name: str + :param dispatch: The dispatch method to invoke. + :type dispatch: Callable[[Request], Awaitable[Response]] + :return: The handler response or an error response. :rtype: Response """ invocation_id = request.path_params["invocation_id"] request.state.invocation_id = invocation_id - carrier = extract_w3c_carrier(request.headers) + carrier = extract_w3c_carrier(request.headers) if self._tracing is not None else {} span_cm = ( self._tracing.span( - "AgentServer.get_invocation", + span_name, attributes={"invocation.id": invocation_id}, carrier=carrier, ) @@ -497,11 +514,11 @@ async def _get_invocation_endpoint(self, request: Request) -> Response: ) with span_cm as _otel_span: try: - return await self._dispatch_get_invocation(request) + return await dispatch(request) except Exception as exc: # pylint: disable=broad-exception-caught if self._tracing is not None: self._tracing.record_error(_otel_span, exc) - logger.error("Error in get_invocation %s: %s", invocation_id, exc, exc_info=True) + logger.error("Error in %s %s: %s", span_name, invocation_id, exc, exc_info=True) message = str(exc) if self._debug_errors else "Internal server error" return error_response( "internal_error", @@ -509,6 +526,18 @@ async def _get_invocation_endpoint(self, request: Request) -> Response: status_code=500, ) + async def _get_invocation_endpoint(self, request: Request) -> Response: + """GET /invocations/{invocation_id} — retrieve an invocation result. + + :param request: The incoming Starlette request. + :type request: Request + :return: The stored result or 404. + :rtype: Response + """ + return await self._traced_invocation_endpoint( + request, "AgentServer.get_invocation", self._dispatch_get_invocation + ) + async def _cancel_invocation_endpoint(self, request: Request) -> Response: """POST /invocations/{invocation_id}/cancel — cancel an invocation. @@ -517,31 +546,9 @@ async def _cancel_invocation_endpoint(self, request: Request) -> Response: :return: The cancellation result or 404. :rtype: Response """ - invocation_id = request.path_params["invocation_id"] - request.state.invocation_id = invocation_id - carrier = extract_w3c_carrier(request.headers) - span_cm = ( - self._tracing.span( - "AgentServer.cancel_invocation", - attributes={"invocation.id": invocation_id}, - carrier=carrier, - ) - if self._tracing is not None - else contextlib.nullcontext(None) + return await self._traced_invocation_endpoint( + request, "AgentServer.cancel_invocation", self._dispatch_cancel_invocation ) - with span_cm as _otel_span: - try: - return await self._dispatch_cancel_invocation(request) - except Exception as exc: # pylint: disable=broad-exception-caught - if self._tracing is not None: - self._tracing.record_error(_otel_span, exc) - logger.error("Error in cancel_invocation %s: %s", invocation_id, exc, exc_info=True) - message = str(exc) if self._debug_errors else "Internal server error" - return error_response( - "internal_error", - message, - status_code=500, - ) async def _liveness_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument """GET /liveness — health check. @@ -551,7 +558,7 @@ async def _liveness_endpoint(self, request: Request) -> Response: # pylint: dis :return: 200 OK response. :rtype: Response """ - return JSONResponse({"status": "alive"}) + return Response(_LIVENESS_BODY, media_type="application/json") async def _readiness_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument """GET /readiness — readiness check. @@ -561,4 +568,4 @@ async def _readiness_endpoint(self, request: Request) -> Response: # pylint: di :return: 200 OK response. :rtype: Response """ - return JSONResponse({"status": "ready"}) + return Response(_READINESS_BODY, media_type="application/json") diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_config.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_config.py index 1102d0395a73..80b603b280fd 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_config.py +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_config.py @@ -57,6 +57,24 @@ def _require_int(name: str, value: object) -> int: return value +def _validate_port(value: int, source: str) -> int: + """Validate that a port number is within the valid range. + + :param value: The port number to validate. + :type value: int + :param source: Human-readable source name for the error message. + :type source: str + :return: The validated port number. + :rtype: int + :raises ValueError: If the port is outside 1-65535. + """ + if not 1 <= value <= 65535: + raise ValueError( + f"Invalid value for {source}: {value} (expected 1-65535)" + ) + return value + + def resolve_port(port: Optional[int]) -> int: """Resolve the server port from argument, env var, or default. @@ -67,19 +85,10 @@ def resolve_port(port: Optional[int]) -> int: :raises ValueError: If the port value is not a valid integer or is outside 1-65535. """ if port is not None: - result = _require_int("port", port) - if not 1 <= result <= 65535: - raise ValueError( - f"Invalid value for port: {result} (expected 1-65535)" - ) - return result + return _validate_port(_require_int("port", port), "port") env_port = _parse_int_env(Constants.AGENT_SERVER_PORT) if env_port is not None: - if not 1 <= env_port <= 65535: - raise ValueError( - f"Invalid value for {Constants.AGENT_SERVER_PORT}: {env_port} (expected 1-65535)" - ) - return env_port + return _validate_port(env_port, Constants.AGENT_SERVER_PORT) return Constants.DEFAULT_PORT diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py index cfe352ff839a..41cf88f24c50 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py @@ -26,16 +26,18 @@ # import copy import json -import logging import re from collections import Counter from datetime import datetime +from collections.abc import Callable # pylint: disable=import-error from typing import Any, Optional import jsonschema from jsonschema import FormatChecker, ValidationError -logger = logging.getLogger("azure.ai.agentserver") +from .._logger import get_logger + +logger = get_logger() # --------------------------------------------------------------------------- # Stdlib-only format checkers so we never depend on optional jsonschema extras @@ -111,8 +113,6 @@ class OpenApiValidator: """ def __init__(self, spec: dict[str, Any], path: str = "/invocations") -> None: - self._spec = spec - self._path = path self._request_body_required = self._is_request_body_required(spec, path) raw_request = self._extract_request_schema(spec, path) raw_response = self._extract_response_schema(spec, path) @@ -295,6 +295,31 @@ def _is_request_body_required(spec: dict[str, Any], path: str) -> bool: request_body = operation.get("requestBody", {}) return request_body.get("required", True) + @staticmethod + def _walk_schema(schema: dict[str, Any], visitor: Callable[[dict[str, Any]], None]) -> None: + """Recursively apply *visitor* to every sub-schema in the tree. + + *visitor* is called on each dict-typed sub-schema found in + ``items``, ``additionalProperties``, ``properties``, and + ``allOf``/``oneOf``/``anyOf``. + + :param schema: Root schema dict. + :type schema: dict[str, Any] + :param visitor: Callable applied to each sub-schema dict. + :type visitor: Callable[[dict[str, Any]], None] + """ + for key in ("items", "additionalProperties"): + child = schema.get(key) + if isinstance(child, dict): + visitor(child) + for prop in schema.get("properties", {}).values(): + if isinstance(prop, dict): + visitor(prop) + for keyword in ("allOf", "oneOf", "anyOf"): + for sub in schema.get(keyword, []): + if isinstance(sub, dict): + visitor(sub) + @staticmethod def _preprocess_schema( schema: dict[str, Any], context: str = "request" @@ -339,18 +364,7 @@ def _apply_nullable(schema: dict[str, Any]) -> None: schema["type"] = [original, "null"] elif isinstance(original, list) and "null" not in original: schema["type"] = original + ["null"] - # Recurse into nested structures - for key in ("items", "additionalProperties"): - child = schema.get(key) - if isinstance(child, dict): - OpenApiValidator._apply_nullable(child) - for prop in schema.get("properties", {}).values(): - if isinstance(prop, dict): - OpenApiValidator._apply_nullable(prop) - for keyword in ("allOf", "oneOf", "anyOf"): - for sub in schema.get(keyword, []): - if isinstance(sub, dict): - OpenApiValidator._apply_nullable(sub) + OpenApiValidator._walk_schema(schema, OpenApiValidator._apply_nullable) @staticmethod def _strip_readonly_writeonly( @@ -386,17 +400,11 @@ def _strip_readonly_writeonly( props.pop(name, None) if name in required: required.remove(name) - # Recurse into nested objects - for prop in props.values(): - if isinstance(prop, dict): - OpenApiValidator._strip_readonly_writeonly(prop, context) - child = schema.get("items") - if isinstance(child, dict): + + def _recurse(child: dict[str, Any]) -> None: OpenApiValidator._strip_readonly_writeonly(child, context) - for keyword in ("allOf", "oneOf", "anyOf"): - for sub in schema.get(keyword, []): - if isinstance(sub, dict): - OpenApiValidator._strip_readonly_writeonly(sub, context) + + OpenApiValidator._walk_schema(schema, _recurse) @staticmethod def _strip_openapi_keywords(schema: dict[str, Any]) -> None: @@ -412,17 +420,7 @@ def _strip_openapi_keywords(schema: dict[str, Any]) -> None: return for kw in _OPENAPI_ONLY_KEYWORDS: schema.pop(kw, None) - for key in ("items", "additionalProperties"): - child = schema.get(key) - if isinstance(child, dict): - OpenApiValidator._strip_openapi_keywords(child) - for prop in schema.get("properties", {}).values(): - if isinstance(prop, dict): - OpenApiValidator._strip_openapi_keywords(prop) - for keyword in ("allOf", "oneOf", "anyOf"): - for sub in schema.get(keyword, []): - if isinstance(sub, dict): - OpenApiValidator._strip_openapi_keywords(sub) + OpenApiValidator._walk_schema(schema, OpenApiValidator._strip_openapi_keywords) # ------------------------------------------------------------------ From 8051db7e55134830a0586c32d35c5e4839dcb782 Mon Sep 17 00:00:00 2001 From: Zhiyong Yang Date: Wed, 11 Mar 2026 14:47:12 -0700 Subject: [PATCH 04/10] Flatten package structure and fix several bugs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove server/ and validation/ subdirectories, moving _base.py, _config.py, and _openapi_validator.py directly into azure/ai/agentserver/ - Fix _parse_int_env treating "0" as unset (if not raw → if raw is None) - Return 501 Not Implemented when no invoke handler is registered instead of an unclear 500 Internal Server Error - Add x-agent-invocation-id header to GET/cancel invocation responses (success and error paths) - Fix pyright warnings by importing best_match from jsonschema.exceptions - Update all internal and test imports for new flat layout Co-Authored-By: Claude Opus 4.6 --- .../azure/ai/agentserver/__init__.py | 2 +- .../ai/agentserver/{server => }/_base.py | 28 +++++++--- .../ai/agentserver/{server => }/_config.py | 4 +- .../{validation => }/_openapi_validator.py | 7 ++- .../azure/ai/agentserver/server/__init__.py | 3 - .../ai/agentserver/validation/__init__.py | 3 - .../tests/test_decorator_pattern.py | 12 +++- .../tests/test_get_cancel.py | 4 ++ .../tests/test_graceful_shutdown.py | 21 ++++--- .../azure-ai-agentserver/tests/test_logger.py | 2 +- .../tests/test_openapi_validation.py | 56 +++++++++---------- .../tests/test_request_limits.py | 7 ++- .../tests/test_server_routes.py | 22 +++++--- .../tests/test_tracing.py | 6 +- 14 files changed, 105 insertions(+), 72 deletions(-) rename sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/{server => }/_base.py (96%) rename sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/{server => }/_config.py (99%) rename sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/{validation => }/_openapi_validator.py (99%) delete mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/__init__.py delete mode 100644 sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/__init__.py diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/__init__.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/__init__.py index aa692afa00e8..4be26052cbe4 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/__init__.py @@ -3,7 +3,7 @@ # --------------------------------------------------------- from ._version import VERSION -from .server._base import AgentServer +from ._base import AgentServer __all__ = ["AgentServer"] __version__ = VERSION diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_base.py similarity index 96% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py rename to sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_base.py index 9a879f26ea91..41b03dad9de4 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_base.py +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_base.py @@ -13,11 +13,11 @@ from starlette.responses import JSONResponse, Response, StreamingResponse from starlette.routing import Route -from .._constants import Constants -from .._errors import error_response -from .._logger import get_logger -from .._tracing import TracingHelper, extract_w3c_carrier -from ..validation._openapi_validator import OpenApiValidator +from ._constants import Constants +from ._errors import error_response +from ._logger import get_logger +from ._tracing import TracingHelper, extract_w3c_carrier +from ._openapi_validator import OpenApiValidator from . import _config logger = get_logger() @@ -216,10 +216,11 @@ async def _dispatch_invoke(self, request: Request) -> Response: :type request: Request :return: The response from the invoke handler. :rtype: Response + :raises NotImplementedError: If no invoke handler has been registered. """ if self._invoke_fn is not None: return await self._invoke_fn(request) - raise RuntimeError( + raise NotImplementedError( "No invoke handler registered. Use the @server.invoke_handler decorator." ) @@ -442,6 +443,16 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: invoke_awaitable = self._dispatch_invoke(request) timeout = self._request_timeout or None # 0 → None (no limit) response = await asyncio.wait_for(invoke_awaitable, timeout=timeout) + except NotImplementedError as exc: + if self._tracing is not None: + self._tracing.end_span(otel_span, exc=exc) + logger.error("Invocation %s failed: %s", invocation_id, exc) + return error_response( + "not_implemented", + str(exc), + status_code=501, + headers={Constants.INVOCATION_ID_HEADER: invocation_id}, + ) except asyncio.TimeoutError: if self._tracing is not None: self._tracing.end_span(otel_span) @@ -514,7 +525,9 @@ async def _traced_invocation_endpoint( ) with span_cm as _otel_span: try: - return await dispatch(request) + response = await dispatch(request) + response.headers[Constants.INVOCATION_ID_HEADER] = invocation_id + return response except Exception as exc: # pylint: disable=broad-exception-caught if self._tracing is not None: self._tracing.record_error(_otel_span, exc) @@ -524,6 +537,7 @@ async def _traced_invocation_endpoint( "internal_error", message, status_code=500, + headers={Constants.INVOCATION_ID_HEADER: invocation_id}, ) async def _get_invocation_endpoint(self, request: Request) -> Response: diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_config.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_config.py similarity index 99% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_config.py rename to sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_config.py index 80b603b280fd..910479580f18 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/_config.py +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_config.py @@ -16,7 +16,7 @@ import os from typing import Optional -from .._constants import Constants +from ._constants import Constants def _parse_int_env(var_name: str) -> Optional[int]: @@ -29,7 +29,7 @@ def _parse_int_env(var_name: str) -> Optional[int]: :raises ValueError: If the variable is set but cannot be parsed as an integer. """ raw = os.environ.get(var_name) - if not raw: + if raw is None: return None try: return int(raw) diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_openapi_validator.py similarity index 99% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py rename to sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_openapi_validator.py index 41cf88f24c50..88688738c31f 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/_openapi_validator.py +++ b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_openapi_validator.py @@ -33,9 +33,10 @@ from typing import Any, Optional import jsonschema +from jsonschema.exceptions import best_match from jsonschema import FormatChecker, ValidationError -from .._logger import get_logger +from ._logger import get_logger logger = get_logger() @@ -491,7 +492,7 @@ def _collect_composition_errors(error: ValidationError) -> list[str]: if len(branch_groups) < 2: # Cannot do branch analysis — fallback - best = jsonschema.exceptions.best_match([error]) + best = best_match([error]) if best is not None and best is not error: return _collect_errors(best) return [_format_error(error)] @@ -499,7 +500,7 @@ def _collect_composition_errors(error: ValidationError) -> list[str]: disc_path = _find_discriminator_path(branch_groups) if disc_path is None: - best = jsonschema.exceptions.best_match([error]) + best = best_match([error]) if best is not None and best is not error: return _collect_errors(best) return [_format_error(error)] diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/__init__.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/__init__.py deleted file mode 100644 index d540fd20468c..000000000000 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/server/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# --------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# --------------------------------------------------------- diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/__init__.py b/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/__init__.py deleted file mode 100644 index d540fd20468c..000000000000 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/validation/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# --------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# --------------------------------------------------------- diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_decorator_pattern.py b/sdk/agentserver/azure-ai-agentserver/tests/test_decorator_pattern.py index 018e9478a7e3..eab02458a9a2 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_decorator_pattern.py +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_decorator_pattern.py @@ -130,16 +130,20 @@ async def handle(request: Request) -> Response: class TestMissingInvokeHandler: - """When no invoke handler is registered and invoke() is not overridden, 500.""" + """When no invoke handler is registered and invoke() is not overridden, 501.""" @pytest.mark.asyncio - async def test_no_handler_returns_500(self): + async def test_no_handler_returns_501(self): server = AgentServer() # No @invoke_handler registered transport = httpx.ASGITransport(app=server.app) async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: resp = await client.post("/invocations", json={}) - assert resp.status_code == 500 + assert resp.status_code == 501 + data = resp.json() + assert data["error"]["code"] == "not_implemented" + assert "invoke handler" in data["error"]["message"].lower() + assert "x-agent-invocation-id" in resp.headers # --------------------------------------------------------------------------- @@ -162,6 +166,7 @@ async def handle(request: Request) -> Response: async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: resp = await client.get("/invocations/some-id") assert resp.status_code == 404 + assert resp.headers.get("x-agent-invocation-id") == "some-id" @pytest.mark.asyncio async def test_cancel_invocation_returns_404_by_default(self): @@ -175,6 +180,7 @@ async def handle(request: Request) -> Response: async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: resp = await client.post("/invocations/some-id/cancel") assert resp.status_code == 404 + assert resp.headers.get("x-agent-invocation-id") == "some-id" # --------------------------------------------------------------------------- diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_get_cancel.py b/sdk/agentserver/azure-ai-agentserver/tests/test_get_cancel.py index 5a6d1b963abd..6ffdf5082a40 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_get_cancel.py +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_get_cancel.py @@ -16,6 +16,7 @@ async def test_get_invocation_after_invoke(async_storage_client): get_resp = await async_storage_client.get(f"/invocations/{invocation_id}") assert get_resp.status_code == 200 + assert get_resp.headers.get("x-agent-invocation-id") == invocation_id data = json.loads(get_resp.content) assert "echo" in data @@ -35,6 +36,7 @@ async def test_cancel_invocation_after_invoke(async_storage_client): cancel_resp = await async_storage_client.post(f"/invocations/{invocation_id}/cancel") assert cancel_resp.status_code == 200 + assert cancel_resp.headers.get("x-agent-invocation-id") == invocation_id data = json.loads(cancel_resp.content) assert data["status"] == "cancelled" @@ -82,6 +84,7 @@ async def get_inv(request: Request) -> Response: assert resp.status_code == 500 assert resp.json()["error"]["code"] == "internal_error" assert resp.json()["error"]["message"] == "Internal server error" + assert resp.headers.get("x-agent-invocation-id") == "some-id" @pytest.mark.asyncio @@ -109,3 +112,4 @@ async def cancel(request: Request) -> Response: assert resp.status_code == 500 assert resp.json()["error"]["code"] == "internal_error" assert resp.json()["error"]["message"] == "Internal server error" + assert resp.headers.get("x-agent-invocation-id") == "some-id" diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_graceful_shutdown.py b/sdk/agentserver/azure-ai-agentserver/tests/test_graceful_shutdown.py index 73d0d8b63479..1fce19d303f5 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_graceful_shutdown.py +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_graceful_shutdown.py @@ -46,6 +46,11 @@ def test_explicit_zero_disables_drain(self): agent = _make_stub_agent(graceful_shutdown_timeout=0) assert agent._graceful_shutdown_timeout == 0 + def test_env_var_zero_disables_drain(self): + with patch.dict(os.environ, {Constants.AGENT_GRACEFUL_SHUTDOWN_TIMEOUT: "0"}): + agent = _make_stub_agent() + assert agent._graceful_shutdown_timeout == 0 + def test_env_var_used_when_no_explicit(self): with patch.dict(os.environ, {Constants.AGENT_GRACEFUL_SHUTDOWN_TIMEOUT: "45"}): agent = _make_stub_agent() @@ -90,7 +95,7 @@ class TestRunPassesTimeout: """Ensure run() forwards the timeout to Hypercorn config.""" @patch("hypercorn.asyncio.serve", new_callable=AsyncMock) - @patch("azure.ai.agentserver.server._base.asyncio") + @patch("azure.ai.agentserver._base.asyncio") def test_run_passes_timeout(self, mock_asyncio, _mock_serve): agent = _make_stub_agent(graceful_shutdown_timeout=15) agent.run() @@ -100,7 +105,7 @@ def test_run_passes_timeout(self, mock_asyncio, _mock_serve): assert config.graceful_timeout == 15.0 @patch("hypercorn.asyncio.serve", new_callable=AsyncMock) - @patch("azure.ai.agentserver.server._base.asyncio") + @patch("azure.ai.agentserver._base.asyncio") def test_run_passes_default_timeout(self, mock_asyncio, _mock_serve): agent = _make_stub_agent() agent.run() @@ -167,7 +172,7 @@ async def test_lifespan_shutdown_logs(self): agent = _make_stub_agent(graceful_shutdown_timeout=99) - with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + with patch("azure.ai.agentserver._base.logger") as mock_logger: # Grab the lifespan from the Starlette app lifespan = agent.app.router.lifespan_context async with lifespan(agent.app): @@ -268,7 +273,7 @@ async def test_default_on_shutdown_is_noop(self): async def test_on_shutdown_exception_is_logged_not_raised(self): """A failing on_shutdown must not crash the shutdown sequence.""" agent = _make_failing_shutdown_agent() - with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + with patch("azure.ai.agentserver._base.logger") as mock_logger: lifespan = agent.app.router.lifespan_context async with lifespan(agent.app): pass # should NOT raise @@ -295,7 +300,7 @@ async def invoke(request: Request) -> Response: async def shutdown(): order.append("callback") - with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + with patch("azure.ai.agentserver._base.logger") as mock_logger: def tracking_info(*args, **kwargs): if args and "shutting down" in str(args[0]).lower(): @@ -355,7 +360,7 @@ async def invoke(request: Request) -> Response: async def shutdown(): await asyncio.sleep(999) # way longer than the 1s timeout - with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + with patch("azure.ai.agentserver._base.logger") as mock_logger: lifespan = server.app.router.lifespan_context async with lifespan(server.app): pass # should NOT hang @@ -386,7 +391,7 @@ async def shutdown(): await asyncio.sleep(0) done = True - with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + with patch("azure.ai.agentserver._base.logger") as mock_logger: lifespan = server.app.router.lifespan_context async with lifespan(server.app): pass @@ -425,7 +430,7 @@ async def shutdown(): async def test_timeout_does_not_suppress_exceptions(self): """Exceptions from on_shutdown are still logged even if within timeout.""" agent = _make_failing_shutdown_agent(graceful_shutdown_timeout=5) - with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + with patch("azure.ai.agentserver._base.logger") as mock_logger: lifespan = agent.app.router.lifespan_context async with lifespan(agent.app): pass diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_logger.py b/sdk/agentserver/azure-ai-agentserver/tests/test_logger.py index 314211cec644..06c3ef31414f 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_logger.py +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_logger.py @@ -17,6 +17,6 @@ def test_log_level_preserved_across_imports(): lib_logger.setLevel(logging.ERROR) # Re-importing the base module should not override the level. - from azure.ai.agentserver.server import _base # noqa: F401 + from azure.ai.agentserver import _base # noqa: F401 assert lib_logger.level == logging.ERROR diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_openapi_validation.py b/sdk/agentserver/azure-ai-agentserver/tests/test_openapi_validation.py index 6806da272fe5..4082567a49f9 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_openapi_validation.py +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_openapi_validation.py @@ -11,7 +11,7 @@ from starlette.responses import JSONResponse, Response from azure.ai.agentserver import AgentServer -from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator +from azure.ai.agentserver._openapi_validator import OpenApiValidator # --------------------------------------------------------------------------- @@ -860,7 +860,7 @@ async def test_complex_multiple_errors_reported(): @pytest.mark.asyncio async def test_validate_response_valid(): """Valid response body passes validation.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator v = OpenApiValidator(RESPONSE_SPEC) errors = v.validate_response(b'{"result": "ok"}', "application/json") @@ -870,7 +870,7 @@ async def test_validate_response_valid(): @pytest.mark.asyncio async def test_validate_response_invalid(): """Invalid response body returns errors.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator v = OpenApiValidator(RESPONSE_SPEC) errors = v.validate_response(b'{"wrong": 42}', "application/json") @@ -880,7 +880,7 @@ async def test_validate_response_invalid(): @pytest.mark.asyncio async def test_validate_response_no_schema(): """When no response schema exists, validation passes (no-op).""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator spec_no_resp = { "openapi": "3.0.0", @@ -900,7 +900,7 @@ async def test_validate_response_no_schema(): @pytest.mark.asyncio async def test_validate_request_no_schema(): """When no request schema exists, validation passes (no-op).""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator spec_no_req = { "openapi": "3.0.0", @@ -920,7 +920,7 @@ async def test_validate_request_no_schema(): @pytest.mark.asyncio async def test_response_schema_fallback_to_first_available(): """Response schema extraction falls back to first response with JSON.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator spec = { "openapi": "3.0.0", @@ -961,7 +961,7 @@ async def test_response_schema_fallback_to_first_available(): @pytest.mark.asyncio async def test_unresolvable_ref(): """An unresolvable $ref leaves the node as-is (no crash).""" - from azure.ai.agentserver.validation._openapi_validator import _resolve_ref + from azure.ai.agentserver._openapi_validator import _resolve_ref spec: dict = {"components": {"schemas": {}}} node = {"$ref": "#/components/schemas/DoesNotExist"} @@ -973,7 +973,7 @@ async def test_unresolvable_ref(): @pytest.mark.asyncio async def test_ref_path_hits_non_dict(): """A $ref path that traverses a non-dict returns the original node.""" - from azure.ai.agentserver.validation._openapi_validator import _resolve_ref + from azure.ai.agentserver._openapi_validator import _resolve_ref spec: dict = {"components": {"schemas": "not-a-dict"}} node = {"$ref": "#/components/schemas/Foo"} @@ -984,7 +984,7 @@ async def test_ref_path_hits_non_dict(): @pytest.mark.asyncio async def test_circular_ref_stops_recursion(): """Circular $ref does not cause infinite recursion.""" - from azure.ai.agentserver.validation._openapi_validator import _resolve_refs_deep + from azure.ai.agentserver._openapi_validator import _resolve_refs_deep spec: dict = { "components": { @@ -1010,7 +1010,7 @@ async def test_circular_ref_stops_recursion(): @pytest.mark.asyncio async def test_ref_resolves_to_non_dict(): """A $ref that resolves to a non-dict value returns the original node.""" - from azure.ai.agentserver.validation._openapi_validator import _resolve_ref + from azure.ai.agentserver._openapi_validator import _resolve_ref spec: dict = {"components": {"schemas": {"Bad": 42}}} node = {"$ref": "#/components/schemas/Bad"} @@ -1211,7 +1211,7 @@ async def test_writeonly_allowed_in_request(): @pytest.mark.asyncio async def test_readonly_schema_introspection_request(): """readOnly properties are removed from the preprocessed request schema.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator v = OpenApiValidator(READONLY_SPEC) # 'id' should be gone from requestschema properties @@ -1221,7 +1221,7 @@ async def test_readonly_schema_introspection_request(): @pytest.mark.asyncio async def test_readonly_schema_introspection_response(): """readOnly properties remain in the preprocessed response schema.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator v = OpenApiValidator(READONLY_SPEC) # 'id' should still be in response schema @@ -1231,7 +1231,7 @@ async def test_readonly_schema_introspection_response(): @pytest.mark.asyncio async def test_writeonly_stripped_in_response(): """writeOnly properties are removed from the preprocessed response schema.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator v = OpenApiValidator(READONLY_SPEC) assert "password" not in v._response_schema.get("properties", {}) @@ -1240,7 +1240,7 @@ async def test_writeonly_stripped_in_response(): @pytest.mark.asyncio async def test_writeonly_present_in_request(): """writeOnly properties remain in the preprocessed request schema.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator v = OpenApiValidator(READONLY_SPEC) assert "password" in v._request_schema.get("properties", {}) @@ -1306,7 +1306,7 @@ async def test_writeonly_present_in_request(): @pytest.mark.asyncio async def test_optional_body_empty_accepted(): """Empty body is accepted when requestBody.required is false.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator v = OpenApiValidator(OPTIONAL_BODY_SPEC) errors = v.validate_request(b"", "application/json") @@ -1316,7 +1316,7 @@ async def test_optional_body_empty_accepted(): @pytest.mark.asyncio async def test_optional_body_whitespace_accepted(): """Whitespace-only body is accepted when requestBody.required is false.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator v = OpenApiValidator(OPTIONAL_BODY_SPEC) errors = v.validate_request(b" ", "application/json") @@ -1326,7 +1326,7 @@ async def test_optional_body_whitespace_accepted(): @pytest.mark.asyncio async def test_optional_body_present_still_validated(): """When body IS present with optional requestBody, it still must be valid.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator v = OpenApiValidator(OPTIONAL_BODY_SPEC) errors = v.validate_request(b'{"wrong": 1}', "application/json") @@ -1336,7 +1336,7 @@ async def test_optional_body_present_still_validated(): @pytest.mark.asyncio async def test_required_body_empty_rejected(): """Empty body is rejected when requestBody.required is true.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator v = OpenApiValidator(REQUIRED_BODY_SPEC) errors = v.validate_request(b"", "application/json") @@ -1346,7 +1346,7 @@ async def test_required_body_empty_rejected(): @pytest.mark.asyncio async def test_default_body_required_behavior(): """When requestBody.required is omitted, body is required by default.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator spec = { "openapi": "3.0.0", @@ -1379,7 +1379,7 @@ async def test_default_body_required_behavior(): @pytest.mark.asyncio async def test_openapi_keywords_stripped(): """discriminator, xml, externalDocs, example are stripped from schemas.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator schema = { "type": "object", @@ -1400,7 +1400,7 @@ async def test_openapi_keywords_stripped(): @pytest.mark.asyncio async def test_openapi_keywords_stripped_nested(): """OpenAPI keywords are stripped from nested properties too.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator schema = { "type": "object", @@ -1729,7 +1729,7 @@ async def test_anyof_wrong_type(): @pytest.mark.asyncio async def test_nullable_ref_accepts_null(): """A nullable $ref field accepts null after preprocessing.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator spec: dict = { "openapi": "3.0.0", @@ -1779,7 +1779,7 @@ async def test_nullable_ref_accepts_null(): @pytest.mark.asyncio async def test_apply_nullable_no_duplicate_null(): """_apply_nullable does not add 'null' twice if type is already a list with null.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator schema: dict = {"type": ["string", "null"], "nullable": True} OpenApiValidator._apply_nullable(schema) @@ -1789,7 +1789,7 @@ async def test_apply_nullable_no_duplicate_null(): @pytest.mark.asyncio async def test_apply_nullable_false(): """nullable: false is a no-op (just removes the key).""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator schema: dict = {"type": "string", "nullable": False} OpenApiValidator._apply_nullable(schema) @@ -1800,7 +1800,7 @@ async def test_apply_nullable_false(): @pytest.mark.asyncio async def test_strip_openapi_keywords_nested_deeply(): """OpenAPI keywords are stripped from deeply nested schemas.""" - from azure.ai.agentserver.validation._openapi_validator import OpenApiValidator + from azure.ai.agentserver._openapi_validator import OpenApiValidator schema: dict = { "type": "object", @@ -2125,7 +2125,7 @@ async def test_nested_oneof_matching_branch_missing_field(): @pytest.mark.asyncio async def test_format_error_includes_path(): """_format_error prefixes with JSON path when not root.""" - from azure.ai.agentserver.validation._openapi_validator import _format_error + from azure.ai.agentserver._openapi_validator import _format_error import jsonschema schema = {"type": "object", "properties": {"age": {"type": "integer"}}, "required": ["age"]} @@ -2139,7 +2139,7 @@ async def test_format_error_includes_path(): @pytest.mark.asyncio async def test_format_error_root_path_no_prefix(): """_format_error does not prefix when error is at root ($).""" - from azure.ai.agentserver.validation._openapi_validator import _format_error + from azure.ai.agentserver._openapi_validator import _format_error import jsonschema schema = {"type": "object", "required": ["name"]} @@ -2154,7 +2154,7 @@ async def test_format_error_root_path_no_prefix(): @pytest.mark.asyncio async def test_collect_errors_flat(): """_collect_errors for a simple (non-composition) error returns path-prefixed message.""" - from azure.ai.agentserver.validation._openapi_validator import _collect_errors + from azure.ai.agentserver._openapi_validator import _collect_errors import jsonschema schema = {"type": "object", "properties": {"x": {"type": "integer"}}} diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_request_limits.py b/sdk/agentserver/azure-ai-agentserver/tests/test_request_limits.py index e8161122a3c9..e8644e388f37 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_request_limits.py +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_request_limits.py @@ -70,6 +70,11 @@ def test_explicit_zero_disables(self): agent = _make_fast_agent(request_timeout=0) assert agent._request_timeout == 0 + def test_env_var_zero_disables(self): + with patch.dict(os.environ, {Constants.AGENT_REQUEST_TIMEOUT: "0"}): + agent = _make_fast_agent() + assert agent._request_timeout == 0 + def test_env_var_used_when_no_explicit(self): with patch.dict(os.environ, {Constants.AGENT_REQUEST_TIMEOUT: "120"}): agent = _make_fast_agent() @@ -132,7 +137,7 @@ async def test_timeout_disabled_allows_no_limit(self): async def test_timeout_error_is_logged(self): agent = _make_slow_agent(request_timeout=1) transport = httpx.ASGITransport(app=agent.app) - with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + with patch("azure.ai.agentserver._base.logger") as mock_logger: async with httpx.AsyncClient( transport=transport, base_url="http://testserver" ) as client: diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_server_routes.py b/sdk/agentserver/azure-ai-agentserver/tests/test_server_routes.py index f4a2882fc0cc..351edee0f2c5 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_server_routes.py +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_server_routes.py @@ -44,15 +44,19 @@ async def test_get_openapi_spec_returns_spec(validated_client): @pytest.mark.asyncio async def test_get_invocation_returns_404_default(echo_client): """GET /invocations/{id} returns 404 when not overridden.""" - resp = await echo_client.get(f"/invocations/{uuid.uuid4()}") + inv_id = str(uuid.uuid4()) + resp = await echo_client.get(f"/invocations/{inv_id}") assert resp.status_code == 404 + assert resp.headers.get("x-agent-invocation-id") == inv_id @pytest.mark.asyncio async def test_cancel_invocation_returns_404_default(echo_client): """POST /invocations/{id}/cancel returns 404 when not overridden.""" - resp = await echo_client.post(f"/invocations/{uuid.uuid4()}/cancel") + inv_id = str(uuid.uuid4()) + resp = await echo_client.post(f"/invocations/{inv_id}/cancel") assert resp.status_code == 404 + assert resp.headers.get("x-agent-invocation-id") == inv_id @pytest.mark.asyncio @@ -72,37 +76,37 @@ class TestResolvePort: def test_explicit_port_wins(self): """Explicit port argument takes precedence over everything.""" - from azure.ai.agentserver.server._config import resolve_port + from azure.ai.agentserver._config import resolve_port assert resolve_port(9090) == 9090 def test_env_var_used_when_no_explicit(self, monkeypatch): """AGENT_SERVER_PORT env var is used when no explicit port given.""" - from azure.ai.agentserver.server._config import resolve_port + from azure.ai.agentserver._config import resolve_port monkeypatch.setenv("AGENT_SERVER_PORT", "7777") assert resolve_port(None) == 7777 def test_default_port_when_nothing_set(self, monkeypatch): """Falls back to 8088 when no explicit port and no env var.""" - from azure.ai.agentserver.server._config import resolve_port + from azure.ai.agentserver._config import resolve_port monkeypatch.delenv("AGENT_SERVER_PORT", raising=False) assert resolve_port(None) == 8088 def test_invalid_env_var_raises(self, monkeypatch): """Non-numeric AGENT_SERVER_PORT env var raises ValueError.""" - from azure.ai.agentserver.server._config import resolve_port + from azure.ai.agentserver._config import resolve_port monkeypatch.setenv("AGENT_SERVER_PORT", "not_a_number") with pytest.raises(ValueError, match="AGENT_SERVER_PORT"): resolve_port(None) def test_non_int_explicit_port_raises(self): """Passing a non-integer port raises ValueError.""" - from azure.ai.agentserver.server._config import resolve_port + from azure.ai.agentserver._config import resolve_port with pytest.raises(ValueError, match="expected an integer"): resolve_port("sss") # type: ignore[arg-type] def test_port_out_of_range_raises(self): """Port outside 1-65535 raises ValueError.""" - from azure.ai.agentserver.server._config import resolve_port + from azure.ai.agentserver._config import resolve_port with pytest.raises(ValueError, match="1-65535"): resolve_port(0) with pytest.raises(ValueError, match="1-65535"): @@ -110,7 +114,7 @@ def test_port_out_of_range_raises(self): def test_env_var_port_out_of_range_raises(self, monkeypatch): """AGENT_SERVER_PORT outside 1-65535 raises ValueError.""" - from azure.ai.agentserver.server._config import resolve_port + from azure.ai.agentserver._config import resolve_port monkeypatch.setenv("AGENT_SERVER_PORT", "0") with pytest.raises(ValueError, match="1-65535"): resolve_port(None) diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_tracing.py b/sdk/agentserver/azure-ai-agentserver/tests/test_tracing.py index f5a9c945e5ea..4651f2f0b4d6 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver/tests/test_tracing.py @@ -402,7 +402,7 @@ class TestAppInsightsConnectionStringResolution: def test_explicit_param_takes_priority(self, monkeypatch): """Constructor param beats env var.""" - from azure.ai.agentserver.server._config import resolve_appinsights_connection_string + from azure.ai.agentserver._config import resolve_appinsights_connection_string monkeypatch.setenv("APPLICATIONINSIGHTS_CONNECTION_STRING", "env-standard") result = resolve_appinsights_connection_string("explicit-value") @@ -410,7 +410,7 @@ def test_explicit_param_takes_priority(self, monkeypatch): def test_standard_env_var_fallback(self, monkeypatch): """Falls back to APPLICATIONINSIGHTS_CONNECTION_STRING.""" - from azure.ai.agentserver.server._config import resolve_appinsights_connection_string + from azure.ai.agentserver._config import resolve_appinsights_connection_string monkeypatch.setenv("APPLICATIONINSIGHTS_CONNECTION_STRING", "env-standard") result = resolve_appinsights_connection_string(None) @@ -418,7 +418,7 @@ def test_standard_env_var_fallback(self, monkeypatch): def test_no_connection_string_returns_none(self, monkeypatch): """Returns None when no source provides a connection string.""" - from azure.ai.agentserver.server._config import resolve_appinsights_connection_string + from azure.ai.agentserver._config import resolve_appinsights_connection_string monkeypatch.delenv("APPLICATIONINSIGHTS_CONNECTION_STRING", raising=False) result = resolve_appinsights_connection_string(None) From 1b222e924b2b872e42045105e1b83d024a33e6a1 Mon Sep 17 00:00:00 2001 From: Zhiyong Yang Date: Wed, 11 Mar 2026 21:30:12 -0700 Subject: [PATCH 05/10] Rename package to azure-ai-agentserver-server and fix CI issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename azure-ai-agentserver → azure-ai-agentserver-server, move import path to azure.ai.agentserver.server to match sibling package conventions - Fix pylint C0412: group collections imports in _openapi_validator.py - Fix TimeoutError span not recorded as ERROR in _base.py - Return 501 (not 404) for unregistered get/cancel handlers - Update ci.yml artifact name, all tests, samples, and docs Co-Authored-By: Claude Opus 4.6 --- .../CHANGELOG.md | 2 +- .../LICENSE | 0 .../MANIFEST.in | 3 +- .../README.md | 10 ++-- .../azure/__init__.py | 0 .../azure/ai/__init__.py | 0 .../azure/ai/agentserver/__init__.py | 1 + .../azure/ai/agentserver/server}/__init__.py | 1 + .../azure/ai/agentserver/server}/_base.py | 10 ++-- .../azure/ai/agentserver/server}/_config.py | 0 .../ai/agentserver/server}/_constants.py | 0 .../azure/ai/agentserver/server}/_errors.py | 0 .../azure/ai/agentserver/server}/_logger.py | 0 .../agentserver/server}/_openapi_validator.py | 2 +- .../azure/ai/agentserver/server}/_tracing.py | 6 +- .../azure/ai/agentserver/server}/_version.py | 0 .../azure/ai/agentserver/server}/py.typed | 0 .../cspell.json | 0 .../dev_requirements.txt | 0 .../pyproject.toml | 6 +- .../pyrightconfig.json | 0 .../agentframework_invoke_agent/.env.sample | 0 .../agentframework_invoke_agent.py | 2 +- .../requirements.txt | 2 +- .../async_invoke_agent/async_invoke_agent.py | 2 +- .../async_invoke_agent/requirements.txt | 1 + .../human_in_the_loop_agent.py | 2 +- .../human_in_the_loop_agent/requirements.txt | 1 + .../langgraph_invoke_agent/.env.sample | 0 .../langgraph_invoke_agent.py | 2 +- .../langgraph_invoke_agent/requirements.txt | 2 +- .../openapi_validated_agent.py | 2 +- .../openapi_validated_agent/requirements.txt | 1 + .../simple_invoke_agent/requirements.txt | 1 + .../simple_invoke_agent.py | 2 +- .../tests/conftest.py | 4 +- .../tests/test_decorator_pattern.py | 10 ++-- .../tests/test_edge_cases.py | 2 +- .../tests/test_get_cancel.py | 4 +- .../tests/test_graceful_shutdown.py | 20 +++---- .../tests/test_health.py | 0 .../tests/test_http2.py | 2 +- .../tests/test_invoke.py | 0 .../tests/test_logger.py | 2 +- .../tests/test_multimodal_protocol.py | 2 +- .../tests/test_openapi_validation.py | 58 +++++++++---------- .../tests/test_request_limits.py | 6 +- .../tests/test_server_routes.py | 26 ++++----- .../tests/test_tracing.py | 24 ++++---- .../async_invoke_agent/requirements.txt | 1 - .../human_in_the_loop_agent/requirements.txt | 1 - .../openapi_validated_agent/requirements.txt | 1 - .../samples/simple_invoke_agent/Dockerfile | 19 ------ .../simple_invoke_agent/requirements.txt | 1 - sdk/agentserver/ci.yml | 4 +- 55 files changed, 116 insertions(+), 132 deletions(-) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/CHANGELOG.md (88%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/LICENSE (100%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/MANIFEST.in (63%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/README.md (97%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/azure/__init__.py (100%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/azure/ai/__init__.py (100%) create mode 100644 sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/__init__.py rename sdk/agentserver/{azure-ai-agentserver/azure/ai/agentserver => azure-ai-agentserver-server/azure/ai/agentserver/server}/__init__.py (81%) rename sdk/agentserver/{azure-ai-agentserver/azure/ai/agentserver => azure-ai-agentserver-server/azure/ai/agentserver/server}/_base.py (99%) rename sdk/agentserver/{azure-ai-agentserver/azure/ai/agentserver => azure-ai-agentserver-server/azure/ai/agentserver/server}/_config.py (100%) rename sdk/agentserver/{azure-ai-agentserver/azure/ai/agentserver => azure-ai-agentserver-server/azure/ai/agentserver/server}/_constants.py (100%) rename sdk/agentserver/{azure-ai-agentserver/azure/ai/agentserver => azure-ai-agentserver-server/azure/ai/agentserver/server}/_errors.py (100%) rename sdk/agentserver/{azure-ai-agentserver/azure/ai/agentserver => azure-ai-agentserver-server/azure/ai/agentserver/server}/_logger.py (100%) rename sdk/agentserver/{azure-ai-agentserver/azure/ai/agentserver => azure-ai-agentserver-server/azure/ai/agentserver/server}/_openapi_validator.py (100%) rename sdk/agentserver/{azure-ai-agentserver/azure/ai/agentserver => azure-ai-agentserver-server/azure/ai/agentserver/server}/_tracing.py (98%) rename sdk/agentserver/{azure-ai-agentserver/azure/ai/agentserver => azure-ai-agentserver-server/azure/ai/agentserver/server}/_version.py (100%) rename sdk/agentserver/{azure-ai-agentserver/azure/ai/agentserver => azure-ai-agentserver-server/azure/ai/agentserver/server}/py.typed (100%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/cspell.json (100%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/dev_requirements.txt (100%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/pyproject.toml (91%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/pyrightconfig.json (100%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/samples/agentframework_invoke_agent/.env.sample (100%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/samples/agentframework_invoke_agent/agentframework_invoke_agent.py (97%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/samples/agentframework_invoke_agent/requirements.txt (71%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/samples/async_invoke_agent/async_invoke_agent.py (99%) create mode 100644 sdk/agentserver/azure-ai-agentserver-server/samples/async_invoke_agent/requirements.txt rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/samples/human_in_the_loop_agent/human_in_the_loop_agent.py (97%) create mode 100644 sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/requirements.txt rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/samples/langgraph_invoke_agent/.env.sample (100%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/samples/langgraph_invoke_agent/langgraph_invoke_agent.py (98%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/samples/langgraph_invoke_agent/requirements.txt (68%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/samples/openapi_validated_agent/openapi_validated_agent.py (98%) create mode 100644 sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/requirements.txt create mode 100644 sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/requirements.txt rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/samples/simple_invoke_agent/simple_invoke_agent.py (94%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/conftest.py (98%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_decorator_pattern.py (97%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_edge_cases.py (99%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_get_cancel.py (97%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_graceful_shutdown.py (95%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_health.py (100%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_http2.py (99%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_invoke.py (100%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_logger.py (92%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_multimodal_protocol.py (99%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_openapi_validation.py (97%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_request_limits.py (96%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_server_routes.py (83%) rename sdk/agentserver/{azure-ai-agentserver => azure-ai-agentserver-server}/tests/test_tracing.py (95%) delete mode 100644 sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/requirements.txt delete mode 100644 sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/requirements.txt delete mode 100644 sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/requirements.txt delete mode 100644 sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/Dockerfile delete mode 100644 sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/requirements.txt diff --git a/sdk/agentserver/azure-ai-agentserver/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-server/CHANGELOG.md similarity index 88% rename from sdk/agentserver/azure-ai-agentserver/CHANGELOG.md rename to sdk/agentserver/azure-ai-agentserver-server/CHANGELOG.md index 1ba538d24413..3b472f26b730 100644 --- a/sdk/agentserver/azure-ai-agentserver/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-server/CHANGELOG.md @@ -4,7 +4,7 @@ ### Features Added -- Initial release of `azure-ai-agentserver`. +- Initial release of `azure-ai-agentserver-server`. - Generic `AgentServer` base class with pluggable protocol heads. - `/invoke` protocol head with all 4 operations: create, get, cancel, and OpenAPI spec. - OpenAPI spec-based request/response validation via `jsonschema`. diff --git a/sdk/agentserver/azure-ai-agentserver/LICENSE b/sdk/agentserver/azure-ai-agentserver-server/LICENSE similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/LICENSE rename to sdk/agentserver/azure-ai-agentserver-server/LICENSE diff --git a/sdk/agentserver/azure-ai-agentserver/MANIFEST.in b/sdk/agentserver/azure-ai-agentserver-server/MANIFEST.in similarity index 63% rename from sdk/agentserver/azure-ai-agentserver/MANIFEST.in rename to sdk/agentserver/azure-ai-agentserver-server/MANIFEST.in index 468601f6166b..49a1b88738e9 100644 --- a/sdk/agentserver/azure-ai-agentserver/MANIFEST.in +++ b/sdk/agentserver/azure-ai-agentserver-server/MANIFEST.in @@ -4,4 +4,5 @@ recursive-include tests *.py recursive-include samples *.py *.md include azure/__init__.py include azure/ai/__init__.py -include azure/ai/agentserver/py.typed +include azure/ai/agentserver/__init__.py +include azure/ai/agentserver/server/py.typed diff --git a/sdk/agentserver/azure-ai-agentserver/README.md b/sdk/agentserver/azure-ai-agentserver-server/README.md similarity index 97% rename from sdk/agentserver/azure-ai-agentserver/README.md rename to sdk/agentserver/azure-ai-agentserver-server/README.md index bd6a8c031225..581da4b56339 100644 --- a/sdk/agentserver/azure-ai-agentserver/README.md +++ b/sdk/agentserver/azure-ai-agentserver-server/README.md @@ -10,7 +10,7 @@ endpoints — with **zero framework coupling**. ### Install the package ```bash -pip install azure-ai-agentserver +pip install azure-ai-agentserver-server ``` **Requires Python >= 3.10.** @@ -21,7 +21,7 @@ pip install azure-ai-agentserver from starlette.requests import Request from starlette.responses import JSONResponse, Response -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer server = AgentServer() @@ -90,7 +90,7 @@ your agent is compatible with the hosting platform — no manual route setup req │ ▼ ┌───────────────────────────────┐ - │ azure-ai-agentserver │ + │ azure-ai-agentserver-server │ │ AgentServer │ │ │ │ Protocol heads: │ @@ -236,7 +236,7 @@ export AGENT_ENABLE_TRACING=true Install the tracing extras (includes OpenTelemetry and the Azure Monitor exporter): ```bash -pip install azure-ai-agentserver[tracing] +pip install azure-ai-agentserver-server[tracing] ``` When enabled, spans are created for `invoke`, `get_invocation`, and `cancel_invocation` @@ -289,7 +289,7 @@ GitHub issue [here](https://github.com/Azure/azure-sdk-for-python/issues). ## Next steps -Please visit [Samples](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/agentserver/azure-ai-agentserver/samples) +Please visit [Samples](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/agentserver/azure-ai-agentserver-server/samples) for more usage examples. ## Contributing diff --git a/sdk/agentserver/azure-ai-agentserver/azure/__init__.py b/sdk/agentserver/azure-ai-agentserver-server/azure/__init__.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/azure/__init__.py rename to sdk/agentserver/azure-ai-agentserver-server/azure/__init__.py diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/__init__.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/__init__.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/__init__.py rename to sdk/agentserver/azure-ai-agentserver-server/azure/ai/__init__.py diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/__init__.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/__init__.py new file mode 100644 index 000000000000..d55ccad1f573 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/__init__.py @@ -0,0 +1 @@ +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/__init__.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/__init__.py similarity index 81% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/__init__.py rename to sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/__init__.py index 4be26052cbe4..a82226f72189 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/__init__.py @@ -1,6 +1,7 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +__path__ = __import__("pkgutil").extend_path(__path__, __name__) from ._version import VERSION from ._base import AgentServer diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_base.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py similarity index 99% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_base.py rename to sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py index 41b03dad9de4..e3c8c26a1f19 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_base.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py @@ -59,7 +59,7 @@ async def handle(request): :param enable_tracing: Enable OpenTelemetry tracing. When *None* (default) the ``AGENT_ENABLE_TRACING`` env var is consulted (``"true"`` to enable). Requires ``opentelemetry-api`` — install with - ``pip install azure-ai-agentserver[tracing]``. + ``pip install azure-ai-agentserver-server[tracing]``. When an Application Insights connection string is also available, traces and logs are automatically exported to Azure Monitor. :type enable_tracing: Optional[bool] @@ -234,7 +234,7 @@ async def _dispatch_get_invocation(self, request: Request) -> Response: """ if self._get_invocation_fn is not None: return await self._get_invocation_fn(request) - return error_response("not_supported", "get_invocation not supported", status_code=404) + return error_response("not_supported", "get_invocation not supported", status_code=501) async def _dispatch_cancel_invocation(self, request: Request) -> Response: """Dispatch to the registered cancel-invocation handler, or return 404. @@ -246,7 +246,7 @@ async def _dispatch_cancel_invocation(self, request: Request) -> Response: """ if self._cancel_invocation_fn is not None: return await self._cancel_invocation_fn(request) - return error_response("not_supported", "cancel_invocation not supported", status_code=404) + return error_response("not_supported", "cancel_invocation not supported", status_code=501) async def _dispatch_shutdown(self) -> None: """Dispatch to the registered shutdown handler, or no-op.""" @@ -453,9 +453,9 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: status_code=501, headers={Constants.INVOCATION_ID_HEADER: invocation_id}, ) - except asyncio.TimeoutError: + except asyncio.TimeoutError as exc: if self._tracing is not None: - self._tracing.end_span(otel_span) + self._tracing.end_span(otel_span, exc=exc) logger.error( "Invocation %s timed out after %ss", invocation_id, diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_config.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_config.py rename to sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_constants.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_constants.py rename to sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_errors.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_errors.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_errors.py rename to sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_errors.py diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_logger.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_logger.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_logger.py rename to sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_logger.py diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_openapi_validator.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_openapi_validator.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_openapi_validator.py rename to sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_openapi_validator.py index 88688738c31f..3345b081d5df 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_openapi_validator.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_openapi_validator.py @@ -28,8 +28,8 @@ import json import re from collections import Counter -from datetime import datetime from collections.abc import Callable # pylint: disable=import-error +from datetime import datetime from typing import Any, Optional import jsonschema diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_tracing.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py similarity index 98% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_tracing.py rename to sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py index 4928662fffa0..8b85575a17cd 100644 --- a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py @@ -11,7 +11,7 @@ When enabled, the module requires ``opentelemetry-api`` to be installed:: - pip install azure-ai-agentserver[tracing] + pip install azure-ai-agentserver-server[tracing] If the package is not installed, tracing silently becomes a no-op. @@ -71,7 +71,7 @@ def __init__(self, connection_string: Optional[str] = None) -> None: if not self._enabled: logger.warning( "Tracing was enabled but opentelemetry-api is not installed. " - "Install it with: pip install azure-ai-agentserver[tracing]" + "Install it with: pip install azure-ai-agentserver-server[tracing]" ) return @@ -124,7 +124,7 @@ def _setup_azure_monitor(connection_string: str) -> None: logger.warning( "Application Insights connection string was provided but " "required packages are not installed. Install them with: " - "pip install azure-ai-agentserver[tracing]" + "pip install azure-ai-agentserver-server[tracing]" ) return diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_version.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_version.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/_version.py rename to sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_version.py diff --git a/sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/py.typed b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/py.typed similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/azure/ai/agentserver/py.typed rename to sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/py.typed diff --git a/sdk/agentserver/azure-ai-agentserver/cspell.json b/sdk/agentserver/azure-ai-agentserver-server/cspell.json similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/cspell.json rename to sdk/agentserver/azure-ai-agentserver-server/cspell.json diff --git a/sdk/agentserver/azure-ai-agentserver/dev_requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/dev_requirements.txt similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/dev_requirements.txt rename to sdk/agentserver/azure-ai-agentserver-server/dev_requirements.txt diff --git a/sdk/agentserver/azure-ai-agentserver/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-server/pyproject.toml similarity index 91% rename from sdk/agentserver/azure-ai-agentserver/pyproject.toml rename to sdk/agentserver/azure-ai-agentserver-server/pyproject.toml index 19063c140e09..663bd57fdb10 100644 --- a/sdk/agentserver/azure-ai-agentserver/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-server/pyproject.toml @@ -1,5 +1,5 @@ [project] -name = "azure-ai-agentserver" +name = "azure-ai-agentserver-server" dynamic = ["version", "readme"] description = "Generic agent server for Azure AI with pluggable protocol heads" requires-python = ">=3.10" @@ -50,7 +50,7 @@ exclude = [ ] [tool.setuptools.dynamic] -version = { attr = "azure.ai.agentserver._version.VERSION" } +version = { attr = "azure.ai.agentserver.server._version.VERSION" } readme = { file = ["README.md"], content-type = "text/markdown" } [tool.setuptools.package-data] @@ -64,7 +64,7 @@ lint.ignore = [] fix = false [tool.ruff.lint.isort] -known-first-party = ["azure.ai.agentserver"] +known-first-party = ["azure.ai.agentserver.server"] combine-as-imports = true [tool.azure-sdk-build] diff --git a/sdk/agentserver/azure-ai-agentserver/pyrightconfig.json b/sdk/agentserver/azure-ai-agentserver-server/pyrightconfig.json similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/pyrightconfig.json rename to sdk/agentserver/azure-ai-agentserver-server/pyrightconfig.json diff --git a/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/.env.sample b/sdk/agentserver/azure-ai-agentserver-server/samples/agentframework_invoke_agent/.env.sample similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/.env.sample rename to sdk/agentserver/azure-ai-agentserver-server/samples/agentframework_invoke_agent/.env.sample diff --git a/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/agentframework_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-server/samples/agentframework_invoke_agent/agentframework_invoke_agent.py similarity index 97% rename from sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/agentframework_invoke_agent.py rename to sdk/agentserver/azure-ai-agentserver-server/samples/agentframework_invoke_agent/agentframework_invoke_agent.py index ffcb7eccf734..b49f330f52c4 100644 --- a/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/agentframework_invoke_agent.py +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/agentframework_invoke_agent/agentframework_invoke_agent.py @@ -28,7 +28,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer # -- Customer defines their tools -- diff --git a/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/agentframework_invoke_agent/requirements.txt similarity index 71% rename from sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/requirements.txt rename to sdk/agentserver/azure-ai-agentserver-server/samples/agentframework_invoke_agent/requirements.txt index bd3b80baf653..a02b64e685a5 100644 --- a/sdk/agentserver/azure-ai-agentserver/samples/agentframework_invoke_agent/requirements.txt +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/agentframework_invoke_agent/requirements.txt @@ -1,4 +1,4 @@ -azure-ai-agentserver +azure-ai-agentserver-server agent-framework>=1.0.0rc2 azure-identity>=1.25.0 python-dotenv>=1.0.0 diff --git a/sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/async_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-server/samples/async_invoke_agent/async_invoke_agent.py similarity index 99% rename from sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/async_invoke_agent.py rename to sdk/agentserver/azure-ai-agentserver-server/samples/async_invoke_agent/async_invoke_agent.py index ee6f77c2b16a..c929fd898107 100644 --- a/sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/async_invoke_agent.py +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/async_invoke_agent/async_invoke_agent.py @@ -44,7 +44,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer # In-memory state for demo purposes (see module docstring for production caveats) diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/async_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/async_invoke_agent/requirements.txt new file mode 100644 index 000000000000..16f731287fdc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/async_invoke_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-server diff --git a/sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/human_in_the_loop_agent.py b/sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/human_in_the_loop_agent.py similarity index 97% rename from sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/human_in_the_loop_agent.py rename to sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/human_in_the_loop_agent.py index 618a2532f59d..4526500e6bb1 100644 --- a/sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/human_in_the_loop_agent.py +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/human_in_the_loop_agent.py @@ -26,7 +26,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer # Holds questions waiting for a human reply, keyed by invocation_id diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/requirements.txt new file mode 100644 index 000000000000..16f731287fdc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-server diff --git a/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/.env.sample b/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/.env.sample similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/.env.sample rename to sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/.env.sample diff --git a/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/langgraph_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/langgraph_invoke_agent.py similarity index 98% rename from sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/langgraph_invoke_agent.py rename to sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/langgraph_invoke_agent.py index 89611059b643..f717551643f8 100644 --- a/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/langgraph_invoke_agent.py +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/langgraph_invoke_agent.py @@ -31,7 +31,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response, StreamingResponse -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer def build_graph() -> StateGraph: diff --git a/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/requirements.txt similarity index 68% rename from sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/requirements.txt rename to sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/requirements.txt index 980438cbf628..681ea34a9682 100644 --- a/sdk/agentserver/azure-ai-agentserver/samples/langgraph_invoke_agent/requirements.txt +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/requirements.txt @@ -1,4 +1,4 @@ -azure-ai-agentserver +azure-ai-agentserver-server langgraph>=1.0.0 langchain-openai>=1.0.0 python-dotenv>=1.0.0 diff --git a/sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/openapi_validated_agent.py b/sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/openapi_validated_agent.py similarity index 98% rename from sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/openapi_validated_agent.py rename to sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/openapi_validated_agent.py index 1f8749b73d4f..60ffb28516de 100644 --- a/sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/openapi_validated_agent.py +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/openapi_validated_agent.py @@ -28,7 +28,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer # Define a simple OpenAPI 3.0 spec inline. In production this could be # loaded from a YAML/JSON file. diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/requirements.txt new file mode 100644 index 000000000000..16f731287fdc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-server diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/requirements.txt new file mode 100644 index 000000000000..16f731287fdc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-server diff --git a/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/simple_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/simple_invoke_agent.py similarity index 94% rename from sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/simple_invoke_agent.py rename to sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/simple_invoke_agent.py index 60d7f3520622..368511106848 100644 --- a/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/simple_invoke_agent.py +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/simple_invoke_agent.py @@ -14,7 +14,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer server = AgentServer() diff --git a/sdk/agentserver/azure-ai-agentserver/tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-server/tests/conftest.py similarity index 98% rename from sdk/agentserver/azure-ai-agentserver/tests/conftest.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/conftest.py index c3647378fc75..e9e6fcda8c3d 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/conftest.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/conftest.py @@ -1,7 +1,7 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -"""Shared fixtures for azure-ai-agentserver tests.""" +"""Shared fixtures for azure-ai-agentserver-server tests.""" import json import pytest @@ -11,7 +11,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response, StreamingResponse -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer # --------------------------------------------------------------------------- diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_decorator_pattern.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_decorator_pattern.py similarity index 97% rename from sdk/agentserver/azure-ai-agentserver/tests/test_decorator_pattern.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_decorator_pattern.py index eab02458a9a2..8e4b6f0f0902 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_decorator_pattern.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_decorator_pattern.py @@ -9,7 +9,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer # --------------------------------------------------------------------------- @@ -155,7 +155,7 @@ class TestOptionalHandlerDefaults: """get_invocation and cancel_invocation return 404 by default.""" @pytest.mark.asyncio - async def test_get_invocation_returns_404_by_default(self): + async def test_get_invocation_returns_501_by_default(self): server = AgentServer() @server.invoke_handler @@ -165,11 +165,11 @@ async def handle(request: Request) -> Response: transport = httpx.ASGITransport(app=server.app) async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: resp = await client.get("/invocations/some-id") - assert resp.status_code == 404 + assert resp.status_code == 501 assert resp.headers.get("x-agent-invocation-id") == "some-id" @pytest.mark.asyncio - async def test_cancel_invocation_returns_404_by_default(self): + async def test_cancel_invocation_returns_501_by_default(self): server = AgentServer() @server.invoke_handler @@ -179,7 +179,7 @@ async def handle(request: Request) -> Response: transport = httpx.ASGITransport(app=server.app) async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: resp = await client.post("/invocations/some-id/cancel") - assert resp.status_code == 404 + assert resp.status_code == 501 assert resp.headers.get("x-agent-invocation-id") == "some-id" diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_edge_cases.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_edge_cases.py similarity index 99% rename from sdk/agentserver/azure-ai-agentserver/tests/test_edge_cases.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_edge_cases.py index 2ea55294ee51..c5db4d49ca66 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_edge_cases.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_edge_cases.py @@ -12,7 +12,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response, StreamingResponse -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer # --------------------------------------------------------------------------- diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_get_cancel.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_get_cancel.py similarity index 97% rename from sdk/agentserver/azure-ai-agentserver/tests/test_get_cancel.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_get_cancel.py index 6ffdf5082a40..edc5f1d47b48 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_get_cancel.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_get_cancel.py @@ -66,7 +66,7 @@ async def test_get_invocation_error_returns_500(): from starlette.requests import Request from starlette.responses import JSONResponse, Response - from azure.ai.agentserver import AgentServer + from azure.ai.agentserver.server import AgentServer server = AgentServer() @@ -94,7 +94,7 @@ async def test_cancel_invocation_error_returns_500(): from starlette.requests import Request from starlette.responses import JSONResponse, Response - from azure.ai.agentserver import AgentServer + from azure.ai.agentserver.server import AgentServer server = AgentServer() diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_graceful_shutdown.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_graceful_shutdown.py similarity index 95% rename from sdk/agentserver/azure-ai-agentserver/tests/test_graceful_shutdown.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_graceful_shutdown.py index 1fce19d303f5..17f6b6c7e7b9 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_graceful_shutdown.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_graceful_shutdown.py @@ -10,8 +10,8 @@ from starlette.requests import Request from starlette.responses import Response -from azure.ai.agentserver import AgentServer -from azure.ai.agentserver._constants import Constants +from azure.ai.agentserver.server import AgentServer +from azure.ai.agentserver.server._constants import Constants # --------------------------------------------------------------------------- @@ -95,7 +95,7 @@ class TestRunPassesTimeout: """Ensure run() forwards the timeout to Hypercorn config.""" @patch("hypercorn.asyncio.serve", new_callable=AsyncMock) - @patch("azure.ai.agentserver._base.asyncio") + @patch("azure.ai.agentserver.server._base.asyncio") def test_run_passes_timeout(self, mock_asyncio, _mock_serve): agent = _make_stub_agent(graceful_shutdown_timeout=15) agent.run() @@ -105,7 +105,7 @@ def test_run_passes_timeout(self, mock_asyncio, _mock_serve): assert config.graceful_timeout == 15.0 @patch("hypercorn.asyncio.serve", new_callable=AsyncMock) - @patch("azure.ai.agentserver._base.asyncio") + @patch("azure.ai.agentserver.server._base.asyncio") def test_run_passes_default_timeout(self, mock_asyncio, _mock_serve): agent = _make_stub_agent() agent.run() @@ -172,7 +172,7 @@ async def test_lifespan_shutdown_logs(self): agent = _make_stub_agent(graceful_shutdown_timeout=99) - with patch("azure.ai.agentserver._base.logger") as mock_logger: + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: # Grab the lifespan from the Starlette app lifespan = agent.app.router.lifespan_context async with lifespan(agent.app): @@ -273,7 +273,7 @@ async def test_default_on_shutdown_is_noop(self): async def test_on_shutdown_exception_is_logged_not_raised(self): """A failing on_shutdown must not crash the shutdown sequence.""" agent = _make_failing_shutdown_agent() - with patch("azure.ai.agentserver._base.logger") as mock_logger: + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: lifespan = agent.app.router.lifespan_context async with lifespan(agent.app): pass # should NOT raise @@ -300,7 +300,7 @@ async def invoke(request: Request) -> Response: async def shutdown(): order.append("callback") - with patch("azure.ai.agentserver._base.logger") as mock_logger: + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: def tracking_info(*args, **kwargs): if args and "shutting down" in str(args[0]).lower(): @@ -360,7 +360,7 @@ async def invoke(request: Request) -> Response: async def shutdown(): await asyncio.sleep(999) # way longer than the 1s timeout - with patch("azure.ai.agentserver._base.logger") as mock_logger: + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: lifespan = server.app.router.lifespan_context async with lifespan(server.app): pass # should NOT hang @@ -391,7 +391,7 @@ async def shutdown(): await asyncio.sleep(0) done = True - with patch("azure.ai.agentserver._base.logger") as mock_logger: + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: lifespan = server.app.router.lifespan_context async with lifespan(server.app): pass @@ -430,7 +430,7 @@ async def shutdown(): async def test_timeout_does_not_suppress_exceptions(self): """Exceptions from on_shutdown are still logged even if within timeout.""" agent = _make_failing_shutdown_agent(graceful_shutdown_timeout=5) - with patch("azure.ai.agentserver._base.logger") as mock_logger: + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: lifespan = agent.app.router.lifespan_context async with lifespan(agent.app): pass diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_health.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_health.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/tests/test_health.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_health.py diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_http2.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_http2.py similarity index 99% rename from sdk/agentserver/azure-ai-agentserver/tests/test_http2.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_http2.py index 15f013804bc9..51da9b840096 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_http2.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_http2.py @@ -41,7 +41,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response, StreamingResponse -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer # --------------------------------------------------------------------------- diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_invoke.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_invoke.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver/tests/test_invoke.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_invoke.py diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_logger.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_logger.py similarity index 92% rename from sdk/agentserver/azure-ai-agentserver/tests/test_logger.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_logger.py index 06c3ef31414f..314211cec644 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_logger.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_logger.py @@ -17,6 +17,6 @@ def test_log_level_preserved_across_imports(): lib_logger.setLevel(logging.ERROR) # Re-importing the base module should not override the level. - from azure.ai.agentserver import _base # noqa: F401 + from azure.ai.agentserver.server import _base # noqa: F401 assert lib_logger.level == logging.ERROR diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_multimodal_protocol.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_multimodal_protocol.py similarity index 99% rename from sdk/agentserver/azure-ai-agentserver/tests/test_multimodal_protocol.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_multimodal_protocol.py index a8d718c7b268..9e481d1ea117 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_multimodal_protocol.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_multimodal_protocol.py @@ -20,7 +20,7 @@ from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse, Response, StreamingResponse -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer # --------------------------------------------------------------------------- diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_openapi_validation.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_openapi_validation.py similarity index 97% rename from sdk/agentserver/azure-ai-agentserver/tests/test_openapi_validation.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_openapi_validation.py index 4082567a49f9..0a2f10703403 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_openapi_validation.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_openapi_validation.py @@ -10,8 +10,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response -from azure.ai.agentserver import AgentServer -from azure.ai.agentserver._openapi_validator import OpenApiValidator +from azure.ai.agentserver.server import AgentServer +from azure.ai.agentserver.server._openapi_validator import OpenApiValidator # --------------------------------------------------------------------------- @@ -860,7 +860,7 @@ async def test_complex_multiple_errors_reported(): @pytest.mark.asyncio async def test_validate_response_valid(): """Valid response body passes validation.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator v = OpenApiValidator(RESPONSE_SPEC) errors = v.validate_response(b'{"result": "ok"}', "application/json") @@ -870,7 +870,7 @@ async def test_validate_response_valid(): @pytest.mark.asyncio async def test_validate_response_invalid(): """Invalid response body returns errors.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator v = OpenApiValidator(RESPONSE_SPEC) errors = v.validate_response(b'{"wrong": 42}', "application/json") @@ -880,7 +880,7 @@ async def test_validate_response_invalid(): @pytest.mark.asyncio async def test_validate_response_no_schema(): """When no response schema exists, validation passes (no-op).""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator spec_no_resp = { "openapi": "3.0.0", @@ -900,7 +900,7 @@ async def test_validate_response_no_schema(): @pytest.mark.asyncio async def test_validate_request_no_schema(): """When no request schema exists, validation passes (no-op).""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator spec_no_req = { "openapi": "3.0.0", @@ -920,7 +920,7 @@ async def test_validate_request_no_schema(): @pytest.mark.asyncio async def test_response_schema_fallback_to_first_available(): """Response schema extraction falls back to first response with JSON.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator spec = { "openapi": "3.0.0", @@ -961,7 +961,7 @@ async def test_response_schema_fallback_to_first_available(): @pytest.mark.asyncio async def test_unresolvable_ref(): """An unresolvable $ref leaves the node as-is (no crash).""" - from azure.ai.agentserver._openapi_validator import _resolve_ref + from azure.ai.agentserver.server._openapi_validator import _resolve_ref spec: dict = {"components": {"schemas": {}}} node = {"$ref": "#/components/schemas/DoesNotExist"} @@ -973,7 +973,7 @@ async def test_unresolvable_ref(): @pytest.mark.asyncio async def test_ref_path_hits_non_dict(): """A $ref path that traverses a non-dict returns the original node.""" - from azure.ai.agentserver._openapi_validator import _resolve_ref + from azure.ai.agentserver.server._openapi_validator import _resolve_ref spec: dict = {"components": {"schemas": "not-a-dict"}} node = {"$ref": "#/components/schemas/Foo"} @@ -984,7 +984,7 @@ async def test_ref_path_hits_non_dict(): @pytest.mark.asyncio async def test_circular_ref_stops_recursion(): """Circular $ref does not cause infinite recursion.""" - from azure.ai.agentserver._openapi_validator import _resolve_refs_deep + from azure.ai.agentserver.server._openapi_validator import _resolve_refs_deep spec: dict = { "components": { @@ -1010,7 +1010,7 @@ async def test_circular_ref_stops_recursion(): @pytest.mark.asyncio async def test_ref_resolves_to_non_dict(): """A $ref that resolves to a non-dict value returns the original node.""" - from azure.ai.agentserver._openapi_validator import _resolve_ref + from azure.ai.agentserver.server._openapi_validator import _resolve_ref spec: dict = {"components": {"schemas": {"Bad": 42}}} node = {"$ref": "#/components/schemas/Bad"} @@ -1211,7 +1211,7 @@ async def test_writeonly_allowed_in_request(): @pytest.mark.asyncio async def test_readonly_schema_introspection_request(): """readOnly properties are removed from the preprocessed request schema.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator v = OpenApiValidator(READONLY_SPEC) # 'id' should be gone from requestschema properties @@ -1221,7 +1221,7 @@ async def test_readonly_schema_introspection_request(): @pytest.mark.asyncio async def test_readonly_schema_introspection_response(): """readOnly properties remain in the preprocessed response schema.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator v = OpenApiValidator(READONLY_SPEC) # 'id' should still be in response schema @@ -1231,7 +1231,7 @@ async def test_readonly_schema_introspection_response(): @pytest.mark.asyncio async def test_writeonly_stripped_in_response(): """writeOnly properties are removed from the preprocessed response schema.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator v = OpenApiValidator(READONLY_SPEC) assert "password" not in v._response_schema.get("properties", {}) @@ -1240,7 +1240,7 @@ async def test_writeonly_stripped_in_response(): @pytest.mark.asyncio async def test_writeonly_present_in_request(): """writeOnly properties remain in the preprocessed request schema.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator v = OpenApiValidator(READONLY_SPEC) assert "password" in v._request_schema.get("properties", {}) @@ -1306,7 +1306,7 @@ async def test_writeonly_present_in_request(): @pytest.mark.asyncio async def test_optional_body_empty_accepted(): """Empty body is accepted when requestBody.required is false.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator v = OpenApiValidator(OPTIONAL_BODY_SPEC) errors = v.validate_request(b"", "application/json") @@ -1316,7 +1316,7 @@ async def test_optional_body_empty_accepted(): @pytest.mark.asyncio async def test_optional_body_whitespace_accepted(): """Whitespace-only body is accepted when requestBody.required is false.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator v = OpenApiValidator(OPTIONAL_BODY_SPEC) errors = v.validate_request(b" ", "application/json") @@ -1326,7 +1326,7 @@ async def test_optional_body_whitespace_accepted(): @pytest.mark.asyncio async def test_optional_body_present_still_validated(): """When body IS present with optional requestBody, it still must be valid.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator v = OpenApiValidator(OPTIONAL_BODY_SPEC) errors = v.validate_request(b'{"wrong": 1}', "application/json") @@ -1336,7 +1336,7 @@ async def test_optional_body_present_still_validated(): @pytest.mark.asyncio async def test_required_body_empty_rejected(): """Empty body is rejected when requestBody.required is true.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator v = OpenApiValidator(REQUIRED_BODY_SPEC) errors = v.validate_request(b"", "application/json") @@ -1346,7 +1346,7 @@ async def test_required_body_empty_rejected(): @pytest.mark.asyncio async def test_default_body_required_behavior(): """When requestBody.required is omitted, body is required by default.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator spec = { "openapi": "3.0.0", @@ -1379,7 +1379,7 @@ async def test_default_body_required_behavior(): @pytest.mark.asyncio async def test_openapi_keywords_stripped(): """discriminator, xml, externalDocs, example are stripped from schemas.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator schema = { "type": "object", @@ -1400,7 +1400,7 @@ async def test_openapi_keywords_stripped(): @pytest.mark.asyncio async def test_openapi_keywords_stripped_nested(): """OpenAPI keywords are stripped from nested properties too.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator schema = { "type": "object", @@ -1729,7 +1729,7 @@ async def test_anyof_wrong_type(): @pytest.mark.asyncio async def test_nullable_ref_accepts_null(): """A nullable $ref field accepts null after preprocessing.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator spec: dict = { "openapi": "3.0.0", @@ -1779,7 +1779,7 @@ async def test_nullable_ref_accepts_null(): @pytest.mark.asyncio async def test_apply_nullable_no_duplicate_null(): """_apply_nullable does not add 'null' twice if type is already a list with null.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator schema: dict = {"type": ["string", "null"], "nullable": True} OpenApiValidator._apply_nullable(schema) @@ -1789,7 +1789,7 @@ async def test_apply_nullable_no_duplicate_null(): @pytest.mark.asyncio async def test_apply_nullable_false(): """nullable: false is a no-op (just removes the key).""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator schema: dict = {"type": "string", "nullable": False} OpenApiValidator._apply_nullable(schema) @@ -1800,7 +1800,7 @@ async def test_apply_nullable_false(): @pytest.mark.asyncio async def test_strip_openapi_keywords_nested_deeply(): """OpenAPI keywords are stripped from deeply nested schemas.""" - from azure.ai.agentserver._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import OpenApiValidator schema: dict = { "type": "object", @@ -2125,7 +2125,7 @@ async def test_nested_oneof_matching_branch_missing_field(): @pytest.mark.asyncio async def test_format_error_includes_path(): """_format_error prefixes with JSON path when not root.""" - from azure.ai.agentserver._openapi_validator import _format_error + from azure.ai.agentserver.server._openapi_validator import _format_error import jsonschema schema = {"type": "object", "properties": {"age": {"type": "integer"}}, "required": ["age"]} @@ -2139,7 +2139,7 @@ async def test_format_error_includes_path(): @pytest.mark.asyncio async def test_format_error_root_path_no_prefix(): """_format_error does not prefix when error is at root ($).""" - from azure.ai.agentserver._openapi_validator import _format_error + from azure.ai.agentserver.server._openapi_validator import _format_error import jsonschema schema = {"type": "object", "required": ["name"]} @@ -2154,7 +2154,7 @@ async def test_format_error_root_path_no_prefix(): @pytest.mark.asyncio async def test_collect_errors_flat(): """_collect_errors for a simple (non-composition) error returns path-prefixed message.""" - from azure.ai.agentserver._openapi_validator import _collect_errors + from azure.ai.agentserver.server._openapi_validator import _collect_errors import jsonschema schema = {"type": "object", "properties": {"x": {"type": "integer"}}} diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_request_limits.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_request_limits.py similarity index 96% rename from sdk/agentserver/azure-ai-agentserver/tests/test_request_limits.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_request_limits.py index e8644e388f37..aa4d387152d6 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_request_limits.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_request_limits.py @@ -12,8 +12,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response -from azure.ai.agentserver import AgentServer -from azure.ai.agentserver._constants import Constants +from azure.ai.agentserver.server import AgentServer +from azure.ai.agentserver.server._constants import Constants # --------------------------------------------------------------------------- @@ -137,7 +137,7 @@ async def test_timeout_disabled_allows_no_limit(self): async def test_timeout_error_is_logged(self): agent = _make_slow_agent(request_timeout=1) transport = httpx.ASGITransport(app=agent.app) - with patch("azure.ai.agentserver._base.logger") as mock_logger: + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: async with httpx.AsyncClient( transport=transport, base_url="http://testserver" ) as client: diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_server_routes.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_server_routes.py similarity index 83% rename from sdk/agentserver/azure-ai-agentserver/tests/test_server_routes.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_server_routes.py index 351edee0f2c5..80ae3e5338fb 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_server_routes.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_server_routes.py @@ -42,20 +42,20 @@ async def test_get_openapi_spec_returns_spec(validated_client): @pytest.mark.asyncio -async def test_get_invocation_returns_404_default(echo_client): - """GET /invocations/{id} returns 404 when not overridden.""" +async def test_get_invocation_returns_501_default(echo_client): + """GET /invocations/{id} returns 501 when not overridden.""" inv_id = str(uuid.uuid4()) resp = await echo_client.get(f"/invocations/{inv_id}") - assert resp.status_code == 404 + assert resp.status_code == 501 assert resp.headers.get("x-agent-invocation-id") == inv_id @pytest.mark.asyncio -async def test_cancel_invocation_returns_404_default(echo_client): - """POST /invocations/{id}/cancel returns 404 when not overridden.""" +async def test_cancel_invocation_returns_501_default(echo_client): + """POST /invocations/{id}/cancel returns 501 when not overridden.""" inv_id = str(uuid.uuid4()) resp = await echo_client.post(f"/invocations/{inv_id}/cancel") - assert resp.status_code == 404 + assert resp.status_code == 501 assert resp.headers.get("x-agent-invocation-id") == inv_id @@ -76,37 +76,37 @@ class TestResolvePort: def test_explicit_port_wins(self): """Explicit port argument takes precedence over everything.""" - from azure.ai.agentserver._config import resolve_port + from azure.ai.agentserver.server._config import resolve_port assert resolve_port(9090) == 9090 def test_env_var_used_when_no_explicit(self, monkeypatch): """AGENT_SERVER_PORT env var is used when no explicit port given.""" - from azure.ai.agentserver._config import resolve_port + from azure.ai.agentserver.server._config import resolve_port monkeypatch.setenv("AGENT_SERVER_PORT", "7777") assert resolve_port(None) == 7777 def test_default_port_when_nothing_set(self, monkeypatch): """Falls back to 8088 when no explicit port and no env var.""" - from azure.ai.agentserver._config import resolve_port + from azure.ai.agentserver.server._config import resolve_port monkeypatch.delenv("AGENT_SERVER_PORT", raising=False) assert resolve_port(None) == 8088 def test_invalid_env_var_raises(self, monkeypatch): """Non-numeric AGENT_SERVER_PORT env var raises ValueError.""" - from azure.ai.agentserver._config import resolve_port + from azure.ai.agentserver.server._config import resolve_port monkeypatch.setenv("AGENT_SERVER_PORT", "not_a_number") with pytest.raises(ValueError, match="AGENT_SERVER_PORT"): resolve_port(None) def test_non_int_explicit_port_raises(self): """Passing a non-integer port raises ValueError.""" - from azure.ai.agentserver._config import resolve_port + from azure.ai.agentserver.server._config import resolve_port with pytest.raises(ValueError, match="expected an integer"): resolve_port("sss") # type: ignore[arg-type] def test_port_out_of_range_raises(self): """Port outside 1-65535 raises ValueError.""" - from azure.ai.agentserver._config import resolve_port + from azure.ai.agentserver.server._config import resolve_port with pytest.raises(ValueError, match="1-65535"): resolve_port(0) with pytest.raises(ValueError, match="1-65535"): @@ -114,7 +114,7 @@ def test_port_out_of_range_raises(self): def test_env_var_port_out_of_range_raises(self, monkeypatch): """AGENT_SERVER_PORT outside 1-65535 raises ValueError.""" - from azure.ai.agentserver._config import resolve_port + from azure.ai.agentserver.server._config import resolve_port monkeypatch.setenv("AGENT_SERVER_PORT", "0") with pytest.raises(ValueError, match="1-65535"): resolve_port(None) diff --git a/sdk/agentserver/azure-ai-agentserver/tests/test_tracing.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py similarity index 95% rename from sdk/agentserver/azure-ai-agentserver/tests/test_tracing.py rename to sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py index 4651f2f0b4d6..ad49224f8379 100644 --- a/sdk/agentserver/azure-ai-agentserver/tests/test_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py @@ -16,7 +16,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response, StreamingResponse -from azure.ai.agentserver import AgentServer +from azure.ai.agentserver.server import AgentServer # --------------------------------------------------------------------------- @@ -153,8 +153,8 @@ async def test_tracing_get_invocation_creates_span(span_exporter): transport = httpx.ASGITransport(app=agent.app) async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: resp = await client.get("/invocations/test-id-123") - # Default returns 404 — but span should still exist - assert resp.status_code == 404 + # Default returns 501 — but span should still exist + assert resp.status_code == 501 spans = span_exporter.get_finished_spans() get_spans = [s for s in spans if s.name == "AgentServer.get_invocation"] @@ -169,7 +169,7 @@ async def test_tracing_cancel_invocation_creates_span(span_exporter): transport = httpx.ASGITransport(app=agent.app) async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: resp = await client.post("/invocations/test-cancel-456/cancel") - assert resp.status_code == 404 + assert resp.status_code == 501 spans = span_exporter.get_finished_spans() cancel_spans = [s for s in spans if s.name == "AgentServer.cancel_invocation"] @@ -402,7 +402,7 @@ class TestAppInsightsConnectionStringResolution: def test_explicit_param_takes_priority(self, monkeypatch): """Constructor param beats env var.""" - from azure.ai.agentserver._config import resolve_appinsights_connection_string + from azure.ai.agentserver.server._config import resolve_appinsights_connection_string monkeypatch.setenv("APPLICATIONINSIGHTS_CONNECTION_STRING", "env-standard") result = resolve_appinsights_connection_string("explicit-value") @@ -410,7 +410,7 @@ def test_explicit_param_takes_priority(self, monkeypatch): def test_standard_env_var_fallback(self, monkeypatch): """Falls back to APPLICATIONINSIGHTS_CONNECTION_STRING.""" - from azure.ai.agentserver._config import resolve_appinsights_connection_string + from azure.ai.agentserver.server._config import resolve_appinsights_connection_string monkeypatch.setenv("APPLICATIONINSIGHTS_CONNECTION_STRING", "env-standard") result = resolve_appinsights_connection_string(None) @@ -418,7 +418,7 @@ def test_standard_env_var_fallback(self, monkeypatch): def test_no_connection_string_returns_none(self, monkeypatch): """Returns None when no source provides a connection string.""" - from azure.ai.agentserver._config import resolve_appinsights_connection_string + from azure.ai.agentserver.server._config import resolve_appinsights_connection_string monkeypatch.delenv("APPLICATIONINSIGHTS_CONNECTION_STRING", raising=False) result = resolve_appinsights_connection_string(None) @@ -435,7 +435,7 @@ class TestSetupAzureMonitor: def test_setup_configures_tracer_provider(self): """_setup_azure_monitor sets a global TracerProvider with exporter.""" - from azure.ai.agentserver._tracing import TracingHelper + from azure.ai.agentserver.server._tracing import TracingHelper mock_exporter = MagicMock() mock_exporter_cls = MagicMock(return_value=mock_exporter) @@ -459,7 +459,7 @@ def test_setup_configures_tracer_provider(self): def test_setup_logs_warning_when_packages_missing(self, caplog): """Warns gracefully when azure-monitor exporter is not installed.""" import builtins - from azure.ai.agentserver._tracing import TracingHelper + from azure.ai.agentserver.server._tracing import TracingHelper real_import = builtins.__import__ @@ -481,7 +481,7 @@ def test_constructor_passes_connection_string(self, monkeypatch): monkeypatch.delenv("APPLICATIONINSIGHTS_CONNECTION_STRING", raising=False) with patch( - "azure.ai.agentserver._tracing.TracingHelper._setup_azure_monitor" + "azure.ai.agentserver.server._tracing.TracingHelper._setup_azure_monitor" ) as mock_setup: _make_echo_traced_agent( enable_tracing=True, @@ -494,7 +494,7 @@ def test_constructor_no_connection_string_skips_setup(self, monkeypatch): monkeypatch.delenv("APPLICATIONINSIGHTS_CONNECTION_STRING", raising=False) with patch( - "azure.ai.agentserver._tracing.TracingHelper._setup_azure_monitor" + "azure.ai.agentserver.server._tracing.TracingHelper._setup_azure_monitor" ) as mock_setup: _make_echo_traced_agent(enable_tracing=True) mock_setup.assert_not_called() @@ -507,7 +507,7 @@ def test_constructor_env_var_connection_string(self, monkeypatch): ) with patch( - "azure.ai.agentserver._tracing.TracingHelper._setup_azure_monitor" + "azure.ai.agentserver.server._tracing.TracingHelper._setup_azure_monitor" ) as mock_setup: _make_echo_traced_agent(enable_tracing=True) mock_setup.assert_called_once_with("InstrumentationKey=from-env") diff --git a/sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/requirements.txt deleted file mode 100644 index 10ccd9f42648..000000000000 --- a/sdk/agentserver/azure-ai-agentserver/samples/async_invoke_agent/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -azure-ai-agentserver diff --git a/sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/requirements.txt deleted file mode 100644 index 10ccd9f42648..000000000000 --- a/sdk/agentserver/azure-ai-agentserver/samples/human_in_the_loop_agent/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -azure-ai-agentserver diff --git a/sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/requirements.txt deleted file mode 100644 index 10ccd9f42648..000000000000 --- a/sdk/agentserver/azure-ai-agentserver/samples/openapi_validated_agent/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -azure-ai-agentserver diff --git a/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/Dockerfile b/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/Dockerfile deleted file mode 100644 index 97ed3fb21fd1..000000000000 --- a/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/Dockerfile +++ /dev/null @@ -1,19 +0,0 @@ -FROM python:3.12-slim - -WORKDIR /app - -# Install the agentserver package from the local source tree -COPY sdk/agentserver/azure-ai-agentserver /src/azure-ai-agentserver -RUN pip install --no-cache-dir /src/azure-ai-agentserver - -# Copy the sample agent -COPY sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/simple_invoke_agent.py . - -EXPOSE 8088 - -# Bind to 0.0.0.0 so the port is accessible from outside the container. -# The default is 127.0.0.1 which is only reachable inside the container. -CMD ["python", "-c", "\ -from simple_invoke_agent import server; \ -server.run(host='0.0.0.0') \ -"] diff --git a/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/requirements.txt deleted file mode 100644 index 10ccd9f42648..000000000000 --- a/sdk/agentserver/azure-ai-agentserver/samples/simple_invoke_agent/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -azure-ai-agentserver diff --git a/sdk/agentserver/ci.yml b/sdk/agentserver/ci.yml index 7e718801a805..9c0128b8089e 100644 --- a/sdk/agentserver/ci.yml +++ b/sdk/agentserver/ci.yml @@ -40,8 +40,8 @@ extends: Selection: sparse GenerateVMJobs: true Artifacts: - - name: azure-ai-agentserver - safeName: azureaiagentserver + - name: azure-ai-agentserver-server + safeName: azureaiagentserverserver - name: azure-ai-agentserver-core safeName: azureaiagentservercore - name: azure-ai-agentserver-agentframework From 58516ac3826633e30b2d0dfa7297cab8d79b1370 Mon Sep 17 00:00:00 2001 From: Zhiyong Yang Date: Thu, 12 Mar 2026 15:40:48 -0700 Subject: [PATCH 06/10] Refactor invocation protocol from mixin to composition and add OTel tracing Extract _InvocationProtocol into a standalone composed class that receives shared server state via a frozen _ServerContext dataclass, enabling clean addition of future protocol heads without inheritance conflicts. Add OpenTelemetry tracing with GenAI semantic convention attributes, W3C Trace Context propagation, baggage header parsing, and Azure Monitor Application Insights export support. Co-Authored-By: Claude Opus 4.6 --- .../azure/ai/agentserver/server/_base.py | 334 ++------------ .../azure/ai/agentserver/server/_config.py | 18 + .../azure/ai/agentserver/server/_constants.py | 3 + .../ai/agentserver/server/_invocation.py | 427 ++++++++++++++++++ .../ai/agentserver/server/_server_context.py | 25 + .../azure/ai/agentserver/server/_tracing.py | 144 +++++- .../tests/test_decorator_pattern.py | 6 +- .../tests/test_request_limits.py | 2 +- .../tests/test_tracing.py | 392 +++++++++++++++- 9 files changed, 1044 insertions(+), 307 deletions(-) create mode 100644 sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py create mode 100644 sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_server_context.py diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py index e3c8c26a1f19..01d8fad49060 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py @@ -4,20 +4,20 @@ import asyncio # pylint: disable=do-not-import-asyncio import contextlib import logging -import uuid from collections.abc import AsyncGenerator, Awaitable, Callable # pylint: disable=import-error from typing import Any, Optional from starlette.applications import Starlette from starlette.requests import Request -from starlette.responses import JSONResponse, Response, StreamingResponse +from starlette.responses import Response from starlette.routing import Route from ._constants import Constants -from ._errors import error_response from ._logger import get_logger -from ._tracing import TracingHelper, extract_w3c_carrier +from ._tracing import TracingHelper from ._openapi_validator import OpenApiValidator +from ._invocation import _InvocationProtocol +from ._server_context import _ServerContext from . import _config logger = get_logger() @@ -105,10 +105,7 @@ def __init__( log_level: Optional[str] = None, debug_errors: Optional[bool] = None, ) -> None: - # Decorator handler slots ------------------------------------------ - self._invoke_fn: Optional[Callable] = None - self._get_invocation_fn: Optional[Callable] = None - self._cancel_invocation_fn: Optional[Callable] = None + # Shutdown handler slot (server-level lifecycle) ------------------- self._shutdown_fn: Optional[Callable] = None # Logging & debug ------------------------------------------------- @@ -122,15 +119,17 @@ def __init__( debug_errors, Constants.AGENT_DEBUG_ERRORS ) - self._openapi_spec = openapi_spec + # OpenAPI validation ----------------------------------------------- _validation_on = _config.resolve_bool_feature( enable_request_validation, Constants.AGENT_ENABLE_REQUEST_VALIDATION ) - self._validator: Optional[OpenApiValidator] = ( + validator: Optional[OpenApiValidator] = ( OpenApiValidator(openapi_spec) if openapi_spec and _validation_on else None ) + + # Tracing ---------------------------------------------------------- _tracing_on = _config.resolve_bool_feature(enable_tracing, Constants.AGENT_ENABLE_TRACING) _conn_str = _config.resolve_appinsights_connection_string( application_insights_connection_string @@ -138,15 +137,46 @@ def __init__( self._tracing: Optional[TracingHelper] = ( TracingHelper(connection_string=_conn_str) if _tracing_on else None ) + + # Timeouts --------------------------------------------------------- self._graceful_shutdown_timeout = _config.resolve_graceful_shutdown_timeout( graceful_shutdown_timeout ) self._request_timeout = _config.resolve_request_timeout(request_timeout) + + # Invocation protocol (composed) ------------------------------------- + ctx = _ServerContext( + tracing=self._tracing, + debug_errors=self._debug_errors, + request_timeout=self._request_timeout, + ) + self._invocation = _InvocationProtocol(ctx, openapi_spec, validator) + self.app: Starlette self._build_app() # ------------------------------------------------------------------ - # Handler decorators + # Shutdown handler (server-level lifecycle) + # ------------------------------------------------------------------ + + def shutdown_handler(self, fn: Callable[[], Awaitable[None]]) -> Callable[[], Awaitable[None]]: + """Register a function as the shutdown handler. + + :param fn: Async function called during graceful shutdown. + :type fn: Callable[[], Awaitable[None]] + :return: The original function (unmodified). + :rtype: Callable[[], Awaitable[None]] + """ + self._shutdown_fn = fn + return fn + + async def _dispatch_shutdown(self) -> None: + """Dispatch to the registered shutdown handler, or no-op.""" + if self._shutdown_fn is not None: + await self._shutdown_fn() + + # ------------------------------------------------------------------ + # Invocation protocol delegates # ------------------------------------------------------------------ def invoke_handler( @@ -165,8 +195,7 @@ async def handle(request: Request) -> Response: :return: The original function (unmodified). :rtype: Callable[[Request], Awaitable[Response]] """ - self._invoke_fn = fn - return fn + return self._invocation.invoke_handler(fn) def get_invocation_handler( self, fn: Callable[[Request], Awaitable[Response]] @@ -178,8 +207,7 @@ def get_invocation_handler( :return: The original function (unmodified). :rtype: Callable[[Request], Awaitable[Response]] """ - self._get_invocation_fn = fn - return fn + return self._invocation.get_invocation_handler(fn) def cancel_invocation_handler( self, fn: Callable[[Request], Awaitable[Response]] @@ -191,67 +219,7 @@ def cancel_invocation_handler( :return: The original function (unmodified). :rtype: Callable[[Request], Awaitable[Response]] """ - self._cancel_invocation_fn = fn - return fn - - def shutdown_handler(self, fn: Callable[[], Awaitable[None]]) -> Callable[[], Awaitable[None]]: - """Register a function as the shutdown handler. - - :param fn: Async function called during graceful shutdown. - :type fn: Callable[[], Awaitable[None]] - :return: The original function (unmodified). - :rtype: Callable[[], Awaitable[None]] - """ - self._shutdown_fn = fn - return fn - - # ------------------------------------------------------------------ - # Dispatch methods (internal) - # ------------------------------------------------------------------ - - async def _dispatch_invoke(self, request: Request) -> Response: - """Dispatch to the registered invoke handler. - - :param request: The incoming Starlette request. - :type request: Request - :return: The response from the invoke handler. - :rtype: Response - :raises NotImplementedError: If no invoke handler has been registered. - """ - if self._invoke_fn is not None: - return await self._invoke_fn(request) - raise NotImplementedError( - "No invoke handler registered. Use the @server.invoke_handler decorator." - ) - - async def _dispatch_get_invocation(self, request: Request) -> Response: - """Dispatch to the registered get-invocation handler, or return 404. - - :param request: The incoming Starlette request. - :type request: Request - :return: The response from the get-invocation handler. - :rtype: Response - """ - if self._get_invocation_fn is not None: - return await self._get_invocation_fn(request) - return error_response("not_supported", "get_invocation not supported", status_code=501) - - async def _dispatch_cancel_invocation(self, request: Request) -> Response: - """Dispatch to the registered cancel-invocation handler, or return 404. - - :param request: The incoming Starlette request. - :type request: Request - :return: The response from the cancel-invocation handler. - :rtype: Response - """ - if self._cancel_invocation_fn is not None: - return await self._cancel_invocation_fn(request) - return error_response("not_supported", "cancel_invocation not supported", status_code=501) - - async def _dispatch_shutdown(self) -> None: - """Dispatch to the registered shutdown handler, or no-op.""" - if self._shutdown_fn is not None: - await self._shutdown_fn() + return self._invocation.cancel_invocation_handler(fn) def get_openapi_spec(self) -> Optional[dict[str, Any]]: """Return the OpenAPI spec dict for this agent, or None. @@ -259,7 +227,7 @@ def get_openapi_spec(self) -> Optional[dict[str, Any]]: :return: The registered OpenAPI spec or None. :rtype: Optional[dict[str, Any]] """ - return self._openapi_spec + return self._invocation.get_openapi_spec() # ------------------------------------------------------------------ # Run helpers @@ -348,222 +316,18 @@ async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF except Exception: # pylint: disable=broad-exception-caught logger.exception("Error in on_shutdown") - routes = [ - Route( - "/invocations/docs/openapi.json", - self._get_openapi_spec_endpoint, - methods=["GET"], - name="get_openapi_spec", - ), - Route( - "/invocations", - self._create_invocation_endpoint, - methods=["POST"], - name="create_invocation", - ), - Route( - "/invocations/{invocation_id}", - self._get_invocation_endpoint, - methods=["GET"], - name="get_invocation", - ), - Route( - "/invocations/{invocation_id}/cancel", - self._cancel_invocation_endpoint, - methods=["POST"], - name="cancel_invocation", - ), + routes = list(self._invocation.routes) + routes.extend([ Route("/liveness", self._liveness_endpoint, methods=["GET"], name="liveness"), Route("/readiness", self._readiness_endpoint, methods=["GET"], name="readiness"), - ] + ]) self.app = Starlette(routes=routes, lifespan=_lifespan) # ------------------------------------------------------------------ - # Private: endpoint handlers + # Health endpoints # ------------------------------------------------------------------ - async def _get_openapi_spec_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument - """GET /invocations/docs/openapi.json — return registered spec or 404. - - :param request: The incoming Starlette request. - :type request: Request - :return: JSON response with the spec or 404. - :rtype: Response - """ - spec = self.get_openapi_spec() - if spec is None: - return error_response("not_found", "No OpenAPI spec registered", status_code=404) - return JSONResponse(spec) - - async def _create_invocation_endpoint(self, request: Request) -> Response: - """POST /invocations — create and process an invocation. - - :param request: The incoming Starlette request. - :type request: Request - :return: The invocation result or error response. - :rtype: Response - """ - invocation_id = ( - request.headers.get(Constants.INVOCATION_ID_HEADER) - or str(uuid.uuid4()) - ) - request.state.invocation_id = invocation_id - - # Validate request body against OpenAPI spec - if self._validator is not None: - content_type = request.headers.get("content-type", "application/json") - body = await request.body() - errors = self._validator.validate_request(body, content_type) - if errors: - return error_response( - "invalid_payload", - "Request validation failed", - status_code=400, - details=[ - {"code": "validation_error", "message": e} - for e in errors - ], - ) - - carrier = extract_w3c_carrier(request.headers) - - # Use manual span management so that streaming responses keep the - # span open until the last chunk is yielded (or an error occurs). - otel_span = ( - self._tracing.start_span( - "AgentServer.invoke", - attributes={"invocation.id": invocation_id}, - carrier=carrier, - ) - if self._tracing is not None - else None - ) - try: - invoke_awaitable = self._dispatch_invoke(request) - timeout = self._request_timeout or None # 0 → None (no limit) - response = await asyncio.wait_for(invoke_awaitable, timeout=timeout) - except NotImplementedError as exc: - if self._tracing is not None: - self._tracing.end_span(otel_span, exc=exc) - logger.error("Invocation %s failed: %s", invocation_id, exc) - return error_response( - "not_implemented", - str(exc), - status_code=501, - headers={Constants.INVOCATION_ID_HEADER: invocation_id}, - ) - except asyncio.TimeoutError as exc: - if self._tracing is not None: - self._tracing.end_span(otel_span, exc=exc) - logger.error( - "Invocation %s timed out after %ss", - invocation_id, - self._request_timeout, - ) - return error_response( - "request_timeout", - f"Invocation timed out after {self._request_timeout}s", - status_code=504, - headers={Constants.INVOCATION_ID_HEADER: invocation_id}, - ) - except Exception as exc: # pylint: disable=broad-exception-caught - if self._tracing is not None: - self._tracing.end_span(otel_span, exc=exc) - logger.error("Error processing invocation %s: %s", invocation_id, exc, exc_info=True) - message = str(exc) if self._debug_errors else "Internal server error" - return error_response( - "internal_error", - message, - status_code=500, - headers={Constants.INVOCATION_ID_HEADER: invocation_id}, - ) - - # For streaming responses, wrap the body iterator so the span stays - # open until all chunks are sent and captures any streaming errors. - if isinstance(response, StreamingResponse) and self._tracing is not None: - response.body_iterator = self._tracing.trace_stream(response.body_iterator, otel_span) - elif self._tracing is not None: - self._tracing.end_span(otel_span) - - # Always set invocation_id header (overrides any handler-set value) - response.headers[Constants.INVOCATION_ID_HEADER] = invocation_id - - return response - - async def _traced_invocation_endpoint( - self, - request: Request, - span_name: str, - dispatch: Callable[[Request], Awaitable[Response]], - ) -> Response: - """Shared implementation for get/cancel invocation endpoints. - - Extracts the invocation ID from path params, optionally creates a - tracing span, dispatches to the handler, and handles errors. - - :param request: The incoming Starlette request. - :type request: Request - :param span_name: OTel span name (e.g. ``"AgentServer.get_invocation"``). - :type span_name: str - :param dispatch: The dispatch method to invoke. - :type dispatch: Callable[[Request], Awaitable[Response]] - :return: The handler response or an error response. - :rtype: Response - """ - invocation_id = request.path_params["invocation_id"] - request.state.invocation_id = invocation_id - carrier = extract_w3c_carrier(request.headers) if self._tracing is not None else {} - span_cm = ( - self._tracing.span( - span_name, - attributes={"invocation.id": invocation_id}, - carrier=carrier, - ) - if self._tracing is not None - else contextlib.nullcontext(None) - ) - with span_cm as _otel_span: - try: - response = await dispatch(request) - response.headers[Constants.INVOCATION_ID_HEADER] = invocation_id - return response - except Exception as exc: # pylint: disable=broad-exception-caught - if self._tracing is not None: - self._tracing.record_error(_otel_span, exc) - logger.error("Error in %s %s: %s", span_name, invocation_id, exc, exc_info=True) - message = str(exc) if self._debug_errors else "Internal server error" - return error_response( - "internal_error", - message, - status_code=500, - headers={Constants.INVOCATION_ID_HEADER: invocation_id}, - ) - - async def _get_invocation_endpoint(self, request: Request) -> Response: - """GET /invocations/{invocation_id} — retrieve an invocation result. - - :param request: The incoming Starlette request. - :type request: Request - :return: The stored result or 404. - :rtype: Response - """ - return await self._traced_invocation_endpoint( - request, "AgentServer.get_invocation", self._dispatch_get_invocation - ) - - async def _cancel_invocation_endpoint(self, request: Request) -> Response: - """POST /invocations/{invocation_id}/cancel — cancel an invocation. - - :param request: The incoming Starlette request. - :type request: Request - :return: The cancellation result or 404. - :rtype: Response - """ - return await self._traced_invocation_endpoint( - request, "AgentServer.cancel_invocation", self._dispatch_cancel_invocation - ) - async def _liveness_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument """GET /liveness — health check. diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py index 910479580f18..7e2ffbb75a76 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py @@ -187,3 +187,21 @@ def resolve_log_level(level: Optional[str]) -> str: f"(expected one of {', '.join(_VALID_LOG_LEVELS)})" ) return normalized + + +def resolve_agent_name() -> str: + """Resolve the agent name from the ``AGENT_NAME`` environment variable. + + :return: The agent name, or an empty string if not set. + :rtype: str + """ + return os.environ.get(Constants.AGENT_NAME, "") + + +def resolve_agent_version() -> str: + """Resolve the agent version from the ``AGENT_VERSION`` environment variable. + + :return: The agent version, or an empty string if not set. + :rtype: str + """ + return os.environ.get(Constants.AGENT_VERSION, "") diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py index d85e143e6dd4..ceb3f42ef302 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py @@ -14,7 +14,10 @@ class Constants: AGENT_REQUEST_TIMEOUT = "AGENT_REQUEST_TIMEOUT" AGENT_ENABLE_REQUEST_VALIDATION = "AGENT_ENABLE_REQUEST_VALIDATION" APPLICATIONINSIGHTS_CONNECTION_STRING = "APPLICATIONINSIGHTS_CONNECTION_STRING" + AGENT_NAME = "AGENT_NAME" + AGENT_VERSION = "AGENT_VERSION" DEFAULT_PORT = 8088 DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT = 30 DEFAULT_REQUEST_TIMEOUT = 300 # 5 minutes INVOCATION_ID_HEADER = "x-agent-invocation-id" + SESSION_ID_HEADER = "x-agent-session-id" diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py new file mode 100644 index 000000000000..ea36c8e0644a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py @@ -0,0 +1,427 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Invocation protocol for AgentServer. + +Encapsulates the ``/invocations`` REST endpoints, handler decorators, +dispatch methods, and OpenAPI spec serving. Designed as a standalone +composed object so that ``AgentServer`` can compose multiple protocol +heads (invocation, chat, etc.) without inheritance conflicts. + +Shared server state (tracing, error handling, timeouts) is received +via a :class:`~._server_context._ServerContext` instance. +""" +import asyncio # pylint: disable=do-not-import-asyncio +import contextlib +import uuid +from collections.abc import Awaitable, Callable # pylint: disable=import-error +from typing import Any, Optional + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse +from starlette.routing import Route + +from ._constants import Constants +from ._errors import error_response +from ._logger import get_logger +from ._tracing import extract_baggage_header, extract_w3c_carrier +from ._openapi_validator import OpenApiValidator +from ._server_context import _ServerContext +from . import _config + +logger = get_logger() + + +class _InvocationProtocol: + """Invocation protocol implementation. + + Receives shared server state via a :class:`_ServerContext` and manages + its own handler slots, agent identity, and endpoint handlers. + + **Not intended for direct instantiation.** Use via + :class:`~azure.ai.agentserver.server.AgentServer`, which creates and + delegates to this object. + """ + + def __init__( + self, + ctx: _ServerContext, + openapi_spec: Optional[dict[str, Any]], + validator: Optional[OpenApiValidator], + ) -> None: + """Initialise the invocation protocol. + + :param ctx: Shared server context (tracing, debug_errors, request_timeout). + :type ctx: _ServerContext + :param openapi_spec: Optional OpenAPI spec dict for documentation. + :type openapi_spec: Optional[dict[str, Any]] + :param validator: Optional request validator built from the spec. + :type validator: Optional[OpenApiValidator] + """ + self._ctx = ctx + self._invoke_fn: Optional[Callable] = None + self._get_invocation_fn: Optional[Callable] = None + self._cancel_invocation_fn: Optional[Callable] = None + self._openapi_spec = openapi_spec + self._validator = validator + + # Agent identity — used for span naming and GenAI attributes. + self._agent_name = _config.resolve_agent_name() + self._agent_version = _config.resolve_agent_version() + self._agent_label = ( + f"{self._agent_name}:{self._agent_version}" + if self._agent_name + else "" + ) + + # ------------------------------------------------------------------ + # Route registration + # ------------------------------------------------------------------ + + @property + def routes(self) -> list[Route]: + """Starlette routes for the invocation protocol. + + :return: A list of four Route objects for the invocation endpoints. + :rtype: list[Route] + """ + return [ + Route( + "/invocations/docs/openapi.json", + self._get_openapi_spec_endpoint, + methods=["GET"], + name="get_openapi_spec", + ), + Route( + "/invocations", + self._create_invocation_endpoint, + methods=["POST"], + name="create_invocation", + ), + Route( + "/invocations/{invocation_id}", + self._get_invocation_endpoint, + methods=["GET"], + name="get_invocation", + ), + Route( + "/invocations/{invocation_id}/cancel", + self._cancel_invocation_endpoint, + methods=["POST"], + name="cancel_invocation", + ), + ] + + # ------------------------------------------------------------------ + # Span name helper (protocol-specific) + # ------------------------------------------------------------------ + + def _span_name(self, operation: str) -> str: + """Build a span name using the operation and agent label. + + :param operation: The operation name (e.g. ``"execute_agent"``). + :type operation: str + :return: ``" :"`` or just ``""``. + :rtype: str + """ + if self._agent_label: + return f"{operation} {self._agent_label}" + return operation + + def _build_span_attrs( + self, + invocation_id: str, + session_id: str, + operation_name: Optional[str] = None, + ) -> dict[str, str]: + """Build GenAI semantic convention span attributes. + + :param invocation_id: The invocation ID for this request. + :type invocation_id: str + :param session_id: The session ID header value (empty string if absent). + :type session_id: str + :param operation_name: Optional ``gen_ai.operation.name`` value + (e.g. ``"invoke_agent"``). Omitted from the dict when *None*. + :type operation_name: Optional[str] + :return: Span attribute dict. + :rtype: dict[str, str] + """ + attrs: dict[str, str] = { + "invocation.id": invocation_id, + "gen_ai.agent.id": self._agent_label, + "gen_ai.response.id": invocation_id, + "gen_ai.provider.name": "microsoft.foundry", + } + if operation_name: + attrs["gen_ai.operation.name"] = operation_name + if session_id: + attrs["gen_ai.conversation.id"] = session_id + return attrs + + # ------------------------------------------------------------------ + # Handler decorators + # ------------------------------------------------------------------ + + def invoke_handler( + self, fn: Callable[[Request], Awaitable[Response]] + ) -> Callable[[Request], Awaitable[Response]]: + """Store *fn* as the invoke handler. See :meth:`AgentServer.invoke_handler`.""" + self._invoke_fn = fn + return fn + + def get_invocation_handler( + self, fn: Callable[[Request], Awaitable[Response]] + ) -> Callable[[Request], Awaitable[Response]]: + """Store *fn* as the get-invocation handler. See :meth:`AgentServer.get_invocation_handler`.""" + self._get_invocation_fn = fn + return fn + + def cancel_invocation_handler( + self, fn: Callable[[Request], Awaitable[Response]] + ) -> Callable[[Request], Awaitable[Response]]: + """Store *fn* as the cancel-invocation handler. See :meth:`AgentServer.cancel_invocation_handler`.""" + self._cancel_invocation_fn = fn + return fn + + # ------------------------------------------------------------------ + # Dispatch methods (internal) + # ------------------------------------------------------------------ + + async def _dispatch_invoke(self, request: Request) -> Response: + """Dispatch to the registered invoke handler. + + :param request: The incoming Starlette request. + :type request: Request + :return: The response from the invoke handler. + :rtype: Response + :raises NotImplementedError: If no invoke handler has been registered. + """ + if self._invoke_fn is not None: + return await self._invoke_fn(request) + raise NotImplementedError( + "No invoke handler registered. Use the @server.invoke_handler decorator." + ) + + async def _dispatch_get_invocation(self, request: Request) -> Response: + """Dispatch to the registered get-invocation handler, or return 501. + + :param request: The incoming Starlette request. + :type request: Request + :return: The response from the get-invocation handler. + :rtype: Response + """ + if self._get_invocation_fn is not None: + return await self._get_invocation_fn(request) + return error_response("not_supported", "get_invocation not supported", status_code=501) + + async def _dispatch_cancel_invocation(self, request: Request) -> Response: + """Dispatch to the registered cancel-invocation handler, or return 501. + + :param request: The incoming Starlette request. + :type request: Request + :return: The response from the cancel-invocation handler. + :rtype: Response + """ + if self._cancel_invocation_fn is not None: + return await self._cancel_invocation_fn(request) + return error_response("not_supported", "cancel_invocation not supported", status_code=501) + + def get_openapi_spec(self) -> Optional[dict[str, Any]]: + """Return the stored OpenAPI spec. See :meth:`AgentServer.get_openapi_spec`.""" + return self._openapi_spec + + # ------------------------------------------------------------------ + # Endpoint handlers + # ------------------------------------------------------------------ + + async def _get_openapi_spec_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument + """GET /invocations/docs/openapi.json — return registered spec or 404. + + :param request: The incoming Starlette request. + :type request: Request + :return: JSON response with the spec or 404. + :rtype: Response + """ + spec = self.get_openapi_spec() + if spec is None: + return error_response("not_found", "No OpenAPI spec registered", status_code=404) + return JSONResponse(spec) + + async def _create_invocation_endpoint(self, request: Request) -> Response: + """POST /invocations — create and process an invocation. + + :param request: The incoming Starlette request. + :type request: Request + :return: The invocation result or error response. + :rtype: Response + """ + invocation_id = ( + request.headers.get(Constants.INVOCATION_ID_HEADER) + or str(uuid.uuid4()) + ) + request.state.invocation_id = invocation_id + + # Validate request body against OpenAPI spec + if self._validator is not None: + content_type = request.headers.get("content-type", "application/json") + body = await request.body() + errors = self._validator.validate_request(body, content_type) + if errors: + return error_response( + "invalid_payload", + "Request validation failed", + status_code=400, + details=[ + {"code": "validation_error", "message": e} + for e in errors + ], + ) + + carrier = extract_w3c_carrier(request.headers) + baggage = extract_baggage_header(request.headers) + session_id = request.headers.get(Constants.SESSION_ID_HEADER, "") + span_attrs = self._build_span_attrs( + invocation_id, session_id, operation_name="invoke_agent" + ) + + # Use manual span management so that streaming responses keep the + # span open until the last chunk is yielded (or an error occurs). + otel_span = ( + self._ctx.tracing.start_span( + self._span_name("execute_agent"), + attributes=span_attrs, + carrier=carrier, + baggage_header=baggage, + ) + if self._ctx.tracing is not None + else None + ) + try: + invoke_awaitable = self._dispatch_invoke(request) + timeout = self._ctx.request_timeout or None # 0 → None (no limit) + response = await asyncio.wait_for(invoke_awaitable, timeout=timeout) + except NotImplementedError as exc: + if self._ctx.tracing is not None: + self._ctx.tracing.end_span(otel_span, exc=exc) + logger.error("Invocation %s failed: %s", invocation_id, exc) + return error_response( + "not_implemented", + str(exc), + status_code=501, + headers={Constants.INVOCATION_ID_HEADER: invocation_id}, + ) + except asyncio.TimeoutError as exc: + if self._ctx.tracing is not None: + self._ctx.tracing.end_span(otel_span, exc=exc) + logger.error( + "Invocation %s timed out after %ss", + invocation_id, + self._ctx.request_timeout, + ) + return error_response( + "request_timeout", + f"Invocation timed out after {self._ctx.request_timeout}s", + status_code=504, + headers={Constants.INVOCATION_ID_HEADER: invocation_id}, + ) + except Exception as exc: # pylint: disable=broad-exception-caught + if self._ctx.tracing is not None: + self._ctx.tracing.end_span(otel_span, exc=exc) + logger.error("Error processing invocation %s: %s", invocation_id, exc, exc_info=True) + message = str(exc) if self._ctx.debug_errors else "Internal server error" + return error_response( + "internal_error", + message, + status_code=500, + headers={Constants.INVOCATION_ID_HEADER: invocation_id}, + ) + + # For streaming responses, wrap the body iterator so the span stays + # open until all chunks are sent and captures any streaming errors. + if isinstance(response, StreamingResponse) and self._ctx.tracing is not None: + response.body_iterator = self._ctx.tracing.trace_stream(response.body_iterator, otel_span) + elif self._ctx.tracing is not None: + self._ctx.tracing.end_span(otel_span) + + # Always set invocation_id header (overrides any handler-set value) + response.headers[Constants.INVOCATION_ID_HEADER] = invocation_id + + return response + + async def _traced_invocation_endpoint( + self, + request: Request, + span_name: str, + dispatch: Callable[[Request], Awaitable[Response]], + ) -> Response: + """Shared implementation for get/cancel invocation endpoints. + + Extracts the invocation ID from path params, optionally creates a + tracing span, dispatches to the handler, and handles errors. + + :param request: The incoming Starlette request. + :type request: Request + :param span_name: OTel span name (e.g. ``"get_invocation"``). + :type span_name: str + :param dispatch: The dispatch method to invoke. + :type dispatch: Callable[[Request], Awaitable[Response]] + :return: The handler response or an error response. + :rtype: Response + """ + invocation_id = request.path_params["invocation_id"] + request.state.invocation_id = invocation_id + carrier = extract_w3c_carrier(request.headers) if self._ctx.tracing is not None else {} + baggage = extract_baggage_header(request.headers) if self._ctx.tracing is not None else None + session_id = request.headers.get(Constants.SESSION_ID_HEADER, "") + span_attrs = self._build_span_attrs(invocation_id, session_id) + + span_cm = ( + self._ctx.tracing.span( + span_name, + attributes=span_attrs, + carrier=carrier, + baggage_header=baggage, + ) + if self._ctx.tracing is not None + else contextlib.nullcontext(None) + ) + with span_cm as _otel_span: + try: + response = await dispatch(request) + response.headers[Constants.INVOCATION_ID_HEADER] = invocation_id + return response + except Exception as exc: # pylint: disable=broad-exception-caught + if self._ctx.tracing is not None: + self._ctx.tracing.record_error(_otel_span, exc) + logger.error("Error in %s %s: %s", span_name, invocation_id, exc, exc_info=True) + message = str(exc) if self._ctx.debug_errors else "Internal server error" + return error_response( + "internal_error", + message, + status_code=500, + headers={Constants.INVOCATION_ID_HEADER: invocation_id}, + ) + + async def _get_invocation_endpoint(self, request: Request) -> Response: + """GET /invocations/{invocation_id} — retrieve an invocation result. + + :param request: The incoming Starlette request. + :type request: Request + :return: The stored result or 501. + :rtype: Response + """ + return await self._traced_invocation_endpoint( + request, self._span_name("get_invocation"), self._dispatch_get_invocation + ) + + async def _cancel_invocation_endpoint(self, request: Request) -> Response: + """POST /invocations/{invocation_id}/cancel — cancel an invocation. + + :param request: The incoming Starlette request. + :type request: Request + :return: The cancellation result or 501. + :rtype: Response + """ + return await self._traced_invocation_endpoint( + request, self._span_name("cancel_invocation"), self._dispatch_cancel_invocation + ) diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_server_context.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_server_context.py new file mode 100644 index 000000000000..a9078aba6450 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_server_context.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Internal server context shared across protocol implementations.""" +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from ._tracing import TracingHelper + + +@dataclasses.dataclass(frozen=True) +class _ServerContext: + """Shared server state passed to protocol implementations. + + Internal — not part of the public API. Each protocol receives this + at construction time so it can access tracing, error handling, and + timeout configuration without coupling to the ``AgentServer`` class. + """ + + tracing: Optional[TracingHelper] + debug_errors: bool + request_timeout: int diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py index 8b85575a17cd..e80048645b72 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py @@ -36,6 +36,9 @@ #: W3C Trace Context header names used for distributed trace propagation. _W3C_HEADERS = ("traceparent", "tracestate") +#: Baggage key whose value overrides the parent span ID. +_LEAF_CUSTOMER_SPAN_ID = "leaf_customer_span_id" + logger = get_logger() _HAS_OTEL = False @@ -85,17 +88,38 @@ def __init__(self, connection_string: Optional[str] = None) -> None: # Azure Monitor auto-configuration # ------------------------------------------------------------------ - def _extract_context(self, carrier: Optional[dict[str, str]]) -> Any: + def _extract_context( + self, + carrier: Optional[dict[str, str]], + baggage_header: Optional[str] = None, + ) -> Any: """Extract parent trace context from a W3C carrier dict. + When a ``baggage`` header is provided and contains a + ``leaf_customer_span_id`` key, the parent span ID is overridden + so that the server's root span is parented under the leaf customer + span rather than the span referenced in the ``traceparent`` header. + :param carrier: W3C trace-context headers or None. :type carrier: Optional[dict[str, str]] + :param baggage_header: Raw ``baggage`` header value or None. + :type baggage_header: Optional[str] :return: The extracted OTel context, or None. :rtype: Any """ - if carrier and self._propagator is not None: - return self._propagator.extract(carrier=carrier) - return None + if not carrier or self._propagator is None: + return None + + ctx = self._propagator.extract(carrier=carrier) + + if not baggage_header: + return ctx + + leaf_span_id = _parse_baggage_key(baggage_header, _LEAF_CUSTOMER_SPAN_ID) + if not leaf_span_id: + return ctx + + return _override_parent_span_id(ctx, leaf_span_id) @staticmethod def _setup_azure_monitor(connection_string: str) -> None: @@ -172,6 +196,7 @@ def span( name: str, attributes: Optional[dict[str, str]] = None, carrier: Optional[dict[str, str]] = None, + baggage_header: Optional[str] = None, ) -> Iterator[Any]: """Create a traced span if tracing is enabled, otherwise no-op. @@ -179,12 +204,15 @@ def span( ``None`` when tracing is disabled. Callers may use the yielded span together with :meth:`record_error` to attach error information. - :param name: Span name, e.g. ``"AgentServer.invoke"``. + :param name: Span name, e.g. ``"execute_agent my_agent:1.0"``. :type name: str :param attributes: Key-value span attributes. :type attributes: Optional[dict[str, str]] :param carrier: Incoming HTTP headers for W3C trace-context propagation. :type carrier: Optional[dict[str, str]] + :param baggage_header: Raw ``baggage`` header value for + ``leaf_customer_span_id`` extraction. + :type baggage_header: Optional[str] :return: Context manager that yields the OTel span or *None*. :rtype: Iterator[Any] """ @@ -192,7 +220,7 @@ def span( yield None return - ctx = self._extract_context(carrier) + ctx = self._extract_context(carrier, baggage_header) with self._tracer.start_as_current_span( name=name, @@ -207,6 +235,7 @@ def start_span( name: str, attributes: Optional[dict[str, str]] = None, carrier: Optional[dict[str, str]] = None, + baggage_header: Optional[str] = None, ) -> Any: """Start a span without a context manager. @@ -214,19 +243,22 @@ def start_span( initial ``invoke()`` call. The caller **must** call :meth:`end_span` when the work is finished. - :param name: Span name, e.g. ``"AgentServer.invoke"``. + :param name: Span name, e.g. ``"execute_agent my_agent:1.0"``. :type name: str :param attributes: Key-value span attributes. :type attributes: Optional[dict[str, str]] :param carrier: Incoming HTTP headers for W3C trace-context propagation. :type carrier: Optional[dict[str, str]] + :param baggage_header: Raw ``baggage`` header value for + ``leaf_customer_span_id`` extraction. + :type baggage_header: Optional[str] :return: The OTel span, or *None* when tracing is disabled. :rtype: Any """ if not self._enabled or self._tracer is None: return None - ctx = self._extract_context(carrier) + ctx = self._extract_context(carrier, baggage_header) return self._tracer.start_span( name=name, @@ -321,3 +353,99 @@ def extract_w3c_carrier(headers: Mapping[str, str]) -> dict[str, str]: if val is not None: result[key] = val return result + + +def extract_baggage_header(headers: Mapping[str, str]) -> Optional[str]: + """Extract the raw ``baggage`` header value from a mapping. + + Returns *None* if the header is not present. + + :param headers: A mapping of header name to value. + :type headers: Mapping[str, str] + :return: The raw baggage header value or None. + :rtype: Optional[str] + """ + return headers.get("baggage") + + +def _parse_baggage_key(baggage: str, key: str) -> Optional[str]: + """Parse a single key from a W3C Baggage header value. + + The `W3C Baggage`_ format is a comma-separated list of + ``key=value`` pairs with optional properties after a ``;``. + + Example:: + + leaf_customer_span_id=abc123,other=val + + .. _W3C Baggage: https://www.w3.org/TR/baggage/ + + :param baggage: The raw header value. + :type baggage: str + :param key: The baggage key to look up. + :type key: str + :return: The value for *key*, or *None* if not found. + :rtype: Optional[str] + """ + for member in baggage.split(","): + member = member.strip() + if not member: + continue + # Split on first '=' only; value may contain '=' + kv_part = member.split(";", 1)[0] # strip optional properties + eq_idx = kv_part.find("=") + if eq_idx < 0: + continue + k = kv_part[:eq_idx].strip() + v = kv_part[eq_idx + 1:].strip() + if k == key: + return v + return None + + +def _override_parent_span_id(ctx: Any, hex_span_id: str) -> Any: + """Create a new context with the same trace ID but a different parent span ID. + + Constructs a :class:`~opentelemetry.trace.SpanContext` with the trace ID + taken from the existing context and the span ID replaced by + *hex_span_id*. The resulting context can be used as the ``context`` + argument to ``start_span`` / ``start_as_current_span``. + + Returns the original *ctx* unchanged if *hex_span_id* is invalid or + ``opentelemetry-api`` is not installed. + + :param ctx: An OTel context produced by ``TraceContextTextMapPropagator.extract()``. + :type ctx: Any + :param hex_span_id: 16-character lower-case hex string representing the + desired parent span ID. + :type hex_span_id: str + :return: A context with the overridden parent span ID, or the original. + :rtype: Any + """ + if not _HAS_OTEL: + return ctx + + try: + new_span_id = int(hex_span_id, 16) + except (ValueError, TypeError): + logger.warning("Invalid leaf_customer_span_id in baggage: %r", hex_span_id) + return ctx + + if new_span_id == 0: + return ctx + + # Grab the trace ID from the current parent span in ctx. + current_span = trace.get_current_span(ctx) + current_ctx = current_span.get_span_context() + if current_ctx is None or not current_ctx.is_valid: + return ctx + + custom_span_ctx = trace.SpanContext( + trace_id=current_ctx.trace_id, + span_id=new_span_id, + is_remote=True, + trace_flags=current_ctx.trace_flags, + trace_state=current_ctx.trace_state, + ) + custom_parent = trace.NonRecordingSpan(custom_span_ctx) + return trace.set_span_in_context(custom_parent, ctx) diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_decorator_pattern.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_decorator_pattern.py index 8e4b6f0f0902..e03d01efe567 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/tests/test_decorator_pattern.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_decorator_pattern.py @@ -27,7 +27,7 @@ def test_invoke_handler_stores_function(self): async def handle(request: Request) -> Response: return JSONResponse({"ok": True}) - assert server._invoke_fn is handle + assert server._invocation._invoke_fn is handle def test_invoke_handler_returns_original_function(self): server = AgentServer() @@ -47,7 +47,7 @@ def test_get_invocation_handler_stores_function(self): async def handle(request: Request) -> Response: return JSONResponse({"found": True}) - assert server._get_invocation_fn is handle + assert server._invocation._get_invocation_fn is handle def test_cancel_invocation_handler_stores_function(self): server = AgentServer() @@ -56,7 +56,7 @@ def test_cancel_invocation_handler_stores_function(self): async def handle(request: Request) -> Response: return JSONResponse({"cancelled": True}) - assert server._cancel_invocation_fn is handle + assert server._invocation._cancel_invocation_fn is handle def test_shutdown_handler_stores_function(self): server = AgentServer() diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_request_limits.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_request_limits.py index aa4d387152d6..c9f0efee996d 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/tests/test_request_limits.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_request_limits.py @@ -137,7 +137,7 @@ async def test_timeout_disabled_allows_no_limit(self): async def test_timeout_error_is_logged(self): agent = _make_slow_agent(request_timeout=1) transport = httpx.ASGITransport(app=agent.app) - with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + with patch("azure.ai.agentserver.server._invocation.logger") as mock_logger: async with httpx.AsyncClient( transport=transport, base_url="http://testserver" ) as client: diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py index ad49224f8379..15f78b7d5bf0 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py @@ -118,7 +118,7 @@ async def test_tracing_enabled_creates_invoke_span(span_exporter): assert resp.status_code == 200 spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + invoke_spans = [s for s in spans if s.name == "execute_agent"] assert len(invoke_spans) == 1 span = invoke_spans[0] @@ -136,7 +136,7 @@ async def test_tracing_invoke_error_records_exception(span_exporter): assert resp.status_code == 500 spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + invoke_spans = [s for s in spans if s.name == "execute_agent"] assert len(invoke_spans) == 1 span = invoke_spans[0] @@ -157,7 +157,7 @@ async def test_tracing_get_invocation_creates_span(span_exporter): assert resp.status_code == 501 spans = span_exporter.get_finished_spans() - get_spans = [s for s in spans if s.name == "AgentServer.get_invocation"] + get_spans = [s for s in spans if s.name == "get_invocation"] assert len(get_spans) == 1 assert dict(get_spans[0].attributes)["invocation.id"] == "test-id-123" @@ -172,7 +172,7 @@ async def test_tracing_cancel_invocation_creates_span(span_exporter): assert resp.status_code == 501 spans = span_exporter.get_finished_spans() - cancel_spans = [s for s in spans if s.name == "AgentServer.cancel_invocation"] + cancel_spans = [s for s in spans if s.name == "cancel_invocation"] assert len(cancel_spans) == 1 assert dict(cancel_spans[0].attributes)["invocation.id"] == "test-cancel-456" @@ -194,7 +194,7 @@ async def test_tracing_enabled_via_env_var(monkeypatch, span_exporter): await client.post("/invocations", content=b'{}') spans = span_exporter.get_finished_spans() - assert any(s.name == "AgentServer.invoke" for s in spans) + assert any(s.name == "execute_agent" for s in spans) @pytest.mark.asyncio @@ -221,7 +221,7 @@ async def test_tracing_propagates_traceparent(span_exporter): assert resp.status_code == 200 spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + invoke_spans = [s for s in spans if s.name == "execute_agent"] assert len(invoke_spans) == 1 span = invoke_spans[0] # The span's trace ID should match the traceparent's trace ID @@ -290,7 +290,7 @@ async def test_streaming_response_creates_span(span_exporter): assert resp.status_code == 200 spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + invoke_spans = [s for s in spans if s.name == "execute_agent"] assert len(invoke_spans) == 1 assert invoke_spans[0].status.status_code == trace.StatusCode.UNSET @@ -309,7 +309,7 @@ async def test_streaming_span_covers_full_body(span_exporter): assert b"slow-2" in resp.content spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + invoke_spans = [s for s in spans if s.name == "execute_agent"] assert len(invoke_spans) == 1 span = invoke_spans[0] @@ -347,7 +347,7 @@ async def test_streaming_error_recorded_in_span(span_exporter): pass # connection reset / partial read is acceptable spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + invoke_spans = [s for s in spans if s.name == "execute_agent"] assert len(invoke_spans) == 1 span = invoke_spans[0] @@ -386,7 +386,7 @@ async def test_streaming_propagates_traceparent(span_exporter): assert resp.status_code == 200 spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "AgentServer.invoke"] + invoke_spans = [s for s in spans if s.name == "execute_agent"] assert len(invoke_spans) == 1 expected_trace_id = int("0af7651916cd43dd8448eb211c80319c", 16) assert invoke_spans[0].context.trace_id == expected_trace_id @@ -521,3 +521,375 @@ def test_tracing_disabled_skips_connection_string_resolution(self, monkeypatch): agent = _make_echo_traced_agent(enable_tracing=False) assert agent._tracing is None # noqa: SLF001 + + +# --------------------------------------------------------------------------- +# Tests: span naming with agent name and version +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_span_name_includes_agent_label(monkeypatch, span_exporter): + """Span name includes agent_name:agent_version when env vars are set.""" + monkeypatch.setenv("AGENT_NAME", "my-agent") + monkeypatch.setenv("AGENT_VERSION", "2.1") + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b'{}') + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if s.name == "execute_agent my-agent:2.1"] + assert len(invoke_spans) == 1 + + +@pytest.mark.asyncio +async def test_span_name_without_agent_label(span_exporter, monkeypatch): + """Span name is just the operation when AGENT_NAME is not set.""" + monkeypatch.delenv("AGENT_NAME", raising=False) + monkeypatch.delenv("AGENT_VERSION", raising=False) + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b'{}') + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if s.name == "execute_agent"] + assert len(invoke_spans) == 1 + + +@pytest.mark.asyncio +async def test_get_invocation_span_name_with_label(monkeypatch, span_exporter): + """GET /invocations/{id} span includes agent label.""" + monkeypatch.setenv("AGENT_NAME", "agent-x") + monkeypatch.setenv("AGENT_VERSION", "0.5") + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.get("/invocations/test-id") + + spans = span_exporter.get_finished_spans() + get_spans = [s for s in spans if s.name == "get_invocation agent-x:0.5"] + assert len(get_spans) == 1 + + +@pytest.mark.asyncio +async def test_cancel_invocation_span_name_with_label(monkeypatch, span_exporter): + """POST /invocations/{id}/cancel span includes agent label.""" + monkeypatch.setenv("AGENT_NAME", "agent-x") + monkeypatch.setenv("AGENT_VERSION", "0.5") + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations/test-id/cancel") + + spans = span_exporter.get_finished_spans() + cancel_spans = [s for s in spans if s.name == "cancel_invocation agent-x:0.5"] + assert len(cancel_spans) == 1 + + +# --------------------------------------------------------------------------- +# Tests: GenAI semantic convention attributes +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_genai_attributes_on_invoke_span(span_exporter, monkeypatch): + """Invoke span has GenAI semantic convention attributes.""" + monkeypatch.setenv("AGENT_NAME", "test-agent") + monkeypatch.setenv("AGENT_VERSION", "1.0") + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b'{}') + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + attrs = dict(invoke_spans[0].attributes) + assert attrs["gen_ai.operation.name"] == "invoke_agent" + assert attrs["gen_ai.agent.id"] == "test-agent:1.0" + assert attrs["gen_ai.provider.name"] == "microsoft.foundry" + assert "gen_ai.response.id" in attrs # UUID invocation ID + + +@pytest.mark.asyncio +async def test_genai_conversation_id_from_session_header(span_exporter, monkeypatch): + """gen_ai.conversation.id is set from x-agent-session-id header.""" + monkeypatch.delenv("AGENT_NAME", raising=False) + monkeypatch.delenv("AGENT_VERSION", raising=False) + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post( + "/invocations", + content=b'{}', + headers={"x-agent-session-id": "session-abc-123"}, + ) + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + attrs = dict(invoke_spans[0].attributes) + assert attrs["gen_ai.conversation.id"] == "session-abc-123" + + +@pytest.mark.asyncio +async def test_genai_conversation_id_absent_when_no_header(span_exporter, monkeypatch): + """gen_ai.conversation.id is NOT set when x-agent-session-id header is absent.""" + monkeypatch.delenv("AGENT_NAME", raising=False) + monkeypatch.delenv("AGENT_VERSION", raising=False) + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b'{}') + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + attrs = dict(invoke_spans[0].attributes) + assert "gen_ai.conversation.id" not in attrs + + +@pytest.mark.asyncio +async def test_genai_attributes_on_get_invocation_span(span_exporter, monkeypatch): + """GET /invocations/{id} span has GenAI attributes (minus operation.name).""" + monkeypatch.setenv("AGENT_NAME", "test-agent") + monkeypatch.setenv("AGENT_VERSION", "1.0") + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.get( + "/invocations/inv-42", + headers={"x-agent-session-id": "sess-99"}, + ) + + spans = span_exporter.get_finished_spans() + get_spans = [s for s in spans if "get_invocation" in s.name] + assert len(get_spans) == 1 + + attrs = dict(get_spans[0].attributes) + assert attrs["gen_ai.agent.id"] == "test-agent:1.0" + assert attrs["gen_ai.provider.name"] == "microsoft.foundry" + assert attrs["gen_ai.response.id"] == "inv-42" + assert attrs["gen_ai.conversation.id"] == "sess-99" + assert attrs["invocation.id"] == "inv-42" + + +# --------------------------------------------------------------------------- +# Tests: baggage extraction and leaf_customer_span_id +# --------------------------------------------------------------------------- + + +class TestBaggageParsing: + """Unit tests for _parse_baggage_key.""" + + def test_single_key(self): + from azure.ai.agentserver.server._tracing import _parse_baggage_key + assert _parse_baggage_key("leaf_customer_span_id=abc123", "leaf_customer_span_id") == "abc123" + + def test_multiple_keys(self): + from azure.ai.agentserver.server._tracing import _parse_baggage_key + baggage = "foo=bar,leaf_customer_span_id=deadbeef01234567,baz=qux" + assert _parse_baggage_key(baggage, "leaf_customer_span_id") == "deadbeef01234567" + + def test_key_not_present(self): + from azure.ai.agentserver.server._tracing import _parse_baggage_key + assert _parse_baggage_key("foo=bar,baz=qux", "leaf_customer_span_id") is None + + def test_empty_baggage(self): + from azure.ai.agentserver.server._tracing import _parse_baggage_key + assert _parse_baggage_key("", "leaf_customer_span_id") is None + + def test_key_with_properties(self): + from azure.ai.agentserver.server._tracing import _parse_baggage_key + baggage = "leaf_customer_span_id=abc123;property1=val1,other=2" + assert _parse_baggage_key(baggage, "leaf_customer_span_id") == "abc123" + + def test_whitespace_handling(self): + from azure.ai.agentserver.server._tracing import _parse_baggage_key + baggage = " leaf_customer_span_id = abc123 , other = val " + assert _parse_baggage_key(baggage, "leaf_customer_span_id") == "abc123" + + +class TestExtractBaggageHeader: + """Unit tests for extract_baggage_header.""" + + def test_present(self): + from azure.ai.agentserver.server._tracing import extract_baggage_header + headers = {"baggage": "key=val", "other": "x"} + assert extract_baggage_header(headers) == "key=val" + + def test_absent(self): + from azure.ai.agentserver.server._tracing import extract_baggage_header + assert extract_baggage_header({"other": "x"}) is None + + +@pytest.mark.asyncio +async def test_baggage_leaf_customer_span_id_overrides_parent(span_exporter): + """When baggage contains leaf_customer_span_id, the span's parent span ID + is overridden to match, while the trace ID stays the same as traceparent.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + + trace_id_hex = "0af7651916cd43dd8448eb211c80319c" + original_parent_hex = "b7ad6b7169203331" + leaf_span_hex = "00f067aa0ba902b7" + + traceparent = f"00-{trace_id_hex}-{original_parent_hex}-01" + baggage = f"leaf_customer_span_id={leaf_span_hex},other=val" + + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b'{"data": 1}', + headers={"traceparent": traceparent, "baggage": baggage}, + ) + assert resp.status_code == 200 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + span = invoke_spans[0] + # Trace ID should match traceparent + expected_trace_id = int(trace_id_hex, 16) + assert span.context.trace_id == expected_trace_id + + # Parent span ID should be the leaf_customer_span_id, not the original + expected_parent_span_id = int(leaf_span_hex, 16) + assert span.parent.span_id == expected_parent_span_id + + +@pytest.mark.asyncio +async def test_baggage_without_leaf_uses_traceparent_parent(span_exporter): + """When baggage is present but does NOT contain leaf_customer_span_id, + the parent span ID comes from traceparent as usual.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + + trace_id_hex = "0af7651916cd43dd8448eb211c80319c" + parent_hex = "b7ad6b7169203331" + traceparent = f"00-{trace_id_hex}-{parent_hex}-01" + baggage = "some_other_key=value" + + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b'{}', + headers={"traceparent": traceparent, "baggage": baggage}, + ) + assert resp.status_code == 200 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + span = invoke_spans[0] + expected_parent_span_id = int(parent_hex, 16) + assert span.parent.span_id == expected_parent_span_id + + +@pytest.mark.asyncio +async def test_baggage_no_traceparent_no_crash(span_exporter): + """When baggage is present but no traceparent, no crash occurs.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b'{}', + headers={"baggage": "leaf_customer_span_id=00f067aa0ba902b7"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_baggage_invalid_leaf_span_id_falls_back(span_exporter): + """Invalid hex in leaf_customer_span_id falls back to traceparent parent.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + + trace_id_hex = "0af7651916cd43dd8448eb211c80319c" + parent_hex = "b7ad6b7169203331" + traceparent = f"00-{trace_id_hex}-{parent_hex}-01" + baggage = "leaf_customer_span_id=not_valid_hex" + + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b'{}', + headers={"traceparent": traceparent, "baggage": baggage}, + ) + assert resp.status_code == 200 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + span = invoke_spans[0] + # Should fall back to traceparent's parent span ID + expected_parent = int(parent_hex, 16) + assert span.parent.span_id == expected_parent + + +@pytest.mark.asyncio +async def test_baggage_leaf_on_get_invocation(span_exporter): + """Baggage leaf_customer_span_id also works on GET /invocations/{id}.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + + trace_id_hex = "0af7651916cd43dd8448eb211c80319c" + original_parent_hex = "b7ad6b7169203331" + leaf_span_hex = "00f067aa0ba902b7" + + traceparent = f"00-{trace_id_hex}-{original_parent_hex}-01" + baggage = f"leaf_customer_span_id={leaf_span_hex}" + + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.get( + "/invocations/test-id", + headers={"traceparent": traceparent, "baggage": baggage}, + ) + + spans = span_exporter.get_finished_spans() + get_spans = [s for s in spans if "get_invocation" in s.name] + assert len(get_spans) == 1 + + expected_trace_id = int(trace_id_hex, 16) + expected_parent = int(leaf_span_hex, 16) + assert get_spans[0].context.trace_id == expected_trace_id + assert get_spans[0].parent.span_id == expected_parent + + +# --------------------------------------------------------------------------- +# Tests: agent name / version resolution +# --------------------------------------------------------------------------- + + +class TestAgentNameVersionResolution: + """Tests for resolve_agent_name and resolve_agent_version.""" + + def test_agent_name_from_env(self, monkeypatch): + from azure.ai.agentserver.server._config import resolve_agent_name + monkeypatch.setenv("AGENT_NAME", "my-agent") + assert resolve_agent_name() == "my-agent" + + def test_agent_name_default_empty(self, monkeypatch): + from azure.ai.agentserver.server._config import resolve_agent_name + monkeypatch.delenv("AGENT_NAME", raising=False) + assert resolve_agent_name() == "" + + def test_agent_version_from_env(self, monkeypatch): + from azure.ai.agentserver.server._config import resolve_agent_version + monkeypatch.setenv("AGENT_VERSION", "3.0.1") + assert resolve_agent_version() == "3.0.1" + + def test_agent_version_default_empty(self, monkeypatch): + from azure.ai.agentserver.server._config import resolve_agent_version + monkeypatch.delenv("AGENT_VERSION", raising=False) + assert resolve_agent_version() == "" From 38403c05e782adbc5a0eaacb11f859a5514aed47 Mon Sep 17 00:00:00 2001 From: Zhiyong Yang Date: Thu, 12 Mar 2026 17:44:16 -0700 Subject: [PATCH 07/10] Consolidate tracing into TracingHelper, make internal classes private MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move request-level header extraction and span creation into _TracingHelper convenience methods so protocol heads get tracing with a single call. Rename internal classes (TracingHelper → _TracingHelper, OpenApiValidator → _OpenApiValidator) to reflect private status. Extract span attribute keys into module constants, refactor Azure Monitor setup into focused helpers, and rename the confusing operation/operation_name params to span_operation/operation_name. Co-Authored-By: Claude Opus 4.6 --- .../azure/ai/agentserver/server/_base.py | 12 +- .../azure/ai/agentserver/server/_config.py | 9 + .../azure/ai/agentserver/server/_constants.py | 1 + .../ai/agentserver/server/_invocation.py | 148 +++----- .../agentserver/server/_openapi_validator.py | 20 +- .../ai/agentserver/server/_server_context.py | 4 +- .../azure/ai/agentserver/server/_tracing.py | 346 ++++++++++++++---- .../azure-ai-agentserver-server/cspell.json | 2 + .../tests/test_graceful_shutdown.py | 20 +- .../tests/test_openapi_validation.py | 86 ++--- .../tests/test_tracing.py | 115 ++++-- 11 files changed, 502 insertions(+), 261 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py index 01d8fad49060..e71d98ab36da 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py @@ -14,8 +14,8 @@ from ._constants import Constants from ._logger import get_logger -from ._tracing import TracingHelper -from ._openapi_validator import OpenApiValidator +from ._tracing import _TracingHelper +from ._openapi_validator import _OpenApiValidator from ._invocation import _InvocationProtocol from ._server_context import _ServerContext from . import _config @@ -123,8 +123,8 @@ def __init__( _validation_on = _config.resolve_bool_feature( enable_request_validation, Constants.AGENT_ENABLE_REQUEST_VALIDATION ) - validator: Optional[OpenApiValidator] = ( - OpenApiValidator(openapi_spec) + validator: Optional[_OpenApiValidator] = ( + _OpenApiValidator(openapi_spec) if openapi_spec and _validation_on else None ) @@ -134,8 +134,8 @@ def __init__( _conn_str = _config.resolve_appinsights_connection_string( application_insights_connection_string ) if _tracing_on else None - self._tracing: Optional[TracingHelper] = ( - TracingHelper(connection_string=_conn_str) if _tracing_on else None + self._tracing: Optional[_TracingHelper] = ( + _TracingHelper(connection_string=_conn_str) if _tracing_on else None ) # Timeouts --------------------------------------------------------- diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py index 7e2ffbb75a76..8110414a3b8d 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py @@ -205,3 +205,12 @@ def resolve_agent_version() -> str: :rtype: str """ return os.environ.get(Constants.AGENT_VERSION, "") + + +def resolve_project_id() -> str: + """Resolve the Foundry project ID from the ``AGENT_PROJECT_NAME`` environment variable. + + :return: The project ID, or an empty string if not set. + :rtype: str + """ + return os.environ.get(Constants.AGENT_PROJECT_NAME, "") diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py index ceb3f42ef302..8be43b96f54d 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py @@ -16,6 +16,7 @@ class Constants: APPLICATIONINSIGHTS_CONNECTION_STRING = "APPLICATIONINSIGHTS_CONNECTION_STRING" AGENT_NAME = "AGENT_NAME" AGENT_VERSION = "AGENT_VERSION" + AGENT_PROJECT_NAME = "AGENT_PROJECT_NAME" DEFAULT_PORT = 8088 DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT = 30 DEFAULT_REQUEST_TIMEOUT = 300 # 5 minutes diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py index ea36c8e0644a..a8b2e60925a0 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py @@ -24,10 +24,8 @@ from ._constants import Constants from ._errors import error_response from ._logger import get_logger -from ._tracing import extract_baggage_header, extract_w3c_carrier -from ._openapi_validator import OpenApiValidator +from ._openapi_validator import _OpenApiValidator from ._server_context import _ServerContext -from . import _config logger = get_logger() @@ -47,7 +45,7 @@ def __init__( self, ctx: _ServerContext, openapi_spec: Optional[dict[str, Any]], - validator: Optional[OpenApiValidator], + validator: Optional[_OpenApiValidator], ) -> None: """Initialise the invocation protocol. @@ -56,7 +54,7 @@ def __init__( :param openapi_spec: Optional OpenAPI spec dict for documentation. :type openapi_spec: Optional[dict[str, Any]] :param validator: Optional request validator built from the spec. - :type validator: Optional[OpenApiValidator] + :type validator: Optional[_OpenApiValidator] """ self._ctx = ctx self._invoke_fn: Optional[Callable] = None @@ -65,15 +63,6 @@ def __init__( self._openapi_spec = openapi_spec self._validator = validator - # Agent identity — used for span naming and GenAI attributes. - self._agent_name = _config.resolve_agent_name() - self._agent_version = _config.resolve_agent_version() - self._agent_label = ( - f"{self._agent_name}:{self._agent_version}" - if self._agent_name - else "" - ) - # ------------------------------------------------------------------ # Route registration # ------------------------------------------------------------------ @@ -112,52 +101,6 @@ def routes(self) -> list[Route]: ), ] - # ------------------------------------------------------------------ - # Span name helper (protocol-specific) - # ------------------------------------------------------------------ - - def _span_name(self, operation: str) -> str: - """Build a span name using the operation and agent label. - - :param operation: The operation name (e.g. ``"execute_agent"``). - :type operation: str - :return: ``" :"`` or just ``""``. - :rtype: str - """ - if self._agent_label: - return f"{operation} {self._agent_label}" - return operation - - def _build_span_attrs( - self, - invocation_id: str, - session_id: str, - operation_name: Optional[str] = None, - ) -> dict[str, str]: - """Build GenAI semantic convention span attributes. - - :param invocation_id: The invocation ID for this request. - :type invocation_id: str - :param session_id: The session ID header value (empty string if absent). - :type session_id: str - :param operation_name: Optional ``gen_ai.operation.name`` value - (e.g. ``"invoke_agent"``). Omitted from the dict when *None*. - :type operation_name: Optional[str] - :return: Span attribute dict. - :rtype: dict[str, str] - """ - attrs: dict[str, str] = { - "invocation.id": invocation_id, - "gen_ai.agent.id": self._agent_label, - "gen_ai.response.id": invocation_id, - "gen_ai.provider.name": "microsoft.foundry", - } - if operation_name: - attrs["gen_ai.operation.name"] = operation_name - if session_id: - attrs["gen_ai.conversation.id"] = session_id - return attrs - # ------------------------------------------------------------------ # Handler decorators # ------------------------------------------------------------------ @@ -165,21 +108,39 @@ def _build_span_attrs( def invoke_handler( self, fn: Callable[[Request], Awaitable[Response]] ) -> Callable[[Request], Awaitable[Response]]: - """Store *fn* as the invoke handler. See :meth:`AgentServer.invoke_handler`.""" + """Store *fn* as the invoke handler. See :meth:`AgentServer.invoke_handler`. + + :param fn: Async function accepting a Starlette Request and returning a Response. + :type fn: Callable[[Request], Awaitable[Response]] + :return: The original function (unmodified). + :rtype: Callable[[Request], Awaitable[Response]] + """ self._invoke_fn = fn return fn def get_invocation_handler( self, fn: Callable[[Request], Awaitable[Response]] ) -> Callable[[Request], Awaitable[Response]]: - """Store *fn* as the get-invocation handler. See :meth:`AgentServer.get_invocation_handler`.""" + """Store *fn* as the get-invocation handler. See :meth:`AgentServer.get_invocation_handler`. + + :param fn: Async function accepting a Starlette Request and returning a Response. + :type fn: Callable[[Request], Awaitable[Response]] + :return: The original function (unmodified). + :rtype: Callable[[Request], Awaitable[Response]] + """ self._get_invocation_fn = fn return fn def cancel_invocation_handler( self, fn: Callable[[Request], Awaitable[Response]] ) -> Callable[[Request], Awaitable[Response]]: - """Store *fn* as the cancel-invocation handler. See :meth:`AgentServer.cancel_invocation_handler`.""" + """Store *fn* as the cancel-invocation handler. See :meth:`AgentServer.cancel_invocation_handler`. + + :param fn: Async function accepting a Starlette Request and returning a Response. + :type fn: Callable[[Request], Awaitable[Response]] + :return: The original function (unmodified). + :rtype: Callable[[Request], Awaitable[Response]] + """ self._cancel_invocation_fn = fn return fn @@ -227,7 +188,11 @@ async def _dispatch_cancel_invocation(self, request: Request) -> Response: return error_response("not_supported", "cancel_invocation not supported", status_code=501) def get_openapi_spec(self) -> Optional[dict[str, Any]]: - """Return the stored OpenAPI spec. See :meth:`AgentServer.get_openapi_spec`.""" + """Return the stored OpenAPI spec. See :meth:`AgentServer.get_openapi_spec`. + + :return: The registered OpenAPI spec or None. + :rtype: Optional[dict[str, Any]] + """ return self._openapi_spec # ------------------------------------------------------------------ @@ -277,25 +242,16 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: ], ) - carrier = extract_w3c_carrier(request.headers) - baggage = extract_baggage_header(request.headers) - session_id = request.headers.get(Constants.SESSION_ID_HEADER, "") - span_attrs = self._build_span_attrs( - invocation_id, session_id, operation_name="invoke_agent" - ) - # Use manual span management so that streaming responses keep the # span open until the last chunk is yielded (or an error occurs). - otel_span = ( - self._ctx.tracing.start_span( - self._span_name("execute_agent"), - attributes=span_attrs, - carrier=carrier, - baggage_header=baggage, + otel_span = None + if self._ctx.tracing is not None: + otel_span = self._ctx.tracing.start_request_span( + request.headers, + invocation_id, + span_operation="execute_agent", + operation_name="invoke_agent", ) - if self._ctx.tracing is not None - else None - ) try: invoke_awaitable = self._dispatch_invoke(request) timeout = self._ctx.request_timeout or None # 0 → None (no limit) @@ -351,7 +307,7 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: async def _traced_invocation_endpoint( self, request: Request, - span_name: str, + span_operation: str, dispatch: Callable[[Request], Awaitable[Response]], ) -> Response: """Shared implementation for get/cancel invocation endpoints. @@ -361,8 +317,9 @@ async def _traced_invocation_endpoint( :param request: The incoming Starlette request. :type request: Request - :param span_name: OTel span name (e.g. ``"get_invocation"``). - :type span_name: str + :param span_operation: Span operation name (e.g. + ``"get_invocation"``). + :type span_operation: str :param dispatch: The dispatch method to invoke. :type dispatch: Callable[[Request], Awaitable[Response]] :return: The handler response or an error response. @@ -370,21 +327,12 @@ async def _traced_invocation_endpoint( """ invocation_id = request.path_params["invocation_id"] request.state.invocation_id = invocation_id - carrier = extract_w3c_carrier(request.headers) if self._ctx.tracing is not None else {} - baggage = extract_baggage_header(request.headers) if self._ctx.tracing is not None else None - session_id = request.headers.get(Constants.SESSION_ID_HEADER, "") - span_attrs = self._build_span_attrs(invocation_id, session_id) - - span_cm = ( - self._ctx.tracing.span( - span_name, - attributes=span_attrs, - carrier=carrier, - baggage_header=baggage, + + span_cm: Any = contextlib.nullcontext(None) + if self._ctx.tracing is not None: + span_cm = self._ctx.tracing.request_span( + request.headers, invocation_id, span_operation ) - if self._ctx.tracing is not None - else contextlib.nullcontext(None) - ) with span_cm as _otel_span: try: response = await dispatch(request) @@ -393,7 +341,7 @@ async def _traced_invocation_endpoint( except Exception as exc: # pylint: disable=broad-exception-caught if self._ctx.tracing is not None: self._ctx.tracing.record_error(_otel_span, exc) - logger.error("Error in %s %s: %s", span_name, invocation_id, exc, exc_info=True) + logger.error("Error in %s %s: %s", span_operation, invocation_id, exc, exc_info=True) message = str(exc) if self._ctx.debug_errors else "Internal server error" return error_response( "internal_error", @@ -411,7 +359,7 @@ async def _get_invocation_endpoint(self, request: Request) -> Response: :rtype: Response """ return await self._traced_invocation_endpoint( - request, self._span_name("get_invocation"), self._dispatch_get_invocation + request, "get_invocation", self._dispatch_get_invocation ) async def _cancel_invocation_endpoint(self, request: Request) -> Response: @@ -423,5 +371,5 @@ async def _cancel_invocation_endpoint(self, request: Request) -> Response: :rtype: Response """ return await self._traced_invocation_endpoint( - request, self._span_name("cancel_invocation"), self._dispatch_cancel_invocation + request, "cancel_invocation", self._dispatch_cancel_invocation ) diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_openapi_validator.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_openapi_validator.py index 3345b081d5df..1e01a1b1e1d8 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_openapi_validator.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_openapi_validator.py @@ -103,7 +103,7 @@ def _check_email(value: object) -> bool: ) -class OpenApiValidator: +class _OpenApiValidator: """Validates request/response bodies against an OpenAPI spec. Extracts the request and response JSON schemas from the provided OpenAPI spec dict @@ -211,7 +211,7 @@ def _extract_request_schema(spec: dict[str, Any], path: str) -> Optional[dict[st :return: JSON Schema dict or None. :rtype: Optional[dict[str, Any]] """ - return OpenApiValidator._find_schema_in_paths( + return _OpenApiValidator._find_schema_in_paths( spec, path, "post", "requestBody" ) @@ -226,7 +226,7 @@ def _extract_response_schema(spec: dict[str, Any], path: str) -> Optional[dict[s :return: JSON Schema dict or None. :rtype: Optional[dict[str, Any]] """ - return OpenApiValidator._find_schema_in_paths( + return _OpenApiValidator._find_schema_in_paths( spec, path, "post", "responses" ) @@ -341,9 +341,9 @@ def _preprocess_schema( :rtype: dict[str, Any] """ schema = copy.deepcopy(schema) - OpenApiValidator._apply_nullable(schema) - OpenApiValidator._strip_readonly_writeonly(schema, context) - OpenApiValidator._strip_openapi_keywords(schema) + _OpenApiValidator._apply_nullable(schema) + _OpenApiValidator._strip_readonly_writeonly(schema, context) + _OpenApiValidator._strip_openapi_keywords(schema) return schema @staticmethod @@ -365,7 +365,7 @@ def _apply_nullable(schema: dict[str, Any]) -> None: schema["type"] = [original, "null"] elif isinstance(original, list) and "null" not in original: schema["type"] = original + ["null"] - OpenApiValidator._walk_schema(schema, OpenApiValidator._apply_nullable) + _OpenApiValidator._walk_schema(schema, _OpenApiValidator._apply_nullable) @staticmethod def _strip_readonly_writeonly( @@ -403,9 +403,9 @@ def _strip_readonly_writeonly( required.remove(name) def _recurse(child: dict[str, Any]) -> None: - OpenApiValidator._strip_readonly_writeonly(child, context) + _OpenApiValidator._strip_readonly_writeonly(child, context) - OpenApiValidator._walk_schema(schema, _recurse) + _OpenApiValidator._walk_schema(schema, _recurse) @staticmethod def _strip_openapi_keywords(schema: dict[str, Any]) -> None: @@ -421,7 +421,7 @@ def _strip_openapi_keywords(schema: dict[str, Any]) -> None: return for kw in _OPENAPI_ONLY_KEYWORDS: schema.pop(kw, None) - OpenApiValidator._walk_schema(schema, OpenApiValidator._strip_openapi_keywords) + _OpenApiValidator._walk_schema(schema, _OpenApiValidator._strip_openapi_keywords) # ------------------------------------------------------------------ diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_server_context.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_server_context.py index a9078aba6450..b3d6651ab293 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_server_context.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_server_context.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: - from ._tracing import TracingHelper + from ._tracing import _TracingHelper @dataclasses.dataclass(frozen=True) @@ -20,6 +20,6 @@ class _ServerContext: timeout configuration without coupling to the ``AgentServer`` class. """ - tracing: Optional[TracingHelper] + tracing: Optional[_TracingHelper] debug_errors: bool request_timeout: int diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py index e80048645b72..7b3cb0a440a5 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py @@ -28,6 +28,8 @@ from collections.abc import AsyncIterable, AsyncIterator, Mapping # pylint: disable=import-error from typing import TYPE_CHECKING, Any, Iterator, Optional, Union +from . import _config +from ._constants import Constants from ._logger import get_logger #: Starlette's ``Content`` type — the element type for streaming bodies. @@ -39,6 +41,19 @@ #: Baggage key whose value overrides the parent span ID. _LEAF_CUSTOMER_SPAN_ID = "leaf_customer_span_id" +# ------------------------------------------------------------------ +# GenAI semantic convention attribute keys +# ------------------------------------------------------------------ +_ATTR_INVOCATION_ID = "invocation.id" +_ATTR_RESPONSE_ID = "gen_ai.response.id" +_ATTR_PROVIDER_NAME = "gen_ai.provider.name" +_ATTR_AGENT_ID = "gen_ai.agent.id" +_ATTR_PROJECT_ID = "microsoft.foundry.project.id" +_ATTR_OPERATION_NAME = "gen_ai.operation.name" +_ATTR_CONVERSATION_ID = "gen_ai.conversation.id" + +_PROVIDER_NAME_VALUE = "microsoft.foundry" + logger = get_logger() _HAS_OTEL = False @@ -53,7 +68,7 @@ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator -class TracingHelper: +class _TracingHelper: """Lightweight wrapper around OpenTelemetry. Only instantiate when tracing is enabled. If ``opentelemetry-api`` is @@ -66,11 +81,22 @@ class TracingHelper: ``azure-monitor-opentelemetry-exporter``. """ - def __init__(self, connection_string: Optional[str] = None) -> None: + def __init__( + self, + connection_string: Optional[str] = None, + ) -> None: self._enabled = _HAS_OTEL self._tracer: Any = None self._propagator: Any = None + # Resolve agent identity from environment variables. + agent_name = _config.resolve_agent_name() + agent_version = _config.resolve_agent_version() + self._agent_label = ( + f"{agent_name}:{agent_version}" if agent_name and agent_version else agent_name + ) + self._project_id = _config.resolve_project_id() + if not self._enabled: logger.warning( "Tracing was enabled but opentelemetry-api is not installed. " @@ -136,59 +162,65 @@ def _setup_azure_monitor(connection_string: str) -> None: :param connection_string: Application Insights connection string. :type connection_string: str """ - try: - from opentelemetry.sdk.resources import Resource - from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider - from opentelemetry.sdk.trace.export import BatchSpanProcessor - - from azure.monitor.opentelemetry.exporter import ( # type: ignore[import-untyped] - AzureMonitorTraceExporter, - ) - except ImportError: - logger.warning( - "Application Insights connection string was provided but " - "required packages are not installed. Install them with: " - "pip install azure-ai-agentserver-server[tracing]" - ) + resource = _create_resource() + if resource is None: return + _setup_trace_export(resource, connection_string) + _setup_log_export(resource, connection_string) - resource = Resource.create({"service.name": "azure.ai.agentserver"}) - - # --- Trace export --- - provider = SdkTracerProvider(resource=resource) - exporter = AzureMonitorTraceExporter( - connection_string=connection_string, - ) - provider.add_span_processor(BatchSpanProcessor(exporter)) - trace.set_tracer_provider(provider) - logger.info("Application Insights trace exporter configured.") + # ------------------------------------------------------------------ + # Span naming and attribute helpers (shared by all protocols) + # ------------------------------------------------------------------ - # --- Log export --- - try: - from opentelemetry._logs import set_logger_provider - from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler - from opentelemetry.sdk._logs.export import BatchLogRecordProcessor + def span_name(self, span_operation: str) -> str: + """Build a span name using the operation and agent label. - from azure.monitor.opentelemetry.exporter import ( # type: ignore[import-untyped] - AzureMonitorLogExporter, - ) + :param span_operation: The span operation (e.g. ``"execute_agent"``). + This becomes the first token of the OTel span name. + :type span_operation: str + :return: ``" :"`` or just + ``""``. + :rtype: str + """ + if self._agent_label: + return f"{span_operation} {self._agent_label}" + return span_operation - log_provider = LoggerProvider(resource=resource) - set_logger_provider(log_provider) - log_exporter = AzureMonitorLogExporter( - connection_string=connection_string, - ) - log_provider.add_log_record_processor( - BatchLogRecordProcessor(log_exporter) - ) - handler = LoggingHandler(logger_provider=log_provider) - logging.getLogger("azure.ai.agentserver").addHandler(handler) - logger.info("Application Insights log exporter configured.") - except ImportError: - logger.warning( - "Log export to Application Insights requires " - "opentelemetry-sdk. Logs will not be forwarded." - ) + def build_span_attrs( + self, + invocation_id: str, + session_id: str, + operation_name: Optional[str] = None, + ) -> dict[str, str]: + """Build GenAI semantic convention span attributes. + + These attributes are common across all protocol heads (invocation, + chat, etc.). + + :param invocation_id: The invocation/request ID for this request. + :type invocation_id: str + :param session_id: The session ID header value (empty string if absent). + :type session_id: str + :param operation_name: Optional ``gen_ai.operation.name`` value + (e.g. ``"invoke_agent"``). Omitted from the dict when *None*. + :type operation_name: Optional[str] + :return: Span attribute dict. + :rtype: dict[str, str] + """ + attrs: dict[str, str] = { + _ATTR_INVOCATION_ID: invocation_id, + _ATTR_RESPONSE_ID: invocation_id, + _ATTR_PROVIDER_NAME: _PROVIDER_NAME_VALUE, + } + if self._agent_label: + attrs[_ATTR_AGENT_ID] = self._agent_label + if self._project_id: + attrs[_ATTR_PROJECT_ID] = self._project_id + if operation_name: + attrs[_ATTR_OPERATION_NAME] = operation_name + if session_id: + attrs[_ATTR_CONVERSATION_ID] = session_id + return attrs @contextmanager def span( @@ -267,6 +299,118 @@ def start_span( context=ctx, ) + # ------------------------------------------------------------------ + # Request-level convenience wrappers + # ------------------------------------------------------------------ + + def _prepare_request_span_args( + self, + headers: Mapping[str, str], + invocation_id: str, + span_operation: str, + operation_name: Optional[str] = None, + ) -> tuple[str, dict[str, str], dict[str, str], Optional[str]]: + """Extract headers and build span arguments for a request. + + Shared pipeline used by :meth:`start_request_span` and + :meth:`request_span` to avoid duplicating header extraction, + attribute building, and span naming. + + :param headers: HTTP request headers (any ``Mapping[str, str]``). + :type headers: Mapping[str, str] + :param invocation_id: The invocation/request ID. + :type invocation_id: str + :param span_operation: Span operation (e.g. ``"execute_agent"``). + :type span_operation: str + :param operation_name: Optional ``gen_ai.operation.name`` value. + :type operation_name: Optional[str] + :return: ``(name, attributes, carrier, baggage)`` ready for + :meth:`span` or :meth:`start_span`. + :rtype: tuple[str, dict[str, str], dict[str, str], Optional[str]] + """ + carrier = _extract_w3c_carrier(headers) + baggage = headers.get("baggage") + session_id = headers.get(Constants.SESSION_ID_HEADER, "") + span_attrs = self.build_span_attrs( + invocation_id, session_id, operation_name=operation_name + ) + return self.span_name(span_operation), span_attrs, carrier, baggage + + def start_request_span( + self, + headers: Mapping[str, str], + invocation_id: str, + span_operation: str, + operation_name: Optional[str] = None, + ) -> Any: + """Start a request-scoped span, extracting context from HTTP headers. + + Convenience method that combines header extraction, attribute + building, span naming, and span creation into a single call. + Use for streaming responses where the span must outlive the + initial handler call. The caller **must** call :meth:`end_span` + when work is finished. + + :param headers: HTTP request headers (any ``Mapping[str, str]``). + :type headers: Mapping[str, str] + :param invocation_id: The invocation/request ID. + :type invocation_id: str + :param span_operation: Span operation (e.g. ``"execute_agent"``). + Becomes the first token of the OTel span name via + :meth:`span_name`. + :type span_operation: str + :param operation_name: Optional ``gen_ai.operation.name`` attribute + value (e.g. ``"invoke_agent"``). Omitted when *None*. + :type operation_name: Optional[str] + :return: The OTel span, or *None* when tracing is disabled. + :rtype: Any + """ + name, attrs, carrier, baggage = self._prepare_request_span_args( + headers, invocation_id, span_operation, operation_name + ) + return self.start_span(name, attributes=attrs, carrier=carrier, baggage_header=baggage) + + @contextmanager + def request_span( + self, + headers: Mapping[str, str], + invocation_id: str, + span_operation: str, + operation_name: Optional[str] = None, + ) -> Iterator[Any]: + """Create a request-scoped span as a context manager. + + Convenience method that combines header extraction, attribute + building, span naming, and span creation into a single call. + Use for non-streaming request handlers where the span should + cover the entire handler execution. + + :param headers: HTTP request headers (any ``Mapping[str, str]``). + :type headers: Mapping[str, str] + :param invocation_id: The invocation/request ID. + :type invocation_id: str + :param span_operation: Span operation (e.g. ``"get_invocation"``). + Becomes the first token of the OTel span name via + :meth:`span_name`. + :type span_operation: str + :param operation_name: Optional ``gen_ai.operation.name`` attribute + value. Omitted when *None*. + :type operation_name: Optional[str] + :return: Context manager that yields the OTel span or *None*. + :rtype: Iterator[Any] + """ + name, attrs, carrier, baggage = self._prepare_request_span_args( + headers, invocation_id, span_operation, operation_name + ) + with self.span( + name, attributes=attrs, carrier=carrier, baggage_header=baggage + ) as otel_span: + yield otel_span + + # ------------------------------------------------------------------ + # Span lifecycle helpers + # ------------------------------------------------------------------ + def end_span(self, span: Any, exc: Optional[Exception] = None) -> None: """End a span started with :meth:`start_span`. @@ -330,7 +474,88 @@ async def trace_stream( self.end_span(span, exc=error) -def extract_w3c_carrier(headers: Mapping[str, str]) -> dict[str, str]: +def _create_resource() -> Any: + """Create the OTel resource for Azure Monitor exporters. + + :return: A :class:`~opentelemetry.sdk.resources.Resource`, or *None* + if the required packages are not installed. + :rtype: Any + """ + try: + from opentelemetry.sdk.resources import Resource + except ImportError: + logger.warning( + "Application Insights connection string was provided but " + "required packages are not installed. Install them with: " + "pip install azure-ai-agentserver-server[tracing]" + ) + return None + return Resource.create({"service.name": "azure.ai.agentserver"}) + + +def _setup_trace_export(resource: Any, connection_string: str) -> None: + """Configure a global :class:`TracerProvider` that exports to App Insights. + + :param resource: The OTel resource describing this service. + :type resource: Any + :param connection_string: Application Insights connection string. + :type connection_string: str + """ + try: + from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + + from azure.monitor.opentelemetry.exporter import ( # type: ignore[import-untyped] + AzureMonitorTraceExporter, + ) + except ImportError: + logger.warning( + "Trace export to Application Insights requires " + "opentelemetry-sdk and azure-monitor-opentelemetry-exporter. " + "Traces will not be forwarded." + ) + return + + provider = SdkTracerProvider(resource=resource) + exporter = AzureMonitorTraceExporter(connection_string=connection_string) + provider.add_span_processor(BatchSpanProcessor(exporter)) + trace.set_tracer_provider(provider) + logger.info("Application Insights trace exporter configured.") + + +def _setup_log_export(resource: Any, connection_string: str) -> None: + """Configure a global :class:`LoggerProvider` that exports to App Insights. + + :param resource: The OTel resource describing this service. + :type resource: Any + :param connection_string: Application Insights connection string. + :type connection_string: str + """ + try: + from opentelemetry._logs import set_logger_provider + from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler + from opentelemetry.sdk._logs.export import BatchLogRecordProcessor + + from azure.monitor.opentelemetry.exporter import ( # type: ignore[import-untyped] + AzureMonitorLogExporter, + ) + except ImportError: + logger.warning( + "Log export to Application Insights requires " + "opentelemetry-sdk. Logs will not be forwarded." + ) + return + + log_provider = LoggerProvider(resource=resource) + set_logger_provider(log_provider) + log_exporter = AzureMonitorLogExporter(connection_string=connection_string) + log_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter)) + handler = LoggingHandler(logger_provider=log_provider) + logging.getLogger("azure.ai.agentserver").addHandler(handler) + logger.info("Application Insights log exporter configured.") + + +def _extract_w3c_carrier(headers: Mapping[str, str]) -> dict[str, str]: """Extract W3C trace-context headers from a mapping. Filters the input to only ``traceparent`` and ``tracestate`` — the two @@ -347,27 +572,10 @@ def extract_w3c_carrier(headers: Mapping[str, str]) -> dict[str, str]: in *headers*. :rtype: dict[str, str] """ - result: dict[str, str] = {} - for key in _W3C_HEADERS: - val = headers.get(key) - if val is not None: - result[key] = val + result: dict[str, str] = {k: v for k in _W3C_HEADERS if (v := headers.get(k)) is not None} return result -def extract_baggage_header(headers: Mapping[str, str]) -> Optional[str]: - """Extract the raw ``baggage`` header value from a mapping. - - Returns *None* if the header is not present. - - :param headers: A mapping of header name to value. - :type headers: Mapping[str, str] - :return: The raw baggage header value or None. - :rtype: Optional[str] - """ - return headers.get("baggage") - - def _parse_baggage_key(baggage: str, key: str) -> Optional[str]: """Parse a single key from a W3C Baggage header value. diff --git a/sdk/agentserver/azure-ai-agentserver-server/cspell.json b/sdk/agentserver/azure-ai-agentserver-server/cspell.json index 5af59c8e52e0..4cf0dce8914d 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/cspell.json +++ b/sdk/agentserver/azure-ai-agentserver-server/cspell.json @@ -9,6 +9,7 @@ "behaviour", "caplog", "delenv", + "genai", "hypercorn", "invocations", "langgraph", @@ -18,6 +19,7 @@ "requestschema", "rtype", "serialisation", + "sess", "Specialised", "Standardised", "starlette", diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_graceful_shutdown.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_graceful_shutdown.py index 17f6b6c7e7b9..31da3d022640 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/tests/test_graceful_shutdown.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_graceful_shutdown.py @@ -100,6 +100,9 @@ def test_run_passes_timeout(self, mock_asyncio, _mock_serve): agent = _make_stub_agent(graceful_shutdown_timeout=15) agent.run() mock_asyncio.run.assert_called_once() + # Close the unawaited coroutine passed to asyncio.run to avoid + # "coroutine was never awaited" warnings on Python 3.14+. + mock_asyncio.run.call_args[0][0].close() # Verify the config built internally has the right graceful_timeout config = agent._build_hypercorn_config("127.0.0.1", 8088) assert config.graceful_timeout == 15.0 @@ -109,6 +112,9 @@ def test_run_passes_timeout(self, mock_asyncio, _mock_serve): def test_run_passes_default_timeout(self, mock_asyncio, _mock_serve): agent = _make_stub_agent() agent.run() + # Close the unawaited coroutine passed to asyncio.run to avoid + # "coroutine was never awaited" warnings on Python 3.14+. + mock_asyncio.run.call_args[0][0].close() config = agent._build_hypercorn_config("127.0.0.1", 8088) assert config.graceful_timeout == 30.0 @@ -300,14 +306,16 @@ async def invoke(request: Request) -> Response: async def shutdown(): order.append("callback") - with patch("azure.ai.agentserver.server._base.logger") as mock_logger: - - def tracking_info(*args, **kwargs): - if args and "shutting down" in str(args[0]).lower(): - order.append("log") + def tracking_info(*args, **kwargs): + if args and "shutting down" in str(args[0]).lower(): + order.append("log") - mock_logger.info.side_effect = tracking_info + mock_logger = MagicMock() + mock_logger.info = MagicMock(side_effect=tracking_info) + with patch( + "azure.ai.agentserver.server._base.logger", mock_logger + ): lifespan = server.app.router.lifespan_context async with lifespan(server.app): pass diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_openapi_validation.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_openapi_validation.py index 0a2f10703403..f9fe9647c127 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/tests/test_openapi_validation.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_openapi_validation.py @@ -11,7 +11,7 @@ from starlette.responses import JSONResponse, Response from azure.ai.agentserver.server import AgentServer -from azure.ai.agentserver.server._openapi_validator import OpenApiValidator +from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator # --------------------------------------------------------------------------- @@ -860,9 +860,9 @@ async def test_complex_multiple_errors_reported(): @pytest.mark.asyncio async def test_validate_response_valid(): """Valid response body passes validation.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator - v = OpenApiValidator(RESPONSE_SPEC) + v = _OpenApiValidator(RESPONSE_SPEC) errors = v.validate_response(b'{"result": "ok"}', "application/json") assert errors == [] @@ -870,9 +870,9 @@ async def test_validate_response_valid(): @pytest.mark.asyncio async def test_validate_response_invalid(): """Invalid response body returns errors.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator - v = OpenApiValidator(RESPONSE_SPEC) + v = _OpenApiValidator(RESPONSE_SPEC) errors = v.validate_response(b'{"wrong": 42}', "application/json") assert len(errors) > 0 @@ -880,14 +880,14 @@ async def test_validate_response_invalid(): @pytest.mark.asyncio async def test_validate_response_no_schema(): """When no response schema exists, validation passes (no-op).""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator spec_no_resp = { "openapi": "3.0.0", "info": {"title": "NoResp", "version": "1.0"}, "paths": {"/invocations": {"post": {}}}, } - v = OpenApiValidator(spec_no_resp) + v = _OpenApiValidator(spec_no_resp) errors = v.validate_response(b'{"anything": true}', "application/json") assert errors == [] @@ -900,14 +900,14 @@ async def test_validate_response_no_schema(): @pytest.mark.asyncio async def test_validate_request_no_schema(): """When no request schema exists, validation passes (no-op).""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator spec_no_req = { "openapi": "3.0.0", "info": {"title": "NoReq", "version": "1.0"}, "paths": {"/invocations": {"post": {}}}, } - v = OpenApiValidator(spec_no_req) + v = _OpenApiValidator(spec_no_req) errors = v.validate_request(b'{"anything": true}', "application/json") assert errors == [] @@ -920,7 +920,7 @@ async def test_validate_request_no_schema(): @pytest.mark.asyncio async def test_response_schema_fallback_to_first_available(): """Response schema extraction falls back to first response with JSON.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator spec = { "openapi": "3.0.0", @@ -945,7 +945,7 @@ async def test_response_schema_fallback_to_first_available(): } }, } - v = OpenApiValidator(spec) + v = _OpenApiValidator(spec) # Valid against the 202 schema assert v.validate_response(b'{"status": "accepted"}', "application/json") == [] # Invalid — missing "status" @@ -1211,9 +1211,9 @@ async def test_writeonly_allowed_in_request(): @pytest.mark.asyncio async def test_readonly_schema_introspection_request(): """readOnly properties are removed from the preprocessed request schema.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator - v = OpenApiValidator(READONLY_SPEC) + v = _OpenApiValidator(READONLY_SPEC) # 'id' should be gone from requestschema properties assert "id" not in v._request_schema.get("properties", {}) @@ -1221,9 +1221,9 @@ async def test_readonly_schema_introspection_request(): @pytest.mark.asyncio async def test_readonly_schema_introspection_response(): """readOnly properties remain in the preprocessed response schema.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator - v = OpenApiValidator(READONLY_SPEC) + v = _OpenApiValidator(READONLY_SPEC) # 'id' should still be in response schema assert "id" in v._response_schema.get("properties", {}) @@ -1231,18 +1231,18 @@ async def test_readonly_schema_introspection_response(): @pytest.mark.asyncio async def test_writeonly_stripped_in_response(): """writeOnly properties are removed from the preprocessed response schema.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator - v = OpenApiValidator(READONLY_SPEC) + v = _OpenApiValidator(READONLY_SPEC) assert "password" not in v._response_schema.get("properties", {}) @pytest.mark.asyncio async def test_writeonly_present_in_request(): """writeOnly properties remain in the preprocessed request schema.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator - v = OpenApiValidator(READONLY_SPEC) + v = _OpenApiValidator(READONLY_SPEC) assert "password" in v._request_schema.get("properties", {}) @@ -1306,9 +1306,9 @@ async def test_writeonly_present_in_request(): @pytest.mark.asyncio async def test_optional_body_empty_accepted(): """Empty body is accepted when requestBody.required is false.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator - v = OpenApiValidator(OPTIONAL_BODY_SPEC) + v = _OpenApiValidator(OPTIONAL_BODY_SPEC) errors = v.validate_request(b"", "application/json") assert errors == [] @@ -1316,9 +1316,9 @@ async def test_optional_body_empty_accepted(): @pytest.mark.asyncio async def test_optional_body_whitespace_accepted(): """Whitespace-only body is accepted when requestBody.required is false.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator - v = OpenApiValidator(OPTIONAL_BODY_SPEC) + v = _OpenApiValidator(OPTIONAL_BODY_SPEC) errors = v.validate_request(b" ", "application/json") assert errors == [] @@ -1326,9 +1326,9 @@ async def test_optional_body_whitespace_accepted(): @pytest.mark.asyncio async def test_optional_body_present_still_validated(): """When body IS present with optional requestBody, it still must be valid.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator - v = OpenApiValidator(OPTIONAL_BODY_SPEC) + v = _OpenApiValidator(OPTIONAL_BODY_SPEC) errors = v.validate_request(b'{"wrong": 1}', "application/json") assert len(errors) > 0 @@ -1336,9 +1336,9 @@ async def test_optional_body_present_still_validated(): @pytest.mark.asyncio async def test_required_body_empty_rejected(): """Empty body is rejected when requestBody.required is true.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator - v = OpenApiValidator(REQUIRED_BODY_SPEC) + v = _OpenApiValidator(REQUIRED_BODY_SPEC) errors = v.validate_request(b"", "application/json") assert len(errors) > 0 # "Invalid JSON body" @@ -1346,7 +1346,7 @@ async def test_required_body_empty_rejected(): @pytest.mark.asyncio async def test_default_body_required_behavior(): """When requestBody.required is omitted, body is required by default.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator spec = { "openapi": "3.0.0", @@ -1366,7 +1366,7 @@ async def test_default_body_required_behavior(): } }, } - v = OpenApiValidator(spec) + v = _OpenApiValidator(spec) errors = v.validate_request(b"", "application/json") assert len(errors) > 0 @@ -1379,7 +1379,7 @@ async def test_default_body_required_behavior(): @pytest.mark.asyncio async def test_openapi_keywords_stripped(): """discriminator, xml, externalDocs, example are stripped from schemas.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator schema = { "type": "object", @@ -1389,7 +1389,7 @@ async def test_openapi_keywords_stripped(): "example": {"type": "bar"}, "properties": {"name": {"type": "string"}}, } - result = OpenApiValidator._preprocess_schema(schema) + result = _OpenApiValidator._preprocess_schema(schema) assert "discriminator" not in result assert "xml" not in result assert "externalDocs" not in result @@ -1400,7 +1400,7 @@ async def test_openapi_keywords_stripped(): @pytest.mark.asyncio async def test_openapi_keywords_stripped_nested(): """OpenAPI keywords are stripped from nested properties too.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator schema = { "type": "object", @@ -1413,7 +1413,7 @@ async def test_openapi_keywords_stripped_nested(): } }, } - result = OpenApiValidator._preprocess_schema(schema) + result = _OpenApiValidator._preprocess_schema(schema) child = result["properties"]["child"] assert "example" not in child assert "xml" not in child @@ -1729,7 +1729,7 @@ async def test_anyof_wrong_type(): @pytest.mark.asyncio async def test_nullable_ref_accepts_null(): """A nullable $ref field accepts null after preprocessing.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator spec: dict = { "openapi": "3.0.0", @@ -1766,7 +1766,7 @@ async def test_nullable_ref_accepts_null(): } }, } - v = OpenApiValidator(spec) + v = _OpenApiValidator(spec) errors = v.validate_request(b'{"addr": null}', "application/json") assert errors == [] @@ -1779,20 +1779,20 @@ async def test_nullable_ref_accepts_null(): @pytest.mark.asyncio async def test_apply_nullable_no_duplicate_null(): """_apply_nullable does not add 'null' twice if type is already a list with null.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator schema: dict = {"type": ["string", "null"], "nullable": True} - OpenApiValidator._apply_nullable(schema) + _OpenApiValidator._apply_nullable(schema) assert schema["type"] == ["string", "null"] @pytest.mark.asyncio async def test_apply_nullable_false(): """nullable: false is a no-op (just removes the key).""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator schema: dict = {"type": "string", "nullable": False} - OpenApiValidator._apply_nullable(schema) + _OpenApiValidator._apply_nullable(schema) assert schema["type"] == "string" assert "nullable" not in schema @@ -1800,7 +1800,7 @@ async def test_apply_nullable_false(): @pytest.mark.asyncio async def test_strip_openapi_keywords_nested_deeply(): """OpenAPI keywords are stripped from deeply nested schemas.""" - from azure.ai.agentserver.server._openapi_validator import OpenApiValidator + from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator schema: dict = { "type": "object", @@ -1817,7 +1817,7 @@ async def test_strip_openapi_keywords_nested_deeply(): } }, } - OpenApiValidator._strip_openapi_keywords(schema) + _OpenApiValidator._strip_openapi_keywords(schema) items_schema = schema["properties"]["items"]["items"] assert "example" not in items_schema assert "xml" not in items_schema["properties"]["val"] @@ -1830,8 +1830,8 @@ async def test_strip_openapi_keywords_nested_deeply(): # --- Helper: direct validator call for unit-level assertions --- def _validate_request(spec: dict, body: dict) -> list[str]: - """Shortcut: build an OpenApiValidator and return request errors.""" - v = OpenApiValidator(spec) + """Shortcut: build an _OpenApiValidator and return request errors.""" + v = _OpenApiValidator(spec) return v.validate_request( json.dumps(body).encode(), "application/json", diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py index 15f78b7d5bf0..aaa4d25e9382 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py @@ -17,6 +17,7 @@ from starlette.responses import JSONResponse, Response, StreamingResponse from azure.ai.agentserver.server import AgentServer +from azure.ai.agentserver.server._tracing import _TracingHelper # --------------------------------------------------------------------------- @@ -66,9 +67,16 @@ async def handle(request: Request) -> Response: @pytest.fixture() def span_exporter(): - """Return the module-level exporter with a clean slate for each test.""" + """Return the module-level exporter with a clean slate for each test. + + Patches ``_setup_azure_monitor`` so that an ``APPLICATIONINSIGHTS_CONNECTION_STRING`` + env var in the CI environment doesn't replace the test TracerProvider. + """ _MODULE_EXPORTER._spans.clear() - yield _MODULE_EXPORTER + # Ensure the module-level provider is active (a prior test may have replaced it). + trace.set_tracer_provider(_MODULE_PROVIDER) + with patch.object(_TracingHelper, "_setup_azure_monitor"): + yield _MODULE_EXPORTER # Module-level OTel setup — set once to avoid @@ -431,11 +439,11 @@ def test_no_connection_string_returns_none(self, monkeypatch): class TestSetupAzureMonitor: - """Tests for TracingHelper._setup_azure_monitor (mocked exporter imports).""" + """Tests for _TracingHelper._setup_azure_monitor (mocked exporter imports).""" def test_setup_configures_tracer_provider(self): """_setup_azure_monitor sets a global TracerProvider with exporter.""" - from azure.ai.agentserver.server._tracing import TracingHelper + from azure.ai.agentserver.server._tracing import _TracingHelper mock_exporter = MagicMock() mock_exporter_cls = MagicMock(return_value=mock_exporter) @@ -449,7 +457,7 @@ def test_setup_configures_tracer_provider(self): ), }, ), patch("opentelemetry.trace.set_tracer_provider") as mock_set_provider: - TracingHelper._setup_azure_monitor("InstrumentationKey=test") + _TracingHelper._setup_azure_monitor("InstrumentationKey=test") mock_exporter_cls.assert_called_once_with( connection_string="InstrumentationKey=test" @@ -459,7 +467,7 @@ def test_setup_configures_tracer_provider(self): def test_setup_logs_warning_when_packages_missing(self, caplog): """Warns gracefully when azure-monitor exporter is not installed.""" import builtins - from azure.ai.agentserver.server._tracing import TracingHelper + from azure.ai.agentserver.server._tracing import _TracingHelper real_import = builtins.__import__ @@ -472,16 +480,16 @@ def _block_monitor(name, *args, **kwargs): import logging with caplog.at_level(logging.WARNING, logger="azure.ai.agentserver"): - TracingHelper._setup_azure_monitor("InstrumentationKey=test") + _TracingHelper._setup_azure_monitor("InstrumentationKey=test") - assert "required packages are not installed" in caplog.text + assert "Traces will not be forwarded" in caplog.text def test_constructor_passes_connection_string(self, monkeypatch): - """AgentServer passes resolved connection string to TracingHelper.""" + """AgentServer passes resolved connection string to _TracingHelper.""" monkeypatch.delenv("APPLICATIONINSIGHTS_CONNECTION_STRING", raising=False) with patch( - "azure.ai.agentserver.server._tracing.TracingHelper._setup_azure_monitor" + "azure.ai.agentserver.server._tracing._TracingHelper._setup_azure_monitor" ) as mock_setup: _make_echo_traced_agent( enable_tracing=True, @@ -494,7 +502,7 @@ def test_constructor_no_connection_string_skips_setup(self, monkeypatch): monkeypatch.delenv("APPLICATIONINSIGHTS_CONNECTION_STRING", raising=False) with patch( - "azure.ai.agentserver.server._tracing.TracingHelper._setup_azure_monitor" + "azure.ai.agentserver.server._tracing._TracingHelper._setup_azure_monitor" ) as mock_setup: _make_echo_traced_agent(enable_tracing=True) mock_setup.assert_not_called() @@ -507,7 +515,7 @@ def test_constructor_env_var_connection_string(self, monkeypatch): ) with patch( - "azure.ai.agentserver.server._tracing.TracingHelper._setup_azure_monitor" + "azure.ai.agentserver.server._tracing._TracingHelper._setup_azure_monitor" ) as mock_setup: _make_echo_traced_agent(enable_tracing=True) mock_setup.assert_called_once_with("InstrumentationKey=from-env") @@ -715,19 +723,6 @@ def test_whitespace_handling(self): assert _parse_baggage_key(baggage, "leaf_customer_span_id") == "abc123" -class TestExtractBaggageHeader: - """Unit tests for extract_baggage_header.""" - - def test_present(self): - from azure.ai.agentserver.server._tracing import extract_baggage_header - headers = {"baggage": "key=val", "other": "x"} - assert extract_baggage_header(headers) == "key=val" - - def test_absent(self): - from azure.ai.agentserver.server._tracing import extract_baggage_header - assert extract_baggage_header({"other": "x"}) is None - - @pytest.mark.asyncio async def test_baggage_leaf_customer_span_id_overrides_parent(span_exporter): """When baggage contains leaf_customer_span_id, the span's parent span ID @@ -893,3 +888,73 @@ def test_agent_version_default_empty(self, monkeypatch): from azure.ai.agentserver.server._config import resolve_agent_version monkeypatch.delenv("AGENT_VERSION", raising=False) assert resolve_agent_version() == "" + + +# --------------------------------------------------------------------------- +# Tests: project ID resolution and tracing attribute +# --------------------------------------------------------------------------- + + +class TestProjectIdResolution: + """Tests for resolve_project_id and microsoft.foundry.project.id span attribute.""" + + def test_project_id_from_env(self, monkeypatch): + from azure.ai.agentserver.server._config import resolve_project_id + monkeypatch.setenv("AGENT_PROJECT_NAME", "proj-abc-123") + assert resolve_project_id() == "proj-abc-123" + + def test_project_id_default_empty(self, monkeypatch): + from azure.ai.agentserver.server._config import resolve_project_id + monkeypatch.delenv("AGENT_PROJECT_NAME", raising=False) + assert resolve_project_id() == "" + + +@pytest.mark.asyncio +async def test_project_id_attribute_on_invoke_span(monkeypatch, span_exporter): + """microsoft.foundry.project.id is set on invoke span when AGENT_PROJECT_NAME is set.""" + monkeypatch.setenv("AGENT_PROJECT_NAME", "proj-xyz-789") + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b'{}') + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + attrs = dict(invoke_spans[0].attributes) + assert attrs["microsoft.foundry.project.id"] == "proj-xyz-789" + + +@pytest.mark.asyncio +async def test_project_id_attribute_absent_when_not_set(monkeypatch, span_exporter): + """microsoft.foundry.project.id is NOT set when AGENT_PROJECT_NAME env var is absent.""" + monkeypatch.delenv("AGENT_PROJECT_NAME", raising=False) + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b'{}') + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + attrs = dict(invoke_spans[0].attributes) + assert "microsoft.foundry.project.id" not in attrs + + +@pytest.mark.asyncio +async def test_project_id_attribute_on_get_invocation_span(monkeypatch, span_exporter): + """microsoft.foundry.project.id is set on get_invocation span too.""" + monkeypatch.setenv("AGENT_PROJECT_NAME", "proj-get-456") + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.get("/invocations/test-id") + + spans = span_exporter.get_finished_spans() + get_spans = [s for s in spans if "get_invocation" in s.name] + assert len(get_spans) == 1 + + attrs = dict(get_spans[0].attributes) + assert attrs["microsoft.foundry.project.id"] == "proj-get-456" From 71dac326512668f55918a6924ae94ac18bff0b9c Mon Sep 17 00:00:00 2001 From: Zhiyong Yang Date: Thu, 12 Mar 2026 18:26:38 -0700 Subject: [PATCH 08/10] Fix pylint W0135 and CI test failures for span name matching Replace nested @contextmanager + with pattern in request_span with explicit start_span/end_span lifecycle to satisfy pylint contextmanager-generator-missing-cleanup check. Use substring matching for span names in tests to handle CI environments where AGENT_NAME is set. Co-Authored-By: Claude Opus 4.6 --- .../azure/ai/agentserver/server/_tracing.py | 9 ++++++--- .../tests/test_tracing.py | 20 +++++++++---------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py index 7b3cb0a440a5..fd22fb1ba2c8 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py @@ -402,10 +402,13 @@ def request_span( name, attrs, carrier, baggage = self._prepare_request_span_args( headers, invocation_id, span_operation, operation_name ) - with self.span( - name, attributes=attrs, carrier=carrier, baggage_header=baggage - ) as otel_span: + otel_span = self.start_span(name, attributes=attrs, carrier=carrier, baggage_header=baggage) + try: yield otel_span + except Exception as exc: # pylint: disable=broad-exception-caught + self.end_span(otel_span, exc=exc) + raise + self.end_span(otel_span) # ------------------------------------------------------------------ # Span lifecycle helpers diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py index aaa4d25e9382..e630c0581ce1 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py @@ -126,7 +126,7 @@ async def test_tracing_enabled_creates_invoke_span(span_exporter): assert resp.status_code == 200 spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "execute_agent"] + invoke_spans = [s for s in spans if "execute_agent" in s.name] assert len(invoke_spans) == 1 span = invoke_spans[0] @@ -144,7 +144,7 @@ async def test_tracing_invoke_error_records_exception(span_exporter): assert resp.status_code == 500 spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "execute_agent"] + invoke_spans = [s for s in spans if "execute_agent" in s.name] assert len(invoke_spans) == 1 span = invoke_spans[0] @@ -165,7 +165,7 @@ async def test_tracing_get_invocation_creates_span(span_exporter): assert resp.status_code == 501 spans = span_exporter.get_finished_spans() - get_spans = [s for s in spans if s.name == "get_invocation"] + get_spans = [s for s in spans if "get_invocation" in s.name] assert len(get_spans) == 1 assert dict(get_spans[0].attributes)["invocation.id"] == "test-id-123" @@ -180,7 +180,7 @@ async def test_tracing_cancel_invocation_creates_span(span_exporter): assert resp.status_code == 501 spans = span_exporter.get_finished_spans() - cancel_spans = [s for s in spans if s.name == "cancel_invocation"] + cancel_spans = [s for s in spans if "cancel_invocation" in s.name] assert len(cancel_spans) == 1 assert dict(cancel_spans[0].attributes)["invocation.id"] == "test-cancel-456" @@ -202,7 +202,7 @@ async def test_tracing_enabled_via_env_var(monkeypatch, span_exporter): await client.post("/invocations", content=b'{}') spans = span_exporter.get_finished_spans() - assert any(s.name == "execute_agent" for s in spans) + assert any("execute_agent" in s.name for s in spans) @pytest.mark.asyncio @@ -229,7 +229,7 @@ async def test_tracing_propagates_traceparent(span_exporter): assert resp.status_code == 200 spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "execute_agent"] + invoke_spans = [s for s in spans if "execute_agent" in s.name] assert len(invoke_spans) == 1 span = invoke_spans[0] # The span's trace ID should match the traceparent's trace ID @@ -298,7 +298,7 @@ async def test_streaming_response_creates_span(span_exporter): assert resp.status_code == 200 spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "execute_agent"] + invoke_spans = [s for s in spans if "execute_agent" in s.name] assert len(invoke_spans) == 1 assert invoke_spans[0].status.status_code == trace.StatusCode.UNSET @@ -317,7 +317,7 @@ async def test_streaming_span_covers_full_body(span_exporter): assert b"slow-2" in resp.content spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "execute_agent"] + invoke_spans = [s for s in spans if "execute_agent" in s.name] assert len(invoke_spans) == 1 span = invoke_spans[0] @@ -355,7 +355,7 @@ async def test_streaming_error_recorded_in_span(span_exporter): pass # connection reset / partial read is acceptable spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "execute_agent"] + invoke_spans = [s for s in spans if "execute_agent" in s.name] assert len(invoke_spans) == 1 span = invoke_spans[0] @@ -394,7 +394,7 @@ async def test_streaming_propagates_traceparent(span_exporter): assert resp.status_code == 200 spans = span_exporter.get_finished_spans() - invoke_spans = [s for s in spans if s.name == "execute_agent"] + invoke_spans = [s for s in spans if "execute_agent" in s.name] assert len(invoke_spans) == 1 expected_trace_id = int("0af7651916cd43dd8448eb211c80319c", 16) assert invoke_spans[0].context.trace_id == expected_trace_id From 5788b4d3c57f20247a52ff07298dcdc01dfce8dc Mon Sep 17 00:00:00 2001 From: Zhiyong Yang Date: Fri, 13 Mar 2026 16:12:31 -0700 Subject: [PATCH 09/10] Replace x-agent-session-id header with agent_session_id query parameter for session tracking Session ID is now read from the agent_session_id query parameter instead of the x-agent-session-id header. Removes SESSION_ID_HEADER constant and header fallback in tracing. Adds activity_weather_agent sample demonstrating Activity protocol bridging via the invoke handler. Co-Authored-By: Claude Opus 4.6 --- .../azure/ai/agentserver/server/_constants.py | 1 - .../ai/agentserver/server/_invocation.py | 4 +- .../azure/ai/agentserver/server/_tracing.py | 20 +- .../activity_weather_agent/.env.sample | 5 + .../activity_weather_agent.py | 242 ++++++++++++++++++ .../activity_weather_agent/requirements.txt | 5 + .../samples/activity_weather_agent/server.py | 115 +++++++++ .../tests/test_tracing.py | 8 +- 8 files changed, 390 insertions(+), 10 deletions(-) create mode 100644 sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/.env.sample create mode 100644 sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/activity_weather_agent.py create mode 100644 sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/requirements.txt create mode 100644 sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/server.py diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py index 8be43b96f54d..7a32bff8f972 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py @@ -21,4 +21,3 @@ class Constants: DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT = 30 DEFAULT_REQUEST_TIMEOUT = 300 # 5 minutes INVOCATION_ID_HEADER = "x-agent-invocation-id" - SESSION_ID_HEADER = "x-agent-session-id" diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py index a8b2e60925a0..1d116f2aee70 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py @@ -251,6 +251,7 @@ async def _create_invocation_endpoint(self, request: Request) -> Response: invocation_id, span_operation="execute_agent", operation_name="invoke_agent", + session_id=request.query_params.get("agent_session_id", ""), ) try: invoke_awaitable = self._dispatch_invoke(request) @@ -331,7 +332,8 @@ async def _traced_invocation_endpoint( span_cm: Any = contextlib.nullcontext(None) if self._ctx.tracing is not None: span_cm = self._ctx.tracing.request_span( - request.headers, invocation_id, span_operation + request.headers, invocation_id, span_operation, + session_id=request.query_params.get("agent_session_id", ""), ) with span_cm as _otel_span: try: diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py index fd22fb1ba2c8..2f2c2592f127 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py @@ -29,7 +29,6 @@ from typing import TYPE_CHECKING, Any, Iterator, Optional, Union from . import _config -from ._constants import Constants from ._logger import get_logger #: Starlette's ``Content`` type — the element type for streaming bodies. @@ -309,6 +308,7 @@ def _prepare_request_span_args( invocation_id: str, span_operation: str, operation_name: Optional[str] = None, + session_id: str = "", ) -> tuple[str, dict[str, str], dict[str, str], Optional[str]]: """Extract headers and build span arguments for a request. @@ -324,13 +324,15 @@ def _prepare_request_span_args( :type span_operation: str :param operation_name: Optional ``gen_ai.operation.name`` value. :type operation_name: Optional[str] + :param session_id: Session ID from the ``agent_session_id`` query + parameter. Defaults to ``""`` (no session). + :type session_id: str :return: ``(name, attributes, carrier, baggage)`` ready for :meth:`span` or :meth:`start_span`. :rtype: tuple[str, dict[str, str], dict[str, str], Optional[str]] """ carrier = _extract_w3c_carrier(headers) baggage = headers.get("baggage") - session_id = headers.get(Constants.SESSION_ID_HEADER, "") span_attrs = self.build_span_attrs( invocation_id, session_id, operation_name=operation_name ) @@ -342,6 +344,7 @@ def start_request_span( invocation_id: str, span_operation: str, operation_name: Optional[str] = None, + session_id: str = "", ) -> Any: """Start a request-scoped span, extracting context from HTTP headers. @@ -362,11 +365,15 @@ def start_request_span( :param operation_name: Optional ``gen_ai.operation.name`` attribute value (e.g. ``"invoke_agent"``). Omitted when *None*. :type operation_name: Optional[str] + :param session_id: Session ID from the ``agent_session_id`` query + parameter. Defaults to ``""`` (no session). + :type session_id: str :return: The OTel span, or *None* when tracing is disabled. :rtype: Any """ name, attrs, carrier, baggage = self._prepare_request_span_args( - headers, invocation_id, span_operation, operation_name + headers, invocation_id, span_operation, operation_name, + session_id=session_id, ) return self.start_span(name, attributes=attrs, carrier=carrier, baggage_header=baggage) @@ -377,6 +384,7 @@ def request_span( invocation_id: str, span_operation: str, operation_name: Optional[str] = None, + session_id: str = "", ) -> Iterator[Any]: """Create a request-scoped span as a context manager. @@ -396,11 +404,15 @@ def request_span( :param operation_name: Optional ``gen_ai.operation.name`` attribute value. Omitted when *None*. :type operation_name: Optional[str] + :param session_id: Session ID from the ``agent_session_id`` query + parameter. Defaults to ``""`` (no session). + :type session_id: str :return: Context manager that yields the OTel span or *None*. :rtype: Iterator[Any] """ name, attrs, carrier, baggage = self._prepare_request_span_args( - headers, invocation_id, span_operation, operation_name + headers, invocation_id, span_operation, operation_name, + session_id=session_id, ) otel_span = self.start_span(name, attributes=attrs, carrier=carrier, baggage_header=baggage) try: diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/.env.sample b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/.env.sample new file mode 100644 index 000000000000..66bb7033c34e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/.env.sample @@ -0,0 +1,5 @@ +# Azure OpenAI credentials (used by AsyncAzureOpenAI) +AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/ +AZURE_OPENAI_API_KEY=your-api-key-here +AZURE_OPENAI_API_VERSION=2024-12-01-preview +AZURE_OPENAI_MODEL=gpt-4o diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/activity_weather_agent.py b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/activity_weather_agent.py new file mode 100644 index 000000000000..06227e8c0eb9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/activity_weather_agent.py @@ -0,0 +1,242 @@ +"""Weather agent built with the Activity protocol. + +Uses ``microsoft_agents.activity`` types (Activity, ChannelAccount, etc.) +for activity handling and the OpenAI Agents SDK for LLM orchestration. + +This module has **no dependency** on ``azure-ai-agentserver-server`` — it +is a standalone agent that can be hosted by any server that speaks the +Activity protocol. See ``server.py`` for the AgentServer hosting layer. + +The agent mirrors the +`weather-agent-open-ai `_ +reference sample from Agents-for-python, using the same ``ActivityHandler`` +dispatch pattern (``on_turn`` → ``on_message_activity`` / +``on_members_added_activity``). +""" + +import logging +import os +import random +from datetime import datetime +from typing import Optional + +from dotenv import load_dotenv + +load_dotenv() + +from openai import AsyncAzureOpenAI +from agents import ( + Agent, + Model, + ModelProvider, + OpenAIChatCompletionsModel, + RunConfig, + Runner, + function_tool, +) + +from microsoft_agents.activity import Activity, ActivityTypes + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Tools — registered with the OpenAI Agents SDK +# --------------------------------------------------------------------------- + + +@function_tool +def get_weather(city: str, date: str) -> dict: + """Get the weather forecast for a city on a given date. + + :param city: City name (e.g. "Seattle"). + :param date: Date string (e.g. "2026-03-14"). + :return: Weather data with temperature and conditions. + """ + logger.info("Tool get_weather called: city=%s, date=%s", city, date) + temperature = random.randint(8, 21) + conditions = random.choice( + ["Sunny with light breeze", "Partly cloudy", "Overcast with rain"] + ) + result = { + "city": city, + "temperature": f"{temperature}C", + "conditions": conditions, + "date": date, + } + logger.info("Tool get_weather result: %s", result) + return result + + +@function_tool +def get_date() -> str: + """Get the current date and time. + + :return: ISO-formatted current datetime string. + """ + now = datetime.now().isoformat() + logger.info("Tool get_date called: %s", now) + return now + + +# --------------------------------------------------------------------------- +# WeatherAgent — Activity protocol agent +# --------------------------------------------------------------------------- + + +class WeatherAgent: + """A weather-forecast agent that speaks the Activity protocol. + + Accepts an :class:`~microsoft_agents.activity.Activity`, dispatches by + ``activity.type``, and returns a list of reply + :class:`~microsoft_agents.activity.Activity` objects. + + :param client: An ``AsyncAzureOpenAI`` client for LLM calls. + :type client: openai.AsyncAzureOpenAI + """ + + def __init__(self, client: AsyncAzureOpenAI) -> None: + self._agent = Agent( + name="WeatherAgent", + instructions=( + "You are a friendly assistant that helps people find weather " + "forecasts. Use the get_weather and get_date tools to look up " + "the forecast. Always respond with plain text. If you need " + "more information (city or date), ask the user a follow-up " + "question." + ), + tools=[get_weather, get_date], + ) + + class _AzureModelProvider(ModelProvider): + """Routes model requests to the Azure OpenAI client.""" + + def get_model(self, model_name: Optional[str] = None) -> Model: + return OpenAIChatCompletionsModel( + model=model_name + or os.environ.get("AZURE_OPENAI_MODEL", "gpt-4o"), + openai_client=client, + ) + + self._model_provider = _AzureModelProvider() + logger.info("WeatherAgent initialised with tools: get_weather, get_date") + + # -- Dispatch ------------------------------------------------------------ + + async def on_turn(self, activity: Activity) -> list[Activity]: + """Process an incoming activity and return reply activities. + + Routes by ``activity.type`` to the appropriate handler, mirroring the + ``ActivityHandler.on_turn()`` pattern from the Microsoft Agents SDK. + + :param activity: The incoming activity. + :type activity: microsoft_agents.activity.Activity + :return: Zero or more reply activities. + :rtype: list[Activity] + """ + sender = activity.from_property.name if activity.from_property else "unknown" + conv_id = activity.conversation.id if activity.conversation else "unknown" + logger.info( + "on_turn: type=%s, from=%s, conversation=%s", + activity.type, sender, conv_id, + ) + + if activity.type == ActivityTypes.message: + return await self._on_message_activity(activity) + if activity.type == ActivityTypes.conversation_update: + return await self._on_conversation_update_activity(activity) + + logger.debug("on_turn: unhandled activity type %r, returning empty", activity.type) + return [] + + # -- Handlers ------------------------------------------------------------ + + async def _on_message_activity(self, activity: Activity) -> list[Activity]: + """Handle a ``message`` activity via the OpenAI Agents SDK. + + :param activity: The incoming message activity. + :type activity: Activity + :return: A single-element list with the reply activity. + :rtype: list[Activity] + """ + user_text = activity.text or "" + logger.info("Message received: %r", user_text) + + logger.debug("Running OpenAI Agents SDK Runner.run() ...") + response = await Runner.run( + self._agent, + user_text, + run_config=RunConfig( + model_provider=self._model_provider, + tracing_disabled=True, + ), + ) + + reply = activity.create_reply(text=response.final_output) + logger.info("Reply: %r", reply.text) + return [reply] + + async def _on_conversation_update_activity( + self, activity: Activity + ) -> list[Activity]: + """Handle a ``conversationUpdate`` activity. + + Sends a welcome message when new members are added (excluding the + agent itself). + + :param activity: The incoming conversationUpdate activity. + :type activity: Activity + :return: A list with one welcome reply, or empty. + :rtype: list[Activity] + """ + members_added = activity.members_added or [] + recipient_id = activity.recipient.id if activity.recipient else None + logger.info( + "conversationUpdate: %d member(s) added, recipient_id=%s", + len(members_added), recipient_id, + ) + + for member in members_added: + if member.id != recipient_id: + logger.info( + "New member joined: id=%s, name=%s — sending welcome", + member.id, member.name, + ) + reply = activity.create_reply( + text=( + "Hello and welcome! I can help you with weather " + "forecasts. Just tell me a city and date, and I'll " + "look up the forecast for you." + ), + ) + return [reply] + + logger.debug("conversationUpdate: no new external members, no welcome sent") + return [] + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def create_weather_agent() -> WeatherAgent: + """Create a :class:`WeatherAgent` from environment variables. + + Reads ``AZURE_OPENAI_ENDPOINT`` and ``AZURE_OPENAI_API_VERSION`` from + the environment (or ``.env`` file via ``python-dotenv``). + + :return: A configured WeatherAgent instance. + :rtype: WeatherAgent + """ + endpoint = os.environ["AZURE_OPENAI_ENDPOINT"] + api_version = os.environ.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview") + logger.info( + "Creating WeatherAgent: endpoint=%s, api_version=%s", + endpoint, api_version, + ) + client = AsyncAzureOpenAI( + api_version=api_version, + azure_endpoint=endpoint, + ) + return WeatherAgent(client=client) diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/requirements.txt new file mode 100644 index 000000000000..0bdd8b1ddfc4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/requirements.txt @@ -0,0 +1,5 @@ +azure-ai-agentserver-server +microsoft-agents-activity>=0.8.0 +openai>=1.60.0 +openai-agents>=0.1.0 +python-dotenv>=1.0.0 diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/server.py b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/server.py new file mode 100644 index 000000000000..a00b54a78409 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/server.py @@ -0,0 +1,115 @@ +"""Host a WeatherAgent via AgentServer with the Activity protocol. + +Bridges the Activity protocol to AgentServer's ``/invocations`` endpoint. +Incoming requests carry Bot Framework Activity JSON; the server +deserialises them into ``microsoft_agents.activity.Activity`` objects, +dispatches to the ``WeatherAgent``, and serialises the reply activities +back to JSON. + +This demonstrates that AgentServer can host agents built with the +Activity protocol — without depending on ``microsoft_agents.hosting.core`` +or ``CloudAdapter``. + +Usage:: + + # Start the server + python server.py + + # Send a message activity + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{ + "type": "message", + "text": "What is the weather in Seattle tomorrow?", + "from": {"id": "user-1", "name": "User"}, + "recipient": {"id": "agent-1", "name": "WeatherAgent"}, + "conversation": {"id": "conv-1"}, + "channelId": "custom", + "serviceUrl": "http://localhost:8088" + }' + + # Send a conversationUpdate to trigger the welcome message + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{ + "type": "conversationUpdate", + "membersAdded": [{"id": "user-1", "name": "User"}], + "from": {"id": "user-1"}, + "recipient": {"id": "agent-1", "name": "WeatherAgent"}, + "conversation": {"id": "conv-1"}, + "channelId": "custom", + "serviceUrl": "http://localhost:8088" + }' +""" + +import logging + +# Configure logging for the sample so all loggers (including the agent's) +# output to the console. This must run before any logger.info() calls. +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", +) + +from microsoft_agents.activity import Activity + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver.server import AgentServer + +from activity_weather_agent import create_weather_agent + +logger = logging.getLogger(__name__) + + +# -- Create the agent and server --------------------------------------------- + +agent = create_weather_agent() +server = AgentServer() +# AgentServer adds its own StreamHandler to the "azure.ai.agentserver" logger. +# Stop propagation to the root logger to avoid duplicate lines from basicConfig. +logging.getLogger("azure.ai.agentserver").propagate = False +logger.info("AgentServer created, WeatherAgent wired to /invocations") + + +# -- Wire the Activity protocol to /invocations ------------------------------ + + +@server.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Bridge ``POST /invocations`` to the Activity protocol. + + Deserialises the JSON body into an + :class:`~microsoft_agents.activity.Activity`, dispatches to the + agent's ``on_turn``, and returns the first reply activity as JSON. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON response containing the reply activity, or 200 with + an empty status if no reply is produced. + :rtype: starlette.responses.Response + """ + data = await request.json() + activity = Activity.model_validate(data) + logger.info( + "Received activity: type=%s, conversation=%s", + activity.type, + activity.conversation.id if activity.conversation else "unknown", + ) + + replies = await agent.on_turn(activity) + + if replies: + # Serialise back to camelCase JSON (matching Bot Framework wire format). + body = replies[0].model_dump( + by_alias=True, exclude_none=True, mode="json" + ) + logger.info( + "Returning reply: type=%s, text=%r", + body.get("type"), body.get("text", "")[:80], + ) + return JSONResponse(body) + + logger.info("No reply activities produced, returning status ok") + return JSONResponse({"status": "ok"}) + + +if __name__ == "__main__": + server.run() diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py index e630c0581ce1..2e09df90ee52 100644 --- a/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py @@ -624,7 +624,7 @@ async def test_genai_attributes_on_invoke_span(span_exporter, monkeypatch): @pytest.mark.asyncio async def test_genai_conversation_id_from_session_header(span_exporter, monkeypatch): - """gen_ai.conversation.id is set from x-agent-session-id header.""" + """gen_ai.conversation.id is set from agent_session_id query parameter.""" monkeypatch.delenv("AGENT_NAME", raising=False) monkeypatch.delenv("AGENT_VERSION", raising=False) agent = _make_echo_traced_agent(enable_tracing=True) @@ -633,7 +633,7 @@ async def test_genai_conversation_id_from_session_header(span_exporter, monkeypa await client.post( "/invocations", content=b'{}', - headers={"x-agent-session-id": "session-abc-123"}, + params={"agent_session_id": "session-abc-123"}, ) spans = span_exporter.get_finished_spans() @@ -646,7 +646,7 @@ async def test_genai_conversation_id_from_session_header(span_exporter, monkeypa @pytest.mark.asyncio async def test_genai_conversation_id_absent_when_no_header(span_exporter, monkeypatch): - """gen_ai.conversation.id is NOT set when x-agent-session-id header is absent.""" + """gen_ai.conversation.id is NOT set when agent_session_id query parameter is absent.""" monkeypatch.delenv("AGENT_NAME", raising=False) monkeypatch.delenv("AGENT_VERSION", raising=False) agent = _make_echo_traced_agent(enable_tracing=True) @@ -672,7 +672,7 @@ async def test_genai_attributes_on_get_invocation_span(span_exporter, monkeypatc async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: await client.get( "/invocations/inv-42", - headers={"x-agent-session-id": "sess-99"}, + params={"agent_session_id": "sess-99"}, ) spans = span_exporter.get_finished_spans() From a9fa33de4d517cfc3c534aba082ea7b05eb22dbe Mon Sep 17 00:00:00 2001 From: Zhiyong Yang Date: Tue, 17 Mar 2026 22:44:51 -0700 Subject: [PATCH 10/10] add no sdk spec and sample --- .../samples/no_sdk_agent/SPEC.md | 120 +++++++ .../samples/no_sdk_agent/full_server.py | 327 ++++++++++++++++++ .../samples/no_sdk_agent/minimal_server.py | 31 ++ 3 files changed, 478 insertions(+) create mode 100644 sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/SPEC.md create mode 100644 sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/full_server.py create mode 100644 sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/minimal_server.py diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/SPEC.md b/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/SPEC.md new file mode 100644 index 000000000000..0403e8c825ed --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/SPEC.md @@ -0,0 +1,120 @@ +# Invocation API Spec + +This document defines the HTTP contract that an agent container must implement +to run on the platform. No SDK is required — any language or framework that +can serve HTTP is supported. + +## Required + +### `POST /invocations` + +Execute the agent. + +- **Port:** `8088`. +- **Request body:** Any content type. The schema and format are defined by your + agent — the platform does not enforce a specific shape or media type. +- **Response body:** Any content type. The schema and format are defined by your + agent. + +## Optional Features + +The following endpoints and capabilities are **not required** by the platform. +Implement them only if your agent needs the corresponding functionality. + +### Async polling and cancel + +The platform only calls `POST /invocations`. If your agent needs async +polling or cancel, you can add your own endpoints under `/invocations/` +with any paths and schemas you choose. + +### OpenAPI spec — `GET /invocations/docs/openapi.json` + +Serve an OpenAPI spec describing your agent's request/response schema. +The platform can use this for documentation and request validation. + +### Health probes — `GET /liveness` and `GET /readiness` + +Standard Kubernetes health probes. Implement these if your container +runs in a Kubernetes environment and you need the orchestrator to detect +unhealthy or unready instances. + +- `/liveness` — return `200` when the process is alive. +- `/readiness` — return `200` when the agent is ready to accept requests + (e.g. models loaded, connections established). + +### Invocation ID tracking + +The platform may send an `x-agent-invocation-id` request header. +If present, echo it back on the response. If absent, generate a UUID +and include it on the response. This enables end-to-end correlation +of requests across services. + +### OpenTelemetry tracing with App Insights export + +Integrate with Foundry distributed tracing by creating spans for each +invocation and exporting them to Application Insights. See the +[Tracing](#tracing) section below for details. + +## Headers + +| Header | Direction | Description | +|--------|-----------|-------------| +| `x-agent-invocation-id` | Request | Platform may send an invocation ID. If absent, agent may generate a UUID. | +| `traceparent` | Request | W3C Trace Context header for distributed tracing. | +| `tracestate` | Request | W3C Trace Context header for distributed tracing. | +| `baggage` | Request | W3C Baggage for cross-service context propagation. | + +## Query Parameters + +| Parameter | Description | +|-----------|-------------| +| `agent_session_id` | Session or conversation ID for tracing correlation. Maps to `gen_ai.conversation.id` span attribute. | + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `AGENT_NAME` | | Agent name | +| `AGENT_VERSION` | | Agent version | +| `AGENT_PROJECT_NAME` | | Azure foundry project name | + +## Tracing + +To integrate wtih foundry tracing, the agent should create an OpenTelemetry span for each +`POST /invocations` request and export traces to Application Insights via +`azure-monitor-opentelemetry-exporter`. + +### W3C Trace Context propagation + +Extract `traceparent` and `tracestate` from the incoming request headers and use +them as the parent context for the span. This connects the agent's spans to the +platform's distributed trace. + +### `leaf_customer_span_id` (baggage) + +The platform may send a `baggage` header containing a `leaf_customer_span_id` +key. When present, the agent **must** override the parent span ID from +`traceparent` with this value. This re-parents the agent's root span under +the caller's leaf span so the trace tree renders correctly in App Insights. + +The value is a 16-character lower-hex span ID. To apply it: + +1. Extract the trace context from `traceparent` normally. +2. Parse `leaf_customer_span_id` from the `baggage` header. +3. Create a new `SpanContext` with the same `trace_id` but the + `span_id` replaced by the baggage value. +4. Use the new context as the parent when starting the span. + +### Span attributes + +Each span should include the following GenAI semantic convention attributes: + +| Attribute | Source | Description | +|-----------|--------|-------------| +| `invocation.id` | `x-agent-invocation-id` header or generated UUID | Unique invocation identifier | +| `gen_ai.response.id` | Same as `invocation.id` | Maps response to invocation | +| `gen_ai.provider.name` | `"microsoft.foundry"` | Fixed provider name | +| `gen_ai.agent.id` | `AGENT_NAME` + `AGENT_VERSION` env vars | Agent identity, e.g. `"my-agent:1.0"` | +| `microsoft.foundry.project.id` | `AGENT_PROJECT_NAME` env var | Project identifier | +| `gen_ai.operation.name` | `"invoke_agent"` | Operation type | +| `gen_ai.conversation.id` | `agent_session_id` query parameter | Session/conversation ID | diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/full_server.py b/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/full_server.py new file mode 100644 index 000000000000..1cbcba1de38e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/full_server.py @@ -0,0 +1,327 @@ +"""Full-featured agent — no SDK required. + +Implements optional Invocation API features: +- Invocation ID tracking (x-agent-invocation-id header) +- Health probes (/liveness, /readiness) +- OpenTelemetry tracing with App Insights export + +See SPEC.md for the full Invocation API contract. + +Usage:: + + pip install fastapi uvicorn opentelemetry-api opentelemetry-sdk \\ + azure-monitor-opentelemetry-exporter + + # Optional: set App Insights connection string for trace export + export APPLICATIONINSIGHTS_CONNECTION_STRING="InstrumentationKey=..." + export AGENT_ENABLE_TRACING=true + export AGENT_NAME=my-agent + export AGENT_VERSION=1.0 + + python full_server.py + + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"name": "Alice"}' + # -> {"greeting": "Hello, Alice!"} + + curl http://localhost:8088/liveness + # -> {"status": "alive"} +""" +import logging +import os +import uuid +from contextlib import contextmanager +from typing import Any, Iterator + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse + +# --------------------------------------------------------------------------- +# Configuration from environment variables +# --------------------------------------------------------------------------- + +LOG_LEVEL = os.environ.get("AGENT_LOG_LEVEL", "INFO").upper() +ENABLE_TRACING = os.environ.get("AGENT_ENABLE_TRACING", "").lower() in ("true", "1", "yes") +APPINSIGHTS_CONN_STR = os.environ.get("APPLICATIONINSIGHTS_CONNECTION_STRING", "") +AGENT_NAME = os.environ.get("AGENT_NAME", "") +AGENT_VERSION = os.environ.get("AGENT_VERSION", "") +AGENT_PROJECT_NAME = os.environ.get("AGENT_PROJECT_NAME", "") + +logging.basicConfig( + level=LOG_LEVEL, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", +) +logger = logging.getLogger(__name__) + +INVOCATION_ID_HEADER = "x-agent-invocation-id" + +app = FastAPI() + + +# --------------------------------------------------------------------------- +# OpenTelemetry tracing (optional) +# --------------------------------------------------------------------------- + +_tracer: Any = None +_propagator: Any = None + +if ENABLE_TRACING: + try: + from opentelemetry import trace + from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator, + ) + + # Set up App Insights export if connection string is available + if APPINSIGHTS_CONN_STR: + try: + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + from azure.monitor.opentelemetry.exporter import ( + AzureMonitorTraceExporter, + ) + + resource = Resource.create({"service.name": "agent"}) + provider = TracerProvider(resource=resource) + exporter = AzureMonitorTraceExporter( + connection_string=APPINSIGHTS_CONN_STR + ) + provider.add_span_processor(BatchSpanProcessor(exporter)) + trace.set_tracer_provider(provider) + logger.info("App Insights trace exporter configured") + except ImportError: + logger.warning( + "App Insights export requires opentelemetry-sdk and " + "azure-monitor-opentelemetry-exporter" + ) + + _tracer = trace.get_tracer("agent") + _propagator = TraceContextTextMapPropagator() + logger.info("OpenTelemetry tracing enabled") + + except ImportError: + logger.warning( + "AGENT_ENABLE_TRACING=true but opentelemetry-api is not installed" + ) + + +def _build_span_attrs( + invocation_id: str, + session_id: str = "", + operation_name: str = "invoke_agent", +) -> dict[str, str]: + """Build GenAI semantic convention span attributes.""" + agent_label = ( + f"{AGENT_NAME}:{AGENT_VERSION}" if AGENT_NAME and AGENT_VERSION else AGENT_NAME + ) + attrs: dict[str, str] = { + "invocation.id": invocation_id, + "gen_ai.response.id": invocation_id, + "gen_ai.provider.name": "microsoft.foundry", + "gen_ai.operation.name": operation_name, + } + if agent_label: + attrs["gen_ai.agent.id"] = agent_label + if AGENT_PROJECT_NAME: + attrs["microsoft.foundry.project.id"] = AGENT_PROJECT_NAME + if session_id: + attrs["gen_ai.conversation.id"] = session_id + return attrs + + +def _parse_baggage_key(baggage_header: str, key: str) -> str: + """Parse a single key from a W3C Baggage header. + + The baggage header is a comma-separated list of key=value pairs. + See https://www.w3.org/TR/baggage/ + + :param baggage_header: Raw baggage header value. + :param key: Key to extract. + :return: The value if found, or empty string. + """ + for member in baggage_header.split(","): + member = member.strip() + if "=" not in member: + continue + k, _, v = member.partition("=") + if k.strip() == key: + # Value may have properties after ';' — take only the value part + return v.split(";")[0].strip() + return "" + + +def _override_parent_span_id( + ctx: Any, + leaf_span_id_hex: str, +) -> Any: + """Re-parent the trace context using leaf_customer_span_id from baggage. + + Creates a new parent context with the same trace_id but the span_id + replaced by *leaf_span_id_hex*. This connects the agent's root span + to the caller's leaf span so the trace tree renders correctly in + App Insights. + + :param ctx: OpenTelemetry context with extracted traceparent. + :param leaf_span_id_hex: 16-character lower-hex span ID from baggage. + :return: New context with overridden parent span ID. + """ + from opentelemetry.trace import ( + NonRecordingSpan, + SpanContext, + set_span_in_context, + ) + + # Get the current span from the context to extract trace_id and trace_flags + current_span = trace.get_current_span(ctx) if ctx else None + if current_span is None: + return ctx + + current_ctx = current_span.get_span_context() + if current_ctx is None or not current_ctx.is_valid: + return ctx + + try: + new_span_id = int(leaf_span_id_hex, 16) + except ValueError: + logger.warning("Invalid leaf_customer_span_id: %s", leaf_span_id_hex) + return ctx + + # Build a new SpanContext with the overridden span_id + new_span_context = SpanContext( + trace_id=current_ctx.trace_id, + span_id=new_span_id, + is_remote=True, + trace_flags=current_ctx.trace_flags, + trace_state=current_ctx.trace_state, + ) + return set_span_in_context(NonRecordingSpan(new_span_context), ctx) + + +_LEAF_CUSTOMER_SPAN_ID = "leaf_customer_span_id" + + +@contextmanager +def _request_span( + request: Request, + invocation_id: str, + operation_name: str = "invoke_agent", +) -> Iterator[Any]: + """Create a traced span for a request, propagating W3C context. + + Handles full W3C Trace Context propagation including the + ``leaf_customer_span_id`` baggage key, which re-parents the agent's + root span under the caller's leaf span for correct trace tree + rendering in App Insights. + + Yields the OTel span or None if tracing is disabled. + """ + if _tracer is None or _propagator is None: + yield None + return + + # Extract W3C trace context from incoming headers + carrier = { + k: v + for k in ("traceparent", "tracestate") + if (v := request.headers.get(k)) is not None + } + ctx = _propagator.extract(carrier=carrier) if carrier else None + + # Override parent span ID with leaf_customer_span_id from baggage + # (see SPEC.md § leaf_customer_span_id) + baggage = request.headers.get("baggage", "") + if ctx is not None and baggage: + leaf_span_id = _parse_baggage_key(baggage, _LEAF_CUSTOMER_SPAN_ID) + if leaf_span_id: + ctx = _override_parent_span_id(ctx, leaf_span_id) + + session_id = request.query_params.get("agent_session_id", "") + attrs = _build_span_attrs(invocation_id, session_id, operation_name) + + agent_label = ( + f"{AGENT_NAME}:{AGENT_VERSION}" if AGENT_NAME and AGENT_VERSION else AGENT_NAME + ) + span_name = f"execute_agent {agent_label}" if agent_label else "execute_agent" + + with _tracer.start_as_current_span( + name=span_name, + attributes=attrs, + kind=trace.SpanKind.SERVER, + context=ctx, + ) as otel_span: + yield otel_span + + +def _record_error(span: Any, exc: Exception) -> None: + """Record an exception on a span if tracing is active.""" + if span is not None and ENABLE_TRACING: + try: + span.set_status(trace.StatusCode.ERROR, str(exc)) + span.record_exception(exc) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Health probes +# --------------------------------------------------------------------------- + + +@app.get("/liveness") +async def liveness(): + """liveness probe.""" + return {"status": "alive"} + + +@app.get("/readiness") +async def readiness(): + """readiness probe.""" + return {"status": "ready"} + + +# --------------------------------------------------------------------------- +# Agent logic — replace this with your own +# --------------------------------------------------------------------------- + + +async def run_agent(data: dict) -> dict: + """Your agent logic goes here. + + :param data: Parsed JSON request body. + :return: Response dict to serialize as JSON. + """ + greeting = f"Hello, {data.get('name', 'World')}!" + return {"greeting": greeting} + + +# --------------------------------------------------------------------------- +# POST /invocations — required +# --------------------------------------------------------------------------- + + +@app.post("/invocations") +async def invoke(request: Request): + """Execute the agent.""" + invocation_id = ( + request.headers.get(INVOCATION_ID_HEADER) or str(uuid.uuid4()) + ) + + data = await request.json() + + with _request_span(request, invocation_id) as span: + try: + result = await run_agent(data) + except Exception as exc: + _record_error(span, exc) + raise + + return JSONResponse( + result, + headers={INVOCATION_ID_HEADER: invocation_id}, + ) + + +if __name__ == "__main__": + import uvicorn + port = int(os.environ.get("AGENT_SERVER_PORT", "8088")) + uvicorn.run(app, host="0.0.0.0", port=port) diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/minimal_server.py b/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/minimal_server.py new file mode 100644 index 000000000000..18ba3999d46e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/minimal_server.py @@ -0,0 +1,31 @@ +"""Minimal agent — no SDK required. + +Implements only the required ``POST /invocations`` endpoint. +See SPEC.md for the full Invocation API contract. + +Usage:: + + pip install fastapi uvicorn + python minimal_server.py + + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"name": "Alice"}' + # -> {"greeting": "Hello, Alice!"} +""" +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse + +app = FastAPI() + + +@app.post("/invocations") +async def invoke(request: Request): + data = await request.json() + greeting = f"Hello, {data.get('name', 'Alice')}!" + return JSONResponse({"greeting": greeting}) + + +if __name__ == "__main__": + import os + import uvicorn + port = int(os.environ.get("AGENT_SERVER_PORT", "8088")) + uvicorn.run(app, host="0.0.0.0", port=port)