diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/agent_framework.py index 7177b522d2a9..09e6904f8a63 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/agent_framework.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/agent_framework.py @@ -77,7 +77,9 @@ def _resolve_stream_timeout(self, request_body: CreateResponse) -> float: def init_tracing(self): exporter = os.environ.get(AdapterConstants.OTEL_EXPORTER_ENDPOINT) - app_insights_conn_str = os.environ.get(AdapterConstants.APPLICATION_INSIGHTS_CONNECTION_STRING) + app_insights_conn_str = os.environ.get( + AdapterConstants.APPLICATION_INSIGHTS_CONNECTION_STRING # pylint: disable=no-member + ) project_endpoint = os.environ.get(AdapterConstants.AZURE_AI_PROJECT_ENDPOINT) if project_endpoint: diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py index 805a5eeb9dec..ce2010566ce0 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py @@ -213,7 +213,7 @@ def _coerce_result_text(self, value: Any) -> str | dict: def _construct_response_data(self, output_items: List[dict]) -> dict: agent_id = AgentIdGenerator.generate(self._context) - response_data = { + response_data: dict[str, Any] = { "object": "response", "metadata": {}, "agent": agent_id, diff --git 
a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py index 910a7c481daa..3f571503620a 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py @@ -34,12 +34,11 @@ def __init__(self, response_id: Optional[str], conversation_id: Optional[str]): def from_request(cls, payload: dict) -> "FoundryIdGenerator": response_id = payload.get("metadata", {}).get("response_id", None) conv_id_raw = payload.get("conversation", None) + conv_id: Optional[str] = None if isinstance(conv_id_raw, str): conv_id = conv_id_raw elif isinstance(conv_id_raw, dict): conv_id = conv_id_raw.get("id", None) - else: - conv_id = None return cls(response_id, conv_id) def generate(self, category: Optional[str] = None) -> str: diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml index 5552ff8233d2..41dc302054a3 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml @@ -64,4 +64,5 @@ pyright = false verifytypes = false # incompatible python version for -core verify_keywords = false mindependency = false # depends on -core package -whl_no_aio = false \ No newline at end of file +whl_no_aio = false +apistub = false # pip 24.0 crashes resolving langchain/langgraph transitive deps during stub generation \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-server/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-server/CHANGELOG.md new file mode 100644 index 000000000000..3b472f26b730 --- /dev/null +++ 
b/sdk/agentserver/azure-ai-agentserver-server/CHANGELOG.md @@ -0,0 +1,12 @@ +# Release History + +## 1.0.0b1 (Unreleased) + +### Features Added + +- Initial release of `azure-ai-agentserver-server`. +- Generic `AgentServer` base class with pluggable protocol heads. +- `/invoke` protocol head with all 4 operations: create, get, cancel, and OpenAPI spec. +- OpenAPI spec-based request/response validation via `jsonschema`. +- Health check endpoints (`/liveness`, `/readiness`). +- Streaming and non-streaming invocation support. diff --git a/sdk/agentserver/azure-ai-agentserver-server/LICENSE b/sdk/agentserver/azure-ai-agentserver-server/LICENSE new file mode 100644 index 000000000000..b2f52a2bad4e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/LICENSE @@ -0,0 +1,21 @@ +Copyright (c) Microsoft Corporation. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/sdk/agentserver/azure-ai-agentserver-server/MANIFEST.in b/sdk/agentserver/azure-ai-agentserver-server/MANIFEST.in new file mode 100644 index 000000000000..49a1b88738e9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/MANIFEST.in @@ -0,0 +1,8 @@ +include *.md +include LICENSE +recursive-include tests *.py +recursive-include samples *.py *.md +include azure/__init__.py +include azure/ai/__init__.py +include azure/ai/agentserver/__init__.py +include azure/ai/agentserver/server/py.typed diff --git a/sdk/agentserver/azure-ai-agentserver-server/README.md b/sdk/agentserver/azure-ai-agentserver-server/README.md new file mode 100644 index 000000000000..581da4b56339 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/README.md @@ -0,0 +1,312 @@ +# Azure AI Agent Server client library for Python + +A standalone, **protocol-agnostic agent server** package for Azure AI. Provides a +Starlette-based `AgentServer` class with pluggable protocol heads, OpenAPI spec +serving, optional request validation, optional tracing, and health +endpoints — with **zero framework coupling**. + +## Getting started + +### Install the package + +```bash +pip install azure-ai-agentserver-server +``` + +**Requires Python >= 3.10.** + +### Quick start + +```python +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver.server import AgentServer + +server = AgentServer() + + +@server.invoke_handler +async def handle(request: Request) -> Response: + data = await request.json() + greeting = f"Hello, {data['name']}!" 
+ return JSONResponse({"greeting": greeting}) + + +if __name__ == "__main__": + server.run() +``` + +```bash +# Start the agent +python my_agent.py + +# Call it +curl -X POST http://localhost:8088/invocations \ + -H "Content-Type: application/json" \ + -d '{"name": "World"}' +# → {"greeting": "Hello, World!"} +``` + +## Key concepts + +`AgentServer` is a class for building agent endpoints that plug into the +Azure Agent Service. Register handlers with decorators on an `AgentServer` +instance. Multiple protocol heads (`/invoke` today, `/responses` and others +in the future) are supported through a pluggable handler architecture. + +The Azure Agent Service expects specific route paths (`/invocations`, `/liveness`, +`/readiness`, etc.) for deployment. `AgentServer` wires these automatically so +your agent is compatible with the hosting platform — no manual route setup required. + +**Key properties:** + +- **Platform-compatible routes** — automatically registers the exact endpoints the Azure + Agent Service expects, so your agent deploys without configuration changes. +- **Starlette + Hypercorn** — lightweight ASGI server with native HTTP/1.1 and HTTP/2 + support. +- **Raw protocol access** — receive raw Starlette `Request` objects, return raw + Starlette `Response` objects. Full control over content types, streaming, headers, + SSE, and status codes. +- **Automatic invocation ID tracking** — every request gets a unique ID injected into + `request.state` and the `x-agent-invocation-id` response header. +- **OpenAPI spec serving** — pass a spec and it is served at + `GET /invocations/docs/openapi.json` for documentation / tooling. +- **Optional request validation** — opt in to validate incoming request bodies + against the OpenAPI spec before reaching your code. +- **Request timeout and graceful shutdown** — configurable invoke timeouts (504) and graceful shutdown — all via + constructor args or environment variables. 
+- **Optional OpenTelemetry tracing** — opt-in span instrumentation that covers the full + request lifecycle, including streaming responses. +- **Health endpoints** — `/liveness` and `/readiness` out of the box. + +### Architecture + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ Azure Agent Service (Cloud Infrastructure) │ +│ Protocols: /invoke, /responses, /mcp, /a2a, /activity │ +└────────────────────────────┬────────────────────────────────────────┘ + │ + ▼ + ┌───────────────────────────────┐ + │ azure-ai-agentserver-server │ + │ AgentServer │ + │ │ + │ Protocol heads: │ + │ • /invoke │ + │ • /responses (planned) │ + │ │ + │ OpenAPI spec serving │ + │ Optional request validation │ + │ Invocation ID tracking │ + │ Invoke timeout / graceful │ + │ shutdown │ + │ Optional OTel tracing │ + │ Health & readiness endpoints │ + └───────────────────────────────┘ + │ + Your integration code: + ┌─────────┼─────────┐ + ▼ ▼ ▼ + LangGraph Agent Semantic + Framework Kernel +``` + +**Single package, multiple protocol heads, no framework coupling.** + +### Handler registration + +Instantiate `AgentServer` and register handlers with decorators: + +| Decorator | Required | Description | +|-----------|----------|-------------| +| `@server.invoke_handler` | **Yes** | Register the invoke handler function. | +| `@server.get_invocation_handler` | No | Register a get-invocation handler. Default returns 404. | +| `@server.cancel_invocation_handler` | No | Register a cancel-invocation handler. Default returns 404. | +| `@server.shutdown_handler` | No | Register a shutdown handler. Default is a no-op. | + +The invocation ID is available via `request.state.invocation_id` (auto-generated for +`invoke`, extracted from the URL path for `get_invocation` / `cancel_invocation`). +The server auto-injects the `x-agent-invocation-id` response header if not already set. 
+ +### Routes + +| Route | Method | Description | +|-------|--------|-------------| +| `/invocations` | POST | Create and process an invocation | +| `/invocations/{id}` | GET | Retrieve a previous invocation result | +| `/invocations/{id}/cancel` | POST | Cancel a running invocation | +| `/invocations/docs/openapi.json` | GET | Return the registered OpenAPI spec | +| `/liveness` | GET | Health check | +| `/readiness` | GET | Readiness check | + +### Configuration + +All settings follow the same resolution order: **constructor argument > environment +variable > default**. Set a value to `0` to disable the corresponding feature. + +| Constructor param | Environment variable | Default | Description | +|---|---|---|---| +| `port` (on `run()`) | `AGENT_SERVER_PORT` | `8088` | Port to bind | +| `graceful_shutdown_timeout` | `AGENT_GRACEFUL_SHUTDOWN_TIMEOUT` | `30` (seconds) | Drain period after SIGTERM before forced exit | +| `request_timeout` | `AGENT_REQUEST_TIMEOUT` | `300` (seconds) | Max time for `invoke()` before 504 | +| `enable_tracing` | `AGENT_ENABLE_TRACING` | `false` | Enable OpenTelemetry tracing | +| `enable_request_validation` | `AGENT_ENABLE_REQUEST_VALIDATION` | `false` | Validate request bodies against `openapi_spec` | +| `log_level` | `AGENT_LOG_LEVEL` | `INFO` | Library log level (`DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`) | +| `debug_errors` | `AGENT_DEBUG_ERRORS` | `false` | Include exception details in error responses | + +```python +server = AgentServer( + request_timeout=60, # 1 minute + graceful_shutdown_timeout=15, # 15 s drain +) +server.run() +``` + +Or configure entirely via environment variables — no code changes needed for deployment +tuning. + +## Examples + +### OpenAPI spec & validation + +Pass an OpenAPI spec to serve it at `/invocations/docs/openapi.json`. 
+Opt in to runtime request validation with `enable_request_validation=True`: + +```python +spec = { + "openapi": "3.0.0", + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"greeting": {"type": "string"}}, + } + } + } + } + }, + } + } + }, +} + +# Spec served for documentation; validation off (default) +server = AgentServer(openapi_spec=spec) + +# Spec served AND requests validated at runtime +server = AgentServer(openapi_spec=spec, enable_request_validation=True) +server.run() +``` + +- `GET /invocations/docs/openapi.json` serves the registered spec (or 404 if none). +- When validation is enabled, non-conforming **requests** return 400 with details. + +### Tracing + +Tracing is **disabled by default**. Enable it via constructor or environment variable: + +```python +server = AgentServer(enable_tracing=True) +``` + +or: + +```bash +export AGENT_ENABLE_TRACING=true +``` + +Install the tracing extras (includes OpenTelemetry and the Azure Monitor exporter): + +```bash +pip install azure-ai-agentserver-server[tracing] +``` + +When enabled, spans are created for `invoke`, `get_invocation`, and `cancel_invocation` +endpoints. For streaming responses, the span stays open until the last chunk is sent, +accurately capturing the full transfer duration. Errors during streaming are recorded on +the span. Incoming `traceparent` / `tracestate` headers are propagated via W3C +TraceContext. + +#### Application Insights integration + +When tracing is enabled **and** an Application Insights connection string is available, +traces and logs are automatically exported to Azure Monitor. The connection string is +resolved in the following order: + +1. 
The `application_insights_connection_string` constructor parameter. +2. The `APPLICATIONINSIGHTS_CONNECTION_STRING` environment variable. + +```python +# Explicit connection string +server = AgentServer( + enable_tracing=True, + application_insights_connection_string="InstrumentationKey=...", +) + +# Or via environment variable (connection string auto-discovered) +# export APPLICATIONINSIGHTS_CONNECTION_STRING="InstrumentationKey=..." +server = AgentServer(enable_tracing=True) +``` + +If no connection string is found, tracing still works — spans are created but not exported +to Azure Monitor (you can bring your own `TracerProvider`). + +### More samples + +| Sample | Description | +|--------|-------------| +| `samples/simple_invoke_agent/` | Minimal from-scratch agent | +| `samples/openapi_validated_agent/` | OpenAPI spec with request/response validation | +| `samples/async_invoke_agent/` | Long-running tasks with get & cancel support | +| `samples/human_in_the_loop_agent/` | Synchronous human-in-the-loop interaction | +| `samples/langgraph_invoke_agent/` | Customer-managed LangGraph adapter | +| `samples/agentframework_invoke_agent/` | Customer-managed Agent Framework adapter | + +## Troubleshooting + +### Reporting issues + +To report an issue with the client library, or request additional features, please open a +GitHub issue [here](https://github.com/Azure/azure-sdk-for-python/issues). + +## Next steps + +Please visit [Samples](https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/agentserver/azure-ai-agentserver-server/samples) +for more usage examples. + +## Contributing + +This project welcomes contributions and suggestions. Most contributions require +you to agree to a Contributor License Agreement (CLA) declaring that you have +the right to, and actually do, grant us the rights to use your contribution. +For details, visit https://cla.microsoft.com. 
+ +When you submit a pull request, a CLA-bot will automatically determine whether +you need to provide a CLA and decorate the PR appropriately (e.g., label, +comment). Simply follow the instructions provided by the bot. You will only +need to do this once across all repos using our CLA. + +This project has adopted the +[Microsoft Open Source Code of Conduct][code_of_conduct]. For more information, +see the Code of Conduct FAQ or contact opencode@microsoft.com with any +additional questions or comments. + +[code_of_conduct]: https://opensource.microsoft.com/codeofconduct/ diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/__init__.py b/sdk/agentserver/azure-ai-agentserver-server/azure/__init__.py new file mode 100644 index 000000000000..d55ccad1f573 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/__init__.py @@ -0,0 +1 @@ +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/__init__.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/__init__.py new file mode 100644 index 000000000000..d55ccad1f573 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/__init__.py @@ -0,0 +1 @@ +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/__init__.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/__init__.py new file mode 100644 index 000000000000..d55ccad1f573 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/__init__.py @@ -0,0 +1 @@ +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/__init__.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/__init__.py new file mode 100644 index 000000000000..a82226f72189 --- /dev/null +++ 
b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/__init__.py @@ -0,0 +1,10 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + +from ._version import VERSION +from ._base import AgentServer + +__all__ = ["AgentServer"] +__version__ = VERSION diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py new file mode 100644 index 000000000000..e71d98ab36da --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_base.py @@ -0,0 +1,349 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import asyncio # pylint: disable=do-not-import-asyncio +import contextlib +import logging +from collections.abc import AsyncGenerator, Awaitable, Callable # pylint: disable=import-error +from typing import Any, Optional + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Route + +from ._constants import Constants +from ._logger import get_logger +from ._tracing import _TracingHelper +from ._openapi_validator import _OpenApiValidator +from ._invocation import _InvocationProtocol +from ._server_context import _ServerContext +from . import _config + +logger = get_logger() + +# Pre-built health-check responses to avoid per-request allocation. +_LIVENESS_BODY = b'{"status":"alive"}' +_READINESS_BODY = b'{"status":"ready"}' + + +class AgentServer: # pylint: disable=too-many-instance-attributes + """Agent server with pluggable protocol heads. 
+ + Instantiate and register handlers with decorators:: + + server = AgentServer() + + @server.invoke_handler + async def handle(request): + return JSONResponse({"ok": True}) + + Optionally register handlers with :meth:`get_invocation_handler` and + :meth:`cancel_invocation_handler` for additional protocol support. + + Developer receives raw Starlette ``Request`` objects and returns raw + Starlette ``Response`` objects, giving full control over content types, + streaming, headers, and status codes. + + :param openapi_spec: Optional OpenAPI spec dict. When provided, the spec + is served at ``GET /invocations/docs/openapi.json`` for documentation. + Runtime request validation is **not** enabled by default — set + *enable_request_validation* to opt in. + :type openapi_spec: Optional[dict[str, Any]] + :param enable_request_validation: When *True*, incoming ``POST /invocations`` + request bodies are validated against the *openapi_spec* before reaching + :meth:`invoke`. When *None* (default) the + ``AGENT_ENABLE_REQUEST_VALIDATION`` env var is consulted (``"true"`` to + enable). Requires *openapi_spec* to be set. + :type enable_request_validation: Optional[bool] + :param enable_tracing: Enable OpenTelemetry tracing. When *None* (default) + the ``AGENT_ENABLE_TRACING`` env var is consulted (``"true"`` to enable). + Requires ``opentelemetry-api`` — install with + ``pip install azure-ai-agentserver-server[tracing]``. + When an Application Insights connection string is also available, + traces and logs are automatically exported to Azure Monitor. + :type enable_tracing: Optional[bool] + :param application_insights_connection_string: Application Insights + connection string for exporting traces and logs to Azure Monitor. + When *None* (default) the ``APPLICATIONINSIGHTS_CONNECTION_STRING`` + env var is consulted. Only takes effect when *enable_tracing* is + ``True``. 
Requires ``opentelemetry-sdk`` and + ``azure-monitor-opentelemetry-exporter`` (included in the + ``[tracing]`` extras group). + :type application_insights_connection_string: Optional[str] + :param graceful_shutdown_timeout: Seconds to wait for in-flight requests to + complete after receiving SIGTERM / shutdown signal. When *None* (default) + the ``AGENT_GRACEFUL_SHUTDOWN_TIMEOUT`` env var is consulted; if that is + also unset the default is 30 seconds. Set to ``0`` to disable the + drain period. + :type graceful_shutdown_timeout: Optional[int] + :param request_timeout: Maximum seconds an ``invoke()`` call may run before + being cancelled. When *None* (default) the ``AGENT_REQUEST_TIMEOUT`` + env var is consulted; if that is also unset the default is 300 seconds + (5 minutes). Set to ``0`` to disable the timeout. + :type request_timeout: Optional[int] + :param log_level: Library log level (e.g. ``"DEBUG"``, ``"INFO"``). When + *None* (default) the ``AGENT_LOG_LEVEL`` env var is consulted; if that + is also unset the default is ``"INFO"``. + :type log_level: Optional[str] + :param debug_errors: When *True*, error responses include the original + exception message instead of a generic ``"Internal server error"``. + When *None* (default) the ``AGENT_DEBUG_ERRORS`` env var is consulted + (any truthy value enables it). Defaults to *False*. 
+ :type debug_errors: Optional[bool] + """ + + def __init__( + self, + *, + openapi_spec: Optional[dict[str, Any]] = None, + enable_request_validation: Optional[bool] = None, + enable_tracing: Optional[bool] = None, + application_insights_connection_string: Optional[str] = None, + graceful_shutdown_timeout: Optional[int] = None, + request_timeout: Optional[int] = None, + log_level: Optional[str] = None, + debug_errors: Optional[bool] = None, + ) -> None: + # Shutdown handler slot (server-level lifecycle) ------------------- + self._shutdown_fn: Optional[Callable] = None + + # Logging & debug ------------------------------------------------- + resolved_level = _config.resolve_log_level(log_level) + logger.setLevel(resolved_level) + if not logger.handlers: + _console = logging.StreamHandler() + _console.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")) + logger.addHandler(_console) + self._debug_errors = _config.resolve_bool_feature( + debug_errors, Constants.AGENT_DEBUG_ERRORS + ) + + # OpenAPI validation ----------------------------------------------- + _validation_on = _config.resolve_bool_feature( + enable_request_validation, Constants.AGENT_ENABLE_REQUEST_VALIDATION + ) + validator: Optional[_OpenApiValidator] = ( + _OpenApiValidator(openapi_spec) + if openapi_spec and _validation_on + else None + ) + + # Tracing ---------------------------------------------------------- + _tracing_on = _config.resolve_bool_feature(enable_tracing, Constants.AGENT_ENABLE_TRACING) + _conn_str = _config.resolve_appinsights_connection_string( + application_insights_connection_string + ) if _tracing_on else None + self._tracing: Optional[_TracingHelper] = ( + _TracingHelper(connection_string=_conn_str) if _tracing_on else None + ) + + # Timeouts --------------------------------------------------------- + self._graceful_shutdown_timeout = _config.resolve_graceful_shutdown_timeout( + graceful_shutdown_timeout + ) + self._request_timeout = 
_config.resolve_request_timeout(request_timeout) + + # Invocation protocol (composed) ------------------------------------- + ctx = _ServerContext( + tracing=self._tracing, + debug_errors=self._debug_errors, + request_timeout=self._request_timeout, + ) + self._invocation = _InvocationProtocol(ctx, openapi_spec, validator) + + self.app: Starlette + self._build_app() + + # ------------------------------------------------------------------ + # Shutdown handler (server-level lifecycle) + # ------------------------------------------------------------------ + + def shutdown_handler(self, fn: Callable[[], Awaitable[None]]) -> Callable[[], Awaitable[None]]: + """Register a function as the shutdown handler. + + :param fn: Async function called during graceful shutdown. + :type fn: Callable[[], Awaitable[None]] + :return: The original function (unmodified). + :rtype: Callable[[], Awaitable[None]] + """ + self._shutdown_fn = fn + return fn + + async def _dispatch_shutdown(self) -> None: + """Dispatch to the registered shutdown handler, or no-op.""" + if self._shutdown_fn is not None: + await self._shutdown_fn() + + # ------------------------------------------------------------------ + # Invocation protocol delegates + # ------------------------------------------------------------------ + + def invoke_handler( + self, fn: Callable[[Request], Awaitable[Response]] + ) -> Callable[[Request], Awaitable[Response]]: + """Register a function as the invoke handler. + + Usage:: + + @server.invoke_handler + async def handle(request: Request) -> Response: + ... + + :param fn: Async function accepting a Starlette Request and returning a Response. + :type fn: Callable[[Request], Awaitable[Response]] + :return: The original function (unmodified). 
+ :rtype: Callable[[Request], Awaitable[Response]] + """ + return self._invocation.invoke_handler(fn) + + def get_invocation_handler( + self, fn: Callable[[Request], Awaitable[Response]] + ) -> Callable[[Request], Awaitable[Response]]: + """Register a function as the get-invocation handler. + + :param fn: Async function accepting a Starlette Request and returning a Response. + :type fn: Callable[[Request], Awaitable[Response]] + :return: The original function (unmodified). + :rtype: Callable[[Request], Awaitable[Response]] + """ + return self._invocation.get_invocation_handler(fn) + + def cancel_invocation_handler( + self, fn: Callable[[Request], Awaitable[Response]] + ) -> Callable[[Request], Awaitable[Response]]: + """Register a function as the cancel-invocation handler. + + :param fn: Async function accepting a Starlette Request and returning a Response. + :type fn: Callable[[Request], Awaitable[Response]] + :return: The original function (unmodified). + :rtype: Callable[[Request], Awaitable[Response]] + """ + return self._invocation.cancel_invocation_handler(fn) + + def get_openapi_spec(self) -> Optional[dict[str, Any]]: + """Return the OpenAPI spec dict for this agent, or None. + + :return: The registered OpenAPI spec or None. + :rtype: Optional[dict[str, Any]] + """ + return self._invocation.get_openapi_spec() + + # ------------------------------------------------------------------ + # Run helpers + # ------------------------------------------------------------------ + + def _build_hypercorn_config(self, host: str, port: int) -> object: + """Create a Hypercorn config with resolved host, port and timeouts. + + :param host: Network interface to bind. + :type host: str + :param port: Port to bind. + :type port: int + :return: Configured Hypercorn config. 
+ :rtype: hypercorn.config.Config + """ + from hypercorn.config import Config as HypercornConfig + + config = HypercornConfig() + config.bind = [f"{host}:{port}"] + config.graceful_timeout = float(self._graceful_shutdown_timeout) + return config + + def run(self, host: str = "127.0.0.1", port: Optional[int] = None) -> None: + """Start the server synchronously. + + Uses Hypercorn as the ASGI server, which supports HTTP/1.1 and HTTP/2. + + :param host: Network interface to bind. Defaults to ``127.0.0.1``. + Use ``"0.0.0.0"`` to listen on all interfaces. + :type host: str + :param port: Port to bind. Defaults to ``AGENT_SERVER_PORT`` env var or 8088. + :type port: Optional[int] + """ + from hypercorn.asyncio import serve as _hypercorn_serve + + resolved_port = _config.resolve_port(port) + logger.info("AgentServer starting on %s:%s", host, resolved_port) + config = self._build_hypercorn_config(host, resolved_port) + asyncio.run(_hypercorn_serve(self.app, config)) # type: ignore[arg-type] # Starlette is ASGI-compatible + + async def run_async(self, host: str = "127.0.0.1", port: Optional[int] = None) -> None: + """Start the server asynchronously (awaitable). + + Uses Hypercorn as the ASGI server, which supports HTTP/1.1 and HTTP/2. + + :param host: Network interface to bind. Defaults to ``127.0.0.1``. + Use ``"0.0.0.0"`` to listen on all interfaces. + :type host: str + :param port: Port to bind. Defaults to ``AGENT_SERVER_PORT`` env var or 8088. 
+ :type port: Optional[int] + """ + from hypercorn.asyncio import serve as _hypercorn_serve + + resolved_port = _config.resolve_port(port) + logger.info("AgentServer starting on %s:%s (async)", host, resolved_port) + config = self._build_hypercorn_config(host, resolved_port) + await _hypercorn_serve(self.app, config) # type: ignore[arg-type] # Starlette is ASGI-compatible + + # ------------------------------------------------------------------ + # Private: app construction + # ------------------------------------------------------------------ + + def _build_app(self) -> None: + """Construct the Starlette ASGI application with all routes.""" + + @contextlib.asynccontextmanager + async def _lifespan(_app: Starlette) -> AsyncGenerator[None, None]: # noqa: RUF029 + logger.info("AgentServer started") + yield + + # --- SHUTDOWN: runs once when the server is stopping --- + logger.info( + "AgentServer shutting down (graceful timeout=%ss)", + self._graceful_shutdown_timeout, + ) + try: + await asyncio.wait_for( + self._dispatch_shutdown(), + timeout=self._graceful_shutdown_timeout or None, + ) + except asyncio.TimeoutError: + logger.warning( + "on_shutdown did not complete within %ss timeout", + self._graceful_shutdown_timeout, + ) + except Exception: # pylint: disable=broad-exception-caught + logger.exception("Error in on_shutdown") + + routes = list(self._invocation.routes) + routes.extend([ + Route("/liveness", self._liveness_endpoint, methods=["GET"], name="liveness"), + Route("/readiness", self._readiness_endpoint, methods=["GET"], name="readiness"), + ]) + + self.app = Starlette(routes=routes, lifespan=_lifespan) + + # ------------------------------------------------------------------ + # Health endpoints + # ------------------------------------------------------------------ + + async def _liveness_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument + """GET /liveness — health check. + + :param request: The incoming Starlette request. 
+ :type request: Request + :return: 200 OK response. + :rtype: Response + """ + return Response(_LIVENESS_BODY, media_type="application/json") + + async def _readiness_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument + """GET /readiness — readiness check. + + :param request: The incoming Starlette request. + :type request: Request + :return: 200 OK response. + :rtype: Response + """ + return Response(_READINESS_BODY, media_type="application/json") diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py new file mode 100644 index 000000000000..8110414a3b8d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_config.py @@ -0,0 +1,216 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Configuration resolution helpers for AgentServer. + +Each ``resolve_*`` function follows the same hierarchy: +1. Explicit argument (if not *None*) +2. Environment variable +3. Built-in default + +A value of ``0`` conventionally disables the corresponding feature. + +Invalid environment variable values raise ``ValueError`` immediately so +misconfiguration is surfaced at startup rather than silently masked. +""" +import os +from typing import Optional + +from ._constants import Constants + + +def _parse_int_env(var_name: str) -> Optional[int]: + """Parse an integer environment variable, raising on invalid values. + + :param var_name: Name of the environment variable. + :type var_name: str + :return: The parsed integer or None if the variable is not set. + :rtype: Optional[int] + :raises ValueError: If the variable is set but cannot be parsed as an integer. 
+ """ + raw = os.environ.get(var_name) + if raw is None: + return None + try: + return int(raw) + except ValueError: + raise ValueError( + f"Invalid value for {var_name}: {raw!r} (expected an integer)" + ) from None + + +def _require_int(name: str, value: object) -> int: + """Validate that *value* is an integer. + + :param name: Human-readable parameter/env-var name for the error message. + :type name: str + :param value: The value to validate. + :type value: object + :return: The value cast to int. + :rtype: int + :raises ValueError: If *value* is not an integer. + """ + if not isinstance(value, int): + raise ValueError( + f"Invalid value for {name}: {value!r} (expected an integer)" + ) + return value + + +def _validate_port(value: int, source: str) -> int: + """Validate that a port number is within the valid range. + + :param value: The port number to validate. + :type value: int + :param source: Human-readable source name for the error message. + :type source: str + :return: The validated port number. + :rtype: int + :raises ValueError: If the port is outside 1-65535. + """ + if not 1 <= value <= 65535: + raise ValueError( + f"Invalid value for {source}: {value} (expected 1-65535)" + ) + return value + + +def resolve_port(port: Optional[int]) -> int: + """Resolve the server port from argument, env var, or default. + + :param port: Explicitly requested port or None. + :type port: Optional[int] + :return: The resolved port number. + :rtype: int + :raises ValueError: If the port value is not a valid integer or is outside 1-65535. + """ + if port is not None: + return _validate_port(_require_int("port", port), "port") + env_port = _parse_int_env(Constants.AGENT_SERVER_PORT) + if env_port is not None: + return _validate_port(env_port, Constants.AGENT_SERVER_PORT) + return Constants.DEFAULT_PORT + + +def resolve_graceful_shutdown_timeout(timeout: Optional[int]) -> int: + """Resolve the graceful shutdown timeout from argument, env var, or default. 
+ + :param timeout: Explicitly requested timeout or None. + :type timeout: Optional[int] + :return: The resolved timeout in seconds. + :rtype: int + :raises ValueError: If the env var is not a valid integer. + """ + if timeout is not None: + return max(0, _require_int("graceful_shutdown_timeout", timeout)) + env_timeout = _parse_int_env(Constants.AGENT_GRACEFUL_SHUTDOWN_TIMEOUT) + if env_timeout is not None: + return max(0, env_timeout) + return Constants.DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT + + +def resolve_request_timeout(timeout: Optional[int]) -> int: + """Resolve the request timeout from argument, env var, or default. + + :param timeout: Explicitly requested timeout in seconds or None. + :type timeout: Optional[int] + :return: The resolved timeout in seconds. + :rtype: int + :raises ValueError: If the env var is not a valid integer. + """ + if timeout is not None: + return max(0, _require_int("request_timeout", timeout)) + env_timeout = _parse_int_env(Constants.AGENT_REQUEST_TIMEOUT) + if env_timeout is not None: + return max(0, env_timeout) + return Constants.DEFAULT_REQUEST_TIMEOUT + + +def resolve_bool_feature(value: Optional[bool], env_var: str) -> bool: + """Resolve an opt-in boolean feature from argument, env var, or default (False). + + :param value: Explicitly requested value or None. + :type value: Optional[bool] + :param env_var: Name of the environment variable to consult. + :type env_var: str + :return: Whether the feature is enabled. + :rtype: bool + """ + if value is not None: + return bool(value) + return os.environ.get(env_var, "").lower() == "true" + + +_VALID_LOG_LEVELS = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL") + + +def resolve_appinsights_connection_string( + connection_string: Optional[str], +) -> Optional[str]: + """Resolve the Application Insights connection string. + + Resolution order: + + 1. Explicit *connection_string* argument (if not *None*). + 2. 
``APPLICATIONINSIGHTS_CONNECTION_STRING`` env var (standard Azure + Monitor convention). + 3. *None* — no connection string available. + + :param connection_string: Explicitly provided connection string or None. + :type connection_string: Optional[str] + :return: The resolved connection string, or None. + :rtype: Optional[str] + """ + if connection_string is not None: + return connection_string + return os.environ.get( + Constants.APPLICATIONINSIGHTS_CONNECTION_STRING + ) + + +def resolve_log_level(level: Optional[str]) -> str: + """Resolve the library log level from argument, env var, or default (``INFO``). + + :param level: Explicitly requested level (e.g. ``"DEBUG"``) or None. + :type level: Optional[str] + :return: Validated, upper-cased log level string. + :rtype: str + :raises ValueError: If the value is not one of DEBUG/INFO/WARNING/ERROR/CRITICAL. + """ + if level is not None: + normalized = level.upper() + else: + normalized = os.environ.get(Constants.AGENT_LOG_LEVEL, "INFO").upper() + if normalized not in _VALID_LOG_LEVELS: + raise ValueError( + f"Invalid log level: {normalized!r} " + f"(expected one of {', '.join(_VALID_LOG_LEVELS)})" + ) + return normalized + + +def resolve_agent_name() -> str: + """Resolve the agent name from the ``AGENT_NAME`` environment variable. + + :return: The agent name, or an empty string if not set. + :rtype: str + """ + return os.environ.get(Constants.AGENT_NAME, "") + + +def resolve_agent_version() -> str: + """Resolve the agent version from the ``AGENT_VERSION`` environment variable. + + :return: The agent version, or an empty string if not set. + :rtype: str + """ + return os.environ.get(Constants.AGENT_VERSION, "") + + +def resolve_project_id() -> str: + """Resolve the Foundry project ID from the ``AGENT_PROJECT_NAME`` environment variable. + + :return: The project ID, or an empty string if not set. 
+ :rtype: str + """ + return os.environ.get(Constants.AGENT_PROJECT_NAME, "") diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py new file mode 100644 index 000000000000..7a32bff8f972 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_constants.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + + +class Constants: + """Well-known environment variables and defaults for AgentServer.""" + + AGENT_LOG_LEVEL = "AGENT_LOG_LEVEL" + AGENT_DEBUG_ERRORS = "AGENT_DEBUG_ERRORS" + AGENT_ENABLE_TRACING = "AGENT_ENABLE_TRACING" + AGENT_SERVER_PORT = "AGENT_SERVER_PORT" + AGENT_GRACEFUL_SHUTDOWN_TIMEOUT = "AGENT_GRACEFUL_SHUTDOWN_TIMEOUT" + AGENT_REQUEST_TIMEOUT = "AGENT_REQUEST_TIMEOUT" + AGENT_ENABLE_REQUEST_VALIDATION = "AGENT_ENABLE_REQUEST_VALIDATION" + APPLICATIONINSIGHTS_CONNECTION_STRING = "APPLICATIONINSIGHTS_CONNECTION_STRING" + AGENT_NAME = "AGENT_NAME" + AGENT_VERSION = "AGENT_VERSION" + AGENT_PROJECT_NAME = "AGENT_PROJECT_NAME" + DEFAULT_PORT = 8088 + DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT = 30 + DEFAULT_REQUEST_TIMEOUT = 300 # 5 minutes + INVOCATION_ID_HEADER = "x-agent-invocation-id" diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_errors.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_errors.py new file mode 100644 index 000000000000..a7774fd09a92 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_errors.py @@ -0,0 +1,52 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Standardised error response builder for AgentServer. + +Every error returned by the framework uses the shape:: + + { + "error": { + "code": "...", // required – machine-readable error code + "message": "...", // required – human-readable description + "details": [ ... ] // optional – child errors + } + } +""" +from __future__ import annotations + +from typing import Any, Optional + +from starlette.responses import JSONResponse + + +def error_response( + code: str, + message: str, + *, + status_code: int, + details: Optional[list[dict[str, Any]]] = None, + headers: Optional[dict[str, str]] = None, +) -> JSONResponse: + """Build a ``JSONResponse`` with the standard error envelope. + + :param code: Machine-readable error code (e.g. ``"internal_error"``). + :type code: str + :param message: Human-readable error message. + :type message: str + :keyword status_code: HTTP status code for the response. + :paramtype status_code: int + :keyword details: Child error objects, each with at least ``code`` and + ``message`` keys. + :paramtype details: Optional[list[dict[str, Any]]] + :keyword headers: Extra HTTP headers to include on the response. + :paramtype headers: Optional[dict[str, str]] + :return: A ready-to-send JSON error response. 
+ :rtype: JSONResponse + """ + body: dict[str, Any] = {"code": code, "message": message} + if details is not None: + body["details"] = details + return JSONResponse( + {"error": body}, status_code=status_code, headers=headers + ) diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py new file mode 100644 index 000000000000..1d116f2aee70 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_invocation.py @@ -0,0 +1,377 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Invocation protocol for AgentServer. + +Encapsulates the ``/invocations`` REST endpoints, handler decorators, +dispatch methods, and OpenAPI spec serving. Designed as a standalone +composed object so that ``AgentServer`` can compose multiple protocol +heads (invocation, chat, etc.) without inheritance conflicts. + +Shared server state (tracing, error handling, timeouts) is received +via a :class:`~._server_context._ServerContext` instance. +""" +import asyncio # pylint: disable=do-not-import-asyncio +import contextlib +import uuid +from collections.abc import Awaitable, Callable # pylint: disable=import-error +from typing import Any, Optional + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse +from starlette.routing import Route + +from ._constants import Constants +from ._errors import error_response +from ._logger import get_logger +from ._openapi_validator import _OpenApiValidator +from ._server_context import _ServerContext + +logger = get_logger() + + +class _InvocationProtocol: + """Invocation protocol implementation. 
+ + Receives shared server state via a :class:`_ServerContext` and manages + its own handler slots, agent identity, and endpoint handlers. + + **Not intended for direct instantiation.** Use via + :class:`~azure.ai.agentserver.server.AgentServer`, which creates and + delegates to this object. + """ + + def __init__( + self, + ctx: _ServerContext, + openapi_spec: Optional[dict[str, Any]], + validator: Optional[_OpenApiValidator], + ) -> None: + """Initialise the invocation protocol. + + :param ctx: Shared server context (tracing, debug_errors, request_timeout). + :type ctx: _ServerContext + :param openapi_spec: Optional OpenAPI spec dict for documentation. + :type openapi_spec: Optional[dict[str, Any]] + :param validator: Optional request validator built from the spec. + :type validator: Optional[_OpenApiValidator] + """ + self._ctx = ctx + self._invoke_fn: Optional[Callable] = None + self._get_invocation_fn: Optional[Callable] = None + self._cancel_invocation_fn: Optional[Callable] = None + self._openapi_spec = openapi_spec + self._validator = validator + + # ------------------------------------------------------------------ + # Route registration + # ------------------------------------------------------------------ + + @property + def routes(self) -> list[Route]: + """Starlette routes for the invocation protocol. + + :return: A list of four Route objects for the invocation endpoints. 
+ :rtype: list[Route] + """ + return [ + Route( + "/invocations/docs/openapi.json", + self._get_openapi_spec_endpoint, + methods=["GET"], + name="get_openapi_spec", + ), + Route( + "/invocations", + self._create_invocation_endpoint, + methods=["POST"], + name="create_invocation", + ), + Route( + "/invocations/{invocation_id}", + self._get_invocation_endpoint, + methods=["GET"], + name="get_invocation", + ), + Route( + "/invocations/{invocation_id}/cancel", + self._cancel_invocation_endpoint, + methods=["POST"], + name="cancel_invocation", + ), + ] + + # ------------------------------------------------------------------ + # Handler decorators + # ------------------------------------------------------------------ + + def invoke_handler( + self, fn: Callable[[Request], Awaitable[Response]] + ) -> Callable[[Request], Awaitable[Response]]: + """Store *fn* as the invoke handler. See :meth:`AgentServer.invoke_handler`. + + :param fn: Async function accepting a Starlette Request and returning a Response. + :type fn: Callable[[Request], Awaitable[Response]] + :return: The original function (unmodified). + :rtype: Callable[[Request], Awaitable[Response]] + """ + self._invoke_fn = fn + return fn + + def get_invocation_handler( + self, fn: Callable[[Request], Awaitable[Response]] + ) -> Callable[[Request], Awaitable[Response]]: + """Store *fn* as the get-invocation handler. See :meth:`AgentServer.get_invocation_handler`. + + :param fn: Async function accepting a Starlette Request and returning a Response. + :type fn: Callable[[Request], Awaitable[Response]] + :return: The original function (unmodified). + :rtype: Callable[[Request], Awaitable[Response]] + """ + self._get_invocation_fn = fn + return fn + + def cancel_invocation_handler( + self, fn: Callable[[Request], Awaitable[Response]] + ) -> Callable[[Request], Awaitable[Response]]: + """Store *fn* as the cancel-invocation handler. See :meth:`AgentServer.cancel_invocation_handler`. 
+ + :param fn: Async function accepting a Starlette Request and returning a Response. + :type fn: Callable[[Request], Awaitable[Response]] + :return: The original function (unmodified). + :rtype: Callable[[Request], Awaitable[Response]] + """ + self._cancel_invocation_fn = fn + return fn + + # ------------------------------------------------------------------ + # Dispatch methods (internal) + # ------------------------------------------------------------------ + + async def _dispatch_invoke(self, request: Request) -> Response: + """Dispatch to the registered invoke handler. + + :param request: The incoming Starlette request. + :type request: Request + :return: The response from the invoke handler. + :rtype: Response + :raises NotImplementedError: If no invoke handler has been registered. + """ + if self._invoke_fn is not None: + return await self._invoke_fn(request) + raise NotImplementedError( + "No invoke handler registered. Use the @server.invoke_handler decorator." + ) + + async def _dispatch_get_invocation(self, request: Request) -> Response: + """Dispatch to the registered get-invocation handler, or return 501. + + :param request: The incoming Starlette request. + :type request: Request + :return: The response from the get-invocation handler. + :rtype: Response + """ + if self._get_invocation_fn is not None: + return await self._get_invocation_fn(request) + return error_response("not_supported", "get_invocation not supported", status_code=501) + + async def _dispatch_cancel_invocation(self, request: Request) -> Response: + """Dispatch to the registered cancel-invocation handler, or return 501. + + :param request: The incoming Starlette request. + :type request: Request + :return: The response from the cancel-invocation handler. 
+ :rtype: Response + """ + if self._cancel_invocation_fn is not None: + return await self._cancel_invocation_fn(request) + return error_response("not_supported", "cancel_invocation not supported", status_code=501) + + def get_openapi_spec(self) -> Optional[dict[str, Any]]: + """Return the stored OpenAPI spec. See :meth:`AgentServer.get_openapi_spec`. + + :return: The registered OpenAPI spec or None. + :rtype: Optional[dict[str, Any]] + """ + return self._openapi_spec + + # ------------------------------------------------------------------ + # Endpoint handlers + # ------------------------------------------------------------------ + + async def _get_openapi_spec_endpoint(self, request: Request) -> Response: # pylint: disable=unused-argument + """GET /invocations/docs/openapi.json — return registered spec or 404. + + :param request: The incoming Starlette request. + :type request: Request + :return: JSON response with the spec or 404. + :rtype: Response + """ + spec = self.get_openapi_spec() + if spec is None: + return error_response("not_found", "No OpenAPI spec registered", status_code=404) + return JSONResponse(spec) + + async def _create_invocation_endpoint(self, request: Request) -> Response: + """POST /invocations — create and process an invocation. + + :param request: The incoming Starlette request. + :type request: Request + :return: The invocation result or error response. 
+ :rtype: Response + """ + invocation_id = ( + request.headers.get(Constants.INVOCATION_ID_HEADER) + or str(uuid.uuid4()) + ) + request.state.invocation_id = invocation_id + + # Validate request body against OpenAPI spec + if self._validator is not None: + content_type = request.headers.get("content-type", "application/json") + body = await request.body() + errors = self._validator.validate_request(body, content_type) + if errors: + return error_response( + "invalid_payload", + "Request validation failed", + status_code=400, + details=[ + {"code": "validation_error", "message": e} + for e in errors + ], + ) + + # Use manual span management so that streaming responses keep the + # span open until the last chunk is yielded (or an error occurs). + otel_span = None + if self._ctx.tracing is not None: + otel_span = self._ctx.tracing.start_request_span( + request.headers, + invocation_id, + span_operation="execute_agent", + operation_name="invoke_agent", + session_id=request.query_params.get("agent_session_id", ""), + ) + try: + invoke_awaitable = self._dispatch_invoke(request) + timeout = self._ctx.request_timeout or None # 0 → None (no limit) + response = await asyncio.wait_for(invoke_awaitable, timeout=timeout) + except NotImplementedError as exc: + if self._ctx.tracing is not None: + self._ctx.tracing.end_span(otel_span, exc=exc) + logger.error("Invocation %s failed: %s", invocation_id, exc) + return error_response( + "not_implemented", + str(exc), + status_code=501, + headers={Constants.INVOCATION_ID_HEADER: invocation_id}, + ) + except asyncio.TimeoutError as exc: + if self._ctx.tracing is not None: + self._ctx.tracing.end_span(otel_span, exc=exc) + logger.error( + "Invocation %s timed out after %ss", + invocation_id, + self._ctx.request_timeout, + ) + return error_response( + "request_timeout", + f"Invocation timed out after {self._ctx.request_timeout}s", + status_code=504, + headers={Constants.INVOCATION_ID_HEADER: invocation_id}, + ) + except Exception as exc: # 
pylint: disable=broad-exception-caught + if self._ctx.tracing is not None: + self._ctx.tracing.end_span(otel_span, exc=exc) + logger.error("Error processing invocation %s: %s", invocation_id, exc, exc_info=True) + message = str(exc) if self._ctx.debug_errors else "Internal server error" + return error_response( + "internal_error", + message, + status_code=500, + headers={Constants.INVOCATION_ID_HEADER: invocation_id}, + ) + + # For streaming responses, wrap the body iterator so the span stays + # open until all chunks are sent and captures any streaming errors. + if isinstance(response, StreamingResponse) and self._ctx.tracing is not None: + response.body_iterator = self._ctx.tracing.trace_stream(response.body_iterator, otel_span) + elif self._ctx.tracing is not None: + self._ctx.tracing.end_span(otel_span) + + # Always set invocation_id header (overrides any handler-set value) + response.headers[Constants.INVOCATION_ID_HEADER] = invocation_id + + return response + + async def _traced_invocation_endpoint( + self, + request: Request, + span_operation: str, + dispatch: Callable[[Request], Awaitable[Response]], + ) -> Response: + """Shared implementation for get/cancel invocation endpoints. + + Extracts the invocation ID from path params, optionally creates a + tracing span, dispatches to the handler, and handles errors. + + :param request: The incoming Starlette request. + :type request: Request + :param span_operation: Span operation name (e.g. + ``"get_invocation"``). + :type span_operation: str + :param dispatch: The dispatch method to invoke. + :type dispatch: Callable[[Request], Awaitable[Response]] + :return: The handler response or an error response. 
+ :rtype: Response + """ + invocation_id = request.path_params["invocation_id"] + request.state.invocation_id = invocation_id + + span_cm: Any = contextlib.nullcontext(None) + if self._ctx.tracing is not None: + span_cm = self._ctx.tracing.request_span( + request.headers, invocation_id, span_operation, + session_id=request.query_params.get("agent_session_id", ""), + ) + with span_cm as _otel_span: + try: + response = await dispatch(request) + response.headers[Constants.INVOCATION_ID_HEADER] = invocation_id + return response + except Exception as exc: # pylint: disable=broad-exception-caught + if self._ctx.tracing is not None: + self._ctx.tracing.record_error(_otel_span, exc) + logger.error("Error in %s %s: %s", span_operation, invocation_id, exc, exc_info=True) + message = str(exc) if self._ctx.debug_errors else "Internal server error" + return error_response( + "internal_error", + message, + status_code=500, + headers={Constants.INVOCATION_ID_HEADER: invocation_id}, + ) + + async def _get_invocation_endpoint(self, request: Request) -> Response: + """GET /invocations/{invocation_id} — retrieve an invocation result. + + :param request: The incoming Starlette request. + :type request: Request + :return: The stored result or 501. + :rtype: Response + """ + return await self._traced_invocation_endpoint( + request, "get_invocation", self._dispatch_get_invocation + ) + + async def _cancel_invocation_endpoint(self, request: Request) -> Response: + """POST /invocations/{invocation_id}/cancel — cancel an invocation. + + :param request: The incoming Starlette request. + :type request: Request + :return: The cancellation result or 501. 
+ :rtype: Response + """ + return await self._traced_invocation_endpoint( + request, "cancel_invocation", self._dispatch_cancel_invocation + ) diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_logger.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_logger.py new file mode 100644 index 000000000000..4b4e960e971d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_logger.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import logging + + +def get_logger() -> logging.Logger: + """Return the library-scoped logger. + + The log level is configured by the ``log_level`` constructor parameter + of :class:`AgentServer` (or the ``AGENT_LOG_LEVEL`` env var as fallback). + This function simply returns the named logger without forcing a level so + that the level already set by the constructor is preserved. + + :return: Logger instance for azure.ai.agentserver. + :rtype: logging.Logger + """ + return logging.getLogger("azure.ai.agentserver") diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_openapi_validator.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_openapi_validator.py new file mode 100644 index 000000000000..1e01a1b1e1d8 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_openapi_validator.py @@ -0,0 +1,714 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# +# Why not use openapi-spec-validator / openapi-schema-validator / openapi-core? +# +# - openapi-spec-validator validates that an OpenAPI *document* is well-formed +# (meta-validation). 
It does not validate request/response data. +# +# - openapi-schema-validator (OAS30Validator) handles nullable, readOnly/ +# writeOnly, and OpenAPI keyword stripping natively, but does NOT provide: +# * $ref resolution from the full spec into extracted sub-schemas, +# * discriminator-aware oneOf/anyOf error collection (ported from the +# server-side C# JsonSchemaValidator), +# * JSON-path-prefixed error messages. +# Adopting it would save ~65 lines of preprocessing but adds three +# transitive dependencies (jsonschema-specifications, rfc3339-validator, +# jsonschema-path) while still requiring all custom error-collection code. +# +# - openapi-core is a full request/response middleware framework with its +# own routing and parsing. It conflicts with our Starlette middleware +# approach and is a much heavier dependency. +# +# Keeping only jsonschema as the single validation dependency gives us full +# control over error output and avoids unnecessary transitive packages. +# +import copy +import json +import re +from collections import Counter +from collections.abc import Callable # pylint: disable=import-error +from datetime import datetime +from typing import Any, Optional + +import jsonschema +from jsonschema.exceptions import best_match +from jsonschema import FormatChecker, ValidationError + +from ._logger import get_logger + +logger = get_logger() + +# --------------------------------------------------------------------------- +# Stdlib-only format checkers so we never depend on optional jsonschema extras +# (rfc3339-validator, fqdn, …). Registered on a module-level instance that +# is reused for every validation call. 
+# --------------------------------------------------------------------------- +_format_checker = FormatChecker(formats=()) # start with no built-in checks + +_EMAIL_RE = re.compile(r"^[^@\s]+@[^@\s]+\.[^@\s]+$") + + +@_format_checker.checks("date-time", raises=ValueError) +def _check_datetime(value: object) -> bool: + """Validate RFC 3339 / ISO 8601 date-time strings using stdlib only. + + :param value: The value to validate. + :type value: object + :return: True if valid. + :rtype: bool + """ + if not isinstance(value, str): + return True # non-string is not a format error + normalized = value.replace("Z", "+00:00") if value.endswith("Z") else value + datetime.fromisoformat(normalized) + return True + + +@_format_checker.checks("date", raises=ValueError) +def _check_date(value: object) -> bool: + """Validate ISO 8601 date strings (YYYY-MM-DD). + + :param value: The value to validate. + :type value: object + :return: True if valid. + :rtype: bool + """ + if not isinstance(value, str): + return True + datetime.strptime(value, "%Y-%m-%d") + return True + + +@_format_checker.checks("email", raises=ValueError) +def _check_email(value: object) -> bool: + """Basic RFC 5322 email format check (no DNS lookup). + + :param value: The value to validate. + :type value: object + :return: True if valid. + :rtype: bool + """ + if not isinstance(value, str): + return True + if not _EMAIL_RE.match(value): + raise ValueError(f"Invalid email: {value!r}") + return True + +# OpenAPI keywords that are not part of JSON Schema and must be stripped +# before handing a schema to a JSON Schema validator. +_OPENAPI_ONLY_KEYWORDS: frozenset[str] = frozenset( + {"discriminator", "xml", "externalDocs", "example"} +) + + +class _OpenApiValidator: + """Validates request/response bodies against an OpenAPI spec. + + Extracts the request and response JSON schemas from the provided OpenAPI spec dict + and uses ``jsonschema`` to validate bodies at runtime. + + :param spec: An OpenAPI spec dictionary. 
+ :type spec: dict[str, Any] + """ + + def __init__(self, spec: dict[str, Any], path: str = "/invocations") -> None: + self._request_body_required = self._is_request_body_required(spec, path) + raw_request = self._extract_request_schema(spec, path) + raw_response = self._extract_response_schema(spec, path) + self._request_schema = ( + self._preprocess_schema(raw_request, context="request") + if raw_request + else None + ) + self._response_schema = ( + self._preprocess_schema(raw_response, context="response") + if raw_response + else None + ) + + def validate_request(self, body: bytes, content_type: str) -> list[str]: + """Validate a request body against the spec's request schema. + + :param body: Raw request body bytes. + :type body: bytes + :param content_type: The Content-Type header value. + :type content_type: str + :return: List of validation error messages. Empty when valid. + :rtype: list[str] + """ + if self._request_schema is None: + return [] + # If requestBody.required is false, allow empty bodies + if not self._request_body_required and body.strip() == b"": + return [] + return self._validate_body(body, content_type, self._request_schema) + + def validate_response(self, body: bytes, content_type: str) -> list[str]: + """Validate a response body against the spec's response schema. + + :param body: Raw response body bytes. + :type body: bytes + :param content_type: The Content-Type header value. + :type content_type: str + :return: List of validation error messages. Empty when valid. + :rtype: list[str] + """ + if self._response_schema is None: + return [] + return self._validate_body(body, content_type, self._response_schema) + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _validate_body(body: bytes, content_type: str, schema: dict[str, Any]) -> list[str]: + """Parse body as JSON and validate against *schema*. 
+ + Uses discriminator-aware error collection for ``oneOf`` / ``anyOf`` + schemas: when a discriminator property is detected across branches the + validator selects the matching branch and reports only *its* errors, + avoiding the noisy dump of every branch. + + Error messages are prefixed with the JSON-path of the failing element + (e.g. ``$.items[0].type: ...``) so callers can pinpoint the problem. + + :param body: Raw bytes to validate. + :type body: bytes + :param content_type: The Content-Type header value. + :type content_type: str + :param schema: JSON Schema dict to validate against. + :type schema: dict[str, Any] + :return: List of validation error strings. Empty when valid. + :rtype: list[str] + """ + if "json" not in content_type.lower(): + return [] # skip validation for non-JSON payloads + + try: + data = json.loads(body) + except (json.JSONDecodeError, UnicodeDecodeError) as exc: + return [f"Invalid JSON body: {exc}"] + + errors: list[str] = [] + validator = jsonschema.Draft7Validator( + schema, format_checker=_format_checker + ) + for error in validator.iter_errors(data): + errors.extend(_collect_errors(error)) + return errors + + @staticmethod + def _extract_request_schema(spec: dict[str, Any], path: str) -> Optional[dict[str, Any]]: + """Extract the request body JSON schema from the POST operation at *path*. + + :param spec: OpenAPI spec dictionary. + :type spec: dict[str, Any] + :param path: The API path (e.g. ``/invocations``). + :type path: str + :return: JSON Schema dict or None. + :rtype: Optional[dict[str, Any]] + """ + return _OpenApiValidator._find_schema_in_paths( + spec, path, "post", "requestBody" + ) + + @staticmethod + def _extract_response_schema(spec: dict[str, Any], path: str) -> Optional[dict[str, Any]]: + """Extract the response body JSON schema from the POST operation at *path*. + + :param spec: OpenAPI spec dictionary. + :type spec: dict[str, Any] + :param path: The API path (e.g. ``/invocations``). 
+ :type path: str + :return: JSON Schema dict or None. + :rtype: Optional[dict[str, Any]] + """ + return _OpenApiValidator._find_schema_in_paths( + spec, path, "post", "responses" + ) + + @staticmethod + def _find_schema_in_paths( + spec: dict[str, Any], + path: str, + method: str, + section: str, + ) -> Optional[dict[str, Any]]: + """Walk the spec to find a JSON schema for the given path/method/section. + + :param spec: OpenAPI spec dictionary. + :type spec: dict[str, Any] + :param path: The API path (e.g. ``/invocations``). + :type path: str + :param method: HTTP method (e.g. ``post``). + :type method: str + :param section: Either ``requestBody`` or ``responses``. + :type section: str + :return: Resolved JSON Schema dict or None. + :rtype: Optional[dict[str, Any]] + """ + paths = spec.get("paths", {}) + operation = paths.get(path, {}).get(method, {}) + + if section == "requestBody": + request_body = operation.get("requestBody", {}) + content = request_body.get("content", {}) + json_content = content.get("application/json", {}) + schema = json_content.get("schema") + return _resolve_refs_deep(spec, schema) if schema else None + + if section == "responses": + responses = operation.get("responses", {}) + # Try 200, then 201, then first available + for code in ("200", "201"): + resp = responses.get(code, {}) + content = resp.get("content", {}) + json_content = content.get("application/json", {}) + schema = json_content.get("schema") + if schema: + return _resolve_refs_deep(spec, schema) + # Fallback: first response with JSON content + for resp in responses.values(): + if isinstance(resp, dict): + content = resp.get("content", {}) + json_content = content.get("application/json", {}) + schema = json_content.get("schema") + if schema: + return _resolve_refs_deep(spec, schema) + return None + + @staticmethod + def _is_request_body_required(spec: dict[str, Any], path: str) -> bool: + """Check whether ``requestBody.required`` is true for the POST operation. 
+ + :param spec: OpenAPI spec dictionary. + :type spec: dict[str, Any] + :param path: The API path (e.g. ``/invocations``). + :type path: str + :return: True if the request body is explicitly required (default True). + :rtype: bool + """ + paths = spec.get("paths", {}) + operation = paths.get(path, {}).get("post", {}) + request_body = operation.get("requestBody", {}) + return request_body.get("required", True) + + @staticmethod + def _walk_schema(schema: dict[str, Any], visitor: Callable[[dict[str, Any]], None]) -> None: + """Recursively apply *visitor* to every sub-schema in the tree. + + *visitor* is called on each dict-typed sub-schema found in + ``items``, ``additionalProperties``, ``properties``, and + ``allOf``/``oneOf``/``anyOf``. + + :param schema: Root schema dict. + :type schema: dict[str, Any] + :param visitor: Callable applied to each sub-schema dict. + :type visitor: Callable[[dict[str, Any]], None] + """ + for key in ("items", "additionalProperties"): + child = schema.get(key) + if isinstance(child, dict): + visitor(child) + for prop in schema.get("properties", {}).values(): + if isinstance(prop, dict): + visitor(prop) + for keyword in ("allOf", "oneOf", "anyOf"): + for sub in schema.get(keyword, []): + if isinstance(sub, dict): + visitor(sub) + + @staticmethod + def _preprocess_schema( + schema: dict[str, Any], context: str = "request" + ) -> dict[str, Any]: + """Convert OpenAPI-specific keywords into pure JSON Schema. + + Performs a deep copy then applies the following transformations: + 1. ``nullable: true`` → ``type: [originalType, "null"]`` + 2. Strip ``readOnly`` properties in request context, + ``writeOnly`` in response context. + 3. Remove OpenAPI-only keywords (``discriminator``, ``xml``, etc.). + + :param schema: Resolved JSON Schema dict (may contain OpenAPI extensions). + :type schema: dict[str, Any] + :param context: Either ``"request"`` or ``"response"``. 
+ :type context: str + :return: A pure JSON Schema dict safe for ``jsonschema`` validation. + :rtype: dict[str, Any] + """ + schema = copy.deepcopy(schema) + _OpenApiValidator._apply_nullable(schema) + _OpenApiValidator._strip_readonly_writeonly(schema, context) + _OpenApiValidator._strip_openapi_keywords(schema) + return schema + + @staticmethod + def _apply_nullable(schema: dict[str, Any]) -> None: + """Convert ``nullable: true`` to JSON Schema union type in-place. + + Walks the schema tree and transforms + ``{"type": "string", "nullable": true}`` into + ``{"type": ["string", "null"]}``. + + :param schema: Schema dict to mutate in-place. + :type schema: dict[str, Any] + """ + if not isinstance(schema, dict): + return + if schema.pop("nullable", False): + original = schema.get("type") + if isinstance(original, str): + schema["type"] = [original, "null"] + elif isinstance(original, list) and "null" not in original: + schema["type"] = original + ["null"] + _OpenApiValidator._walk_schema(schema, _OpenApiValidator._apply_nullable) + + @staticmethod + def _strip_readonly_writeonly( + schema: dict[str, Any], context: str + ) -> None: + """Remove readOnly/writeOnly properties based on context. + + In **request** context, properties marked ``readOnly: true`` are removed + from ``properties`` and from the ``required`` list (the server generates + them; clients should not send them). + + In **response** context, ``writeOnly: true`` properties are removed + (e.g. passwords that the client sends but the server never echoes back). + + :param schema: Schema dict to mutate in-place. + :type schema: dict[str, Any] + :param context: ``"request"`` or ``"response"``. 
+ :type context: str + """ + if not isinstance(schema, dict): + return + props = schema.get("properties", {}) + required = schema.get("required", []) + to_remove: list[str] = [] + for name, prop_schema in props.items(): + if not isinstance(prop_schema, dict): + continue + if context == "request" and prop_schema.get("readOnly"): + to_remove.append(name) + elif context == "response" and prop_schema.get("writeOnly"): + to_remove.append(name) + for name in to_remove: + props.pop(name, None) + if name in required: + required.remove(name) + + def _recurse(child: dict[str, Any]) -> None: + _OpenApiValidator._strip_readonly_writeonly(child, context) + + _OpenApiValidator._walk_schema(schema, _recurse) + + @staticmethod + def _strip_openapi_keywords(schema: dict[str, Any]) -> None: + """Remove OpenAPI-only keywords that confuse JSON Schema validators. + + Strips ``discriminator``, ``xml``, ``externalDocs``, and ``example`` + from the schema tree in-place. + + :param schema: Schema dict to mutate in-place. + :type schema: dict[str, Any] + """ + if not isinstance(schema, dict): + return + for kw in _OPENAPI_ONLY_KEYWORDS: + schema.pop(kw, None) + _OpenApiValidator._walk_schema(schema, _OpenApiValidator._strip_openapi_keywords) + + +# ------------------------------------------------------------------ +# Discriminator-aware error collection helpers +# ------------------------------------------------------------------ + + +def _format_error(error: ValidationError) -> str: + """Format a single validation error with its JSON path prefix. + + :param error: A ``jsonschema`` validation error. + :type error: ValidationError + :return: Human-readable error string, optionally path-prefixed. + :rtype: str + """ + path = error.json_path + if path and path != "$": + return f"{path}: {error.message}" + return error.message + + +def _collect_errors(error: ValidationError) -> list[str]: + """Collect formatted error messages from a ``ValidationError``. 
+
+    For ``oneOf`` / ``anyOf`` errors the helper attempts discriminator-aware
+    branch selection (mirroring the server-side C# ``JsonSchemaValidator``).
+    When a discriminator property is detected, only the *matching* branch's
+    errors are reported, avoiding a noisy dump of every branch.
+
+    :param error: A ``jsonschema`` validation error.
+    :type error: ValidationError
+    :return: List of formatted error strings.
+    :rtype: list[str]
+    """
+    if error.validator in ("oneOf", "anyOf") and error.context:
+        return _collect_composition_errors(error)
+    return [_format_error(error)]
+
+
+def _collect_composition_errors(error: ValidationError) -> list[str]:
+    """Handle ``oneOf`` / ``anyOf`` errors with discriminator-based branch selection.
+
+    Algorithm (ported from the C# ``JsonSchemaValidator``):
+
+    1. Group errors by branch index (``schema_path[0]``).
+    2. Detect a *discriminator path*: a ``const`` / ``type`` / ``enum``
+       error that appears at the same ``absolute_path`` in at least half
+       of the branches (threshold: ``ceil(n/2)``).
+    3. The *matching branch* is the one **without** a discriminator error
+       at that path.
+    4. Report only the matching branch's errors. If no branch matches,
+       report a concise ``"Invalid value. Expected one of: ..."`` message.
+
+    Falls back to ``best_match`` when no discriminator can be identified.
+
+    :param error: A ``oneOf`` / ``anyOf`` validation error.
+    :type error: ValidationError
+    :return: List of formatted error strings.
+ :rtype: list[str] + """ + # Group sub-errors by branch index (first element of schema_path) + branch_groups: dict[int, list[ValidationError]] = {} + for sub in error.context or []: + if sub.schema_path: + idx = sub.schema_path[0] + if isinstance(idx, int): + branch_groups.setdefault(idx, []).append(sub) + + if len(branch_groups) < 2: + # Cannot do branch analysis — fallback + best = best_match([error]) + if best is not None and best is not error: + return _collect_errors(best) + return [_format_error(error)] + + disc_path = _find_discriminator_path(branch_groups) + + if disc_path is None: + best = best_match([error]) + if best is not None and best is not error: + return _collect_errors(best) + return [_format_error(error)] + + # Find the branch that matches the discriminator value + matching_idx = _find_matching_branch(branch_groups, disc_path) + + if matching_idx is not None: + result: list[str] = [] + for sub in branch_groups[matching_idx]: + result.extend(_collect_errors(sub)) + return result if result else [_format_error(error)] + + # No matching branch — report the discriminator mismatch + return _report_discriminator_error(branch_groups, disc_path, error) + + +_DISCRIMINATOR_VALIDATORS: frozenset[str] = frozenset({"const", "type", "enum"}) + + +def _find_discriminator_path( + branch_groups: dict[int, list[ValidationError]], +) -> Optional[tuple[str | int, ...]]: + """Detect a discriminator property across ``oneOf`` / ``anyOf`` branches. + + A discriminator is a property whose ``const``, ``enum``, or ``type`` + constraint fails at the same ``absolute_path`` in the majority of + branches. + + :param branch_groups: Errors grouped by branch index. + :type branch_groups: dict[int, list[ValidationError]] + :return: The ``absolute_path`` of the discriminator as a tuple, or *None*. 
+ :rtype: Optional[tuple[str | int, ...]] + """ + n_branches = len(branch_groups) + if n_branches < 2: + return None + + # Collect discriminator-error paths per branch + per_branch_paths: list[set[tuple[str | int, ...]]] = [] + for errors in branch_groups.values(): + paths: set[tuple[str | int, ...]] = set() + for err in errors: + if err.validator in _DISCRIMINATOR_VALIDATORS: + paths.add(tuple(err.absolute_path)) + per_branch_paths.append(paths) + + path_counts: Counter[tuple[str | int, ...]] = Counter() + for paths in per_branch_paths: + for p in paths: + path_counts[p] += 1 + + min_threshold = (n_branches + 1) // 2 # ceil(n/2) — at least half the branches + for path, count in path_counts.most_common(): + if count >= min_threshold: + return path + return None + + +def _find_matching_branch( + branch_groups: dict[int, list[ValidationError]], + disc_path: tuple[str | int, ...], +) -> Optional[int]: + """Return the branch index that has **no** discriminator error at *disc_path*. + + :param branch_groups: Errors grouped by branch index. + :type branch_groups: dict[int, list[ValidationError]] + :param disc_path: The discriminator property path. + :type disc_path: tuple[str | int, ...] + :return: The matching branch index or *None*. + :rtype: Optional[int] + """ + for idx, errors in branch_groups.items(): + has_disc_error = any( + err.validator in _DISCRIMINATOR_VALIDATORS + and tuple(err.absolute_path) == disc_path + for err in errors + ) + if not has_disc_error: + return idx + return None + + +def _report_discriminator_error( + branch_groups: dict[int, list[ValidationError]], + disc_path: tuple[str | int, ...], + parent: ValidationError, +) -> list[str]: + """Produce a concise discriminator-mismatch message. + + Collects all expected ``const`` / ``enum`` values from the branches and + reports them as ``"Invalid value. Expected one of: X, Y, Z"``. + + :param branch_groups: Errors grouped by branch index. 
+ :type branch_groups: dict[int, list[ValidationError]] + :param disc_path: The discriminator property path. + :type disc_path: tuple[str | int, ...] + :param parent: The parent ``oneOf`` / ``anyOf`` error. + :type parent: ValidationError + :return: List containing a single formatted error string. + :rtype: list[str] + """ + # Build the JSON-path string for the discriminator property + path_str = "$" + for segment in disc_path: + if isinstance(segment, int): + path_str += f"[{segment}]" + else: + path_str += f".{segment}" + + # Check for type errors first (more fundamental than const/enum) + for errors in branch_groups.values(): + for err in errors: + if err.validator == "type" and tuple(err.absolute_path) == disc_path: + return [f"{path_str}: {err.message}"] + + # Collect expected values from const and enum errors + expected: list[str] = [] + for errors in branch_groups.values(): + for err in errors: + if tuple(err.absolute_path) != disc_path: + continue + if err.validator == "const": + val = err.schema.get("const") if isinstance(err.schema, dict) else None + if val is not None: + formatted = json.dumps(val) if not isinstance(val, str) else f'"{val}"' + if formatted not in expected: + expected.append(formatted) + elif err.validator == "enum": + enum_vals = err.schema.get("enum", []) if isinstance(err.schema, dict) else [] + for val in enum_vals: + formatted = json.dumps(val) if not isinstance(val, str) else f'"{val}"' + if formatted not in expected: + expected.append(formatted) + + if expected: + if len(expected) == 1: + return [f"{path_str}: Invalid value. Expected: {expected[0]}"] + return [f"{path_str}: Invalid value. Expected one of: {', '.join(expected)}"] + + # Fallback + return [_format_error(parent)] + + +def _resolve_ref(spec: dict[str, Any], schema: dict[str, Any]) -> dict[str, Any]: + """Resolve a single ``$ref`` pointer within the spec. + + :param spec: The full OpenAPI spec dictionary. 
+ :type spec: dict[str, Any] + :param schema: A schema dict that may contain a ``$ref`` key. + :type schema: dict[str, Any] + :return: The resolved schema. + :rtype: dict[str, Any] + """ + if "$ref" not in schema: + return schema + ref_path = schema["$ref"] # e.g. "#/components/schemas/MyModel" + parts = ref_path.lstrip("#/").split("/") + current: Any = spec + for part in parts: + if isinstance(current, dict): + current = current.get(part) + else: + return schema # can't resolve, return as-is + return current if isinstance(current, dict) else schema + + +def _resolve_refs_deep( # pylint: disable=too-many-return-statements + spec: dict[str, Any], node: Any, _seen: Optional[set[str]] = None +) -> Any: + """Recursively resolve all ``$ref`` pointers in a schema tree. + + Walks the schema, replacing every ``{"$ref": "..."}`` with the referenced + definition inlined from *spec*. A *_seen* set guards against infinite + recursion from circular references. + + :param spec: The full OpenAPI spec dictionary. + :type spec: dict[str, Any] + :param node: The current schema node (dict, list, or scalar). + :type node: Any + :param _seen: Set of already-visited ``$ref`` paths (cycle guard). + :type _seen: Optional[set[str]] + :return: The schema tree with all ``$ref`` pointers inlined. + :rtype: Any + """ + if _seen is None: + _seen = set() + + if isinstance(node, dict): + if "$ref" in node: + ref_path = node["$ref"] + if ref_path in _seen: + return node # circular – leave the $ref as-is + _seen = _seen | {ref_path} + resolved = _resolve_ref(spec, node) + if resolved is node: + return node # couldn't resolve + # Preserve sibling keywords (e.g. 
nullable: true alongside $ref) + siblings = {k: v for k, v in node.items() if k != "$ref"} + resolved = _resolve_refs_deep(spec, resolved, _seen) + if siblings and isinstance(resolved, dict): + merged = dict(resolved) + merged.update(siblings) + return merged + return resolved + return {k: _resolve_refs_deep(spec, v, _seen) for k, v in node.items()} + + if isinstance(node, list): + return [_resolve_refs_deep(spec, item, _seen) for item in node] + + return node diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_server_context.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_server_context.py new file mode 100644 index 000000000000..b3d6651ab293 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_server_context.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Internal server context shared across protocol implementations.""" +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from ._tracing import _TracingHelper + + +@dataclasses.dataclass(frozen=True) +class _ServerContext: + """Shared server state passed to protocol implementations. + + Internal — not part of the public API. Each protocol receives this + at construction time so it can access tracing, error handling, and + timeout configuration without coupling to the ``AgentServer`` class. 
+ """ + + tracing: Optional[_TracingHelper] + debug_errors: bool + request_timeout: int diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py new file mode 100644 index 000000000000..2f2c2592f127 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_tracing.py @@ -0,0 +1,674 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Optional OpenTelemetry tracing for AgentServer. + +Tracing is **disabled by default**. Enable it in one of two ways: + +1. Set the environment variable ``AGENT_ENABLE_TRACING=true``. +2. Pass ``enable_tracing=True`` to the :class:`AgentServer` constructor + (constructor argument takes precedence over the env var). + +When enabled, the module requires ``opentelemetry-api`` to be installed:: + + pip install azure-ai-agentserver-server[tracing] + +If the package is not installed, tracing silently becomes a no-op. + +When an Application Insights connection string is available (via constructor +or ``APPLICATIONINSIGHTS_CONNECTION_STRING`` env var), traces **and** logs are +automatically exported to Azure Monitor. This requires the additional +``opentelemetry-sdk`` and ``azure-monitor-opentelemetry-exporter`` packages +(both included in the ``[tracing]`` extras group). +""" +from __future__ import annotations + +import logging +from contextlib import contextmanager +from collections.abc import AsyncIterable, AsyncIterator, Mapping # pylint: disable=import-error +from typing import TYPE_CHECKING, Any, Iterator, Optional, Union + +from . import _config +from ._logger import get_logger + +#: Starlette's ``Content`` type — the element type for streaming bodies. 
+_Content = Union[str, bytes, memoryview] + +#: W3C Trace Context header names used for distributed trace propagation. +_W3C_HEADERS = ("traceparent", "tracestate") + +#: Baggage key whose value overrides the parent span ID. +_LEAF_CUSTOMER_SPAN_ID = "leaf_customer_span_id" + +# ------------------------------------------------------------------ +# GenAI semantic convention attribute keys +# ------------------------------------------------------------------ +_ATTR_INVOCATION_ID = "invocation.id" +_ATTR_RESPONSE_ID = "gen_ai.response.id" +_ATTR_PROVIDER_NAME = "gen_ai.provider.name" +_ATTR_AGENT_ID = "gen_ai.agent.id" +_ATTR_PROJECT_ID = "microsoft.foundry.project.id" +_ATTR_OPERATION_NAME = "gen_ai.operation.name" +_ATTR_CONVERSATION_ID = "gen_ai.conversation.id" + +_PROVIDER_NAME_VALUE = "microsoft.foundry" + +logger = get_logger() + +_HAS_OTEL = False +try: + from opentelemetry import trace + from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + + _HAS_OTEL = True +except ImportError: + if TYPE_CHECKING: + from opentelemetry import trace + from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + + +class _TracingHelper: + """Lightweight wrapper around OpenTelemetry. + + Only instantiate when tracing is enabled. If ``opentelemetry-api`` is + not installed, a warning is logged and all methods become no-ops. + + When *connection_string* is provided, a :class:`TracerProvider` with an + Azure Monitor exporter is configured globally and log records from the + ``azure.ai.agentserver`` logger are forwarded to Application Insights. + This requires ``opentelemetry-sdk`` and + ``azure-monitor-opentelemetry-exporter``. + """ + + def __init__( + self, + connection_string: Optional[str] = None, + ) -> None: + self._enabled = _HAS_OTEL + self._tracer: Any = None + self._propagator: Any = None + + # Resolve agent identity from environment variables. 
+ agent_name = _config.resolve_agent_name() + agent_version = _config.resolve_agent_version() + self._agent_label = ( + f"{agent_name}:{agent_version}" if agent_name and agent_version else agent_name + ) + self._project_id = _config.resolve_project_id() + + if not self._enabled: + logger.warning( + "Tracing was enabled but opentelemetry-api is not installed. " + "Install it with: pip install azure-ai-agentserver-server[tracing]" + ) + return + + if connection_string: + self._setup_azure_monitor(connection_string) + + self._tracer = trace.get_tracer("azure.ai.agentserver") + self._propagator = TraceContextTextMapPropagator() + + # ------------------------------------------------------------------ + # Azure Monitor auto-configuration + # ------------------------------------------------------------------ + + def _extract_context( + self, + carrier: Optional[dict[str, str]], + baggage_header: Optional[str] = None, + ) -> Any: + """Extract parent trace context from a W3C carrier dict. + + When a ``baggage`` header is provided and contains a + ``leaf_customer_span_id`` key, the parent span ID is overridden + so that the server's root span is parented under the leaf customer + span rather than the span referenced in the ``traceparent`` header. + + :param carrier: W3C trace-context headers or None. + :type carrier: Optional[dict[str, str]] + :param baggage_header: Raw ``baggage`` header value or None. + :type baggage_header: Optional[str] + :return: The extracted OTel context, or None. + :rtype: Any + """ + if not carrier or self._propagator is None: + return None + + ctx = self._propagator.extract(carrier=carrier) + + if not baggage_header: + return ctx + + leaf_span_id = _parse_baggage_key(baggage_header, _LEAF_CUSTOMER_SPAN_ID) + if not leaf_span_id: + return ctx + + return _override_parent_span_id(ctx, leaf_span_id) + + @staticmethod + def _setup_azure_monitor(connection_string: str) -> None: + """Configure global TracerProvider and LoggerProvider for App Insights. 
+
+        Sets up ``AzureMonitorTraceExporter`` so spans are exported to
+        Application Insights, and ``AzureMonitorLogExporter`` so log records
+        from the ``azure.ai.agentserver`` namespace are forwarded.
+
+        If the required packages are not installed, a warning is logged and
+        export is silently skipped — span creation still works via the
+        default no-op or user-configured provider.
+
+        :param connection_string: Application Insights connection string.
+        :type connection_string: str
+        """
+        resource = _create_resource()
+        if resource is None:
+            return
+        _setup_trace_export(resource, connection_string)
+        _setup_log_export(resource, connection_string)
+
+    # ------------------------------------------------------------------
+    # Span naming and attribute helpers (shared by all protocols)
+    # ------------------------------------------------------------------
+
+    def span_name(self, span_operation: str) -> str:
+        """Build a span name using the operation and agent label.
+
+        :param span_operation: The span operation (e.g. ``"execute_agent"``).
+            This becomes the first token of the OTel span name.
+        :type span_operation: str
+        :return: ``"<span_operation> <agent_name>:<agent_version>"`` or just
+            ``"<span_operation>"`` when no agent label is configured.
+        :rtype: str
+        """
+        if self._agent_label:
+            return f"{span_operation} {self._agent_label}"
+        return span_operation
+
+    def build_span_attrs(
+        self,
+        invocation_id: str,
+        session_id: str,
+        operation_name: Optional[str] = None,
+    ) -> dict[str, str]:
+        """Build GenAI semantic convention span attributes.
+
+        These attributes are common across all protocol heads (invocation,
+        chat, etc.).
+
+        :param invocation_id: The invocation/request ID for this request.
+        :type invocation_id: str
+        :param session_id: The session ID header value (empty string if absent).
+        :type session_id: str
+        :param operation_name: Optional ``gen_ai.operation.name`` value
+            (e.g. ``"invoke_agent"``). Omitted from the dict when *None*.
+        :type operation_name: Optional[str]
+        :return: Span attribute dict.
+ :rtype: dict[str, str] + """ + attrs: dict[str, str] = { + _ATTR_INVOCATION_ID: invocation_id, + _ATTR_RESPONSE_ID: invocation_id, + _ATTR_PROVIDER_NAME: _PROVIDER_NAME_VALUE, + } + if self._agent_label: + attrs[_ATTR_AGENT_ID] = self._agent_label + if self._project_id: + attrs[_ATTR_PROJECT_ID] = self._project_id + if operation_name: + attrs[_ATTR_OPERATION_NAME] = operation_name + if session_id: + attrs[_ATTR_CONVERSATION_ID] = session_id + return attrs + + @contextmanager + def span( + self, + name: str, + attributes: Optional[dict[str, str]] = None, + carrier: Optional[dict[str, str]] = None, + baggage_header: Optional[str] = None, + ) -> Iterator[Any]: + """Create a traced span if tracing is enabled, otherwise no-op. + + Yields the OpenTelemetry span object when tracing is active, or + ``None`` when tracing is disabled. Callers may use the yielded span + together with :meth:`record_error` to attach error information. + + :param name: Span name, e.g. ``"execute_agent my_agent:1.0"``. + :type name: str + :param attributes: Key-value span attributes. + :type attributes: Optional[dict[str, str]] + :param carrier: Incoming HTTP headers for W3C trace-context propagation. + :type carrier: Optional[dict[str, str]] + :param baggage_header: Raw ``baggage`` header value for + ``leaf_customer_span_id`` extraction. + :type baggage_header: Optional[str] + :return: Context manager that yields the OTel span or *None*. + :rtype: Iterator[Any] + """ + if not self._enabled or self._tracer is None: + yield None + return + + ctx = self._extract_context(carrier, baggage_header) + + with self._tracer.start_as_current_span( + name=name, + attributes=attributes or {}, + kind=trace.SpanKind.SERVER, + context=ctx, + ) as otel_span: + yield otel_span + + def start_span( + self, + name: str, + attributes: Optional[dict[str, str]] = None, + carrier: Optional[dict[str, str]] = None, + baggage_header: Optional[str] = None, + ) -> Any: + """Start a span without a context manager. 
+ + Use this for streaming responses where the span must outlive the + initial ``invoke()`` call. The caller **must** call :meth:`end_span` + when the work is finished. + + :param name: Span name, e.g. ``"execute_agent my_agent:1.0"``. + :type name: str + :param attributes: Key-value span attributes. + :type attributes: Optional[dict[str, str]] + :param carrier: Incoming HTTP headers for W3C trace-context propagation. + :type carrier: Optional[dict[str, str]] + :param baggage_header: Raw ``baggage`` header value for + ``leaf_customer_span_id`` extraction. + :type baggage_header: Optional[str] + :return: The OTel span, or *None* when tracing is disabled. + :rtype: Any + """ + if not self._enabled or self._tracer is None: + return None + + ctx = self._extract_context(carrier, baggage_header) + + return self._tracer.start_span( + name=name, + attributes=attributes or {}, + kind=trace.SpanKind.SERVER, + context=ctx, + ) + + # ------------------------------------------------------------------ + # Request-level convenience wrappers + # ------------------------------------------------------------------ + + def _prepare_request_span_args( + self, + headers: Mapping[str, str], + invocation_id: str, + span_operation: str, + operation_name: Optional[str] = None, + session_id: str = "", + ) -> tuple[str, dict[str, str], dict[str, str], Optional[str]]: + """Extract headers and build span arguments for a request. + + Shared pipeline used by :meth:`start_request_span` and + :meth:`request_span` to avoid duplicating header extraction, + attribute building, and span naming. + + :param headers: HTTP request headers (any ``Mapping[str, str]``). + :type headers: Mapping[str, str] + :param invocation_id: The invocation/request ID. + :type invocation_id: str + :param span_operation: Span operation (e.g. ``"execute_agent"``). + :type span_operation: str + :param operation_name: Optional ``gen_ai.operation.name`` value. 
+ :type operation_name: Optional[str] + :param session_id: Session ID from the ``agent_session_id`` query + parameter. Defaults to ``""`` (no session). + :type session_id: str + :return: ``(name, attributes, carrier, baggage)`` ready for + :meth:`span` or :meth:`start_span`. + :rtype: tuple[str, dict[str, str], dict[str, str], Optional[str]] + """ + carrier = _extract_w3c_carrier(headers) + baggage = headers.get("baggage") + span_attrs = self.build_span_attrs( + invocation_id, session_id, operation_name=operation_name + ) + return self.span_name(span_operation), span_attrs, carrier, baggage + + def start_request_span( + self, + headers: Mapping[str, str], + invocation_id: str, + span_operation: str, + operation_name: Optional[str] = None, + session_id: str = "", + ) -> Any: + """Start a request-scoped span, extracting context from HTTP headers. + + Convenience method that combines header extraction, attribute + building, span naming, and span creation into a single call. + Use for streaming responses where the span must outlive the + initial handler call. The caller **must** call :meth:`end_span` + when work is finished. + + :param headers: HTTP request headers (any ``Mapping[str, str]``). + :type headers: Mapping[str, str] + :param invocation_id: The invocation/request ID. + :type invocation_id: str + :param span_operation: Span operation (e.g. ``"execute_agent"``). + Becomes the first token of the OTel span name via + :meth:`span_name`. + :type span_operation: str + :param operation_name: Optional ``gen_ai.operation.name`` attribute + value (e.g. ``"invoke_agent"``). Omitted when *None*. + :type operation_name: Optional[str] + :param session_id: Session ID from the ``agent_session_id`` query + parameter. Defaults to ``""`` (no session). + :type session_id: str + :return: The OTel span, or *None* when tracing is disabled. 
+ :rtype: Any + """ + name, attrs, carrier, baggage = self._prepare_request_span_args( + headers, invocation_id, span_operation, operation_name, + session_id=session_id, + ) + return self.start_span(name, attributes=attrs, carrier=carrier, baggage_header=baggage) + + @contextmanager + def request_span( + self, + headers: Mapping[str, str], + invocation_id: str, + span_operation: str, + operation_name: Optional[str] = None, + session_id: str = "", + ) -> Iterator[Any]: + """Create a request-scoped span as a context manager. + + Convenience method that combines header extraction, attribute + building, span naming, and span creation into a single call. + Use for non-streaming request handlers where the span should + cover the entire handler execution. + + :param headers: HTTP request headers (any ``Mapping[str, str]``). + :type headers: Mapping[str, str] + :param invocation_id: The invocation/request ID. + :type invocation_id: str + :param span_operation: Span operation (e.g. ``"get_invocation"``). + Becomes the first token of the OTel span name via + :meth:`span_name`. + :type span_operation: str + :param operation_name: Optional ``gen_ai.operation.name`` attribute + value. Omitted when *None*. + :type operation_name: Optional[str] + :param session_id: Session ID from the ``agent_session_id`` query + parameter. Defaults to ``""`` (no session). + :type session_id: str + :return: Context manager that yields the OTel span or *None*. 
+ :rtype: Iterator[Any] + """ + name, attrs, carrier, baggage = self._prepare_request_span_args( + headers, invocation_id, span_operation, operation_name, + session_id=session_id, + ) + otel_span = self.start_span(name, attributes=attrs, carrier=carrier, baggage_header=baggage) + try: + yield otel_span + except Exception as exc: # pylint: disable=broad-exception-caught + self.end_span(otel_span, exc=exc) + raise + self.end_span(otel_span) + + # ------------------------------------------------------------------ + # Span lifecycle helpers + # ------------------------------------------------------------------ + + def end_span(self, span: Any, exc: Optional[Exception] = None) -> None: + """End a span started with :meth:`start_span`. + + Optionally records an error before ending. No-op when *span* is + ``None`` (tracing disabled). + + :param span: The OTel span, or *None*. + :type span: Any + :param exc: Optional exception to record before ending. + :type exc: Optional[Exception] + """ + if span is None: + return + if exc is not None: + self.record_error(span, exc) + span.end() + + @staticmethod + def record_error(span: Any, exc: Exception) -> None: + """Record an exception and ERROR status on a span. + + No-op when *span* is ``None`` (tracing disabled) or when + ``opentelemetry-api`` is not installed. + + :param span: The OTel span returned by :meth:`span`, or *None*. + :type span: Any + :param exc: The exception to record. + :type exc: Exception + """ + if span is not None and _HAS_OTEL: + span.set_status(trace.StatusCode.ERROR, str(exc)) + span.record_exception(exc) + + async def trace_stream( + self, iterator: AsyncIterable[_Content], span: Any + ) -> AsyncIterator[_Content]: + """Wrap a streaming body iterator so the tracing span covers the full + duration of data transmission. + + Yields chunks from *iterator* unchanged. When the iterator is + exhausted or raises an exception the span is ended (with error status + if applicable). 
Safe to call when tracing is disabled (*span* is + ``None``). + + :param iterator: The original async body iterator from + :class:`~starlette.responses.StreamingResponse`. + :type iterator: AsyncIterable[Union[str, bytes, memoryview]] + :param span: The OTel span (or *None* when tracing is disabled). + :type span: Any + :return: An async iterator that yields chunks unchanged. + :rtype: AsyncIterator[Union[str, bytes, memoryview]] + """ + error: Optional[Exception] = None + try: + async for chunk in iterator: + yield chunk + except Exception as exc: + error = exc + raise + finally: + self.end_span(span, exc=error) + + +def _create_resource() -> Any: + """Create the OTel resource for Azure Monitor exporters. + + :return: A :class:`~opentelemetry.sdk.resources.Resource`, or *None* + if the required packages are not installed. + :rtype: Any + """ + try: + from opentelemetry.sdk.resources import Resource + except ImportError: + logger.warning( + "Application Insights connection string was provided but " + "required packages are not installed. Install them with: " + "pip install azure-ai-agentserver-server[tracing]" + ) + return None + return Resource.create({"service.name": "azure.ai.agentserver"}) + + +def _setup_trace_export(resource: Any, connection_string: str) -> None: + """Configure a global :class:`TracerProvider` that exports to App Insights. + + :param resource: The OTel resource describing this service. + :type resource: Any + :param connection_string: Application Insights connection string. + :type connection_string: str + """ + try: + from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + + from azure.monitor.opentelemetry.exporter import ( # type: ignore[import-untyped] + AzureMonitorTraceExporter, + ) + except ImportError: + logger.warning( + "Trace export to Application Insights requires " + "opentelemetry-sdk and azure-monitor-opentelemetry-exporter. 
" + "Traces will not be forwarded." + ) + return + + provider = SdkTracerProvider(resource=resource) + exporter = AzureMonitorTraceExporter(connection_string=connection_string) + provider.add_span_processor(BatchSpanProcessor(exporter)) + trace.set_tracer_provider(provider) + logger.info("Application Insights trace exporter configured.") + + +def _setup_log_export(resource: Any, connection_string: str) -> None: + """Configure a global :class:`LoggerProvider` that exports to App Insights. + + :param resource: The OTel resource describing this service. + :type resource: Any + :param connection_string: Application Insights connection string. + :type connection_string: str + """ + try: + from opentelemetry._logs import set_logger_provider + from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler + from opentelemetry.sdk._logs.export import BatchLogRecordProcessor + + from azure.monitor.opentelemetry.exporter import ( # type: ignore[import-untyped] + AzureMonitorLogExporter, + ) + except ImportError: + logger.warning( + "Log export to Application Insights requires " + "opentelemetry-sdk. Logs will not be forwarded." + ) + return + + log_provider = LoggerProvider(resource=resource) + set_logger_provider(log_provider) + log_exporter = AzureMonitorLogExporter(connection_string=connection_string) + log_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter)) + handler = LoggingHandler(logger_provider=log_provider) + logging.getLogger("azure.ai.agentserver").addHandler(handler) + logger.info("Application Insights log exporter configured.") + + +def _extract_w3c_carrier(headers: Mapping[str, str]) -> dict[str, str]: + """Extract W3C trace-context headers from a mapping. + + Filters the input to only ``traceparent`` and ``tracestate`` — the two + headers defined by the `W3C Trace Context`_ standard. This avoids + passing unrelated headers (e.g. ``authorization``, ``cookie``) into the + OpenTelemetry propagator. + + .. 
_W3C Trace Context: https://www.w3.org/TR/trace-context/ + + :param headers: A mapping of header name to value (e.g. + ``request.headers``). + :type headers: Mapping[str, str] + :return: A dict containing only the W3C propagation headers present + in *headers*. + :rtype: dict[str, str] + """ + result: dict[str, str] = {k: v for k in _W3C_HEADERS if (v := headers.get(k)) is not None} + return result + + +def _parse_baggage_key(baggage: str, key: str) -> Optional[str]: + """Parse a single key from a W3C Baggage header value. + + The `W3C Baggage`_ format is a comma-separated list of + ``key=value`` pairs with optional properties after a ``;``. + + Example:: + + leaf_customer_span_id=abc123,other=val + + .. _W3C Baggage: https://www.w3.org/TR/baggage/ + + :param baggage: The raw header value. + :type baggage: str + :param key: The baggage key to look up. + :type key: str + :return: The value for *key*, or *None* if not found. + :rtype: Optional[str] + """ + for member in baggage.split(","): + member = member.strip() + if not member: + continue + # Split on first '=' only; value may contain '=' + kv_part = member.split(";", 1)[0] # strip optional properties + eq_idx = kv_part.find("=") + if eq_idx < 0: + continue + k = kv_part[:eq_idx].strip() + v = kv_part[eq_idx + 1:].strip() + if k == key: + return v + return None + + +def _override_parent_span_id(ctx: Any, hex_span_id: str) -> Any: + """Create a new context with the same trace ID but a different parent span ID. + + Constructs a :class:`~opentelemetry.trace.SpanContext` with the trace ID + taken from the existing context and the span ID replaced by + *hex_span_id*. The resulting context can be used as the ``context`` + argument to ``start_span`` / ``start_as_current_span``. + + Returns the original *ctx* unchanged if *hex_span_id* is invalid or + ``opentelemetry-api`` is not installed. + + :param ctx: An OTel context produced by ``TraceContextTextMapPropagator.extract()``. 
+ :type ctx: Any + :param hex_span_id: 16-character lower-case hex string representing the + desired parent span ID. + :type hex_span_id: str + :return: A context with the overridden parent span ID, or the original. + :rtype: Any + """ + if not _HAS_OTEL: + return ctx + + try: + new_span_id = int(hex_span_id, 16) + except (ValueError, TypeError): + logger.warning("Invalid leaf_customer_span_id in baggage: %r", hex_span_id) + return ctx + + if new_span_id == 0: + return ctx + + # Grab the trace ID from the current parent span in ctx. + current_span = trace.get_current_span(ctx) + current_ctx = current_span.get_span_context() + if current_ctx is None or not current_ctx.is_valid: + return ctx + + custom_span_ctx = trace.SpanContext( + trace_id=current_ctx.trace_id, + span_id=new_span_id, + is_remote=True, + trace_flags=current_ctx.trace_flags, + trace_state=current_ctx.trace_state, + ) + custom_parent = trace.NonRecordingSpan(custom_span_ctx) + return trace.set_span_in_context(custom_parent, ctx) diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_version.py b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_version.py new file mode 100644 index 000000000000..67d209a8cafd --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/_version.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +VERSION = "1.0.0b1" diff --git a/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/py.typed b/sdk/agentserver/azure-ai-agentserver-server/azure/ai/agentserver/server/py.typed new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sdk/agentserver/azure-ai-agentserver-server/cspell.json b/sdk/agentserver/azure-ai-agentserver-server/cspell.json new file mode 100644 index 000000000000..4cf0dce8914d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/cspell.json @@ -0,0 +1,38 @@ +{ + "ignoreWords": [ + "agentframework", + "agentserver", + "appinsights", + "ASGI", + "azureai", + "ainvoke", + "behaviour", + "caplog", + "delenv", + "genai", + "hypercorn", + "invocations", + "langgraph", + "msgpack", + "openapi", + "paramtype", + "requestschema", + "rtype", + "serialisation", + "sess", + "Specialised", + "Standardised", + "starlette", + "traceparent", + "tracestate", + "tracecontext", + "varint", + "writeonly" + ], + "ignorePaths": [ + "*.csv", + "*.json", + "*.rst", + "samples/**" + ] +} \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-server/dev_requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/dev_requirements.txt new file mode 100644 index 000000000000..a4d2cb770dbc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/dev_requirements.txt @@ -0,0 +1,9 @@ +-e ../../../eng/tools/azure-sdk-tools +pytest +httpx +httpx[http2] +pytest-asyncio +opentelemetry-api>=1.20.0 +opentelemetry-sdk>=1.20.0 +azure-monitor-opentelemetry-exporter>=1.0.0b21 +cryptography diff --git a/sdk/agentserver/azure-ai-agentserver-server/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-server/pyproject.toml new file mode 100644 index 000000000000..663bd57fdb10 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/pyproject.toml @@ -0,0 +1,76 @@ +[project] +name = "azure-ai-agentserver-server" +dynamic = ["version", "readme"] 
+description = "Generic agent server for Azure AI with pluggable protocol heads"
+requires-python = ">=3.10"
+authors = [
+    { name = "Microsoft Corporation", email = "azpysdkhelp@microsoft.com" },
+]
+license = "MIT"
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "Programming Language :: Python",
+    "Programming Language :: Python :: 3 :: Only",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+]
+keywords = ["azure", "azure sdk", "agent", "agentserver"]
+
+dependencies = [
+    "azure-core>=1.35.0",
+    "starlette>=0.45.0",
+    "hypercorn>=0.17.0",
+    "jsonschema>=4.0.0",
+]
+
+[project.optional-dependencies]
+tracing = [
+    "opentelemetry-api>=1.20.0",
+    "opentelemetry-sdk>=1.20.0",
+    "azure-monitor-opentelemetry-exporter>=1.0.0b21",
+]
+
+[build-system]
+requires = ["setuptools>=69", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project.urls]
+repository = "https://github.com/Azure/azure-sdk-for-python"
+
+[tool.setuptools.packages.find]
+exclude = [
+    "tests*",
+    "samples*",
+    "doc*",
+    "azure",
+    "azure.ai",
+]
+
+[tool.setuptools.dynamic]
+version = { attr = "azure.ai.agentserver.server._version.VERSION" }
+readme = { file = ["README.md"], content-type = "text/markdown" }
+
+[tool.setuptools.package-data]
+# Key must be the package containing the PEP 561 marker; "pytyped" is not a
+# package name and would silently drop py.typed from the wheel.
+"azure.ai.agentserver.server" = ["py.typed"]
+
+[tool.ruff]
+line-length = 120
+target-version = "py310"
+lint.select = ["E", "F", "B", "I"]
+lint.ignore = []
+fix = false
+
+[tool.ruff.lint.isort]
+known-first-party = ["azure.ai.agentserver.server"]
+combine-as-imports = true
+
+[tool.azure-sdk-build]
+breaking = false
+mypy = true
+pyright = true
+verifytypes = true
+pylint = true
+type_check_samples = false # samples require env-specific deps (langchain, agent-framework, etc.)
diff --git a/sdk/agentserver/azure-ai-agentserver-server/pyrightconfig.json b/sdk/agentserver/azure-ai-agentserver-server/pyrightconfig.json new file mode 100644 index 000000000000..5f81af3c9da7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/pyrightconfig.json @@ -0,0 +1,11 @@ +{ + "reportOptionalMemberAccess": "warning", + "reportArgumentType": "warning", + "reportAttributeAccessIssue": "warning", + "reportMissingImports": "warning", + "reportGeneralTypeIssues": "warning", + "reportReturnType": "warning", + "exclude": [ + "**/samples/**" + ] +} \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/.env.sample b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/.env.sample new file mode 100644 index 000000000000..66bb7033c34e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/.env.sample @@ -0,0 +1,5 @@ +# Azure OpenAI credentials (used by AsyncAzureOpenAI) +AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/ +AZURE_OPENAI_API_KEY=your-api-key-here +AZURE_OPENAI_API_VERSION=2024-12-01-preview +AZURE_OPENAI_MODEL=gpt-4o diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/activity_weather_agent.py b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/activity_weather_agent.py new file mode 100644 index 000000000000..06227e8c0eb9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/activity_weather_agent.py @@ -0,0 +1,242 @@ +"""Weather agent built with the Activity protocol. + +Uses ``microsoft_agents.activity`` types (Activity, ChannelAccount, etc.) +for activity handling and the OpenAI Agents SDK for LLM orchestration. + +This module has **no dependency** on ``azure-ai-agentserver-server`` — it +is a standalone agent that can be hosted by any server that speaks the +Activity protocol. 
import logging
import os
import random
from datetime import datetime
from typing import Optional

from dotenv import load_dotenv

load_dotenv()

from openai import AsyncAzureOpenAI
from agents import (
    Agent,
    Model,
    ModelProvider,
    OpenAIChatCompletionsModel,
    RunConfig,
    Runner,
    function_tool,
)

from microsoft_agents.activity import Activity, ActivityTypes

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Tools — registered with the OpenAI Agents SDK
# ---------------------------------------------------------------------------


@function_tool
def get_weather(city: str, date: str) -> dict:
    """Get the weather forecast for a city on a given date.

    :param city: City name (e.g. "Seattle").
    :param date: Date string (e.g. "2026-03-14").
    :return: Weather data with temperature and conditions.
    """
    logger.info("Tool get_weather called: city=%s, date=%s", city, date)
    # Demo data: random but plausible temperature and sky conditions.
    forecast = {
        "city": city,
        "temperature": f"{random.randint(8, 21)}C",
        "conditions": random.choice(
            ["Sunny with light breeze", "Partly cloudy", "Overcast with rain"]
        ),
        "date": date,
    }
    logger.info("Tool get_weather result: %s", forecast)
    return forecast


@function_tool
def get_date() -> str:
    """Get the current date and time.

    :return: ISO-formatted current datetime string.
    """
    stamp = datetime.now().isoformat()
    logger.info("Tool get_date called: %s", stamp)
    return stamp


# ---------------------------------------------------------------------------
# WeatherAgent — Activity protocol agent
# ---------------------------------------------------------------------------


class WeatherAgent:
    """A weather-forecast agent that speaks the Activity protocol.

    Accepts an :class:`~microsoft_agents.activity.Activity`, dispatches by
    ``activity.type``, and returns a list of reply
    :class:`~microsoft_agents.activity.Activity` objects.

    :param client: An ``AsyncAzureOpenAI`` client for LLM calls.
    :type client: openai.AsyncAzureOpenAI
    """

    def __init__(self, client: AsyncAzureOpenAI) -> None:
        self._agent = Agent(
            name="WeatherAgent",
            instructions=(
                "You are a friendly assistant that helps people find weather "
                "forecasts. Use the get_weather and get_date tools to look up "
                "the forecast. Always respond with plain text. If you need "
                "more information (city or date), ask the user a follow-up "
                "question."
            ),
            tools=[get_weather, get_date],
        )

        class _AzureModelProvider(ModelProvider):
            """Routes model requests to the Azure OpenAI client."""

            def get_model(self, model_name: Optional[str] = None) -> Model:
                chosen = model_name or os.environ.get("AZURE_OPENAI_MODEL", "gpt-4o")
                return OpenAIChatCompletionsModel(
                    model=chosen,
                    openai_client=client,
                )

        self._model_provider = _AzureModelProvider()
        logger.info("WeatherAgent initialised with tools: get_weather, get_date")

    # -- Dispatch ------------------------------------------------------------

    async def on_turn(self, activity: Activity) -> list[Activity]:
        """Process an incoming activity and return reply activities.

        Routes by ``activity.type`` to the appropriate handler, mirroring the
        ``ActivityHandler.on_turn()`` pattern from the Microsoft Agents SDK.

        :param activity: The incoming activity.
        :type activity: microsoft_agents.activity.Activity
        :return: Zero or more reply activities.
        :rtype: list[Activity]
        """
        sender = activity.from_property.name if activity.from_property else "unknown"
        conv_id = activity.conversation.id if activity.conversation else "unknown"
        logger.info(
            "on_turn: type=%s, from=%s, conversation=%s",
            activity.type, sender, conv_id,
        )

        # Dispatch table instead of an if/elif chain; unhandled types fall through.
        handler = {
            ActivityTypes.message: self._on_message_activity,
            ActivityTypes.conversation_update: self._on_conversation_update_activity,
        }.get(activity.type)
        if handler is not None:
            return await handler(activity)

        logger.debug("on_turn: unhandled activity type %r, returning empty", activity.type)
        return []

    # -- Handlers ------------------------------------------------------------

    async def _on_message_activity(self, activity: Activity) -> list[Activity]:
        """Handle a ``message`` activity via the OpenAI Agents SDK.

        :param activity: The incoming message activity.
        :type activity: Activity
        :return: A single-element list with the reply activity.
        :rtype: list[Activity]
        """
        question = activity.text or ""
        logger.info("Message received: %r", question)

        logger.debug("Running OpenAI Agents SDK Runner.run() ...")
        run_result = await Runner.run(
            self._agent,
            question,
            run_config=RunConfig(
                model_provider=self._model_provider,
                tracing_disabled=True,
            ),
        )

        answer = activity.create_reply(text=run_result.final_output)
        logger.info("Reply: %r", answer.text)
        return [answer]

    async def _on_conversation_update_activity(
        self, activity: Activity
    ) -> list[Activity]:
        """Handle a ``conversationUpdate`` activity.

        Sends a welcome message when new members are added (excluding the
        agent itself).

        :param activity: The incoming conversationUpdate activity.
        :type activity: Activity
        :return: A list with one welcome reply, or empty.
        :rtype: list[Activity]
        """
        joined = activity.members_added or []
        self_id = activity.recipient.id if activity.recipient else None
        logger.info(
            "conversationUpdate: %d member(s) added, recipient_id=%s",
            len(joined), self_id,
        )

        # Welcome the first added member that is not the agent itself.
        newcomer = next((m for m in joined if m.id != self_id), None)
        if newcomer is None:
            logger.debug("conversationUpdate: no new external members, no welcome sent")
            return []

        logger.info(
            "New member joined: id=%s, name=%s — sending welcome",
            newcomer.id, newcomer.name,
        )
        welcome = activity.create_reply(
            text=(
                "Hello and welcome! I can help you with weather "
                "forecasts. Just tell me a city and date, and I'll "
                "look up the forecast for you."
            ),
        )
        return [welcome]


# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------


def create_weather_agent() -> WeatherAgent:
    """Create a :class:`WeatherAgent` from environment variables.

    Reads ``AZURE_OPENAI_ENDPOINT`` and ``AZURE_OPENAI_API_VERSION`` from
    the environment (or ``.env`` file via ``python-dotenv``).

    :return: A configured WeatherAgent instance.
    :rtype: WeatherAgent
    """
    azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
    version = os.environ.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview")
    logger.info(
        "Creating WeatherAgent: endpoint=%s, api_version=%s",
        azure_endpoint, version,
    )
    return WeatherAgent(
        client=AsyncAzureOpenAI(
            api_version=version,
            azure_endpoint=azure_endpoint,
        )
    )
+ :rtype: WeatherAgent + """ + endpoint = os.environ["AZURE_OPENAI_ENDPOINT"] + api_version = os.environ.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview") + logger.info( + "Creating WeatherAgent: endpoint=%s, api_version=%s", + endpoint, api_version, + ) + client = AsyncAzureOpenAI( + api_version=api_version, + azure_endpoint=endpoint, + ) + return WeatherAgent(client=client) diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/requirements.txt new file mode 100644 index 000000000000..0bdd8b1ddfc4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/requirements.txt @@ -0,0 +1,5 @@ +azure-ai-agentserver-server +microsoft-agents-activity>=0.8.0 +openai>=1.60.0 +openai-agents>=0.1.0 +python-dotenv>=1.0.0 diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/server.py b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/server.py new file mode 100644 index 000000000000..a00b54a78409 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/activity_weather_agent/server.py @@ -0,0 +1,115 @@ +"""Host a WeatherAgent via AgentServer with the Activity protocol. + +Bridges the Activity protocol to AgentServer's ``/invocations`` endpoint. +Incoming requests carry Bot Framework Activity JSON; the server +deserialises them into ``microsoft_agents.activity.Activity`` objects, +dispatches to the ``WeatherAgent``, and serialises the reply activities +back to JSON. + +This demonstrates that AgentServer can host agents built with the +Activity protocol — without depending on ``microsoft_agents.hosting.core`` +or ``CloudAdapter``. 
import logging

# Configure logging for the sample so all loggers (including the agent's)
# output to the console. This must run before any logger.info() calls.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

from microsoft_agents.activity import Activity

from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from azure.ai.agentserver.server import AgentServer

from activity_weather_agent import create_weather_agent

logger = logging.getLogger(__name__)


# -- Create the agent and server ---------------------------------------------

agent = create_weather_agent()
server = AgentServer()
# AgentServer adds its own StreamHandler to the "azure.ai.agentserver" logger.
# Stop propagation to the root logger to avoid duplicate lines from basicConfig.
logging.getLogger("azure.ai.agentserver").propagate = False
logger.info("AgentServer created, WeatherAgent wired to /invocations")


# -- Wire the Activity protocol to /invocations ------------------------------


@server.invoke_handler
async def handle_invoke(request: Request) -> Response:
    """Bridge ``POST /invocations`` to the Activity protocol.

    Deserialises the JSON body into an
    :class:`~microsoft_agents.activity.Activity`, dispatches to the
    agent's ``on_turn``, and returns the first reply activity as JSON.

    :param request: The raw Starlette request.
    :type request: starlette.requests.Request
    :return: JSON response containing the reply activity, or 200 with
        an empty status if no reply is produced.
    :rtype: starlette.responses.Response
    """
    payload = await request.json()
    incoming = Activity.model_validate(payload)
    conv = incoming.conversation.id if incoming.conversation else "unknown"
    logger.info("Received activity: type=%s, conversation=%s", incoming.type, conv)

    replies = await agent.on_turn(incoming)

    if not replies:
        logger.info("No reply activities produced, returning status ok")
        return JSONResponse({"status": "ok"})

    # Serialise back to camelCase JSON (matching Bot Framework wire format).
    reply_body = replies[0].model_dump(by_alias=True, exclude_none=True, mode="json")
    logger.info(
        "Returning reply: type=%s, text=%r",
        reply_body.get("type"), reply_body.get("text", "")[:80],
    )
    return JSONResponse(reply_body)


if __name__ == "__main__":
    server.run()
import asyncio
import os
from random import randint
from typing import Annotated

from dotenv import load_dotenv

load_dotenv()

from agent_framework import Agent
from agent_framework.azure import AzureOpenAIChatClient
from azure.identity import DefaultAzureCredential

from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from azure.ai.agentserver.server import AgentServer


# -- Customer defines their tools --

def get_weather(
    location: Annotated[str, "The location to get the weather for."],
) -> str:
    """Get the weather for a given location."""
    # Random demo forecast: pick the sky first, then the temperature, so the
    # RNG call order matches the original sample.
    sky = ["sunny", "cloudy", "rainy", "stormy"][randint(0, 3)]
    high = randint(10, 30)
    return f"The weather in {location} is {sky} with a high of {high}°C."


# -- Customer builds their Agent Framework agent --

def build_agent() -> Agent:
    """Create an Agent Framework Agent with tools."""
    chat_client = AzureOpenAIChatClient(credential=DefaultAzureCredential())
    return chat_client.as_agent(
        instructions="You are a helpful weather assistant.",
        tools=get_weather,
    )


# -- Customer-managed adapter: Agent Framework <-> /invoke --

agent = build_agent()
server = AgentServer()


@server.invoke_handler
async def handle_invoke(request: Request) -> Response:
    """Process an invocation via Agent Framework.

    :param request: The raw Starlette request.
    :type request: starlette.requests.Request
    :return: JSON response with the agent result.
    :rtype: starlette.responses.JSONResponse
    """
    payload = await request.json()
    question = payload.get("input", "")

    # Run the agent
    run = await agent.run(question)
    answer = run.content if hasattr(run, "content") else str(run)

    return JSONResponse({"result": answer})


if __name__ == "__main__":
    server.run()
+ * Consider an external task queue (Celery, Azure Queue, etc.) instead + of ``asyncio.create_task`` for work that must survive restarts. + +Usage:: + + # Start the agent + python async_invoke_agent.py + + # Start a long-running invocation + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"query": "analyze dataset"}' + # -> x-agent-invocation-id: abc-123 + # -> {"invocation_id": "abc-123", "status": "running"} + + # Poll for result + curl http://localhost:8088/invocations/abc-123 + # -> {"invocation_id": "abc-123", "status": "running"} (still working) + # -> {"invocation_id": "abc-123", "status": "completed"} (done) + + # Or cancel + curl -X POST http://localhost:8088/invocations/abc-123/cancel + # -> {"invocation_id": "abc-123", "status": "cancelled"} +""" +import asyncio +import json + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver.server import AgentServer + + +# In-memory state for demo purposes (see module docstring for production caveats) +_tasks: dict[str, asyncio.Task] = {} +_results: dict[str, bytes] = {} + +server = AgentServer() + + +async def _do_work(invocation_id: str, data: dict) -> bytes: + """Simulate long-running work. + + :param invocation_id: The invocation ID for this task. + :type invocation_id: str + :param data: The parsed request data. + :type data: dict + :return: JSON result bytes. + :rtype: bytes + """ + await asyncio.sleep(10) + result = json.dumps({ + "invocation_id": invocation_id, + "status": "completed", + "output": f"Processed: {data}", + }).encode() + _results[invocation_id] = result + return result + + +@server.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Start a long-running invocation in a background task. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON status indicating the task is running. 
+ :rtype: starlette.responses.JSONResponse + """ + data = await request.json() + invocation_id = request.state.invocation_id + + task = asyncio.create_task(_do_work(invocation_id, data)) + _tasks[invocation_id] = task + + return JSONResponse({ + "invocation_id": invocation_id, + "status": "running", + }) + + +@server.get_invocation_handler +async def handle_get_invocation(request: Request) -> Response: + """Retrieve a previous invocation result. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON status or result. + :rtype: starlette.responses.JSONResponse + """ + invocation_id = request.state.invocation_id + + if invocation_id in _results: + return Response(content=_results[invocation_id], media_type="application/json") + + if invocation_id in _tasks: + task = _tasks[invocation_id] + if not task.done(): + return JSONResponse({ + "invocation_id": invocation_id, + "status": "running", + }) + result = task.result() + _results[invocation_id] = result + del _tasks[invocation_id] + return Response(content=result, media_type="application/json") + + return JSONResponse({"error": "not found"}, status_code=404) + + +@server.cancel_invocation_handler +async def handle_cancel_invocation(request: Request) -> Response: + """Cancel a running invocation. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON cancellation status. 
+ :rtype: starlette.responses.JSONResponse + """ + invocation_id = request.state.invocation_id + + # Already completed — cannot cancel + if invocation_id in _results: + return JSONResponse({ + "invocation_id": invocation_id, + "status": "completed", + "error": "invocation already completed", + }) + + if invocation_id in _tasks: + task = _tasks[invocation_id] + if task.done(): + # Task finished between check — treat as completed + _results[invocation_id] = task.result() + del _tasks[invocation_id] + return JSONResponse({ + "invocation_id": invocation_id, + "status": "completed", + "error": "invocation already completed", + }) + task.cancel() + del _tasks[invocation_id] + return JSONResponse({ + "invocation_id": invocation_id, + "status": "cancelled", + }) + + return JSONResponse({"error": "not found"}, status_code=404) + + +if __name__ == "__main__": + server.run() diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/async_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/async_invoke_agent/requirements.txt new file mode 100644 index 000000000000..16f731287fdc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/async_invoke_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-server diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/human_in_the_loop_agent.py b/sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/human_in_the_loop_agent.py new file mode 100644 index 000000000000..4526500e6bb1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/human_in_the_loop_agent.py @@ -0,0 +1,74 @@ +"""Human-in-the-loop invoke agent example. + +Demonstrates a synchronous human-in-the-loop pattern using only +POST /invocations. The agent asks a clarifying question, and the client +replies in a second request. + +Flow: + 1. Client sends a message -> agent returns a question + invocation_id + 2. 
Client sends a reply -> agent returns the final result + +Usage:: + + # Start the agent + python human_in_the_loop_agent.py + + # Step 1: Send a request — agent asks a clarifying question + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"message": "Book me a flight"}' + # -> {"invocation_id": "", "status": "needs_input", "question": "Where would you like to fly to?"} + + # Step 2: Reply with the answer — agent completes + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"reply_to": "", "message": "Seattle"}' + # -> {"invocation_id": "", "status": "completed", "response": "Flight to Seattle booked."} +""" +from typing import Any + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver.server import AgentServer + + +# Holds questions waiting for a human reply, keyed by invocation_id +_waiting: dict[str, dict[str, Any]] = {} + +server = AgentServer() + + +@server.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Handle messages and replies. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON response indicating status. 
+ :rtype: starlette.responses.JSONResponse + """ + data = await request.json() + invocation_id = request.state.invocation_id + + # --- Reply to a previous question --- + reply_to = data.get("reply_to") + if reply_to: + if reply_to not in _waiting: + return JSONResponse({"error": f"No pending question for {reply_to}"}) + + return JSONResponse({ + "invocation_id": reply_to, + "status": "completed", + "response": f"Flight to {data.get('message', '?')} booked.", + }) + + # --- New request: ask a clarifying question --- + _waiting[invocation_id] = { + "message": data.get("message", ""), + } + return JSONResponse({ + "invocation_id": invocation_id, + "status": "needs_input", + "question": "Where would you like to fly to?", + }) + + +if __name__ == "__main__": + server.run() diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/requirements.txt new file mode 100644 index 000000000000..16f731287fdc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/human_in_the_loop_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-server diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/.env.sample b/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/.env.sample new file mode 100644 index 000000000000..a75e7aec8869 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/.env.sample @@ -0,0 +1,5 @@ +# Azure OpenAI credentials +AZURE_OPENAI_API_KEY=your-api-key-here +AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/ +AZURE_OPENAI_API_VERSION=2024-12-01-preview +AZURE_OPENAI_MODEL=gpt-4o diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/langgraph_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/langgraph_invoke_agent.py new file mode 100644 
index 000000000000..f717551643f8 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/langgraph_invoke_agent.py @@ -0,0 +1,100 @@ +"""LangGraph agent served via /invoke. + +Customer owns the LangGraph <-> invoke conversion logic. +This replaces the need for azure-ai-agentserver-langgraph. + +Usage:: + + # Start the agent + python langgraph_invoke_agent.py + + # Non-streaming request + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"message": "What is the capital of France?"}' + # -> {"reply": "The capital of France is Paris."} + + # Streaming request + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"message": "Tell me a joke", "stream": true}' + # -> {"delta": "Why did..."} + # {"delta": " the chicken..."} +""" +import json +import os +from typing import AsyncGenerator + +from dotenv import load_dotenv + +load_dotenv() + +from langgraph.graph import END, START, MessagesState, StateGraph +from langchain_openai import AzureChatOpenAI + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver.server import AgentServer + + +def build_graph() -> StateGraph: + """Customer builds their LangGraph agent as usual.""" + llm = AzureChatOpenAI( + model=os.environ["AZURE_OPENAI_MODEL"], + api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview"), + ) + + def chatbot(state: MessagesState): + return {"messages": [llm.invoke(state["messages"])]} + + graph = StateGraph(MessagesState) + graph.add_node("chatbot", chatbot) + graph.add_edge(START, "chatbot") + graph.add_edge("chatbot", END) + return graph.compile() + + +graph = build_graph() +server = AgentServer() + + +async def _stream_response(user_message: str) -> AsyncGenerator[bytes, None]: + """Async generator that yields response chunks. + + :param user_message: The user message to process. 
+ :type user_message: str + :return: An async generator yielding JSON-encoded byte chunks. + :rtype: AsyncGenerator[bytes, None] + """ + async for event in graph.astream_events( + {"messages": [{"role": "user", "content": user_message}]}, + version="v2", + ): + if event["event"] == "on_chat_model_stream": + chunk = event["data"]["chunk"].content + if chunk: + yield json.dumps({"delta": chunk}).encode() + b"\n" + + +@server.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Process the invocation via LangGraph. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON response or streaming response. + :rtype: starlette.responses.Response + """ + data = await request.json() + user_message = data["message"] + stream = data.get("stream", False) + + if stream: + return StreamingResponse(_stream_response(user_message)) + + result = await graph.ainvoke( + {"messages": [{"role": "user", "content": user_message}]} + ) + last_message = result["messages"][-1] + return JSONResponse({"reply": last_message.content}) + + +if __name__ == "__main__": + server.run() diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/requirements.txt new file mode 100644 index 000000000000..681ea34a9682 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/langgraph_invoke_agent/requirements.txt @@ -0,0 +1,4 @@ +azure-ai-agentserver-server +langgraph>=1.0.0 +langchain-openai>=1.0.0 +python-dotenv>=1.0.0 diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/SPEC.md b/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/SPEC.md new file mode 100644 index 000000000000..0403e8c825ed --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/SPEC.md @@ -0,0 +1,120 @@ +# Invocation API Spec + +This document defines the HTTP 
contract that an agent container must implement +to run on the platform. No SDK is required — any language or framework that +can serve HTTP is supported. + +## Required + +### `POST /invocations` + +Execute the agent. + +- **Port:** `8088`. +- **Request body:** Any content type. The schema and format are defined by your + agent — the platform does not enforce a specific shape or media type. +- **Response body:** Any content type. The schema and format are defined by your + agent. + +## Optional Features + +The following endpoints and capabilities are **not required** by the platform. +Implement them only if your agent needs the corresponding functionality. + +### Async polling and cancel + +The platform only calls `POST /invocations`. If your agent needs async +polling or cancel, you can add your own endpoints under `/invocations/` +with any paths and schemas you choose. + +### OpenAPI spec — `GET /invocations/docs/openapi.json` + +Serve an OpenAPI spec describing your agent's request/response schema. +The platform can use this for documentation and request validation. + +### Health probes — `GET /liveness` and `GET /readiness` + +Standard Kubernetes health probes. Implement these if your container +runs in a Kubernetes environment and you need the orchestrator to detect +unhealthy or unready instances. + +- `/liveness` — return `200` when the process is alive. +- `/readiness` — return `200` when the agent is ready to accept requests + (e.g. models loaded, connections established). + +### Invocation ID tracking + +The platform may send an `x-agent-invocation-id` request header. +If present, echo it back on the response. If absent, generate a UUID +and include it on the response. This enables end-to-end correlation +of requests across services. + +### OpenTelemetry tracing with App Insights export + +Integrate with Foundry distributed tracing by creating spans for each +invocation and exporting them to Application Insights. 
See the +[Tracing](#tracing) section below for details. + +## Headers + +| Header | Direction | Description | +|--------|-----------|-------------| +| `x-agent-invocation-id` | Request | Platform may send an invocation ID. If absent, agent may generate a UUID. | +| `traceparent` | Request | W3C Trace Context header for distributed tracing. | +| `tracestate` | Request | W3C Trace Context header for distributed tracing. | +| `baggage` | Request | W3C Baggage for cross-service context propagation. | + +## Query Parameters + +| Parameter | Description | +|-----------|-------------| +| `agent_session_id` | Session or conversation ID for tracing correlation. Maps to `gen_ai.conversation.id` span attribute. | + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `AGENT_NAME` | | Agent name | +| `AGENT_VERSION` | | Agent version | +| `AGENT_PROJECT_NAME` | | Azure foundry project name | + +## Tracing + +To integrate with Foundry tracing, the agent should create an OpenTelemetry span for each +`POST /invocations` request and export traces to Application Insights via +`azure-monitor-opentelemetry-exporter`. + +### W3C Trace Context propagation + +Extract `traceparent` and `tracestate` from the incoming request headers and use +them as the parent context for the span. This connects the agent's spans to the +platform's distributed trace. + +### `leaf_customer_span_id` (baggage) + +The platform may send a `baggage` header containing a `leaf_customer_span_id` +key. When present, the agent **must** override the parent span ID from +`traceparent` with this value. This re-parents the agent's root span under +the caller's leaf span so the trace tree renders correctly in App Insights. + +The value is a 16-character lower-hex span ID. To apply it: + +1. Extract the trace context from `traceparent` normally. +2. Parse `leaf_customer_span_id` from the `baggage` header. +3.
Create a new `SpanContext` with the same `trace_id` but the + `span_id` replaced by the baggage value. +4. Use the new context as the parent when starting the span. + +### Span attributes + +Each span should include the following GenAI semantic convention attributes: + +| Attribute | Source | Description | +|-----------|--------|-------------| +| `invocation.id` | `x-agent-invocation-id` header or generated UUID | Unique invocation identifier | +| `gen_ai.response.id` | Same as `invocation.id` | Maps response to invocation | +| `gen_ai.provider.name` | `"microsoft.foundry"` | Fixed provider name | +| `gen_ai.agent.id` | `AGENT_NAME` + `AGENT_VERSION` env vars | Agent identity, e.g. `"my-agent:1.0"` | +| `microsoft.foundry.project.id` | `AGENT_PROJECT_NAME` env var | Project identifier | +| `gen_ai.operation.name` | `"invoke_agent"` | Operation type | +| `gen_ai.conversation.id` | `agent_session_id` query parameter | Session/conversation ID | diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/full_server.py b/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/full_server.py new file mode 100644 index 000000000000..1cbcba1de38e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/full_server.py @@ -0,0 +1,327 @@ +"""Full-featured agent — no SDK required. + +Implements optional Invocation API features: +- Invocation ID tracking (x-agent-invocation-id header) +- Health probes (/liveness, /readiness) +- OpenTelemetry tracing with App Insights export + +See SPEC.md for the full Invocation API contract. + +Usage:: + + pip install fastapi uvicorn opentelemetry-api opentelemetry-sdk \\ + azure-monitor-opentelemetry-exporter + + # Optional: set App Insights connection string for trace export + export APPLICATIONINSIGHTS_CONNECTION_STRING="InstrumentationKey=..." 
+ export AGENT_ENABLE_TRACING=true + export AGENT_NAME=my-agent + export AGENT_VERSION=1.0 + + python full_server.py + + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"name": "Alice"}' + # -> {"greeting": "Hello, Alice!"} + + curl http://localhost:8088/liveness + # -> {"status": "alive"} +""" +import logging +import os +import uuid +from contextlib import contextmanager +from typing import Any, Iterator + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse + +# --------------------------------------------------------------------------- +# Configuration from environment variables +# --------------------------------------------------------------------------- + +LOG_LEVEL = os.environ.get("AGENT_LOG_LEVEL", "INFO").upper() +ENABLE_TRACING = os.environ.get("AGENT_ENABLE_TRACING", "").lower() in ("true", "1", "yes") +APPINSIGHTS_CONN_STR = os.environ.get("APPLICATIONINSIGHTS_CONNECTION_STRING", "") +AGENT_NAME = os.environ.get("AGENT_NAME", "") +AGENT_VERSION = os.environ.get("AGENT_VERSION", "") +AGENT_PROJECT_NAME = os.environ.get("AGENT_PROJECT_NAME", "") + +logging.basicConfig( + level=LOG_LEVEL, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", +) +logger = logging.getLogger(__name__) + +INVOCATION_ID_HEADER = "x-agent-invocation-id" + +app = FastAPI() + + +# --------------------------------------------------------------------------- +# OpenTelemetry tracing (optional) +# --------------------------------------------------------------------------- + +_tracer: Any = None +_propagator: Any = None + +if ENABLE_TRACING: + try: + from opentelemetry import trace + from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator, + ) + + # Set up App Insights export if connection string is available + if APPINSIGHTS_CONN_STR: + try: + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from 
opentelemetry.sdk.trace.export import BatchSpanProcessor + from azure.monitor.opentelemetry.exporter import ( + AzureMonitorTraceExporter, + ) + + resource = Resource.create({"service.name": "agent"}) + provider = TracerProvider(resource=resource) + exporter = AzureMonitorTraceExporter( + connection_string=APPINSIGHTS_CONN_STR + ) + provider.add_span_processor(BatchSpanProcessor(exporter)) + trace.set_tracer_provider(provider) + logger.info("App Insights trace exporter configured") + except ImportError: + logger.warning( + "App Insights export requires opentelemetry-sdk and " + "azure-monitor-opentelemetry-exporter" + ) + + _tracer = trace.get_tracer("agent") + _propagator = TraceContextTextMapPropagator() + logger.info("OpenTelemetry tracing enabled") + + except ImportError: + logger.warning( + "AGENT_ENABLE_TRACING=true but opentelemetry-api is not installed" + ) + + +def _build_span_attrs( + invocation_id: str, + session_id: str = "", + operation_name: str = "invoke_agent", +) -> dict[str, str]: + """Build GenAI semantic convention span attributes.""" + agent_label = ( + f"{AGENT_NAME}:{AGENT_VERSION}" if AGENT_NAME and AGENT_VERSION else AGENT_NAME + ) + attrs: dict[str, str] = { + "invocation.id": invocation_id, + "gen_ai.response.id": invocation_id, + "gen_ai.provider.name": "microsoft.foundry", + "gen_ai.operation.name": operation_name, + } + if agent_label: + attrs["gen_ai.agent.id"] = agent_label + if AGENT_PROJECT_NAME: + attrs["microsoft.foundry.project.id"] = AGENT_PROJECT_NAME + if session_id: + attrs["gen_ai.conversation.id"] = session_id + return attrs + + +def _parse_baggage_key(baggage_header: str, key: str) -> str: + """Parse a single key from a W3C Baggage header. + + The baggage header is a comma-separated list of key=value pairs. + See https://www.w3.org/TR/baggage/ + + :param baggage_header: Raw baggage header value. + :param key: Key to extract. + :return: The value if found, or empty string. 
+ """ + for member in baggage_header.split(","): + member = member.strip() + if "=" not in member: + continue + k, _, v = member.partition("=") + if k.strip() == key: + # Value may have properties after ';' — take only the value part + return v.split(";")[0].strip() + return "" + + +def _override_parent_span_id( + ctx: Any, + leaf_span_id_hex: str, +) -> Any: + """Re-parent the trace context using leaf_customer_span_id from baggage. + + Creates a new parent context with the same trace_id but the span_id + replaced by *leaf_span_id_hex*. This connects the agent's root span + to the caller's leaf span so the trace tree renders correctly in + App Insights. + + :param ctx: OpenTelemetry context with extracted traceparent. + :param leaf_span_id_hex: 16-character lower-hex span ID from baggage. + :return: New context with overridden parent span ID. + """ + from opentelemetry.trace import ( + NonRecordingSpan, + SpanContext, + set_span_in_context, + ) + + # Get the current span from the context to extract trace_id and trace_flags + current_span = trace.get_current_span(ctx) if ctx else None + if current_span is None: + return ctx + + current_ctx = current_span.get_span_context() + if current_ctx is None or not current_ctx.is_valid: + return ctx + + try: + new_span_id = int(leaf_span_id_hex, 16) + except ValueError: + logger.warning("Invalid leaf_customer_span_id: %s", leaf_span_id_hex) + return ctx + + # Build a new SpanContext with the overridden span_id + new_span_context = SpanContext( + trace_id=current_ctx.trace_id, + span_id=new_span_id, + is_remote=True, + trace_flags=current_ctx.trace_flags, + trace_state=current_ctx.trace_state, + ) + return set_span_in_context(NonRecordingSpan(new_span_context), ctx) + + +_LEAF_CUSTOMER_SPAN_ID = "leaf_customer_span_id" + + +@contextmanager +def _request_span( + request: Request, + invocation_id: str, + operation_name: str = "invoke_agent", +) -> Iterator[Any]: + """Create a traced span for a request, propagating W3C context. 
+ + Handles full W3C Trace Context propagation including the + ``leaf_customer_span_id`` baggage key, which re-parents the agent's + root span under the caller's leaf span for correct trace tree + rendering in App Insights. + + Yields the OTel span or None if tracing is disabled. + """ + if _tracer is None or _propagator is None: + yield None + return + + # Extract W3C trace context from incoming headers + carrier = { + k: v + for k in ("traceparent", "tracestate") + if (v := request.headers.get(k)) is not None + } + ctx = _propagator.extract(carrier=carrier) if carrier else None + + # Override parent span ID with leaf_customer_span_id from baggage + # (see SPEC.md § leaf_customer_span_id) + baggage = request.headers.get("baggage", "") + if ctx is not None and baggage: + leaf_span_id = _parse_baggage_key(baggage, _LEAF_CUSTOMER_SPAN_ID) + if leaf_span_id: + ctx = _override_parent_span_id(ctx, leaf_span_id) + + session_id = request.query_params.get("agent_session_id", "") + attrs = _build_span_attrs(invocation_id, session_id, operation_name) + + agent_label = ( + f"{AGENT_NAME}:{AGENT_VERSION}" if AGENT_NAME and AGENT_VERSION else AGENT_NAME + ) + span_name = f"execute_agent {agent_label}" if agent_label else "execute_agent" + + with _tracer.start_as_current_span( + name=span_name, + attributes=attrs, + kind=trace.SpanKind.SERVER, + context=ctx, + ) as otel_span: + yield otel_span + + +def _record_error(span: Any, exc: Exception) -> None: + """Record an exception on a span if tracing is active.""" + if span is not None and ENABLE_TRACING: + try: + span.set_status(trace.StatusCode.ERROR, str(exc)) + span.record_exception(exc) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Health probes +# --------------------------------------------------------------------------- + + +@app.get("/liveness") +async def liveness(): + """liveness probe.""" + return {"status": "alive"} + + +@app.get("/readiness") +async def 
readiness(): + """readiness probe.""" + return {"status": "ready"} + + +# --------------------------------------------------------------------------- +# Agent logic — replace this with your own +# --------------------------------------------------------------------------- + + +async def run_agent(data: dict) -> dict: + """Your agent logic goes here. + + :param data: Parsed JSON request body. + :return: Response dict to serialize as JSON. + """ + greeting = f"Hello, {data.get('name', 'World')}!" + return {"greeting": greeting} + + +# --------------------------------------------------------------------------- +# POST /invocations — required +# --------------------------------------------------------------------------- + + +@app.post("/invocations") +async def invoke(request: Request): + """Execute the agent.""" + invocation_id = ( + request.headers.get(INVOCATION_ID_HEADER) or str(uuid.uuid4()) + ) + + data = await request.json() + + with _request_span(request, invocation_id) as span: + try: + result = await run_agent(data) + except Exception as exc: + _record_error(span, exc) + raise + + return JSONResponse( + result, + headers={INVOCATION_ID_HEADER: invocation_id}, + ) + + +if __name__ == "__main__": + import uvicorn + port = int(os.environ.get("AGENT_SERVER_PORT", "8088")) + uvicorn.run(app, host="0.0.0.0", port=port) diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/minimal_server.py b/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/minimal_server.py new file mode 100644 index 000000000000..18ba3999d46e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/no_sdk_agent/minimal_server.py @@ -0,0 +1,31 @@ +"""Minimal agent — no SDK required. + +Implements only the required ``POST /invocations`` endpoint. +See SPEC.md for the full Invocation API contract. 
+ +Usage:: + + pip install fastapi uvicorn + python minimal_server.py + + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"name": "Alice"}' + # -> {"greeting": "Hello, Alice!"} +""" +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse + +app = FastAPI() + + +@app.post("/invocations") +async def invoke(request: Request): + data = await request.json() + greeting = f"Hello, {data.get('name', 'Alice')}!" + return JSONResponse({"greeting": greeting}) + + +if __name__ == "__main__": + import os + import uvicorn + port = int(os.environ.get("AGENT_SERVER_PORT", "8088")) + uvicorn.run(app, host="0.0.0.0", port=port) diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/openapi_validated_agent.py b/sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/openapi_validated_agent.py new file mode 100644 index 000000000000..60ffb28516de --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/openapi_validated_agent.py @@ -0,0 +1,117 @@ +"""OpenAPI-validated agent example. + +Demonstrates how to supply an OpenAPI spec to AgentServer so that +incoming requests are validated automatically. Invalid requests receive +a 400 response before ``invoke`` is called. + +The spec is also served at ``GET /invocations/docs/openapi.json`` so +that callers can discover the agent's contract at runtime. 
+ +Usage:: + + # Start the agent + python openapi_validated_agent.py + + # Fetch the OpenAPI spec + curl http://localhost:8088/invocations/docs/openapi.json + + # Valid request (200) + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"name": "Alice", "language": "fr"}' + # -> {"greeting": "Bonjour, Alice!"} + + # Invalid request — missing required "name" field (400) + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"language": "en"}' + # -> {"error": ["'name' is a required property"]} +""" +from typing import Any + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver.server import AgentServer + +# Define a simple OpenAPI 3.0 spec inline. In production this could be +# loaded from a YAML/JSON file. +OPENAPI_SPEC: dict[str, Any] = { + "openapi": "3.0.0", + "info": {"title": "Greeting Agent", "version": "1.0.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["name"], + "properties": { + "name": { + "type": "string", + "description": "Name of the person to greet.", + }, + "language": { + "type": "string", + "enum": ["en", "es", "fr"], + "description": "Language for the greeting.", + }, + }, + "additionalProperties": False, + } + } + }, + }, + "responses": { + "200": { + "description": "Greeting response", + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["greeting"], + "properties": { + "greeting": {"type": "string"}, + }, + } + } + }, + } + }, + } + } + }, +} + +GREETINGS = { + "en": "Hello", + "es": "Hola", + "fr": "Bonjour", +} + + +server = AgentServer(openapi_spec=OPENAPI_SPEC, enable_request_validation=True) + + +@server.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Return a localised greeting. 
+ + The OpenAPI spec enforces that "name" is required, "language" must be + one of ``en``, ``es``, or ``fr``, and no extra fields are allowed. + Requests that violate the schema are rejected with 400 before reaching + the invoke handler. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON greeting response. + :rtype: starlette.responses.JSONResponse + """ + data = await request.json() + language = data.get("language", "en") + prefix = GREETINGS.get(language, "Hello") + greeting = f"{prefix}, {data['name']}!" + return JSONResponse({"greeting": greeting}) + + +if __name__ == "__main__": + server.run() diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/requirements.txt new file mode 100644 index 000000000000..16f731287fdc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/openapi_validated_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-server diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/requirements.txt new file mode 100644 index 000000000000..16f731287fdc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/requirements.txt @@ -0,0 +1 @@ +azure-ai-agentserver-server diff --git a/sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/simple_invoke_agent.py b/sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/simple_invoke_agent.py new file mode 100644 index 000000000000..368511106848 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/samples/simple_invoke_agent/simple_invoke_agent.py @@ -0,0 +1,38 @@ +"""Simple invoke agent example. + +Accepts JSON requests, echoes back with a greeting. 
+ +Usage:: + + # Start the agent + python simple_invoke_agent.py + + # Send a greeting request + curl -X POST http://localhost:8088/invocations -H "Content-Type: application/json" -d '{"name": "Alice"}' + # -> {"greeting": "Hello, Alice!"} +""" +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver.server import AgentServer + + +server = AgentServer() + + +@server.invoke_handler +async def handle_invoke(request: Request) -> Response: + """Process the invocation by echoing a greeting. + + :param request: The raw Starlette request. + :type request: starlette.requests.Request + :return: JSON greeting response. + :rtype: starlette.responses.JSONResponse + """ + data = await request.json() + greeting = f"Hello, {data['name']}!" + return JSONResponse({"greeting": greeting}) + + +if __name__ == "__main__": + server.run() diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-server/tests/conftest.py new file mode 100644 index 000000000000..e9e6fcda8c3d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/conftest.py @@ -0,0 +1,205 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Shared fixtures for azure-ai-agentserver-server tests.""" +import json + +import pytest +import pytest_asyncio +import httpx + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver.server import AgentServer + + +# --------------------------------------------------------------------------- +# Agent factory functions (decorator pattern) +# --------------------------------------------------------------------------- + + +def _make_echo_agent(**kwargs) -> AgentServer: + """Create an echo agent that returns the request body as-is.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + body = await request.body() + return Response(content=body, media_type="application/octet-stream") + + return server + + +def _make_streaming_agent(**kwargs) -> AgentServer: + """Create an agent that returns a multi-chunk streaming response.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> StreamingResponse: + async def generate(): + for i in range(3): + yield json.dumps({"chunk": i}).encode() + b"\n" + + return StreamingResponse(generate()) + + return server + + +def _make_async_storage_agent(**kwargs) -> AgentServer: + """Create an agent with get/cancel support via in-memory storage.""" + store: dict[str, bytes] = {} + server = AgentServer(**kwargs) + server._store = store # expose for test access + + @server.invoke_handler + async def invoke(request: Request) -> Response: + body = await request.body() + invocation_id = request.state.invocation_id + result = json.dumps({"echo": body.decode()}).encode() + store[invocation_id] = result + return Response(content=result, media_type="application/json") + + @server.get_invocation_handler + async def get_invocation(request: Request) -> Response: + invocation_id = 
request.state.invocation_id + if invocation_id in store: + return Response(content=store[invocation_id], media_type="application/json") + return JSONResponse({"error": "not found"}, status_code=404) + + @server.cancel_invocation_handler + async def cancel_invocation(request: Request) -> Response: + invocation_id = request.state.invocation_id + if invocation_id in store: + del store[invocation_id] + return JSONResponse({"status": "cancelled"}) + return JSONResponse({"error": "not found"}, status_code=404) + + return server + + +SAMPLE_OPENAPI_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "Test", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + "required": ["name"], + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "greeting": {"type": "string"}, + }, + "required": ["greeting"], + } + } + } + } + }, + } + } + }, +} + + +def _make_validated_agent() -> AgentServer: + """Create an agent with OpenAPI validation that returns a greeting.""" + server = AgentServer(openapi_spec=SAMPLE_OPENAPI_SPEC, enable_request_validation=True) + + @server.invoke_handler + async def handle(request: Request) -> Response: + data = await request.json() + return JSONResponse({"greeting": f"Hello, {data['name']}!"}) + + return server + + +def _make_failing_agent(**kwargs) -> AgentServer: + """Create an agent whose invoke handler always raises.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + raise ValueError("something went wrong") + + return server + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def echo_client(): 
+ """httpx.AsyncClient wired to an echo agent's ASGI app.""" + server = _make_echo_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest_asyncio.fixture +async def streaming_client(): + """httpx.AsyncClient wired to a streaming agent's ASGI app.""" + server = _make_streaming_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest.fixture +def async_storage_server(): + """An async storage agent instance.""" + return _make_async_storage_agent() + + +@pytest_asyncio.fixture +async def async_storage_client(async_storage_server): + """httpx.AsyncClient wired to an async storage agent's ASGI app.""" + transport = httpx.ASGITransport(app=async_storage_server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest_asyncio.fixture +async def validated_client(): + """httpx.AsyncClient wired to a validated agent's ASGI app.""" + server = _make_validated_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest_asyncio.fixture +async def no_spec_client(): + """httpx.AsyncClient wired to an echo agent (no OpenAPI spec).""" + server = _make_echo_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest_asyncio.fixture +async def failing_client(): + """httpx.AsyncClient wired to a failing agent's ASGI app.""" + server = _make_failing_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client diff --git 
a/sdk/agentserver/azure-ai-agentserver-server/tests/test_decorator_pattern.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_decorator_pattern.py new file mode 100644 index 000000000000..e03d01efe567 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_decorator_pattern.py @@ -0,0 +1,343 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for the decorator-based handler registration pattern.""" +from __future__ import annotations + +import httpx +import pytest +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver.server import AgentServer + + +# --------------------------------------------------------------------------- +# Decorator registration +# --------------------------------------------------------------------------- + + +class TestDecoratorRegistration: + """Verify that decorators store the function and return it unchanged.""" + + def test_invoke_handler_stores_function(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + assert server._invocation._invoke_fn is handle + + def test_invoke_handler_returns_original_function(self): + server = AgentServer() + original = None + + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + original = handle + result = server.invoke_handler(handle) + assert result is original + + def test_get_invocation_handler_stores_function(self): + server = AgentServer() + + @server.get_invocation_handler + async def handle(request: Request) -> Response: + return JSONResponse({"found": True}) + + assert server._invocation._get_invocation_fn is handle + + def test_cancel_invocation_handler_stores_function(self): + server = AgentServer() + + 
@server.cancel_invocation_handler + async def handle(request: Request) -> Response: + return JSONResponse({"cancelled": True}) + + assert server._invocation._cancel_invocation_fn is handle + + def test_shutdown_handler_stores_function(self): + server = AgentServer() + + @server.shutdown_handler + async def handle(): + pass + + assert server._shutdown_fn is handle + + +# --------------------------------------------------------------------------- +# Invoke handler — full request flow +# --------------------------------------------------------------------------- + + +class TestInvokeHandlerFlow: + """Verify that POST /invocations delegates to @invoke_handler.""" + + @pytest.mark.asyncio + async def test_invoke_returns_handler_response(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + data = await request.json() + return JSONResponse({"echo": data["msg"]}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", json={"msg": "hello"}) + assert resp.status_code == 200 + assert resp.json()["echo"] == "hello" + + @pytest.mark.asyncio + async def test_invoke_includes_invocation_id_header(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", json={}) + assert "x-agent-invocation-id" in resp.headers + + @pytest.mark.asyncio + async def test_invoke_request_has_invocation_id_in_state(self): + """The handler receives request.state.invocation_id.""" + server = AgentServer() + captured_id = None + + @server.invoke_handler + async def handle(request: Request) -> Response: + nonlocal captured_id + captured_id = 
request.state.invocation_id + return JSONResponse({"id": captured_id}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", json={}) + assert resp.status_code == 200 + assert captured_id is not None + assert resp.json()["id"] == captured_id + + +# --------------------------------------------------------------------------- +# Missing invoke handler +# --------------------------------------------------------------------------- + + +class TestMissingInvokeHandler: + """When no invoke handler is registered and invoke() is not overridden, 501.""" + + @pytest.mark.asyncio + async def test_no_handler_returns_501(self): + server = AgentServer() + # No @invoke_handler registered + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", json={}) + assert resp.status_code == 501 + data = resp.json() + assert data["error"]["code"] == "not_implemented" + assert "invoke handler" in data["error"]["message"].lower() + assert "x-agent-invocation-id" in resp.headers + + +# --------------------------------------------------------------------------- +# Optional handler defaults +# --------------------------------------------------------------------------- + + +class TestOptionalHandlerDefaults: + """get_invocation and cancel_invocation return 404 by default.""" + + @pytest.mark.asyncio + async def test_get_invocation_returns_501_by_default(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/some-id") + assert resp.status_code == 501 + assert 
resp.headers.get("x-agent-invocation-id") == "some-id" + + @pytest.mark.asyncio + async def test_cancel_invocation_returns_501_by_default(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/some-id/cancel") + assert resp.status_code == 501 + assert resp.headers.get("x-agent-invocation-id") == "some-id" + + +# --------------------------------------------------------------------------- +# Optional handler overrides via decorator +# --------------------------------------------------------------------------- + + +class TestOptionalHandlerOverrides: + """Registered optional handlers are called instead of the defaults.""" + + @pytest.mark.asyncio + async def test_get_invocation_handler_called(self): + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({"ok": True}) + + @server.get_invocation_handler + async def get_inv(request: Request) -> Response: + inv_id = request.state.invocation_id + return JSONResponse({"id": inv_id, "status": "completed"}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/test-123") + assert resp.status_code == 200 + assert resp.json()["id"] == "test-123" + assert resp.json()["status"] == "completed" + + @pytest.mark.asyncio + async def test_cancel_invocation_handler_called(self): + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({"ok": True}) + + @server.cancel_invocation_handler + async def cancel(request: Request) -> Response: + inv_id = request.state.invocation_id + return JSONResponse({"id": 
inv_id, "cancelled": True}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/test-456/cancel") + assert resp.status_code == 200 + assert resp.json()["cancelled"] is True + + +# --------------------------------------------------------------------------- +# Shutdown handler via decorator +# --------------------------------------------------------------------------- + + +class TestShutdownHandler: + """@shutdown_handler is called during lifespan teardown.""" + + @pytest.mark.asyncio + async def test_shutdown_handler_called(self): + server = AgentServer() + shutdown_called = False + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({"ok": True}) + + @server.shutdown_handler + async def cleanup(): + nonlocal shutdown_called + shutdown_called = True + + # Exercise the lifespan directly + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass # startup; on exit → shutdown + + assert shutdown_called is True + + @pytest.mark.asyncio + async def test_no_shutdown_handler_is_noop(self): + """Without @shutdown_handler, on_shutdown is a silent no-op.""" + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({"ok": True}) + + # Should not raise + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass + + +# --------------------------------------------------------------------------- +# Config passthrough +# --------------------------------------------------------------------------- + + +class TestConfigPassthrough: + """All AgentServer kwargs still work in decorator mode.""" + + def test_request_timeout_resolved(self): + server = AgentServer(request_timeout=42) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": 
True}) + + assert server._request_timeout == 42 + + def test_graceful_shutdown_timeout_resolved(self): + server = AgentServer(graceful_shutdown_timeout=10) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + assert server._graceful_shutdown_timeout == 10 + + def test_debug_errors_resolved(self): + server = AgentServer(debug_errors=True) + assert server._debug_errors is True + + +# --------------------------------------------------------------------------- +# Health endpoints work in decorator mode +# --------------------------------------------------------------------------- + + +class TestHealthEndpointsDecoratorMode: + """Health endpoints respond even when using the decorator pattern.""" + + @pytest.mark.asyncio + async def test_liveness(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/liveness") + assert resp.status_code == 200 + assert resp.json()["status"] == "alive" + + @pytest.mark.asyncio + async def test_readiness(self): + server = AgentServer() + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/readiness") + assert resp.status_code == 200 + assert resp.json()["status"] == "ready" diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_edge_cases.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_edge_cases.py new file mode 100644 index 000000000000..c5db4d49ca66 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_edge_cases.py @@ -0,0 +1,519 @@ +# 
--------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""End-to-end edge-case tests for AgentServer.""" +import json +import uuid + +import httpx +import pytest +import pytest_asyncio + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver.server import AgentServer + + +# --------------------------------------------------------------------------- +# Agent factory functions for edge cases +# --------------------------------------------------------------------------- + + +def _make_custom_header_agent(**kwargs) -> AgentServer: + """Create an agent that sets its own x-agent-invocation-id header.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse( + {"ok": True}, + headers={"x-agent-invocation-id": "custom-id-from-agent"}, + ) + + return server + + +def _make_empty_streaming_agent(**kwargs) -> AgentServer: + """Create an agent that returns an empty streaming response (0 chunks).""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> StreamingResponse: + async def generate(): + return + yield # noqa: RUF028 — makes this an async generator + + return StreamingResponse(generate(), media_type="text/event-stream") + + return server + + +def _make_slow_failing_get_agent(**kwargs) -> AgentServer: + """Create an agent whose get_invocation raises so we can test debug errors on GET.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({"ok": True}) + + @server.get_invocation_handler + async def get_invocation(request: Request) -> Response: + raise ValueError("get-debug-detail") + + return server + + +def _make_slow_failing_cancel_agent(**kwargs) -> 
AgentServer: + """Create an agent whose cancel_invocation raises so we can test debug errors on cancel.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({"ok": True}) + + @server.cancel_invocation_handler + async def cancel_invocation(request: Request) -> Response: + raise ValueError("cancel-debug-detail") + + return server + + +def _make_large_payload_agent(**kwargs) -> AgentServer: + """Create an agent that echoes the request body length as JSON.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + body = await request.body() + return JSONResponse({"length": len(body)}) + + return server + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def custom_header_client(): + server = _make_custom_header_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest_asyncio.fixture +async def empty_streaming_client(): + server = _make_empty_streaming_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +@pytest_asyncio.fixture +async def large_payload_client(): + server = _make_large_payload_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + yield client + + +# --------------------------------------------------------------------------- +# Tests: wrong HTTP methods +# --------------------------------------------------------------------------- + + +class TestMethodNotAllowed: + """Verify that wrong HTTP methods return 405.""" + + 
@pytest.mark.asyncio + async def test_get_on_invocations_returns_405(self, echo_client): + """GET /invocations is not allowed (POST only).""" + resp = await echo_client.get("/invocations") + assert resp.status_code == 405 + + @pytest.mark.asyncio + async def test_post_on_liveness_returns_405(self, echo_client): + """POST /liveness is not allowed (GET only).""" + resp = await echo_client.post("/liveness") + assert resp.status_code == 405 + + @pytest.mark.asyncio + async def test_post_on_readiness_returns_405(self, echo_client): + """POST /readiness is not allowed (GET only).""" + resp = await echo_client.post("/readiness") + assert resp.status_code == 405 + + @pytest.mark.asyncio + async def test_put_on_invocations_returns_405(self, echo_client): + """PUT /invocations is not allowed.""" + resp = await echo_client.put("/invocations", content=b"{}") + assert resp.status_code == 405 + + @pytest.mark.asyncio + async def test_delete_on_invocations_id_returns_405(self, echo_client): + """DELETE /invocations/{id} is not allowed.""" + resp = await echo_client.delete(f"/invocations/{uuid.uuid4()}") + assert resp.status_code == 405 + + @pytest.mark.asyncio + async def test_post_on_openapi_spec_returns_405(self, echo_client): + """POST /invocations/docs/openapi.json is not allowed (GET only).""" + resp = await echo_client.post("/invocations/docs/openapi.json", content=b"{}") + assert resp.status_code == 405 + + +# --------------------------------------------------------------------------- +# Tests: OpenAPI validation edge cases +# --------------------------------------------------------------------------- + + +class TestOpenAPIValidation: + """End-to-end OpenAPI validation via HTTP.""" + + @pytest.mark.asyncio + async def test_valid_request_returns_200(self, validated_client): + """POST with valid body passes validation and returns greeting.""" + resp = await validated_client.post( + "/invocations", + content=json.dumps({"name": "Alice"}).encode(), + headers={"Content-Type": 
"application/json"}, + ) + assert resp.status_code == 200 + assert resp.json()["greeting"] == "Hello, Alice!" + + @pytest.mark.asyncio + async def test_missing_required_field_returns_400(self, validated_client): + """POST missing required 'name' field returns 400 with validation details.""" + resp = await validated_client.post( + "/invocations", + content=json.dumps({"wrong_field": "foo"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + body = resp.json() + assert body["error"]["code"] == "invalid_payload" + assert "details" in body["error"] + + @pytest.mark.asyncio + async def test_malformed_json_returns_400(self, validated_client): + """POST with non-JSON body when spec expects JSON returns 400.""" + resp = await validated_client.post( + "/invocations", + content=b"this is not json", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + assert resp.json()["error"]["code"] == "invalid_payload" + + @pytest.mark.asyncio + async def test_extra_field_accepted(self, validated_client): + """Extra fields are accepted when additionalProperties is not false.""" + resp = await validated_client.post( + "/invocations", + content=json.dumps({"name": "Bob", "bonus": 42}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + assert resp.json()["greeting"] == "Hello, Bob!" 
+ + @pytest.mark.asyncio + async def test_no_spec_skips_validation(self, no_spec_client): + """Agent with no OpenAPI spec accepts any payload without validation.""" + resp = await no_spec_client.post( + "/invocations", + content=b"arbitrary bytes not json at all!", + ) + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# Tests: response header behaviour +# --------------------------------------------------------------------------- + + +class TestResponseHeaders: + """Verify invocation-id header auto-injection and passthrough.""" + + @pytest.mark.asyncio + async def test_agent_custom_invocation_id_not_overwritten(self, custom_header_client): + """If agent sets x-agent-invocation-id, the server overwrites it with the canonical value.""" + resp = await custom_header_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + # Server always controls the invocation-id header; handler cannot override it. + invocation_id = resp.headers["x-agent-invocation-id"] + assert invocation_id != "custom-id-from-agent" + uuid.UUID(invocation_id) # must be a valid server-generated UUID + + @pytest.mark.asyncio + async def test_invocation_id_injected_when_missing(self, echo_client): + """EchoAgent doesn't set the header; server auto-injects it.""" + resp = await echo_client.post("/invocations", content=b"hello") + assert resp.status_code == 200 + invocation_id = resp.headers.get("x-agent-invocation-id") + assert invocation_id is not None + uuid.UUID(invocation_id) # valid UUID + + @pytest.mark.asyncio + async def test_invocation_id_accepted_from_request_header(self, echo_client): + """If caller sends x-agent-invocation-id in the request, server uses it.""" + custom_id = "caller-provided-id-12345" + resp = await echo_client.post( + "/invocations", + content=b"hello", + headers={"x-agent-invocation-id": custom_id}, + ) + assert resp.status_code == 200 + assert resp.headers["x-agent-invocation-id"] == 
custom_id + + @pytest.mark.asyncio + async def test_invocation_id_generated_when_request_header_empty(self, echo_client): + """If caller sends empty x-agent-invocation-id, server generates a new one.""" + resp = await echo_client.post( + "/invocations", + content=b"hello", + headers={"x-agent-invocation-id": ""}, + ) + assert resp.status_code == 200 + invocation_id = resp.headers["x-agent-invocation-id"] + assert invocation_id != "" + uuid.UUID(invocation_id) # valid UUID + + +# --------------------------------------------------------------------------- +# Tests: payload edge cases +# --------------------------------------------------------------------------- + + +class TestPayloadEdgeCases: + """Body content edge cases.""" + + @pytest.mark.asyncio + async def test_large_payload(self, large_payload_client): + """Server handles a 1 MB payload correctly.""" + big = b"A" * (1024 * 1024) + resp = await large_payload_client.post("/invocations", content=big) + assert resp.status_code == 200 + assert resp.json()["length"] == 1024 * 1024 + + @pytest.mark.asyncio + async def test_unicode_body(self, echo_client): + """Unicode characters round-trip correctly.""" + data = "\u3053\u3093\u306b\u3061\u306f\u4e16\u754c \U0001f30d \u0645\u0631\u062d\u0628\u0627" + resp = await echo_client.post("/invocations", content=data.encode("utf-8")) + assert resp.status_code == 200 + assert resp.content.decode("utf-8") == data + + @pytest.mark.asyncio + async def test_binary_body(self, echo_client): + """Binary (non-UTF-8) bytes round-trip correctly.""" + data = bytes(range(256)) + resp = await echo_client.post("/invocations", content=data) + assert resp.status_code == 200 + assert resp.content == data + + +# --------------------------------------------------------------------------- +# Tests: streaming edge cases +# --------------------------------------------------------------------------- + + +class TestStreamingEdgeCases: + """Streaming response edge cases.""" + + @pytest.mark.asyncio + 
async def test_empty_streaming_response(self, empty_streaming_client): + """Empty streaming response returns 200 with no body.""" + resp = await empty_streaming_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert resp.content == b"" + + @pytest.mark.asyncio + async def test_streaming_response_has_invocation_id(self, streaming_client): + """Streaming response still gets the auto-injected invocation-id header.""" + resp = await streaming_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + invocation_id = resp.headers.get("x-agent-invocation-id") + assert invocation_id is not None + uuid.UUID(invocation_id) # valid UUID + + +# --------------------------------------------------------------------------- +# Tests: invocation lifecycle (storage agent) +# --------------------------------------------------------------------------- + + +class TestInvocationLifecycle: + """Multi-step invocation workflows.""" + + @pytest.mark.asyncio + async def test_multiple_gets_on_same_invocation(self, async_storage_client): + """GET /invocations/{id} returns the same data on repeated calls.""" + post_resp = await async_storage_client.post("/invocations", content=b'{"data":"hello"}') + invocation_id = post_resp.headers["x-agent-invocation-id"] + + get1 = await async_storage_client.get(f"/invocations/{invocation_id}") + get2 = await async_storage_client.get(f"/invocations/{invocation_id}") + assert get1.status_code == 200 + assert get2.status_code == 200 + assert get1.content == get2.content + + @pytest.mark.asyncio + async def test_double_cancel_returns_404_second_time(self, async_storage_client): + """Cancelling the same invocation twice: first succeeds, second 404.""" + post_resp = await async_storage_client.post("/invocations", content=b'{"data":"x"}') + invocation_id = post_resp.headers["x-agent-invocation-id"] + + cancel1 = await async_storage_client.post(f"/invocations/{invocation_id}/cancel") + assert cancel1.status_code == 
200 + + cancel2 = await async_storage_client.post(f"/invocations/{invocation_id}/cancel") + assert cancel2.status_code == 404 + + @pytest.mark.asyncio + async def test_invoke_then_cancel_then_get_returns_404(self, async_storage_client): + """Full lifecycle: invoke → cancel → get returns 404.""" + post_resp = await async_storage_client.post("/invocations", content=b'{"val":1}') + invocation_id = post_resp.headers["x-agent-invocation-id"] + + await async_storage_client.post(f"/invocations/{invocation_id}/cancel") + + get_resp = await async_storage_client.get(f"/invocations/{invocation_id}") + assert get_resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Tests: debug errors on get/cancel endpoints +# --------------------------------------------------------------------------- + + +class TestDebugErrorsOnGetCancel: + """AGENT_DEBUG_ERRORS exposes details on GET and CANCEL error responses too.""" + + @pytest.mark.asyncio + async def test_get_hides_details_by_default(self): + """GET error hides exception detail without AGENT_DEBUG_ERRORS.""" + agent = _make_slow_failing_get_agent() + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/some-id") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "Internal server error" + + @pytest.mark.asyncio + async def test_get_exposes_details_with_debug(self, monkeypatch): + """GET error exposes exception detail with AGENT_DEBUG_ERRORS set.""" + monkeypatch.setenv("AGENT_DEBUG_ERRORS", "true") + agent = _make_slow_failing_get_agent() + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/some-id") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "get-debug-detail" + + 
@pytest.mark.asyncio + async def test_cancel_hides_details_by_default(self): + """CANCEL error hides exception detail without AGENT_DEBUG_ERRORS.""" + agent = _make_slow_failing_cancel_agent() + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/some-id/cancel") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "Internal server error" + + @pytest.mark.asyncio + async def test_cancel_exposes_details_with_debug(self, monkeypatch): + """CANCEL error exposes exception detail with AGENT_DEBUG_ERRORS set.""" + monkeypatch.setenv("AGENT_DEBUG_ERRORS", "true") + agent = _make_slow_failing_cancel_agent() + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/some-id/cancel") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "cancel-debug-detail" + + @pytest.mark.asyncio + async def test_get_exposes_details_with_constructor_param(self): + """debug_errors=True in constructor exposes GET exception detail.""" + agent = _make_slow_failing_get_agent(debug_errors=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/some-id") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "get-debug-detail" + + @pytest.mark.asyncio + async def test_cancel_exposes_details_with_constructor_param(self): + """debug_errors=True in constructor exposes CANCEL exception detail.""" + agent = _make_slow_failing_cancel_agent(debug_errors=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/some-id/cancel") + 
assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "cancel-debug-detail" + + @pytest.mark.asyncio + async def test_constructor_overrides_env_var(self, monkeypatch): + """debug_errors=False in constructor overrides AGENT_DEBUG_ERRORS=true.""" + monkeypatch.setenv("AGENT_DEBUG_ERRORS", "true") + agent = _make_slow_failing_get_agent(debug_errors=False) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/some-id") + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "Internal server error" + + +# --------------------------------------------------------------------------- +# Tests: log_level constructor parameter +# --------------------------------------------------------------------------- + + +class TestLogLevel: + """log_level parameter controls the library logger level.""" + + def test_log_level_via_constructor(self): + """log_level='DEBUG' sets library logger to DEBUG.""" + import logging + agent = _make_custom_header_agent(log_level="DEBUG") + lib_logger = logging.getLogger("azure.ai.agentserver") + assert lib_logger.level == logging.DEBUG + + def test_log_level_via_env_var(self, monkeypatch): + """AGENT_LOG_LEVEL=info sets library logger to INFO.""" + import logging + monkeypatch.setenv("AGENT_LOG_LEVEL", "info") + agent = _make_custom_header_agent() + lib_logger = logging.getLogger("azure.ai.agentserver") + assert lib_logger.level == logging.INFO + + def test_log_level_constructor_overrides_env_var(self, monkeypatch): + """Constructor log_level overrides AGENT_LOG_LEVEL env var.""" + import logging + monkeypatch.setenv("AGENT_LOG_LEVEL", "DEBUG") + agent = _make_custom_header_agent(log_level="ERROR") + lib_logger = logging.getLogger("azure.ai.agentserver") + assert lib_logger.level == logging.ERROR + + def test_invalid_log_level_raises(self): + """Invalid log_level raises ValueError.""" 
+ with pytest.raises(ValueError, match="Invalid log level"): + _make_custom_header_agent(log_level="BOGUS") + + +class TestConcurrency: + """Verify multiple concurrent requests produce unique invocation IDs.""" + + @pytest.mark.asyncio + async def test_concurrent_invocations_unique_ids(self, echo_client): + """10 concurrent POSTs each get a unique invocation ID.""" + import asyncio + + async def do_post(): + return await echo_client.post("/invocations", content=b"concurrent") + + responses = await asyncio.gather(*[do_post() for _ in range(10)]) + ids = {r.headers["x-agent-invocation-id"] for r in responses} + assert len(ids) == 10 # all unique diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_get_cancel.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_get_cancel.py new file mode 100644 index 000000000000..edc5f1d47b48 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_get_cancel.py @@ -0,0 +1,115 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Tests for get_invocation and cancel_invocation.""" +import json +import uuid + +import pytest + + +@pytest.mark.asyncio +async def test_get_invocation_after_invoke(async_storage_client): + """Invoke, then GET /invocations/{id} returns stored result.""" + resp = await async_storage_client.post("/invocations", content=b'{"key":"value"}') + invocation_id = resp.headers["x-agent-invocation-id"] + + get_resp = await async_storage_client.get(f"/invocations/{invocation_id}") + assert get_resp.status_code == 200 + assert get_resp.headers.get("x-agent-invocation-id") == invocation_id + data = json.loads(get_resp.content) + assert "echo" in data + + +@pytest.mark.asyncio +async def test_get_invocation_unknown_id_returns_404(async_storage_client): + """GET /invocations/{unknown} returns 404.""" + resp = await async_storage_client.get(f"/invocations/{uuid.uuid4()}") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_cancel_invocation_after_invoke(async_storage_client): + """Invoke, then POST /invocations/{id}/cancel returns cancelled status.""" + resp = await async_storage_client.post("/invocations", content=b'{"key":"value"}') + invocation_id = resp.headers["x-agent-invocation-id"] + + cancel_resp = await async_storage_client.post(f"/invocations/{invocation_id}/cancel") + assert cancel_resp.status_code == 200 + assert cancel_resp.headers.get("x-agent-invocation-id") == invocation_id + data = json.loads(cancel_resp.content) + assert data["status"] == "cancelled" + + +@pytest.mark.asyncio +async def test_cancel_invocation_unknown_id_returns_404(async_storage_client): + """POST /invocations/{unknown}/cancel returns 404.""" + resp = await async_storage_client.post(f"/invocations/{uuid.uuid4()}/cancel") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_after_cancel_returns_404(async_storage_client): + """Cancel, then get same ID returns 404.""" + resp = await 
async_storage_client.post("/invocations", content=b'{"key":"value"}') + invocation_id = resp.headers["x-agent-invocation-id"] + + await async_storage_client.post(f"/invocations/{invocation_id}/cancel") + get_resp = await async_storage_client.get(f"/invocations/{invocation_id}") + assert get_resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_invocation_error_returns_500(): + """GET /invocations/{id} returns 500 when customer code raises an error.""" + import httpx + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + from azure.ai.agentserver.server import AgentServer + + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({}) + + @server.get_invocation_handler + async def get_inv(request: Request) -> Response: + raise RuntimeError("storage unavailable") + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/some-id") + assert resp.status_code == 500 + assert resp.json()["error"]["code"] == "internal_error" + assert resp.json()["error"]["message"] == "Internal server error" + assert resp.headers.get("x-agent-invocation-id") == "some-id" + + +@pytest.mark.asyncio +async def test_cancel_invocation_error_returns_500(): + """POST /invocations/{id}/cancel returns 500 when customer code raises an error.""" + import httpx + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + from azure.ai.agentserver.server import AgentServer + + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return JSONResponse({}) + + @server.cancel_invocation_handler + async def cancel(request: Request) -> Response: + raise RuntimeError("cancel failed") + + transport = httpx.ASGITransport(app=server.app) + async with 
httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/some-id/cancel") + assert resp.status_code == 500 + assert resp.json()["error"]["code"] == "internal_error" + assert resp.json()["error"]["message"] == "Internal server error" + assert resp.headers.get("x-agent-invocation-id") == "some-id" diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_graceful_shutdown.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_graceful_shutdown.py new file mode 100644 index 000000000000..31da3d022640 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_graceful_shutdown.py @@ -0,0 +1,451 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for graceful shutdown configuration and lifecycle behaviour.""" +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from starlette.requests import Request +from starlette.responses import Response + +from azure.ai.agentserver.server import AgentServer +from azure.ai.agentserver.server._constants import Constants + + +# --------------------------------------------------------------------------- +# Agent factory functions +# --------------------------------------------------------------------------- + + +def _make_stub_agent(**kwargs) -> AgentServer: + """Create a no-op agent used to inspect internal state.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return Response(content=b"ok") + + return server + + +# --------------------------------------------------------------------------- +# _resolve_graceful_shutdown_timeout +# --------------------------------------------------------------------------- + + +class TestResolveGracefulShutdownTimeout: + """Unit tests for the timeout resolution 
hierarchy: explicit > env > default.""" + + def test_explicit_value_takes_precedence(self): + agent = _make_stub_agent(graceful_shutdown_timeout=10) + assert agent._graceful_shutdown_timeout == 10 + + def test_explicit_zero_disables_drain(self): + agent = _make_stub_agent(graceful_shutdown_timeout=0) + assert agent._graceful_shutdown_timeout == 0 + + def test_env_var_zero_disables_drain(self): + with patch.dict(os.environ, {Constants.AGENT_GRACEFUL_SHUTDOWN_TIMEOUT: "0"}): + agent = _make_stub_agent() + assert agent._graceful_shutdown_timeout == 0 + + def test_env_var_used_when_no_explicit(self): + with patch.dict(os.environ, {Constants.AGENT_GRACEFUL_SHUTDOWN_TIMEOUT: "45"}): + agent = _make_stub_agent() + assert agent._graceful_shutdown_timeout == 45 + + def test_invalid_env_var_raises(self): + with patch.dict(os.environ, {Constants.AGENT_GRACEFUL_SHUTDOWN_TIMEOUT: "not-a-number"}): + with pytest.raises(ValueError, match="AGENT_GRACEFUL_SHUTDOWN_TIMEOUT"): + _make_stub_agent() + + def test_non_int_explicit_value_raises(self): + with pytest.raises(ValueError, match="expected an integer"): + _make_stub_agent(graceful_shutdown_timeout="ten") # type: ignore[arg-type] + + def test_default_when_nothing_set(self): + with patch.dict(os.environ, {}, clear=True): + agent = _make_stub_agent() + assert agent._graceful_shutdown_timeout == Constants.DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT + + def test_default_is_30_seconds(self): + assert Constants.DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT == 30 + + +# --------------------------------------------------------------------------- +# Constant exists +# --------------------------------------------------------------------------- + + +class TestConstants: + """Verify the new constant is wired correctly.""" + + def test_env_var_name(self): + assert Constants.AGENT_GRACEFUL_SHUTDOWN_TIMEOUT == "AGENT_GRACEFUL_SHUTDOWN_TIMEOUT" + + +# --------------------------------------------------------------------------- +# Hypercorn config receives 
graceful_timeout (sync run) +# --------------------------------------------------------------------------- + + +class TestRunPassesTimeout: + """Ensure run() forwards the timeout to Hypercorn config.""" + + @patch("hypercorn.asyncio.serve", new_callable=AsyncMock) + @patch("azure.ai.agentserver.server._base.asyncio") + def test_run_passes_timeout(self, mock_asyncio, _mock_serve): + agent = _make_stub_agent(graceful_shutdown_timeout=15) + agent.run() + mock_asyncio.run.assert_called_once() + # Close the unawaited coroutine passed to asyncio.run to avoid + # "coroutine was never awaited" warnings on Python 3.14+. + mock_asyncio.run.call_args[0][0].close() + # Verify the config built internally has the right graceful_timeout + config = agent._build_hypercorn_config("127.0.0.1", 8088) + assert config.graceful_timeout == 15.0 + + @patch("hypercorn.asyncio.serve", new_callable=AsyncMock) + @patch("azure.ai.agentserver.server._base.asyncio") + def test_run_passes_default_timeout(self, mock_asyncio, _mock_serve): + agent = _make_stub_agent() + agent.run() + # Close the unawaited coroutine passed to asyncio.run to avoid + # "coroutine was never awaited" warnings on Python 3.14+. 
+ mock_asyncio.run.call_args[0][0].close() + config = agent._build_hypercorn_config("127.0.0.1", 8088) + assert config.graceful_timeout == 30.0 + + +# --------------------------------------------------------------------------- +# Hypercorn config receives graceful_timeout (async run) +# --------------------------------------------------------------------------- + + +class TestRunAsyncPassesTimeout: + """Ensure run_async() forwards the timeout to Hypercorn config.""" + + @pytest.mark.asyncio + @patch("hypercorn.asyncio.serve", new_callable=AsyncMock) + async def test_run_async_passes_timeout(self, mock_serve): + agent = _make_stub_agent(graceful_shutdown_timeout=20) + await agent.run_async() + mock_serve.assert_awaited_once() + # Check the config passed to serve + call_args = mock_serve.call_args + config = call_args[0][1] if len(call_args[0]) > 1 else call_args.kwargs.get("config") + assert config.graceful_timeout == 20.0 + + @pytest.mark.asyncio + @patch("hypercorn.asyncio.serve", new_callable=AsyncMock) + async def test_run_async_passes_default_timeout(self, mock_serve): + agent = _make_stub_agent() + await agent.run_async() + call_args = mock_serve.call_args + config = call_args[0][1] if len(call_args[0]) > 1 else call_args.kwargs.get("config") + assert config.graceful_timeout == 30.0 + + +# --------------------------------------------------------------------------- +# Lifespan shutdown log +# --------------------------------------------------------------------------- + + +class TestLifespanShutdown: + """Verify the lifespan emits a shutdown log message.""" + + @pytest.mark.asyncio + async def test_shutdown_log_emitted(self): + agent = _make_stub_agent(graceful_shutdown_timeout=42) + # Exercise the full lifespan by sending a request through the ASGI app + import httpx + + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/liveness") + assert 
resp.status_code == 200 + + # The shutdown log is emitted after the lifespan exits. + # We verify that the app is configured with the correct timeout so the + # log message references it. A direct lifespan exercise is below. + + @pytest.mark.asyncio + async def test_lifespan_shutdown_logs(self): + """Directly exercise the lifespan context manager and verify shutdown log.""" + import logging + + agent = _make_stub_agent(graceful_shutdown_timeout=99) + + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + # Grab the lifespan from the Starlette app + lifespan = agent.app.router.lifespan_context + async with lifespan(agent.app): + pass # startup yields here + + # After exiting the context manager, shutdown log should fire + shutdown_calls = [ + c + for c in mock_logger.info.call_args_list + if "shutting down" in str(c).lower() + ] + assert len(shutdown_calls) == 1 + assert "99" in str(shutdown_calls[0]) + + +# --------------------------------------------------------------------------- +# on_shutdown overridable method +# --------------------------------------------------------------------------- + + +def _make_shutdown_recording_agent(**kwargs) -> AgentServer: + """Create an agent that records on_shutdown calls.""" + server = AgentServer(**kwargs) + server.shutdown_log: list[str] = [] + + @server.invoke_handler + async def handle(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + server.shutdown_log.append("shutdown") + + return server + + +def _make_async_checkpoint_agent(**kwargs) -> AgentServer: + """Create an agent whose shutdown handler performs async work.""" + server = AgentServer(**kwargs) + server.flushed: list[str] = [] + + @server.invoke_handler + async def handle(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + import asyncio + await asyncio.sleep(0) + server.flushed.append("async-checkpoint") + + 
return server + + +def _make_failing_shutdown_agent(**kwargs) -> AgentServer: + """Create an agent whose shutdown handler raises an exception.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + raise RuntimeError("disk full") + + return server + + +class TestOnShutdownMethod: + """Verify the overridable on_shutdown() method is called during lifespan teardown.""" + + @pytest.mark.asyncio + async def test_on_shutdown_called(self): + agent = _make_shutdown_recording_agent() + lifespan = agent.app.router.lifespan_context + async with lifespan(agent.app): + pass + assert agent.shutdown_log == ["shutdown"] + + @pytest.mark.asyncio + async def test_async_work_in_on_shutdown(self): + agent = _make_async_checkpoint_agent() + lifespan = agent.app.router.lifespan_context + async with lifespan(agent.app): + pass + assert agent.flushed == ["async-checkpoint"] + + @pytest.mark.asyncio + async def test_default_on_shutdown_is_noop(self): + """Base class on_shutdown does nothing and doesn't raise.""" + agent = _make_stub_agent() + lifespan = agent.app.router.lifespan_context + async with lifespan(agent.app): + pass # should complete without error + + @pytest.mark.asyncio + async def test_on_shutdown_exception_is_logged_not_raised(self): + """A failing on_shutdown must not crash the shutdown sequence.""" + agent = _make_failing_shutdown_agent() + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + lifespan = agent.app.router.lifespan_context + async with lifespan(agent.app): + pass # should NOT raise + + exception_calls = [ + c + for c in mock_logger.exception.call_args_list + if "on_shutdown" in str(c).lower() + ] + assert len(exception_calls) == 1 + + @pytest.mark.asyncio + async def test_on_shutdown_runs_after_shutdown_log(self): + """on_shutdown fires after the shutdown log message.""" + order: list[str] = [] + 
+ server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + order.append("callback") + + def tracking_info(*args, **kwargs): + if args and "shutting down" in str(args[0]).lower(): + order.append("log") + + mock_logger = MagicMock() + mock_logger.info = MagicMock(side_effect=tracking_info) + + with patch( + "azure.ai.agentserver.server._base.logger", mock_logger + ): + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass + + assert order == ["log", "callback"] + + @pytest.mark.asyncio + async def test_on_shutdown_has_access_to_state(self): + """Shutdown handler can access state via closure.""" + connections: list[str] = ["db", "cache"] + closed: list[str] = [] + + server = AgentServer() + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + for conn in connections: + closed.append(conn) + + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass + assert closed == ["db", "cache"] + + +# --------------------------------------------------------------------------- +# on_shutdown timeout enforcement +# --------------------------------------------------------------------------- + + +class TestOnShutdownTimeout: + """Verify on_shutdown is bounded by graceful_shutdown_timeout.""" + + @pytest.mark.asyncio + async def test_slow_on_shutdown_is_cancelled_and_warning_logged(self): + """If on_shutdown exceeds the timeout, it is cancelled and a warning is logged.""" + import asyncio + + server = AgentServer(graceful_shutdown_timeout=1) + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + await asyncio.sleep(999) # way longer than the 1s timeout + + with 
patch("azure.ai.agentserver.server._base.logger") as mock_logger: + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass # should NOT hang + + warning_calls = [ + c + for c in mock_logger.warning.call_args_list + if "did not complete" in str(c).lower() + ] + assert len(warning_calls) == 1 + assert "1" in str(warning_calls[0]) + + @pytest.mark.asyncio + async def test_fast_on_shutdown_completes_normally(self): + """on_shutdown that finishes within the timeout succeeds without warnings.""" + import asyncio + + done = False + server = AgentServer(graceful_shutdown_timeout=5) + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + nonlocal done + await asyncio.sleep(0) + done = True + + with patch("azure.ai.agentserver.server._base.logger") as mock_logger: + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass + + warning_calls = [ + c + for c in mock_logger.warning.call_args_list + if "did not complete" in str(c).lower() + ] + assert len(warning_calls) == 0 + assert done is True + + @pytest.mark.asyncio + async def test_zero_timeout_disables_wait(self): + """graceful_shutdown_timeout=0 passes None to wait_for (no timeout).""" + import asyncio + + ran = False + server = AgentServer(graceful_shutdown_timeout=0) + + @server.invoke_handler + async def invoke(request: Request) -> Response: + return Response(content=b"ok") + + @server.shutdown_handler + async def shutdown(): + nonlocal ran + ran = True + + lifespan = server.app.router.lifespan_context + async with lifespan(server.app): + pass + assert ran is True + + @pytest.mark.asyncio + async def test_timeout_does_not_suppress_exceptions(self): + """Exceptions from on_shutdown are still logged even if within timeout.""" + agent = _make_failing_shutdown_agent(graceful_shutdown_timeout=5) + with patch("azure.ai.agentserver.server._base.logger") as 
mock_logger: + lifespan = agent.app.router.lifespan_context + async with lifespan(agent.app): + pass + + exception_calls = [ + c + for c in mock_logger.exception.call_args_list + if "on_shutdown" in str(c).lower() + ] + assert len(exception_calls) == 1 diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_health.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_health.py new file mode 100644 index 000000000000..3c0786f8eeac --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_health.py @@ -0,0 +1,23 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for health check endpoints.""" +import pytest + + +@pytest.mark.asyncio +async def test_liveness_returns_200(echo_client): + """GET /liveness returns 200.""" + resp = await echo_client.get("/liveness") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "alive" + + +@pytest.mark.asyncio +async def test_readiness_returns_200(echo_client): + """GET /readiness returns 200.""" + resp = await echo_client.get("/readiness") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ready" diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_http2.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_http2.py new file mode 100644 index 000000000000..51da9b840096 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_http2.py @@ -0,0 +1,473 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Integration tests verifying HTTP/2 support via Hypercorn. + +**TLS tests** start a real Hypercorn server with a self-signed cert and +connect via HTTP/2 using ``httpx`` with ALPN negotiation. 
# cspell:ignore ALPN + +**h2c tests** (HTTP/2 cleartext) start a plain-TCP Hypercorn server and +connect via HTTP/2 *prior knowledge* (``http1=False, http2=True``), +proving that HTTP/2 works without any TLS certificates. + +Requirements (auto-skipped when missing): +- ``cryptography`` — self-signed certificate generation (TLS tests only) +- ``h2`` — HTTP/2 client support in httpx (installed with hypercorn) +""" +from __future__ import annotations + +import asyncio +import datetime +import ipaddress +import socket +import tempfile +from pathlib import Path +from typing import AsyncGenerator + +import pytest +import pytest_asyncio + +cryptography = pytest.importorskip("cryptography", reason="cryptography needed for TLS cert generation") +pytest.importorskip("h2", reason="h2 needed for HTTP/2 client support") + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + +import httpx +from hypercorn.asyncio import serve as _hypercorn_serve +from hypercorn.config import Config as HypercornConfig +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver.server import AgentServer + + +# --------------------------------------------------------------------------- +# Test agents +# --------------------------------------------------------------------------- + + +def _make_echo_agent(**kwargs): + """Agent that echoes the request body in the response.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + body = await request.body() + return JSONResponse({"echo": body.decode()}) + + return server + + +def _make_stream_agent(**kwargs): + """Agent that returns a streaming response.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + 
async def generate() -> AsyncGenerator[bytes, None]: + for chunk in [b"chunk1", b"chunk2", b"chunk3"]: + yield chunk + + return StreamingResponse(generate(), media_type="application/octet-stream") + + return server + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_free_port() -> int: + """Find an available TCP port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def _generate_self_signed_cert(cert_path: Path, key_path: Path) -> None: + """Generate a self-signed TLS certificate for localhost / 127.0.0.1.""" + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), + ]) + + now = datetime.datetime.now(datetime.timezone.utc) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(hours=1)) + .add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ]), + critical=False, + ) + .sign(key, hashes.SHA256()) + ) + + key_path.write_bytes( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) + + +async def _wait_for_server(host: str, port: int, timeout: float = 5.0) -> None: + """Poll until the server is accepting connections.""" + deadline = asyncio.get_event_loop().time() + timeout + while asyncio.get_event_loop().time() < deadline: + try: + _, writer = await asyncio.open_connection(host, port) + writer.close() + await 
writer.wait_closed() + return + except OSError: + await asyncio.sleep(0.05) + raise TimeoutError(f"Server on {host}:{port} did not start within {timeout}s") + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def tls_cert_pair(): + """Generate a self-signed cert + key in a temp directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + cert_path = Path(tmpdir) / "cert.pem" + key_path = Path(tmpdir) / "key.pem" + _generate_self_signed_cert(cert_path, key_path) + yield str(cert_path), str(key_path) + + +@pytest_asyncio.fixture() +async def echo_h2_server(tls_cert_pair): + """Start _EchoAgent on a random port with TLS, yield the port.""" + cert_path, key_path = tls_cert_pair + port = _get_free_port() + + agent = _make_echo_agent() + config = HypercornConfig() + config.bind = [f"127.0.0.1:{port}"] + config.certfile = cert_path + config.keyfile = key_path + config.graceful_timeout = 1.0 + + shutdown_event = asyncio.Event() + task = asyncio.create_task( + _hypercorn_serve(agent.app, config, shutdown_trigger=shutdown_event.wait) + ) + await _wait_for_server("127.0.0.1", port) + yield port + + shutdown_event.set() + await task + + +@pytest_asyncio.fixture() +async def stream_h2_server(tls_cert_pair): + """Start _StreamAgent on a random port with TLS, yield the port.""" + cert_path, key_path = tls_cert_pair + port = _get_free_port() + + agent = _make_stream_agent() + config = HypercornConfig() + config.bind = [f"127.0.0.1:{port}"] + config.certfile = cert_path + config.keyfile = key_path + config.graceful_timeout = 1.0 + + shutdown_event = asyncio.Event() + task = asyncio.create_task( + _hypercorn_serve(agent.app, config, shutdown_trigger=shutdown_event.wait) + ) + await _wait_for_server("127.0.0.1", port) + yield port + + shutdown_event.set() + await task + + +# 
--------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestHttp2Health: + """Verify health endpoints work over HTTP/2.""" + + @pytest.mark.asyncio + async def test_h2_liveness(self, echo_h2_server): + port = echo_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + resp = await client.get(f"https://127.0.0.1:{port}/liveness") + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + assert resp.json() == {"status": "alive"} + + @pytest.mark.asyncio + async def test_h2_readiness(self, echo_h2_server): + port = echo_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + resp = await client.get(f"https://127.0.0.1:{port}/readiness") + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + assert resp.json() == {"status": "ready"} + + +class TestHttp2Invoke: + """Verify invocation works over HTTP/2.""" + + @pytest.mark.asyncio + async def test_h2_invoke_returns_200(self, echo_h2_server): + port = echo_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + resp = await client.post( + f"https://127.0.0.1:{port}/invocations", + json={"message": "hello"}, + ) + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + data = resp.json() + assert "echo" in data + + @pytest.mark.asyncio + async def test_h2_invocation_id_header(self, echo_h2_server): + port = echo_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + resp = await client.post( + f"https://127.0.0.1:{port}/invocations", + json={"message": "test"}, + ) + assert resp.http_version == "HTTP/2" + assert "x-agent-invocation-id" in resp.headers + + @pytest.mark.asyncio + async def test_h2_multiple_requests_on_same_connection(self, echo_h2_server): + """HTTP/2 multiplexing: multiple requests on a single connection.""" + port = echo_h2_server + async with 
httpx.AsyncClient(http2=True, verify=False) as client: + responses = await asyncio.gather( + client.post(f"https://127.0.0.1:{port}/invocations", json={"n": 1}), + client.post(f"https://127.0.0.1:{port}/invocations", json={"n": 2}), + client.post(f"https://127.0.0.1:{port}/invocations", json={"n": 3}), + ) + for resp in responses: + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + # Each invocation gets a unique ID + ids = {resp.headers["x-agent-invocation-id"] for resp in responses} + assert len(ids) == 3 + + +class TestHttp2Streaming: + """Verify streaming responses work over HTTP/2.""" + + @pytest.mark.asyncio + async def test_h2_streaming_response(self, stream_h2_server): + port = stream_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + resp = await client.post( + f"https://127.0.0.1:{port}/invocations", + json={}, + ) + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + assert resp.content == b"chunk1chunk2chunk3" + + @pytest.mark.asyncio + async def test_h2_streaming_has_invocation_id(self, stream_h2_server): + port = stream_h2_server + async with httpx.AsyncClient(http2=True, verify=False) as client: + resp = await client.post( + f"https://127.0.0.1:{port}/invocations", + json={}, + ) + assert resp.http_version == "HTTP/2" + assert "x-agent-invocation-id" in resp.headers + + +class TestHttp1Fallback: + """Verify HTTP/1.1 still works (client without h2 negotiation).""" + + @pytest.mark.asyncio + async def test_http1_invoke(self, echo_h2_server): + """A client that does not request HTTP/2 gets HTTP/1.1.""" + port = echo_h2_server + async with httpx.AsyncClient(http2=False, verify=False) as client: + resp = await client.post( + f"https://127.0.0.1:{port}/invocations", + json={"message": "http1"}, + ) + assert resp.status_code == 200 + assert resp.http_version == "HTTP/1.1" + + +# =========================================================================== +# h2c (HTTP/2 cleartext) — no TLS 
certificates required +# =========================================================================== + + +@pytest_asyncio.fixture() +async def echo_h2c_server(): + """Start _EchoAgent on a random port *without* TLS, yield the port.""" + port = _get_free_port() + + agent = _make_echo_agent() + config = HypercornConfig() + config.bind = [f"127.0.0.1:{port}"] + config.graceful_timeout = 1.0 + + shutdown_event = asyncio.Event() + task = asyncio.create_task( + _hypercorn_serve(agent.app, config, shutdown_trigger=shutdown_event.wait) + ) + await _wait_for_server("127.0.0.1", port) + yield port + + shutdown_event.set() + await task + + +@pytest_asyncio.fixture() +async def stream_h2c_server(): + """Start _StreamAgent on a random port *without* TLS, yield the port.""" + port = _get_free_port() + + agent = _make_stream_agent() + config = HypercornConfig() + config.bind = [f"127.0.0.1:{port}"] + config.graceful_timeout = 1.0 + + shutdown_event = asyncio.Event() + task = asyncio.create_task( + _hypercorn_serve(agent.app, config, shutdown_trigger=shutdown_event.wait) + ) + await _wait_for_server("127.0.0.1", port) + yield port + + shutdown_event.set() + await task + + +class TestH2cHealth: + """Verify health endpoints work over h2c (HTTP/2 cleartext).""" + + @pytest.mark.asyncio + async def test_h2c_liveness(self, echo_h2c_server): + port = echo_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + resp = await client.get(f"http://127.0.0.1:{port}/liveness") + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + assert resp.json() == {"status": "alive"} + + @pytest.mark.asyncio + async def test_h2c_readiness(self, echo_h2c_server): + port = echo_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + resp = await client.get(f"http://127.0.0.1:{port}/readiness") + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + assert resp.json() == {"status": "ready"} + + +class TestH2cInvoke: + 
"""Verify invocation works over h2c (HTTP/2 cleartext).""" + + @pytest.mark.asyncio + async def test_h2c_invoke_returns_200(self, echo_h2c_server): + port = echo_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + resp = await client.post( + f"http://127.0.0.1:{port}/invocations", + json={"message": "hello"}, + ) + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + data = resp.json() + assert "echo" in data + + @pytest.mark.asyncio + async def test_h2c_invocation_id_header(self, echo_h2c_server): + port = echo_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + resp = await client.post( + f"http://127.0.0.1:{port}/invocations", + json={"message": "test"}, + ) + assert resp.http_version == "HTTP/2" + assert "x-agent-invocation-id" in resp.headers + + @pytest.mark.asyncio + async def test_h2c_multiplexing(self, echo_h2c_server): + """HTTP/2 multiplexing over cleartext: concurrent requests, unique IDs.""" + port = echo_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + responses = await asyncio.gather( + client.post(f"http://127.0.0.1:{port}/invocations", json={"n": 1}), + client.post(f"http://127.0.0.1:{port}/invocations", json={"n": 2}), + client.post(f"http://127.0.0.1:{port}/invocations", json={"n": 3}), + ) + for resp in responses: + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + ids = {resp.headers["x-agent-invocation-id"] for resp in responses} + assert len(ids) == 3 + + +class TestH2cStreaming: + """Verify streaming responses work over h2c (HTTP/2 cleartext).""" + + @pytest.mark.asyncio + async def test_h2c_streaming_response(self, stream_h2c_server): + port = stream_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + resp = await client.post( + f"http://127.0.0.1:{port}/invocations", + json={}, + ) + assert resp.status_code == 200 + assert resp.http_version == "HTTP/2" + assert resp.content == 
b"chunk1chunk2chunk3" + + @pytest.mark.asyncio + async def test_h2c_streaming_has_invocation_id(self, stream_h2c_server): + port = stream_h2c_server + async with httpx.AsyncClient(http1=False, http2=True) as client: + resp = await client.post( + f"http://127.0.0.1:{port}/invocations", + json={}, + ) + assert resp.http_version == "HTTP/2" + assert "x-agent-invocation-id" in resp.headers + + +class TestH2cFallback: + """Verify HTTP/1.1 still works on the plain-TCP server.""" + + @pytest.mark.asyncio + async def test_h2c_server_accepts_http1(self, echo_h2c_server): + """A client that only speaks HTTP/1.1 is served normally.""" + port = echo_h2c_server + async with httpx.AsyncClient(http1=True, http2=False) as client: + resp = await client.post( + f"http://127.0.0.1:{port}/invocations", + json={"message": "http1"}, + ) + assert resp.status_code == 200 + assert resp.http_version == "HTTP/1.1" diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_invoke.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_invoke.py new file mode 100644 index 000000000000..2192e96c477b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_invoke.py @@ -0,0 +1,119 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Tests for invoke dispatch (streaming and non-streaming).""" +import json +import uuid + +import httpx +import pytest + +from conftest import _make_failing_agent + + +@pytest.mark.asyncio +async def test_invoke_echoes_body(echo_client): + """POST /invocations body is passed to invoke() and echoed back.""" + payload = b'{"message":"ping"}' + resp = await echo_client.post("/invocations", content=payload) + assert resp.status_code == 200 + assert resp.content == payload + + +@pytest.mark.asyncio +async def test_invoke_receives_headers(echo_client): + """request.headers contains sent HTTP headers.""" + # EchoAgent echoes body; we just confirm the request succeeds with custom headers. + resp = await echo_client.post( + "/invocations", + content=b"{}", + headers={"X-Custom-Header": "test-value", "Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_invoke_receives_invocation_id(echo_client): + """Response x-agent-invocation-id is a non-empty UUID string.""" + resp = await echo_client.post("/invocations", content=b"{}") + invocation_id = resp.headers["x-agent-invocation-id"] + assert invocation_id + uuid.UUID(invocation_id) # raises if not valid UUID + + +@pytest.mark.asyncio +async def test_invoke_invocation_id_unique(echo_client): + """Two consecutive POST /invocations return different x-agent-invocation-id values.""" + resp1 = await echo_client.post("/invocations", content=b"{}") + resp2 = await echo_client.post("/invocations", content=b"{}") + id1 = resp1.headers["x-agent-invocation-id"] + id2 = resp2.headers["x-agent-invocation-id"] + assert id1 != id2 + + +@pytest.mark.asyncio +async def test_invoke_streaming_returns_chunked(streaming_client): + """Streaming agent returns a StreamingResponse.""" + resp = await streaming_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + # The response should contain all chunks + 
assert len(resp.content) > 0 + + +@pytest.mark.asyncio +async def test_invoke_streaming_yields_all_chunks(streaming_client): + """All chunks from the async generator are received by the client.""" + resp = await streaming_client.post("/invocations", content=b"{}") + lines = resp.content.decode().strip().split("\n") + chunks = [json.loads(line) for line in lines] + assert len(chunks) == 3 + assert chunks[0]["chunk"] == 0 + assert chunks[1]["chunk"] == 1 + assert chunks[2]["chunk"] == 2 + + +@pytest.mark.asyncio +async def test_invoke_streaming_has_invocation_id_header(streaming_client): + """Streaming response also includes x-agent-invocation-id header.""" + resp = await streaming_client.post("/invocations", content=b"{}") + invocation_id = resp.headers.get("x-agent-invocation-id") + assert invocation_id is not None + uuid.UUID(invocation_id) + + +@pytest.mark.asyncio +async def test_invoke_empty_body(echo_client): + """POST /invocations with empty body doesn't crash.""" + resp = await echo_client.post("/invocations", content=b"") + assert resp.status_code == 200 + assert resp.content == b"" + + +@pytest.mark.asyncio +async def test_invoke_error_returns_500(failing_client): + """When invoke() raises, server returns 500 with generic error message and invocation id.""" + resp = await failing_client.post("/invocations", content=b'{"key":"value"}') + assert resp.status_code == 500 + data = resp.json() + assert data["error"]["code"] == "internal_error" + assert data["error"]["message"] == "Internal server error" + assert resp.headers.get("x-agent-invocation-id") is not None + + +@pytest.mark.asyncio +async def test_invoke_error_hides_details_by_default(failing_client): + """Without AGENT_DEBUG_ERRORS, the actual exception message is hidden.""" + resp = await failing_client.post("/invocations", content=b'{}') + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "Internal server error" + + +@pytest.mark.asyncio +async def 
test_invoke_error_exposes_details_with_debug(): + """With debug_errors=True, the actual exception message is returned.""" + agent = _make_failing_agent(debug_errors=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b'{}') + assert resp.status_code == 500 + assert resp.json()["error"]["message"] == "something went wrong" diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_logger.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_logger.py new file mode 100644 index 000000000000..314211cec644 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_logger.py @@ -0,0 +1,22 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for the library-scoped logger.""" +import logging + + +def test_library_logger_exists(): + """The azure.ai.agentserver logger is a standard named logger.""" + lib_logger = logging.getLogger("azure.ai.agentserver") + assert lib_logger.name == "azure.ai.agentserver" + + +def test_log_level_preserved_across_imports(): + """Importing the server module does not reset a level already set.""" + lib_logger = logging.getLogger("azure.ai.agentserver") + lib_logger.setLevel(logging.ERROR) + + # Re-importing the base module should not override the level. 
+ from azure.ai.agentserver.server import _base # noqa: F401 + + assert lib_logger.level == logging.ERROR diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_multimodal_protocol.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_multimodal_protocol.py new file mode 100644 index 000000000000..9e481d1ea117 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_multimodal_protocol.py @@ -0,0 +1,807 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for multi-modality payloads, content types, and protocol edge cases. + +AgentServer is an HTTP/ASGI server. It does NOT define WebSocket or gRPC +routes, so those protocol attempts must be handled gracefully. The +server is content-type agnostic: agents can receive and return any +media type (images, audio, protobuf, SSE, etc.). +""" +import base64 +import io +import json +import uuid + +import httpx +import pytest +import pytest_asyncio + +from starlette.requests import Request +from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse, Response, StreamingResponse + +from azure.ai.agentserver.server import AgentServer + + +# --------------------------------------------------------------------------- +# Minimal binary blobs for realistic content types +# --------------------------------------------------------------------------- + +# 1x1 red pixel PNG (67 bytes — smallest valid PNG) +_MINIMAL_PNG = bytes.fromhex( + "89504e470d0a1a0a" # PNG signature + "0000000d49484452" "00000001000000010802" "000000907753de" # header + "0000000c49444154" "789c63f80f00000101000518d84e" # image data + "0000000049454e44ae426082" # end +) + +# Minimal WAV header (44 bytes) — 1 sample, 16-bit mono 8 kHz +_MINIMAL_WAV = bytes.fromhex( + "52494646" # RIFF + "26000000" # file size - 8 (38) + "57415645" # WAVE + "666d7420" # fmt + "10000000" 
# chunk size (16) + "0100" # PCM + "0100" # mono + "401f0000" # sample rate 8000 + "803e0000" # byte rate 16000 + "0200" # block align + "1000" # bits per sample (16) + "64617461" # data + "02000000" # data bytes (2) + "0000" # 1 silent sample +) + +# Tiny protobuf-like payload (field 1, varint value 150) +_PROTO_PAYLOAD = b"\x08\x96\x01" + + +# --------------------------------------------------------------------------- +# Specialised agent implementations: multi-modal +# --------------------------------------------------------------------------- + + +def _make_ctype_echo_agent(**kwargs): + """Returns the request Content-Type and body length as JSON.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + ctype = request.headers.get("content-type", "") + body = await request.body() + return JSONResponse({"content_type": ctype, "length": len(body)}) + + return server + + +def _make_image_agent(**kwargs): + """Returns a pre-canned PNG image.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return Response(content=_MINIMAL_PNG, media_type="image/png") + + return server + + +def _make_audio_agent(**kwargs): + """Returns a pre-canned WAV clip.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return Response(content=_MINIMAL_WAV, media_type="audio/wav") + + return server + + +def _make_html_agent(**kwargs): + """Returns an HTML response.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return HTMLResponse("

Hello Agent

") + + return server + + +def _make_plaintext_agent(**kwargs): + """Returns plain text.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return PlainTextResponse("Hello, plain text world!") + + return server + + +def _make_xml_agent(**kwargs): + """Returns XML content.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + xml = 'ok' + return Response(content=xml.encode(), media_type="application/xml") + + return server + + +def _make_sse_agent(**kwargs): + """Returns a Server-Sent Events stream.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> StreamingResponse: + async def event_stream(): + for i in range(3): + yield f"event: message\ndata: {json.dumps({'n': i})}\n\n".encode() + + return StreamingResponse(event_stream(), media_type="text/event-stream") + + return server + + +def _make_custom_status_agent(**kwargs): + """Returns whichever HTTP status code the client requests in body.status_code.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + data = await request.json() + code = int(data.get("status_code", 200)) + return JSONResponse({"status_code": code}, status_code=code) + + return server + + +def _make_multipart_raw_agent(**kwargs): + """Echoes the raw multipart body's content-type and length.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + ctype = request.headers.get("content-type", "") + body = await request.body() + return JSONResponse({"content_type": ctype, "length": len(body), "body_prefix": body[:80].decode("latin-1")}) + + return server + + +def _make_query_string_agent(**kwargs): + """Echoes query parameters as JSON.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return 
JSONResponse({"query": dict(request.query_params)}) + + return server + + +def _make_b64image_agent(**kwargs): + """Accepts JSON with base64-encoded image data and decodes it.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + data = await request.json() + image_b64 = data.get("image", "") + decoded = base64.b64decode(image_b64) + return JSONResponse({"decoded_size": len(decoded), "starts_with_png": decoded[:4] == b"\x89PNG"}) + + return server + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def ctype_echo_client(): + server = _make_ctype_echo_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def image_client(): + server = _make_image_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def audio_client(): + server = _make_audio_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def html_client(): + server = _make_html_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def plaintext_client(): + server = _make_plaintext_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def xml_client(): + server = _make_xml_agent() + transport = 
httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def sse_client(): + server = _make_sse_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def custom_status_client(): + server = _make_custom_status_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def multipart_raw_client(): + server = _make_multipart_raw_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def query_client(): + server = _make_query_string_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +@pytest_asyncio.fixture +async def b64image_client(): + server = _make_b64image_agent() + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as c: + yield c + + +# ======================================================================== +# MULTI-MODALITY: requests with different content types +# ======================================================================== + + +class TestMultiModalRequests: + """Verify the server accepts and forwards any content type to the agent.""" + + @pytest.mark.asyncio + async def test_image_png_payload(self, ctype_echo_client): + """POST with image/png content type and binary PNG data.""" + resp = await ctype_echo_client.post( + "/invocations", + content=_MINIMAL_PNG, + headers={"Content-Type": "image/png"}, + ) + assert resp.status_code == 200 + body = 
resp.json() + assert body["content_type"] == "image/png" + assert body["length"] == len(_MINIMAL_PNG) + + @pytest.mark.asyncio + async def test_image_jpeg_payload(self, ctype_echo_client): + """POST with image/jpeg content type.""" + fake_jpeg = b"\xff\xd8\xff\xe0" + b"\x00" * 100 # JPEG SOI marker + padding + resp = await ctype_echo_client.post( + "/invocations", + content=fake_jpeg, + headers={"Content-Type": "image/jpeg"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "image/jpeg" + assert resp.json()["length"] == len(fake_jpeg) + + @pytest.mark.asyncio + async def test_audio_wav_payload(self, ctype_echo_client): + """POST with audio/wav content type and WAV header.""" + resp = await ctype_echo_client.post( + "/invocations", + content=_MINIMAL_WAV, + headers={"Content-Type": "audio/wav"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "audio/wav" + assert resp.json()["length"] == len(_MINIMAL_WAV) + + @pytest.mark.asyncio + async def test_video_mp4_payload(self, ctype_echo_client): + """POST with video/mp4 content type and opaque bytes.""" + fake_mp4 = bytes.fromhex("00000018667479706d703432") + b"\x00" * 200 + resp = await ctype_echo_client.post( + "/invocations", + content=fake_mp4, + headers={"Content-Type": "video/mp4"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "video/mp4" + + @pytest.mark.asyncio + async def test_protobuf_payload(self, ctype_echo_client): + """POST with application/x-protobuf content type.""" + resp = await ctype_echo_client.post( + "/invocations", + content=_PROTO_PAYLOAD, + headers={"Content-Type": "application/x-protobuf"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "application/x-protobuf" + assert resp.json()["length"] == len(_PROTO_PAYLOAD) + + @pytest.mark.asyncio + async def test_octet_stream_payload(self, ctype_echo_client): + """POST with application/octet-stream content type.""" + resp = await 
ctype_echo_client.post( + "/invocations", + content=bytes(range(256)), + headers={"Content-Type": "application/octet-stream"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "application/octet-stream" + assert resp.json()["length"] == 256 + + @pytest.mark.asyncio + async def test_msgpack_payload(self, ctype_echo_client): + """POST with application/msgpack content type (binary serialisation).""" + resp = await ctype_echo_client.post( + "/invocations", + content=b"\x82\xa3key\xa5value\xa3num\x2a", # {"key":"value","num":42} + headers={"Content-Type": "application/msgpack"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "application/msgpack" + + @pytest.mark.asyncio + async def test_base64_encoded_image_in_json(self, b64image_client): + """JSON body carrying a base64-encoded PNG (multi-modal pattern).""" + encoded = base64.b64encode(_MINIMAL_PNG).decode("ascii") + resp = await b64image_client.post( + "/invocations", + content=json.dumps({"image": encoded}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["decoded_size"] == len(_MINIMAL_PNG) + assert body["starts_with_png"] is True + + +# ======================================================================== +# MULTI-MODALITY: responses with different content types +# ======================================================================== + + +class TestMultiModalResponses: + """Agents can return any media type — verify end-to-end.""" + + @pytest.mark.asyncio + async def test_agent_returns_png(self, image_client): + """Agent returning image/png is received correctly by the client.""" + resp = await image_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "image/png" + assert resp.content[:4] == b"\x89PNG" + assert resp.headers.get("x-agent-invocation-id") is not None + + @pytest.mark.asyncio + async def 
test_agent_returns_wav(self, audio_client): + """Agent returning audio/wav is received correctly.""" + resp = await audio_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "audio/wav" + assert resp.content[:4] == b"RIFF" + + @pytest.mark.asyncio + async def test_agent_returns_html(self, html_client): + """Agent returning HTML is received with correct content type.""" + resp = await html_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert "text/html" in resp.headers["content-type"] + assert "

Hello Agent

" in resp.text + + @pytest.mark.asyncio + async def test_agent_returns_plain_text(self, plaintext_client): + """Agent returning plain text is received correctly.""" + resp = await plaintext_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert "text/plain" in resp.headers["content-type"] + assert resp.text == "Hello, plain text world!" + + @pytest.mark.asyncio + async def test_agent_returns_xml(self, xml_client): + """Agent returning application/xml is received correctly.""" + resp = await xml_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert resp.headers["content-type"] == "application/xml" + assert "ok" in resp.text + + +# ======================================================================== +# STREAMING: SSE (Server-Sent Events) protocol +# ======================================================================== + + +class TestServerSentEvents: + """Verify SSE streaming works end-to-end.""" + + @pytest.mark.asyncio + async def test_sse_stream_yields_events(self, sse_client): + """Agent producing SSE events delivers them to the client.""" + resp = await sse_client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert "text/event-stream" in resp.headers.get("content-type", "") + + # Parse the SSE stream + raw = resp.text + events = [line for line in raw.split("\n") if line.startswith("data:")] + assert len(events) == 3 + for i, ev in enumerate(events): + payload = json.loads(ev.removeprefix("data:").strip()) + assert payload["n"] == i + + @pytest.mark.asyncio + async def test_sse_has_invocation_id(self, sse_client): + """SSE response still carries x-agent-invocation-id.""" + resp = await sse_client.post("/invocations", content=b"{}") + assert resp.headers.get("x-agent-invocation-id") is not None + + +# ======================================================================== +# MULTIPART form data (file uploads) +# 
======================================================================== + + +class TestMultipartFormData: + """Verify the server passes multipart/form-data payloads to the agent.""" + + @pytest.mark.asyncio + async def test_single_file_upload(self, multipart_raw_client): + """Upload a single file via multipart form data.""" + resp = await multipart_raw_client.post( + "/invocations", + files={"image": ("photo.png", _MINIMAL_PNG, "image/png")}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "multipart/form-data" in body["content_type"] + assert body["length"] > len(_MINIMAL_PNG) # includes boundary overhead + + @pytest.mark.asyncio + async def test_multiple_files_upload(self, multipart_raw_client): + """Upload multiple files in one request.""" + resp = await multipart_raw_client.post( + "/invocations", + files=[ + ("file1", ("a.png", _MINIMAL_PNG, "image/png")), + ("file2", ("b.wav", _MINIMAL_WAV, "audio/wav")), + ], + ) + assert resp.status_code == 200 + body = resp.json() + assert "multipart/form-data" in body["content_type"] + # Total body must be larger than both payloads combined + assert body["length"] > len(_MINIMAL_PNG) + len(_MINIMAL_WAV) + + @pytest.mark.asyncio + async def test_mixed_form_fields_and_files(self, multipart_raw_client): + """Multipart with both plain text fields and file uploads.""" + resp = await multipart_raw_client.post( + "/invocations", + data={"prompt": "describe this image"}, + files={"image": ("pic.png", _MINIMAL_PNG, "image/png")}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "multipart/form-data" in body["content_type"] + assert body["length"] > 0 + + +# ======================================================================== +# CUSTOM HTTP STATUS CODES +# ======================================================================== + + +class TestCustomStatusCodes: + """AgentServer preserves whatever HTTP status the agent returns.""" + + @pytest.mark.asyncio + async def test_201_created(self, 
custom_status_client): + """Agent returning 201 Created.""" + resp = await custom_status_client.post( + "/invocations", + content=json.dumps({"status_code": 201}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 201 + assert resp.headers.get("x-agent-invocation-id") is not None + + @pytest.mark.asyncio + async def test_202_accepted(self, custom_status_client): + """Agent returning 202 Accepted (async processing).""" + resp = await custom_status_client.post( + "/invocations", + content=json.dumps({"status_code": 202}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 202 + + @pytest.mark.asyncio + async def test_204_no_content(self, custom_status_client): + """Agent returning 204 No Content.""" + resp = await custom_status_client.post( + "/invocations", + content=json.dumps({"status_code": 204}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 204 + + @pytest.mark.asyncio + async def test_400_bad_request(self, custom_status_client): + """Agent can signal a 400 back to caller.""" + resp = await custom_status_client.post( + "/invocations", + content=json.dumps({"status_code": 400}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + @pytest.mark.asyncio + async def test_422_unprocessable(self, custom_status_client): + """Agent returning 422 Unprocessable Entity.""" + resp = await custom_status_client.post( + "/invocations", + content=json.dumps({"status_code": 422}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 422 + + @pytest.mark.asyncio + async def test_503_service_unavailable(self, custom_status_client): + """Agent signalling downstream unavailability.""" + resp = await custom_status_client.post( + "/invocations", + content=json.dumps({"status_code": 503}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 503 + 
+ +# ======================================================================== +# HTTP PROTOCOL: query strings, HEAD, path params +# ======================================================================== + + +class TestHttpProtocol: + """HTTP protocol-level edge cases.""" + + @pytest.mark.asyncio + async def test_query_string_forwarded_to_agent(self, query_client): + """Query parameters on /invocations are accessible to the agent.""" + resp = await query_client.post( + "/invocations?model=gpt-4&temperature=0.7", + content=b"{}", + ) + assert resp.status_code == 200 + q = resp.json()["query"] + assert q["model"] == "gpt-4" + assert q["temperature"] == "0.7" + + @pytest.mark.asyncio + async def test_head_liveness(self, echo_client): + """HEAD /liveness returns 200 with no body (standard HTTP HEAD).""" + resp = await echo_client.head("/liveness") + # Starlette may return 200 or 405 — both are valid server behavior + assert resp.status_code in (200, 405) + + @pytest.mark.asyncio + async def test_path_param_special_characters(self, echo_client): + """GET /invocations/{id} with URL-encoded special chars in the ID.""" + weird_id = "abc%20def%2F123" + resp = await echo_client.get(f"/invocations/{weird_id}") + # default agent returns 404 — ensure it doesn't crash + assert resp.status_code == 404 + + @pytest.mark.asyncio + async def test_accept_header_passthrough(self, ctype_echo_client): + """Accept header doesn't interfere with agent processing.""" + resp = await ctype_echo_client.post( + "/invocations", + content=b"hello", + headers={ + "Accept": "application/xml", + "Content-Type": "text/plain", + }, + ) + assert resp.status_code == 200 + # Agent returns JSON regardless — server doesn't enforce Accept + assert resp.json()["content_type"] == "text/plain" + + @pytest.mark.asyncio + async def test_multiple_custom_headers_forwarded(self, ctype_echo_client): + """Custom X- headers reach the agent without interference.""" + resp = await ctype_echo_client.post( + 
"/invocations", + content=b"{}", + headers={ + "Content-Type": "application/json", + "X-Request-Id": "req-42", + "X-Session-Token": "tok-abc", + }, + ) + assert resp.status_code == 200 + + +# ======================================================================== +# WebSocket: upgrade attempt to an HTTP-only server +# ======================================================================== + + +class TestWebSocketUpgradeRejected: + """The server has no WebSocket routes. WS upgrade attempts must fail + gracefully rather than crashing.""" + + @pytest.mark.asyncio + async def test_websocket_upgrade_on_invocations(self): + """WebSocket handshake on /invocations fails cleanly.""" + agent = _make_ctype_echo_agent() + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + # Simulate a WebSocket upgrade via HTTP headers + resp = await client.get( + "/invocations", + headers={ + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "13", + }, + ) + # Server should reject: no WS route → 4xx response + assert resp.status_code in (400, 403, 404, 405, 426) + + @pytest.mark.asyncio + async def test_websocket_upgrade_on_unknown_path(self): + """WebSocket handshake on an undefined path fails cleanly.""" + agent = _make_ctype_echo_agent() + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get( + "/ws/invocations", + headers={ + "Upgrade": "websocket", + "Connection": "Upgrade", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "13", + }, + ) + assert resp.status_code in (400, 403, 404, 405, 426) + + +# ======================================================================== +# gRPC-like: binary framed payload over HTTP/2 style headers +# 
======================================================================== + + +class TestGrpcLikePayloads: + """The server is HTTP/1.1, but clients may send gRPC-style payloads + (application/grpc, length-prefixed protobuf). Server should accept + bytes and let the agent deal with them — no crash.""" + + @pytest.mark.asyncio + async def test_grpc_content_type_accepted(self, ctype_echo_client): + """POST with application/grpc content type is forwarded to the agent.""" + # gRPC frames: 1-byte compressed flag + 4-byte length + payload + grpc_frame = b"\x00" + len(_PROTO_PAYLOAD).to_bytes(4, "big") + _PROTO_PAYLOAD + resp = await ctype_echo_client.post( + "/invocations", + content=grpc_frame, + headers={"Content-Type": "application/grpc"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "application/grpc" + assert resp.json()["length"] == len(grpc_frame) + + @pytest.mark.asyncio + async def test_grpc_web_content_type_accepted(self, ctype_echo_client): + """POST with application/grpc-web content type (for gRPC-Web proxy).""" + resp = await ctype_echo_client.post( + "/invocations", + content=_PROTO_PAYLOAD, + headers={"Content-Type": "application/grpc-web"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "application/grpc-web" + + @pytest.mark.asyncio + async def test_grpc_plus_proto_content_type(self, ctype_echo_client): + """POST with application/grpc+proto content type.""" + resp = await ctype_echo_client.post( + "/invocations", + content=_PROTO_PAYLOAD, + headers={"Content-Type": "application/grpc+proto"}, + ) + assert resp.status_code == 200 + assert resp.json()["content_type"] == "application/grpc+proto" + + +# ======================================================================== +# Multiple modalities in one request / response +# ======================================================================== + + +class TestMixedModality: + """Scenarios combining text + binary in a single invocation.""" + + 
@pytest.mark.asyncio + async def test_json_with_base64_image_roundtrip(self, b64image_client): + """JSON containing base64-encoded image data round-trips correctly.""" + encoded = base64.b64encode(_MINIMAL_PNG).decode("ascii") + payload = {"image": encoded, "prompt": "What is in this image?"} + resp = await b64image_client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["decoded_size"] == len(_MINIMAL_PNG) + assert body["starts_with_png"] is True + + @pytest.mark.asyncio + async def test_multipart_image_plus_audio(self, multipart_raw_client): + """Multipart upload combining image and audio files.""" + resp = await multipart_raw_client.post( + "/invocations", + files=[ + ("image", ("pic.png", _MINIMAL_PNG, "image/png")), + ("audio", ("clip.wav", _MINIMAL_WAV, "audio/wav")), + ], + data={"instruction": "transcribe the audio and describe the image"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "multipart/form-data" in body["content_type"] + assert body["length"] > len(_MINIMAL_PNG) + len(_MINIMAL_WAV) + + @pytest.mark.asyncio + async def test_multipart_large_file(self, multipart_raw_client): + """Multipart upload with a large binary file (512 KB).""" + big_file = b"\x00" * (512 * 1024) + resp = await multipart_raw_client.post( + "/invocations", + files={"bigfile": ("large.bin", big_file, "application/octet-stream")}, + ) + assert resp.status_code == 200 + assert resp.json()["length"] > 512 * 1024 diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_openapi_validation.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_openapi_validation.py new file mode 100644 index 000000000000..f9fe9647c127 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_openapi_validation.py @@ -0,0 +1,2167 @@ +# --------------------------------------------------------- +# Copyright (c) 
Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Tests for OpenAPI spec validation.""" +import json + +import httpx +import pytest + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver.server import AgentServer +from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator + + +# --------------------------------------------------------------------------- +# Helpers – inline agent + client builder +# --------------------------------------------------------------------------- + + +def _make_echo_agent(spec: dict) -> AgentServer: + """Create an agent that echoes the parsed JSON body with validation.""" + server = AgentServer(openapi_spec=spec, enable_request_validation=True) + + @server.invoke_handler + async def handle(request: Request) -> Response: + data = await request.json() + return JSONResponse(data) + + return server + + +async def _client_for(agent: AgentServer): + transport = httpx.ASGITransport(app=agent.app) + return httpx.AsyncClient(transport=transport, base_url="http://testserver") + + +# --------------------------------------------------------------------------- +# Complex OpenAPI spec with nested objects, arrays, enums, $ref, constraints +# --------------------------------------------------------------------------- + +COMPLEX_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "Complex Agent", "version": "2.0"}, + "components": { + "schemas": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string", "minLength": 1}, + "city": {"type": "string"}, + "zip": {"type": "string", "pattern": "^[0-9]{5}$"}, + "country": {"type": "string", "enum": ["US", "CA", "MX"]}, + }, + "required": ["street", "city", "zip", "country"], + }, + "OrderItem": { + "type": "object", + "properties": { + "sku": {"type": "string", "minLength": 3, "maxLength": 12}, + "quantity": {"type": "integer", 
"minimum": 1, "maximum": 999}, + "unit_price": {"type": "number", "exclusiveMinimum": 0}, + }, + "required": ["sku", "quantity", "unit_price"], + }, + } + }, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "customer_name": { + "type": "string", + "minLength": 1, + "maxLength": 100, + }, + "email": { + "type": "string", + "format": "email", + }, + "age": { + "type": "integer", + "minimum": 18, + "maximum": 150, + }, + "tier": { + "type": "string", + "enum": ["bronze", "silver", "gold", "platinum"], + }, + "shipping_address": { + "$ref": "#/components/schemas/Address" + }, + "items": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OrderItem" + }, + "minItems": 1, + "maxItems": 50, + }, + "notes": { + "type": "string", + }, + "tags": { + "type": "array", + "items": {"type": "string"}, + "uniqueItems": True, + }, + }, + "required": [ + "customer_name", + "age", + "tier", + "shipping_address", + "items", + ], + "additionalProperties": False, + } + } + } + }, + "responses": { + "200": { + "description": "Order accepted", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "order_id": {"type": "string"}, + }, + } + } + }, + } + }, + } + } + }, +} + + +def _valid_order(**overrides) -> dict: + """Return a fully valid order payload, with optional field overrides.""" + base = { + "customer_name": "Alice Smith", + "age": 30, + "tier": "gold", + "shipping_address": { + "street": "123 Main St", + "city": "Redmond", + "zip": "98052", + "country": "US", + }, + "items": [{"sku": "ABC123", "quantity": 2, "unit_price": 19.99}], + } + base.update(overrides) + return base + + +# =================================================================== +# Original tests (kept as-is) +# =================================================================== + + +@pytest.mark.asyncio +async def 
test_valid_request_passes(validated_client): + """Request matching schema returns 200.""" + resp = await validated_client.post( + "/invocations", + content=json.dumps({"name": "World"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + data = json.loads(resp.content) + assert data["greeting"] == "Hello, World!" + + +@pytest.mark.asyncio +async def test_invalid_request_returns_400(validated_client): + """Request missing required field returns 400 with error details.""" + resp = await validated_client.post( + "/invocations", + content=json.dumps({"wrong_field": "oops"}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + data = resp.json() + assert data["error"]["code"] == "invalid_payload" + assert "details" in data["error"] + + +@pytest.mark.asyncio +async def test_invalid_request_wrong_type_returns_400(validated_client): + """Request with wrong field type returns 400.""" + resp = await validated_client.post( + "/invocations", + content=json.dumps({"name": 12345}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + data = resp.json() + assert "details" in data["error"] + + +@pytest.mark.asyncio +async def test_no_spec_skips_validation(no_spec_client): + """Agent with no spec accepts any request body.""" + resp = await no_spec_client.post( + "/invocations", + content=b"this is not json at all", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_spec_endpoint_returns_spec(validated_client): + """GET /invocations/docs/openapi.json returns the registered spec.""" + resp = await validated_client.get("/invocations/docs/openapi.json") + assert resp.status_code == 200 + data = resp.json() + assert "paths" in data + + +@pytest.mark.asyncio +async def test_non_json_body_skips_validation(no_spec_client): + """Non-JSON content type bypasses JSON schema validation.""" 
+ server = AgentServer(openapi_spec={ + "openapi": "3.0.0", + "info": {"title": "Test", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": {"type": "object", "required": ["name"]} + } + } + }, + "responses": {"200": {"description": "OK"}} + } + } + } + }, enable_request_validation=True) + + @server.invoke_handler + async def handle(request: Request) -> Response: + body = await request.body() + return Response(content=body, media_type="text/plain") + + transport = httpx.ASGITransport(app=server.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b"plain text body", + headers={"Content-Type": "text/plain"}, + ) + assert resp.status_code == 200 + + +# =================================================================== +# Complex spec – valid payloads +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_valid_full_payload(): + """A fully-populated valid order is accepted.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + payload = _valid_order( + email="alice@example.com", + notes="Leave at the door", + tags=["priority", "fragile"], + ) + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_complex_valid_minimal_payload(): + """Only required fields are present — should pass.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order()).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_complex_valid_multiple_items(): + 
"""Order with several line items passes.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + items = [ + {"sku": "AAA", "quantity": 1, "unit_price": 5.0}, + {"sku": "BBB", "quantity": 10, "unit_price": 12.50}, + {"sku": "CCC", "quantity": 999, "unit_price": 0.01}, + ] + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=items)).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 200 + + +# =================================================================== +# Complex spec – missing required fields +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_missing_top_level_required(): + """Missing top-level required field 'tier' triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + payload = _valid_order() + del payload["tier"] + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + details = resp.json()["error"]["details"] + assert any("tier" in d["message"] for d in details) + + +@pytest.mark.asyncio +async def test_complex_missing_nested_required(): + """Missing required 'city' in shipping_address triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + payload = _valid_order() + del payload["shipping_address"]["city"] + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + details = resp.json()["error"]["details"] + assert any("city" in d["message"] for d in details) + + +@pytest.mark.asyncio +async def test_complex_missing_array_item_field(): + """Order item missing required 'sku' field triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async 
with await _client_for(agent) as client: + bad_item = {"quantity": 1, "unit_price": 5.0} # no sku + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[bad_item])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + details = resp.json()["error"]["details"] + assert any("sku" in d["message"] for d in details) + + +# =================================================================== +# Complex spec – type mismatches +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_wrong_type_age_string(): + """'age' must be integer, not string.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(age="thirty")).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_wrong_type_items_not_array(): + """'items' must be an array, not an object.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps( + _valid_order(items={"sku": "X", "quantity": 1, "unit_price": 1.0}) + ).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_wrong_type_quantity_float(): + """'quantity' must be integer, not float.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + bad_item = {"sku": "ABC", "quantity": 2.5, "unit_price": 10.0} + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(items=[bad_item])).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +# =================================================================== +# Complex spec – enum 
validation +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_invalid_tier_enum(): + """'tier' value not in enum list triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(tier="diamond")).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + details = resp.json()["error"]["details"] + assert any("diamond" in d["message"] for d in details) + + +@pytest.mark.asyncio +async def test_complex_invalid_country_enum(): + """'country' outside allowed enum triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + payload = _valid_order() + payload["shipping_address"]["country"] = "FR" + resp = await client.post( + "/invocations", + content=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + details = resp.json()["error"]["details"] + assert any("FR" in d["message"] for d in details) + + +# =================================================================== +# Complex spec – numeric constraints (minimum / maximum) +# =================================================================== + + +@pytest.mark.asyncio +async def test_complex_age_below_minimum(): + """'age' below minimum (18) triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + content=json.dumps(_valid_order(age=10)).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_complex_age_above_maximum(): + """'age' above maximum (150) triggers 400.""" + agent = _make_echo_agent(COMPLEX_SPEC) + async with await _client_for(agent) as client: + resp = await client.post( + "/invocations", + 
# ---------------------------------------------------------------------------
# Shared HTTP helpers — factor out the agent/client/POST boilerplate that was
# repeated verbatim in every HTTP-level test below.
# ---------------------------------------------------------------------------


async def _post_raw(spec: dict, content: bytes):
    """POST raw *content* to /invocations on an echo agent built from *spec*.

    Returns the httpx response. The response body is fully read before the
    client context exits, so status_code / json() may be inspected afterwards.
    """
    agent = _make_echo_agent(spec)
    async with await _client_for(agent) as client:
        return await client.post(
            "/invocations",
            content=content,
            headers={"Content-Type": "application/json"},
        )


async def _post_json(spec: dict, payload):
    """JSON-encode *payload* and POST it to /invocations (see ``_post_raw``)."""
    return await _post_raw(spec, json.dumps(payload).encode())


@pytest.mark.asyncio
async def test_complex_age_above_maximum():
    """Age above maximum (150) triggers 400."""
    # NOTE(review): the decorator/def line of this test was cut off before this
    # chunk; the name and docstring are reconstructed from the visible body
    # (age=200 → 400) — confirm against the original file.
    resp = await _post_json(COMPLEX_SPEC, _valid_order(age=200))
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_complex_quantity_below_minimum():
    """Order item quantity below 1 triggers 400."""
    bad_item = {"sku": "ABC", "quantity": 0, "unit_price": 10.0}
    resp = await _post_json(COMPLEX_SPEC, _valid_order(items=[bad_item]))
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_complex_quantity_above_maximum():
    """Order item quantity above 999 triggers 400."""
    bad_item = {"sku": "ABC", "quantity": 1000, "unit_price": 10.0}
    resp = await _post_json(COMPLEX_SPEC, _valid_order(items=[bad_item]))
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_complex_unit_price_zero_exclusive():
    """unit_price with exclusiveMinimum: 0 rejects exactly 0."""
    bad_item = {"sku": "ABC", "quantity": 1, "unit_price": 0}
    resp = await _post_json(COMPLEX_SPEC, _valid_order(items=[bad_item]))
    assert resp.status_code == 400


# ===================================================================
# Complex spec – string constraints (minLength / maxLength / pattern)
# ===================================================================


@pytest.mark.asyncio
async def test_complex_customer_name_empty():
    """Empty customer_name violates minLength: 1."""
    resp = await _post_json(COMPLEX_SPEC, _valid_order(customer_name=""))
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_complex_customer_name_too_long():
    """customer_name exceeding maxLength: 100 triggers 400."""
    resp = await _post_json(COMPLEX_SPEC, _valid_order(customer_name="A" * 101))
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_complex_sku_too_short():
    """SKU shorter than minLength: 3 triggers 400."""
    bad_item = {"sku": "AB", "quantity": 1, "unit_price": 5.0}
    resp = await _post_json(COMPLEX_SPEC, _valid_order(items=[bad_item]))
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_complex_sku_too_long():
    """SKU longer than maxLength: 12 triggers 400."""
    bad_item = {"sku": "A" * 13, "quantity": 1, "unit_price": 5.0}
    resp = await _post_json(COMPLEX_SPEC, _valid_order(items=[bad_item]))
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_complex_zip_pattern_invalid():
    """Zip code not matching '^[0-9]{5}$' triggers 400."""
    payload = _valid_order()
    payload["shipping_address"]["zip"] = "ABCDE"
    resp = await _post_json(COMPLEX_SPEC, payload)
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_complex_zip_pattern_too_short():
    """Zip code with fewer than 5 digits fails pattern."""
    payload = _valid_order()
    payload["shipping_address"]["zip"] = "1234"
    resp = await _post_json(COMPLEX_SPEC, payload)
    assert resp.status_code == 400


# ===================================================================
# Complex spec – array constraints (minItems / maxItems / uniqueItems)
# ===================================================================


@pytest.mark.asyncio
async def test_complex_empty_items_array():
    """Empty items array violates minItems: 1."""
    resp = await _post_json(COMPLEX_SPEC, _valid_order(items=[]))
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_complex_duplicate_tags():
    """Duplicate entries in 'tags' violates uniqueItems."""
    resp = await _post_json(COMPLEX_SPEC, _valid_order(tags=["urgent", "urgent"]))
    assert resp.status_code == 400


# ===================================================================
# Complex spec – additionalProperties: false
# ===================================================================


@pytest.mark.asyncio
async def test_complex_additional_properties_rejected():
    """Extra top-level field rejected by additionalProperties: false."""
    payload = _valid_order()
    payload["surprise_field"] = "not allowed"
    resp = await _post_json(COMPLEX_SPEC, payload)
    assert resp.status_code == 400
    details = resp.json()["error"]["details"]
    assert any("surprise_field" in d["message"] for d in details)


# ===================================================================
# Complex spec – $ref resolution
# ===================================================================


@pytest.mark.asyncio
async def test_complex_ref_address_validated():
    """$ref to Address schema is resolved and validated."""
    payload = _valid_order()
    # Replace structured address with a scalar — violates $ref schema
    payload["shipping_address"] = "123 Main St"
    resp = await _post_json(COMPLEX_SPEC, payload)
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_complex_ref_order_item_validated():
    """$ref to OrderItem schema is resolved; invalid item caught."""
    # unit_price negative — violates exclusiveMinimum
    bad_item = {"sku": "ABC", "quantity": 1, "unit_price": -5.0}
    resp = await _post_json(COMPLEX_SPEC, _valid_order(items=[bad_item]))
    assert resp.status_code == 400


# ===================================================================
# Complex spec – malformed JSON body
# ===================================================================


@pytest.mark.asyncio
async def test_complex_malformed_json():
    """Malformed JSON body returns 400."""
    resp = await _post_raw(COMPLEX_SPEC, b"{not json at all")
    assert resp.status_code == 400
    data = resp.json()
    assert data["error"]["code"] == "invalid_payload"


# ===================================================================
# Complex spec – boundary / edge-case valid values
# ===================================================================


@pytest.mark.asyncio
async def test_complex_boundary_age_exactly_minimum():
    """Age exactly at minimum (18) should pass."""
    resp = await _post_json(COMPLEX_SPEC, _valid_order(age=18))
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_complex_boundary_age_exactly_maximum():
    """Age exactly at maximum (150) should pass."""
    resp = await _post_json(COMPLEX_SPEC, _valid_order(age=150))
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_complex_boundary_sku_exactly_min_length():
    """SKU at exactly minLength: 3 should pass."""
    item = {"sku": "XYZ", "quantity": 1, "unit_price": 1.0}
    resp = await _post_json(COMPLEX_SPEC, _valid_order(items=[item]))
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_complex_boundary_sku_exactly_max_length():
    """SKU at exactly maxLength: 12 should pass."""
    item = {"sku": "A" * 12, "quantity": 1, "unit_price": 1.0}
    resp = await _post_json(COMPLEX_SPEC, _valid_order(items=[item]))
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_complex_multiple_errors_reported():
    """Multiple validation errors are all reported in 'details'."""
    # Wrong type for age, invalid enum for tier, empty items array
    payload = {
        "customer_name": "Bob",
        "age": "old",
        "tier": "diamond",
        "shipping_address": {
            "street": "1 Elm",
            "city": "X",
            "zip": "00000",
            "country": "US",
        },
        "items": [],
    }
    resp = await _post_json(COMPLEX_SPEC, payload)
    assert resp.status_code == 400
    details = resp.json()["error"]["details"]
    assert len(details) >= 3  # at least age, tier, items errors


# ---------------------------------------------------------------------------
# Tests: validate_response
# ---------------------------------------------------------------------------

RESPONSE_SPEC: dict = {
    "openapi": "3.0.0",
    "info": {"title": "Resp", "version": "1.0"},
    "paths": {
        "/invocations": {
            "post": {
                "requestBody": {
                    "content": {
                        "application/json": {
                            "schema": {"type": "object"},
                        }
                    }
                },
                "responses": {
                    "200": {
                        "content": {
                            "application/json": {
                                "schema": {
                                    "type": "object",
                                    "properties": {
                                        "result": {"type": "string"},
                                    },
                                    "required": ["result"],
                                }
                            }
                        }
                    }
                },
            }
        }
    },
}


@pytest.mark.asyncio
async def test_validate_response_valid():
    """Valid response body passes validation."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    v = _OpenApiValidator(RESPONSE_SPEC)
    errors = v.validate_response(b'{"result": "ok"}', "application/json")
    assert errors == []


@pytest.mark.asyncio
async def test_validate_response_invalid():
    """Invalid response body returns errors."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    v = _OpenApiValidator(RESPONSE_SPEC)
    errors = v.validate_response(b'{"wrong": 42}', "application/json")
    assert len(errors) > 0


@pytest.mark.asyncio
async def test_validate_response_no_schema():
    """When no response schema exists, validation passes (no-op)."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    spec_no_resp = {
        "openapi": "3.0.0",
        "info": {"title": "NoResp", "version": "1.0"},
        "paths": {"/invocations": {"post": {}}},
    }
    v = _OpenApiValidator(spec_no_resp)
    errors = v.validate_response(b'{"anything": true}', "application/json")
    assert errors == []


# ---------------------------------------------------------------------------
# Tests: no request body schema
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_validate_request_no_schema():
    """When no request schema exists, validation passes (no-op)."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    spec_no_req = {
        "openapi": "3.0.0",
        "info": {"title": "NoReq", "version": "1.0"},
        "paths": {"/invocations": {"post": {}}},
    }
    v = _OpenApiValidator(spec_no_req)
    errors = v.validate_request(b'{"anything": true}', "application/json")
    assert errors == []


# ---------------------------------------------------------------------------
# Tests: response schema from non-200/201 status
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_response_schema_fallback_to_first_available():
    """Response schema extraction falls back to first response with JSON."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    spec = {
        "openapi": "3.0.0",
        "info": {"title": "Fallback", "version": "1.0"},
        "paths": {
            "/invocations": {
                "post": {
                    "responses": {
                        "202": {
                            "content": {
                                "application/json": {
                                    "schema": {
                                        "type": "object",
                                        "properties": {"status": {"type": "string"}},
                                        "required": ["status"],
                                    }
                                }
                            }
                        }
                    }
                }
            }
        },
    }
    v = _OpenApiValidator(spec)
    # Valid against the 202 schema
    assert v.validate_response(b'{"status": "accepted"}', "application/json") == []
    # Invalid — missing "status"
    errors = v.validate_response(b'{"other": 1}', "application/json")
    assert len(errors) > 0


# ---------------------------------------------------------------------------
# Tests: $ref edge cases
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_unresolvable_ref():
    """An unresolvable $ref leaves the node as-is (no crash)."""
    from azure.ai.agentserver.server._openapi_validator import _resolve_ref

    spec: dict = {"components": {"schemas": {}}}
    node = {"$ref": "#/components/schemas/DoesNotExist"}
    result = _resolve_ref(spec, node)
    # Can't resolve — returns the original node
    assert result is node


@pytest.mark.asyncio
async def test_ref_path_hits_non_dict():
    """A $ref path that traverses a non-dict returns the original node."""
    from azure.ai.agentserver.server._openapi_validator import _resolve_ref

    spec: dict = {"components": {"schemas": "not-a-dict"}}
    node = {"$ref": "#/components/schemas/Foo"}
    result = _resolve_ref(spec, node)
    assert result is node


@pytest.mark.asyncio
async def test_circular_ref_stops_recursion():
    """Circular $ref does not cause infinite recursion."""
    from azure.ai.agentserver.server._openapi_validator import _resolve_refs_deep

    spec: dict = {
        "components": {
            "schemas": {
                "Node": {
                    "type": "object",
                    "properties": {
                        "child": {"$ref": "#/components/schemas/Node"},
                    },
                }
            }
        }
    }
    node = {"$ref": "#/components/schemas/Node"}
    result = _resolve_refs_deep(spec, node)
    # Should resolve at least the first level, and leave the circular ref as-is
    assert result["type"] == "object"
    child = result["properties"]["child"]
    # The circular reference should be left unresolved
    assert "$ref" in child


@pytest.mark.asyncio
async def test_ref_resolves_to_non_dict():
    """A $ref that resolves to a non-dict value returns the original node."""
    from azure.ai.agentserver.server._openapi_validator import _resolve_ref

    spec: dict = {"components": {"schemas": {"Bad": 42}}}
    node = {"$ref": "#/components/schemas/Bad"}
    result = _resolve_ref(spec, node)
    assert result is node


# ===================================================================
# OpenAPI nullable support
# ===================================================================

NULLABLE_SPEC: dict = {
    "openapi": "3.0.0",
    "info": {"title": "Nullable", "version": "1.0"},
    "paths": {
        "/invocations": {
            "post": {
                "requestBody": {
                    "content": {
                        "application/json": {
                            "schema": {
                                "type": "object",
                                "properties": {
                                    "name": {"type": "string"},
                                    "nickname": {
                                        "type": "string",
                                        "nullable": True,
                                    },
                                    "age": {
                                        "type": "integer",
                                        "nullable": True,
                                    },
                                },
                                "required": ["name", "nickname"],
                            }
                        }
                    }
                },
                "responses": {"200": {"description": "OK"}},
            }
        }
    },
}


@pytest.mark.asyncio
async def test_nullable_accepts_null():
    """A nullable field accepts a JSON null value."""
    resp = await _post_json(NULLABLE_SPEC, {"name": "Alice", "nickname": None})
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_nullable_accepts_value():
    """A nullable field also accepts a normal string value."""
    resp = await _post_json(NULLABLE_SPEC, {"name": "Alice", "nickname": "Ali"})
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_nullable_rejects_wrong_type():
    """A nullable string still rejects integers."""
    resp = await _post_json(NULLABLE_SPEC, {"name": "Alice", "nickname": 42})
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_non_nullable_rejects_null():
    """A non-nullable field (name) rejects null."""
    resp = await _post_json(NULLABLE_SPEC, {"name": None, "nickname": "Ali"})
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_nullable_integer_accepts_null():
    """A nullable integer accepts null."""
    resp = await _post_json(NULLABLE_SPEC, {"name": "A", "nickname": "B", "age": None})
    assert resp.status_code == 200


# ===================================================================
# readOnly / writeOnly support
# ===================================================================

READONLY_SPEC: dict = {
    "openapi": "3.0.0",
    "info": {"title": "ReadOnly", "version": "1.0"},
    "paths": {
        "/invocations": {
            "post": {
                "requestBody": {
                    "content": {
                        "application/json": {
                            "schema": {
                                "type": "object",
                                "properties": {
                                    "name": {"type": "string"},
                                    "id": {
                                        "type": "string",
                                        "readOnly": True,
                                    },
                                    "password": {
                                        "type": "string",
                                        "writeOnly": True,
                                    },
                                },
                                "required": ["name", "id"],
                            }
                        }
                    }
                },
                "responses": {
                    "200": {
                        "content": {
                            "application/json": {
                                "schema": {
                                    "type": "object",
                                    "properties": {
                                        "name": {"type": "string"},
                                        "id": {
                                            "type": "string",
                                            "readOnly": True,
                                        },
                                        "password": {
                                            "type": "string",
                                            "writeOnly": True,
                                        },
                                    },
                                    "required": ["name", "id"],
                                }
                            }
                        }
                    }
                },
            }
        }
    },
}


@pytest.mark.asyncio
async def test_readonly_not_required_in_request():
    """readOnly field 'id' is stripped from required in request context."""
    # Don't send 'id' — it's readOnly, should not be required
    resp = await _post_json(READONLY_SPEC, {"name": "Alice"})
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_writeonly_allowed_in_request():
    """writeOnly field 'password' is allowed in request."""
    resp = await _post_json(READONLY_SPEC, {"name": "Alice", "password": "secret"})
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_readonly_schema_introspection_request():
    """readOnly properties are removed from the preprocessed request schema."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    v = _OpenApiValidator(READONLY_SPEC)
    # 'id' should be gone from request schema properties
    assert "id" not in v._request_schema.get("properties", {})


@pytest.mark.asyncio
async def test_readonly_schema_introspection_response():
    """readOnly properties remain in the preprocessed response schema."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    v = _OpenApiValidator(READONLY_SPEC)
    # 'id' should still be in response schema
    assert "id" in v._response_schema.get("properties", {})


@pytest.mark.asyncio
async def test_writeonly_stripped_in_response():
    """writeOnly properties are removed from the preprocessed response schema."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    v = _OpenApiValidator(READONLY_SPEC)
    assert "password" not in v._response_schema.get("properties", {})


@pytest.mark.asyncio
async def test_writeonly_present_in_request():
    """writeOnly properties remain in the preprocessed request schema."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    v = _OpenApiValidator(READONLY_SPEC)
    assert "password" in v._request_schema.get("properties", {})


# ===================================================================
# requestBody.required: false
# ===================================================================

OPTIONAL_BODY_SPEC: dict = {
    "openapi": "3.0.0",
    "info": {"title": "OptionalBody", "version": "1.0"},
    "paths": {
        "/invocations": {
            "post": {
                "requestBody": {
                    "required": False,
                    "content": {
                        "application/json": {
                            "schema": {
                                "type": "object",
                                "properties": {
                                    "query": {"type": "string"},
                                },
                                "required": ["query"],
                            }
                        }
                    },
                },
                "responses": {"200": {"description": "OK"}},
            }
        }
    },
}

REQUIRED_BODY_SPEC: dict = {
    "openapi": "3.0.0",
    "info": {"title": "RequiredBody", "version": "1.0"},
    "paths": {
        "/invocations": {
            "post": {
                "requestBody": {
                    "required": True,
                    "content": {
                        "application/json": {
                            "schema": {
                                "type": "object",
                                "properties": {
                                    "query": {"type": "string"},
                                },
                                "required": ["query"],
                            }
                        }
                    },
                },
                "responses": {"200": {"description": "OK"}},
            }
        }
    },
}


@pytest.mark.asyncio
async def test_optional_body_empty_accepted():
    """Empty body is accepted when requestBody.required is false."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    v = _OpenApiValidator(OPTIONAL_BODY_SPEC)
    errors = v.validate_request(b"", "application/json")
    assert errors == []


@pytest.mark.asyncio
async def test_optional_body_whitespace_accepted():
    """Whitespace-only body is accepted when requestBody.required is false."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    v = _OpenApiValidator(OPTIONAL_BODY_SPEC)
    errors = v.validate_request(b"   ", "application/json")
    assert errors == []


@pytest.mark.asyncio
async def test_optional_body_present_still_validated():
    """When body IS present with optional requestBody, it still must be valid."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    v = _OpenApiValidator(OPTIONAL_BODY_SPEC)
    errors = v.validate_request(b'{"wrong": 1}', "application/json")
    assert len(errors) > 0


@pytest.mark.asyncio
async def test_required_body_empty_rejected():
    """Empty body is rejected when requestBody.required is true."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    v = _OpenApiValidator(REQUIRED_BODY_SPEC)
    errors = v.validate_request(b"", "application/json")
    assert len(errors) > 0  # "Invalid JSON body"


@pytest.mark.asyncio
async def test_default_body_required_behavior():
    """When requestBody.required is omitted, body is required by default."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    spec = {
        "openapi": "3.0.0",
        "info": {"title": "Default", "version": "1.0"},
        "paths": {
            "/invocations": {
                "post": {
                    "requestBody": {
                        "content": {
                            "application/json": {
                                "schema": {"type": "object", "required": ["q"]},
                            }
                        }
                    },
                    "responses": {"200": {"description": "OK"}},
                }
            }
        },
    }
    v = _OpenApiValidator(spec)
    errors = v.validate_request(b"", "application/json")
    assert len(errors) > 0


# ===================================================================
# OpenAPI keyword stripping
# ===================================================================


@pytest.mark.asyncio
async def test_openapi_keywords_stripped():
    """discriminator, xml, externalDocs, example are stripped from schemas."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    schema = {
        "type": "object",
        "discriminator": {"propertyName": "type"},
        "xml": {"name": "Foo"},
        "externalDocs": {"url": "https://example.com"},
        "example": {"type": "bar"},
        "properties": {"name": {"type": "string"}},
    }
    result = _OpenApiValidator._preprocess_schema(schema)
    assert "discriminator" not in result
    assert "xml" not in result
    assert "externalDocs" not in result
    assert "example" not in result
    assert result["properties"]["name"]["type"] == "string"


@pytest.mark.asyncio
async def test_openapi_keywords_stripped_nested():
    """OpenAPI keywords are stripped from nested properties too."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    schema = {
        "type": "object",
        "properties": {
            "child": {
                "type": "object",
                "example": {"key": "value"},
                "xml": {"wrapped": True},
                "properties": {"inner": {"type": "string"}},
            }
        },
    }
    result = _OpenApiValidator._preprocess_schema(schema)
    child = result["properties"]["child"]
    assert "example" not in child
    assert "xml" not in child


# ===================================================================
# Format validation (enabled via FormatChecker)
# ===================================================================

FORMAT_SPEC: dict = {
    "openapi": "3.0.0",
    "info": {"title": "Format", "version": "1.0"},
    "paths": {
        "/invocations": {
            "post": {
                "requestBody": {
                    "content": {
                        "application/json": {
                            "schema": {
                                "type": "object",
                                "properties": {
                                    "email": {
                                        "type": "string",
                                        "format": "email",
                                    },
                                    "created_at": {
                                        "type": "string",
                                        "format": "date-time",
                                    },
                                },
                                "required": ["email"],
                            }
                        }
                    }
                },
                "responses": {"200": {"description": "OK"}},
            }
        }
    },
}


@pytest.mark.asyncio
async def test_format_email_valid():
    """Valid email passes format validation."""
    resp = await _post_json(FORMAT_SPEC, {"email": "alice@example.com"})
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_format_email_invalid():
    """Invalid email fails format validation."""
    resp = await _post_json(FORMAT_SPEC, {"email": "not-an-email"})
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_format_datetime_valid():
    """Valid ISO 8601 date-time passes format validation."""
    resp = await _post_json(
        FORMAT_SPEC, {"email": "a@b.com", "created_at": "2025-01-01T00:00:00Z"}
    )
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_format_datetime_invalid():
    """Invalid date-time fails format validation."""
    resp = await _post_json(
        FORMAT_SPEC, {"email": "a@b.com", "created_at": "not-a-date"}
    )
    assert resp.status_code == 400


# ===================================================================
# allOf / oneOf / anyOf composition
# ===================================================================

COMPOSITION_SPEC: dict = {
    "openapi": "3.0.0",
    "info": {"title": "Composition", "version": "1.0"},
    "paths": {
        "/invocations": {
            "post": {
                "requestBody": {
                    "content": {
                        "application/json": {
                            "schema": {
                                "allOf": [
                                    {
                                        "type": "object",
                                        "properties": {"name": {"type": "string"}},
                                        "required": ["name"],
                                    },
                                    {
                                        "type": "object",
                                        "properties": {"age": {"type": "integer"}},
                                        "required": ["age"],
                                    },
                                ]
                            }
                        }
                    }
                },
                "responses": {"200": {"description": "OK"}},
            }
        }
    },
}

ONEOF_SPEC: dict = {
    "openapi": "3.0.0",
    "info": {"title": "OneOf", "version": "1.0"},
    "paths": {
        "/invocations": {
            "post": {
                "requestBody": {
                    "content": {
                        "application/json": {
                            "schema": {
                                "oneOf": [
                                    {
                                        "type": "object",
                                        "properties": {
                                            "kind": {
                                                "type": "string",
                                                "enum": ["cat"],
                                            },
                                            "purrs": {"type": "boolean"},
                                        },
                                        "required": ["kind", "purrs"],
                                    },
                                    {
                                        "type": "object",
                                        "properties": {
                                            "kind": {
                                                "type": "string",
                                                "enum": ["dog"],
                                            },
                                            "barks": {"type": "boolean"},
                                        },
                                        "required": ["kind", "barks"],
                                    },
                                ]
                            }
                        }
                    }
                },
                "responses": {"200": {"description": "OK"}},
            }
        }
    },
}

ANYOF_SPEC: dict = {
    "openapi": "3.0.0",
    "info": {"title": "AnyOf", "version": "1.0"},
    "paths": {
        "/invocations": {
            "post": {
                "requestBody": {
                    "content": {
                        "application/json": {
                            "schema": {
                                "anyOf": [
                                    {"type": "string"},
                                    {"type": "integer"},
                                ]
                            }
                        }
                    }
                },
                "responses": {"200": {"description": "OK"}},
            }
        }
    },
}


@pytest.mark.asyncio
async def test_allof_valid():
    """allOf: both sub-schemas satisfied → 200."""
    resp = await _post_json(COMPOSITION_SPEC, {"name": "Alice", "age": 30})
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_allof_missing_field():
    """allOf: missing field from second schema → 400."""
    resp = await _post_json(COMPOSITION_SPEC, {"name": "Alice"})
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_oneof_first_branch():
    """oneOf: matching first branch (cat) → 200."""
    resp = await _post_json(ONEOF_SPEC, {"kind": "cat", "purrs": True})
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_oneof_second_branch():
    """oneOf: matching second branch (dog) → 200."""
    resp = await _post_json(ONEOF_SPEC, {"kind": "dog", "barks": True})
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_oneof_no_match():
    """oneOf: matching neither branch → 400."""
    resp = await _post_json(ONEOF_SPEC, {"kind": "fish"})
    assert resp.status_code == 400


@pytest.mark.asyncio
async def test_anyof_string():
    """anyOf: string value matches first branch → 200."""
    resp = await _post_json(ANYOF_SPEC, "hello")
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_anyof_integer():
    """anyOf: integer value matches second branch → 200."""
    resp = await _post_json(ANYOF_SPEC, 42)
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_anyof_wrong_type():
    """anyOf: boolean matches neither string nor integer → 400."""
    resp = await _post_json(ANYOF_SPEC, True)
    assert resp.status_code == 400


# ===================================================================
# nullable + $ref combination
# ===================================================================


@pytest.mark.asyncio
async def test_nullable_ref_accepts_null():
    """A nullable $ref field accepts null after preprocessing."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    spec: dict = {
        "openapi": "3.0.0",
        "info": {"title": "NullRef", "version": "1.0"},
        "components": {
            "schemas": {
                "Address": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                }
            }
        },
        "paths": {
            "/invocations": {
                "post": {
                    "requestBody": {
                        "content": {
                            "application/json": {
                                "schema": {
                                    "type": "object",
                                    "properties": {
                                        "addr": {
                                            "$ref": "#/components/schemas/Address",
                                            "nullable": True,
                                        }
                                    },
                                }
                            }
                        }
                    },
                    "responses": {"200": {"description": "OK"}},
                }
            }
        },
    }
    v = _OpenApiValidator(spec)
    errors = v.validate_request(b'{"addr": null}', "application/json")
    assert errors == []


# ===================================================================
# Preprocessing unit tests
# ===================================================================


@pytest.mark.asyncio
async def test_apply_nullable_no_duplicate_null():
    """_apply_nullable does not add 'null' twice if type is already a list with null."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    schema: dict = {"type": ["string", "null"], "nullable": True}
    _OpenApiValidator._apply_nullable(schema)
    assert schema["type"] == ["string", "null"]


@pytest.mark.asyncio
async def test_apply_nullable_false():
    """nullable: false is a no-op (just removes the key)."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    schema: dict = {"type": "string", "nullable": False}
    _OpenApiValidator._apply_nullable(schema)
    assert schema["type"] == "string"
    assert "nullable" not in schema


@pytest.mark.asyncio
async def test_strip_openapi_keywords_nested_deeply():
    """OpenAPI keywords are stripped from deeply nested schemas."""
    from azure.ai.agentserver.server._openapi_validator import _OpenApiValidator

    schema: dict = {
        "type": "object",
        "properties": {
            "items": {
                "type": "array",
                "items": {
                    "type": "object",
                    "example": {"deep": True},
                    "properties": {
                        "val": {"type": "string", "xml": {"attr": True}},
                    },
                },
            }
        },
    }
    _OpenApiValidator._strip_openapi_keywords(schema)
    items_schema = schema["properties"]["items"]["items"]
    assert "example" not in items_schema
    assert "xml" not in items_schema["properties"]["val"]


# ===================================================================
# Discriminator-aware error collection (ported from C# JsonSchemaValidator)
# ===================================================================

# --- Helper: direct validator call for unit-level assertions ---

def _validate_request(spec: dict, body: dict) -> list[str]:
    """Shortcut: build an _OpenApiValidator and return request errors."""
    v = _OpenApiValidator(spec)
    return v.validate_request(
        json.dumps(body).encode(),
        "application/json",
    )


# Spec with const-based discriminator (Azure-style polymorphic schema)
CONST_DISCRIMINATOR_SPEC: dict = {
    "openapi": "3.0.0",
    "info": {"title": "Disc", "version": "1.0"},
    "paths": {
        "/invocations": {
            "post": {
                "requestBody": {
                    "content": {
                        "application/json": {
                            "schema": {
                                "oneOf": [
                                    {
                                        "type": "object",
                                        "properties": {
                                            "type": {"type": "string", "const": "workflow"},
                                            "steps": {"type": "array", "items": {"type": "string"}},
                                        },
                                        "required": ["type", "steps"],
                                    },
                                    {
                                        "type": "object",
                                        "properties": {
                                            "type": {"type": "string", "const": "prompt"},
                                            "text": {"type": "string"},
                                        },
                                        "required": ["type", "text"],
                                    },
                                    {
                                        "type": "object",
                                        "properties": {
                                            "type": {"type": "string", "const": "tool_call"},
                                            "tool_name": {"type": "string"},
                                            "args": {"type": "object"},
                                        },
                                        "required": ["type", "tool_name"],
                                    },
                                ]
                            }
                        }
                    }
                },
                "responses": {"200": {"description": "OK"}},
            }
        }
    },
}


@pytest.mark.asyncio
async def test_discriminator_valid_workflow():
    """const discriminator: valid workflow branch → passes."""
    errors = _validate_request(
        CONST_DISCRIMINATOR_SPEC,
        {"type": "workflow", "steps": ["a", "b"]},
    )
    assert errors == []


@pytest.mark.asyncio
async def test_discriminator_valid_prompt():
    """const discriminator: valid prompt branch → passes."""
    errors = _validate_request(
        CONST_DISCRIMINATOR_SPEC,
        {"type": "prompt", "text": "hello"},
    )
    assert errors == []


@pytest.mark.asyncio
async def test_discriminator_valid_tool_call():
    """const discriminator: valid tool_call branch → passes."""
    errors = _validate_request(
        CONST_DISCRIMINATOR_SPEC,
        {"type": "tool_call", "tool_name": "search"},
    )
    assert errors == []


@pytest.mark.asyncio
async def test_discriminator_reports_matching_branch_errors():
    """const discriminator: correct type but missing required field → reports only that branch."""
    errors = _validate_request(
        CONST_DISCRIMINATOR_SPEC,
        {"type": "workflow"},  # missing "steps"
    )
    assert len(errors) >= 1
    # Should report the missing "steps" field, not errors from prompt/tool_call branches
    assert any("steps" in e for e in errors)
    # NOTE(review): the original test continues past this chunk — the visible
    # comment says it asserts that fields from the non-matching branches are NOT
    # mentioned; reconstructed below, confirm against the original file.
    assert not any("text" in e for e in errors)
    assert not any("tool_name" in e for e in errors)
other branches + assert not any("text" in e for e in errors) + assert not any("tool_name" in e for e in errors) + + +@pytest.mark.asyncio +async def test_discriminator_unknown_value_reports_expected(): + """const discriminator: unknown type value → concise 'Expected one of' message.""" + errors = _validate_request( + CONST_DISCRIMINATOR_SPEC, + {"type": "unknown_thing", "steps": ["a"]}, + ) + assert len(errors) >= 1 + # Should mention the expected values, not dump all branch errors + combined = " ".join(errors) + assert "Expected" in combined or "type" in combined + + +@pytest.mark.asyncio +async def test_discriminator_wrong_type_value(): + """const discriminator: type field is integer instead of string → type error reported.""" + errors = _validate_request( + CONST_DISCRIMINATOR_SPEC, + {"type": 123, "steps": ["a"]}, + ) + assert len(errors) >= 1 + combined = " ".join(errors) + assert "type" in combined + + +@pytest.mark.asyncio +async def test_discriminator_missing_type_field(): + """const discriminator: missing discriminator field entirely → error.""" + errors = _validate_request( + CONST_DISCRIMINATOR_SPEC, + {"steps": ["a"]}, + ) + assert len(errors) >= 1 + + +# Spec with enum-based discriminator (the existing ONEOF_SPEC pattern) +@pytest.mark.asyncio +async def test_enum_discriminator_matching_branch_errors(): + """enum discriminator: correct kind but missing required field → branch-specific errors.""" + errors = _validate_request( + ONEOF_SPEC, + {"kind": "cat"}, # missing "purrs" + ) + assert len(errors) >= 1 + assert any("purrs" in e for e in errors) + # Should not mention "barks" from the dog branch + assert not any("barks" in e for e in errors) + + +# --- JSON path in error messages --- + +@pytest.mark.asyncio +async def test_error_includes_json_path_for_nested_property(): + """Validation errors for nested properties include the JSON path prefix.""" + errors = _validate_request( + COMPLEX_SPEC, + { + "customer_name": "Bob", + "tier": "gold", + 
"shipping_address": { + "street": "1 Elm", + # missing "city" + "zip": "00000", + "country": "US", + }, + "items": [{"sku": "ABCD", "quantity": 1, "unit_price": 9.99}], + }, + ) + assert len(errors) >= 1 + # At least one error should contain a JSON-path-like prefix + assert any("$." in e or "city" in e for e in errors) + + +@pytest.mark.asyncio +async def test_error_includes_json_path_for_array_item(): + """Validation errors inside array items include the element index in the path.""" + errors = _validate_request( + COMPLEX_SPEC, + { + "customer_name": "Bob", + "tier": "gold", + "shipping_address": { + "street": "1 Elm", + "city": "X", + "zip": "00000", + "country": "US", + }, + "items": [{"quantity": 1, "unit_price": 9.99}], # missing "sku" + }, + ) + assert len(errors) >= 1 + assert any("items" in e for e in errors) + + +# --- Spec with nested oneOf inside properties --- + +NESTED_ONEOF_SPEC: dict = { + "openapi": "3.0.0", + "info": {"title": "NestedOneOf", "version": "1.0"}, + "paths": { + "/invocations": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "payload": { + "oneOf": [ + { + "type": "object", + "properties": { + "kind": {"type": "string", "const": "text"}, + "content": {"type": "string"}, + }, + "required": ["kind", "content"], + }, + { + "type": "object", + "properties": { + "kind": {"type": "string", "const": "image"}, + "url": {"type": "string", "format": "uri"}, + }, + "required": ["kind", "url"], + }, + ], + }, + }, + "required": ["name", "payload"], + } + } + } + }, + "responses": {"200": {"description": "OK"}}, + } + } + }, +} + + +@pytest.mark.asyncio +async def test_nested_oneof_valid_text(): + """Nested oneOf: valid text payload → passes.""" + errors = _validate_request( + NESTED_ONEOF_SPEC, + {"name": "test", "payload": {"kind": "text", "content": "hello"}}, + ) + assert errors == [] + + +@pytest.mark.asyncio +async def 
test_nested_oneof_valid_image(): + """Nested oneOf: valid image payload → passes.""" + errors = _validate_request( + NESTED_ONEOF_SPEC, + {"name": "test", "payload": {"kind": "image", "url": "https://example.com/img.png"}}, + ) + assert errors == [] + + +@pytest.mark.asyncio +async def test_nested_oneof_wrong_discriminator(): + """Nested oneOf: unknown discriminator value → error mentioning payload path.""" + errors = _validate_request( + NESTED_ONEOF_SPEC, + {"name": "test", "payload": {"kind": "video", "url": "https://example.com/v.mp4"}}, + ) + assert len(errors) >= 1 + combined = " ".join(errors) + # Should mention "payload" in the path + assert "payload" in combined + + +@pytest.mark.asyncio +async def test_nested_oneof_matching_branch_missing_field(): + """Nested oneOf: correct kind=text but missing content → branch-specific error.""" + errors = _validate_request( + NESTED_ONEOF_SPEC, + {"name": "test", "payload": {"kind": "text"}}, + ) + assert len(errors) >= 1 + assert any("content" in e for e in errors) + # Should NOT mention "url" from the image branch + assert not any("url" in e for e in errors) + + +# --- Unit-level tests for helper functions --- + +@pytest.mark.asyncio +async def test_format_error_includes_path(): + """_format_error prefixes with JSON path when not root.""" + from azure.ai.agentserver.server._openapi_validator import _format_error + import jsonschema + + schema = {"type": "object", "properties": {"age": {"type": "integer"}}, "required": ["age"]} + validator = jsonschema.Draft7Validator(schema) + errors = list(validator.iter_errors({"age": "not_int"})) + assert len(errors) == 1 + formatted = _format_error(errors[0]) + assert "$.age" in formatted + + +@pytest.mark.asyncio +async def test_format_error_root_path_no_prefix(): + """_format_error does not prefix when error is at root ($).""" + from azure.ai.agentserver.server._openapi_validator import _format_error + import jsonschema + + schema = {"type": "object", "required": ["name"]} + 
validator = jsonschema.Draft7Validator(schema) + errors = list(validator.iter_errors({})) + assert len(errors) == 1 + formatted = _format_error(errors[0]) + # Root-level error — should be the message without "$.xxx:" prefix + assert formatted == errors[0].message + + +@pytest.mark.asyncio +async def test_collect_errors_flat(): + """_collect_errors for a simple (non-composition) error returns path-prefixed message.""" + from azure.ai.agentserver.server._openapi_validator import _collect_errors + import jsonschema + + schema = {"type": "object", "properties": {"x": {"type": "integer"}}} + validator = jsonschema.Draft7Validator(schema) + errors = list(validator.iter_errors({"x": "abc"})) + collected = [] + for e in errors: + collected.extend(_collect_errors(e)) + assert len(collected) == 1 + assert "$.x" in collected[0] diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_request_limits.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_request_limits.py new file mode 100644 index 000000000000..c9f0efee996d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_request_limits.py @@ -0,0 +1,152 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Tests for invoke timeout guards.""" +import asyncio +import os +from unittest.mock import patch + +import httpx +import pytest + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from azure.ai.agentserver.server import AgentServer +from azure.ai.agentserver.server._constants import Constants + + +# --------------------------------------------------------------------------- +# Agent factory functions +# --------------------------------------------------------------------------- + + +def _make_slow_agent(**kwargs) -> AgentServer: + """Create an agent that sleeps forever — used to test invoke timeout.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + await asyncio.sleep(999) + return Response(content=b"done") + + return server + + +def _make_fast_agent(**kwargs) -> AgentServer: + """Create an agent that returns immediately — used to verify no false timeout.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + return JSONResponse({"ok": True}) + + return server + + +# =========================================================================== +# Request timeout +# =========================================================================== + + +class TestRequestTimeoutConstants: + """Verify constants are wired correctly.""" + + def test_default_is_300_seconds(self): + assert Constants.DEFAULT_REQUEST_TIMEOUT == 300 + + def test_env_var_name(self): + assert Constants.AGENT_REQUEST_TIMEOUT == "AGENT_REQUEST_TIMEOUT" + + +class TestRequestTimeoutResolution: + """Test the resolution hierarchy: explicit > env > default.""" + + def test_explicit_value(self): + agent = _make_fast_agent(request_timeout=60) + assert agent._request_timeout == 60 + + def test_explicit_zero_disables(self): + agent = _make_fast_agent(request_timeout=0) + assert 
agent._request_timeout == 0 + + def test_env_var_zero_disables(self): + with patch.dict(os.environ, {Constants.AGENT_REQUEST_TIMEOUT: "0"}): + agent = _make_fast_agent() + assert agent._request_timeout == 0 + + def test_env_var_used_when_no_explicit(self): + with patch.dict(os.environ, {Constants.AGENT_REQUEST_TIMEOUT: "120"}): + agent = _make_fast_agent() + assert agent._request_timeout == 120 + + def test_invalid_env_var_raises(self): + with patch.dict(os.environ, {Constants.AGENT_REQUEST_TIMEOUT: "abc"}): + with pytest.raises(ValueError, match="AGENT_REQUEST_TIMEOUT"): + _make_fast_agent() + + def test_default_when_nothing_set(self): + with patch.dict(os.environ, {}, clear=True): + agent = _make_fast_agent() + assert agent._request_timeout == Constants.DEFAULT_REQUEST_TIMEOUT + + +class TestInvokeTimeoutEnforcement: + """Test that long-running invoke() calls are cancelled.""" + + @pytest.mark.asyncio + async def test_slow_invoke_returns_504(self): + agent = _make_slow_agent(request_timeout=1) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b"{}") + assert resp.status_code == 504 + data = resp.json() + assert data["error"]["code"] == "request_timeout" + assert "timed out" in data["error"]["message"].lower() + assert "1s" in data["error"]["message"] + + @pytest.mark.asyncio + async def test_slow_invoke_includes_invocation_id_header(self): + agent = _make_slow_agent(request_timeout=1) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b"{}") + assert resp.status_code == 504 + assert Constants.INVOCATION_ID_HEADER in resp.headers + + @pytest.mark.asyncio + async def test_fast_invoke_not_affected(self): + agent = _make_fast_agent(request_timeout=5) + transport = 
httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + assert resp.json() == {"ok": True} + + @pytest.mark.asyncio + async def test_timeout_disabled_allows_no_limit(self): + """request_timeout=0 means no timeout (passes None to wait_for).""" + agent = _make_fast_agent(request_timeout=0) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b"{}") + assert resp.status_code == 200 + + @pytest.mark.asyncio + async def test_timeout_error_is_logged(self): + agent = _make_slow_agent(request_timeout=1) + transport = httpx.ASGITransport(app=agent.app) + with patch("azure.ai.agentserver.server._invocation.logger") as mock_logger: + async with httpx.AsyncClient( + transport=transport, base_url="http://testserver" + ) as client: + resp = await client.post("/invocations", content=b"{}") + assert resp.status_code == 504 + + timeout_calls = [ + c + for c in mock_logger.error.call_args_list + if "timed out" in str(c).lower() + ] + assert len(timeout_calls) == 1 diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_server_routes.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_server_routes.py new file mode 100644 index 000000000000..80ae3e5338fb --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_server_routes.py @@ -0,0 +1,120 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Tests for route registration and basic endpoint behavior.""" +import uuid + +import pytest + + +@pytest.mark.asyncio +async def test_post_invocations_returns_200(echo_client): + """POST /invocations with valid body returns 200.""" + resp = await echo_client.post("/invocations", content=b'{"hello":"world"}') + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_post_invocations_returns_invocation_id_header(echo_client): + """Response includes x-agent-invocation-id header in UUID format.""" + resp = await echo_client.post("/invocations", content=b'{"hello":"world"}') + invocation_id = resp.headers.get("x-agent-invocation-id") + assert invocation_id is not None + # Validate UUID format + uuid.UUID(invocation_id) + + +@pytest.mark.asyncio +async def test_get_openapi_spec_returns_404_when_not_set(echo_client): + """GET /invocations/docs/openapi.json returns 404 if no spec registered.""" + resp = await echo_client.get("/invocations/docs/openapi.json") + assert resp.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_openapi_spec_returns_spec(validated_client): + """GET /invocations/docs/openapi.json returns registered spec as JSON.""" + resp = await validated_client.get("/invocations/docs/openapi.json") + assert resp.status_code == 200 + data = resp.json() + assert data["openapi"] == "3.0.0" + assert "/invocations" in data["paths"] + + +@pytest.mark.asyncio +async def test_get_invocation_returns_501_default(echo_client): + """GET /invocations/{id} returns 501 when not overridden.""" + inv_id = str(uuid.uuid4()) + resp = await echo_client.get(f"/invocations/{inv_id}") + assert resp.status_code == 501 + assert resp.headers.get("x-agent-invocation-id") == inv_id + + +@pytest.mark.asyncio +async def test_cancel_invocation_returns_501_default(echo_client): + """POST /invocations/{id}/cancel returns 501 when not overridden.""" + inv_id = str(uuid.uuid4()) + resp = await 
echo_client.post(f"/invocations/{inv_id}/cancel") + assert resp.status_code == 501 + assert resp.headers.get("x-agent-invocation-id") == inv_id + + +@pytest.mark.asyncio +async def test_unknown_route_returns_404(echo_client): + """GET /unknown returns 404.""" + resp = await echo_client.get("/unknown") + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# _resolve_port tests +# --------------------------------------------------------------------------- + + +class TestResolvePort: + """Unit tests for resolve_port.""" + + def test_explicit_port_wins(self): + """Explicit port argument takes precedence over everything.""" + from azure.ai.agentserver.server._config import resolve_port + assert resolve_port(9090) == 9090 + + def test_env_var_used_when_no_explicit(self, monkeypatch): + """AGENT_SERVER_PORT env var is used when no explicit port given.""" + from azure.ai.agentserver.server._config import resolve_port + monkeypatch.setenv("AGENT_SERVER_PORT", "7777") + assert resolve_port(None) == 7777 + + def test_default_port_when_nothing_set(self, monkeypatch): + """Falls back to 8088 when no explicit port and no env var.""" + from azure.ai.agentserver.server._config import resolve_port + monkeypatch.delenv("AGENT_SERVER_PORT", raising=False) + assert resolve_port(None) == 8088 + + def test_invalid_env_var_raises(self, monkeypatch): + """Non-numeric AGENT_SERVER_PORT env var raises ValueError.""" + from azure.ai.agentserver.server._config import resolve_port + monkeypatch.setenv("AGENT_SERVER_PORT", "not_a_number") + with pytest.raises(ValueError, match="AGENT_SERVER_PORT"): + resolve_port(None) + + def test_non_int_explicit_port_raises(self): + """Passing a non-integer port raises ValueError.""" + from azure.ai.agentserver.server._config import resolve_port + with pytest.raises(ValueError, match="expected an integer"): + resolve_port("sss") # type: ignore[arg-type] + + def 
test_port_out_of_range_raises(self): + """Port outside 1-65535 raises ValueError.""" + from azure.ai.agentserver.server._config import resolve_port + with pytest.raises(ValueError, match="1-65535"): + resolve_port(0) + with pytest.raises(ValueError, match="1-65535"): + resolve_port(70000) + + def test_env_var_port_out_of_range_raises(self, monkeypatch): + """AGENT_SERVER_PORT outside 1-65535 raises ValueError.""" + from azure.ai.agentserver.server._config import resolve_port + monkeypatch.setenv("AGENT_SERVER_PORT", "0") + with pytest.raises(ValueError, match="1-65535"): + resolve_port(None) diff --git a/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py b/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py new file mode 100644 index 000000000000..2e09df90ee52 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-server/tests/test_tracing.py @@ -0,0 +1,960 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Tests for optional OpenTelemetry tracing.""" +import asyncio +import json +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from azure.ai.agentserver.server import AgentServer +from azure.ai.agentserver.server._tracing import _TracingHelper + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _InMemoryExporter(SpanExporter): + """Minimal in-memory span exporter for tests.""" + + def __init__(self): + self._spans: list = [] + + def export(self, spans): # type: ignore[override] + self._spans.extend(spans) + return SpanExportResult.SUCCESS + + def get_finished_spans(self): + return list(self._spans) + + def shutdown(self): + self._spans.clear() + + +def _make_echo_traced_agent(**kwargs) -> AgentServer: + """Create a simple agent for tracing tests.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + body = await request.body() + return Response(content=body, media_type="application/octet-stream") + + return server + + +def _make_failing_traced_agent(**kwargs) -> AgentServer: + """Create an agent whose invoke raises — so the span records the error.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + raise RuntimeError("trace-this-error") + + return server + + +@pytest.fixture() +def span_exporter(): + """Return the module-level exporter with a clean slate for each test. 
+ + Patches ``_setup_azure_monitor`` so that an ``APPLICATIONINSIGHTS_CONNECTION_STRING`` + env var in the CI environment doesn't replace the test TracerProvider. + """ + _MODULE_EXPORTER._spans.clear() + # Ensure the module-level provider is active (a prior test may have replaced it). + trace.set_tracer_provider(_MODULE_PROVIDER) + with patch.object(_TracingHelper, "_setup_azure_monitor"): + yield _MODULE_EXPORTER + + +# Module-level OTel setup — set once to avoid +# "Overriding of current TracerProvider is not allowed" warnings. +_MODULE_EXPORTER = _InMemoryExporter() +_MODULE_PROVIDER = TracerProvider() +_MODULE_PROVIDER.add_span_processor(SimpleSpanProcessor(_MODULE_EXPORTER)) +trace.set_tracer_provider(_MODULE_PROVIDER) + + +# --------------------------------------------------------------------------- +# Tests: tracing disabled (default) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_tracing_disabled_by_default(): + """Agent created without enable_tracing has tracing off.""" + agent = _make_echo_traced_agent() + assert agent._tracing is None # noqa: SLF001 + + +@pytest.mark.asyncio +async def test_tracing_disabled_no_spans(span_exporter): + """When tracing is disabled, no spans are produced.""" + agent = _make_echo_traced_agent() # default: tracing off + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b'{"hi": "there"}') + + spans = span_exporter.get_finished_spans() + assert len(spans) == 0 + + +# --------------------------------------------------------------------------- +# Tests: tracing enabled via constructor +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_tracing_enabled_creates_invoke_span(span_exporter): + """POST /invocations with tracing enabled creates a span.""" + agent 
= _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b'{"data": 1}') + assert resp.status_code == 200 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + span = invoke_spans[0] + assert "invocation.id" in dict(span.attributes) + assert span.status.status_code == trace.StatusCode.UNSET # success + + +@pytest.mark.asyncio +async def test_tracing_invoke_error_records_exception(span_exporter): + """When invoke() raises, the span records the error status and exception.""" + agent = _make_failing_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations", content=b'{}') + assert resp.status_code == 500 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + span = invoke_spans[0] + assert span.status.status_code == trace.StatusCode.ERROR + # Exception should be recorded in events + events = span.events + assert any("trace-this-error" in str(e.attributes) for e in events) + + +@pytest.mark.asyncio +async def test_tracing_get_invocation_creates_span(span_exporter): + """GET /invocations/{id} with tracing creates a span.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.get("/invocations/test-id-123") + # Default returns 501 — but span should still exist + assert resp.status_code == 501 + + spans = span_exporter.get_finished_spans() + get_spans = [s for s in spans if 
"get_invocation" in s.name] + assert len(get_spans) == 1 + assert dict(get_spans[0].attributes)["invocation.id"] == "test-id-123" + + +@pytest.mark.asyncio +async def test_tracing_cancel_invocation_creates_span(span_exporter): + """POST /invocations/{id}/cancel with tracing creates a span.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post("/invocations/test-cancel-456/cancel") + assert resp.status_code == 501 + + spans = span_exporter.get_finished_spans() + cancel_spans = [s for s in spans if "cancel_invocation" in s.name] + assert len(cancel_spans) == 1 + assert dict(cancel_spans[0].attributes)["invocation.id"] == "test-cancel-456" + + +# --------------------------------------------------------------------------- +# Tests: tracing enabled via env var +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_tracing_enabled_via_env_var(monkeypatch, span_exporter): + """AGENT_ENABLE_TRACING=true activates tracing.""" + monkeypatch.setenv("AGENT_ENABLE_TRACING", "true") + agent = _make_echo_traced_agent() + assert agent._tracing is not None # noqa: SLF001 + + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b'{}') + + spans = span_exporter.get_finished_spans() + assert any("execute_agent" in s.name for s in spans) + + +@pytest.mark.asyncio +async def test_tracing_constructor_overrides_env_var(monkeypatch, span_exporter): + """Constructor enable_tracing=False overrides AGENT_ENABLE_TRACING=true.""" + monkeypatch.setenv("AGENT_ENABLE_TRACING", "true") + agent = _make_echo_traced_agent(enable_tracing=False) + assert agent._tracing is None # noqa: SLF001 + + +@pytest.mark.asyncio +async def 
test_tracing_propagates_traceparent(span_exporter): + """Incoming traceparent header is extracted as parent context.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + # Valid W3C traceparent — version-trace-id-parent-id-flags + traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b'{"hello": "world"}', + headers={"traceparent": traceparent}, + ) + assert resp.status_code == 200 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + span = invoke_spans[0] + # The span's trace ID should match the traceparent's trace ID + expected_trace_id = int("0af7651916cd43dd8448eb211c80319c", 16) + assert span.context.trace_id == expected_trace_id + + +# --------------------------------------------------------------------------- +# Tests: streaming response tracing +# --------------------------------------------------------------------------- + + +def _make_streaming_traced_agent(**kwargs) -> AgentServer: + """Create an agent that returns a StreamingResponse.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + async def _generate(): + for i in range(5): + yield f"chunk-{i}\n".encode() + + return StreamingResponse(_generate(), media_type="application/octet-stream") + + return server + + +def _make_slow_streaming_agent(**kwargs) -> AgentServer: + """Create an agent that streams with deliberate delays per chunk.""" + server = AgentServer(**kwargs) + + @server.invoke_handler + async def handle(request: Request) -> Response: + async def _generate(): + for i in range(3): + await asyncio.sleep(0.05) + yield f"slow-{i}\n".encode() + + return StreamingResponse(_generate(), media_type="text/plain") + + return server 
def _make_failing_stream_agent(**kwargs) -> AgentServer:
    """Build an agent whose streaming body raises partway through the stream."""
    srv = AgentServer(**kwargs)

    @srv.invoke_handler
    async def handle(request: Request) -> Response:
        async def _body():
            yield b"ok-chunk\n"
            raise RuntimeError("stream-exploded")

        return StreamingResponse(_body(), media_type="text/plain")

    return srv


@pytest.mark.asyncio
async def test_streaming_response_creates_span(span_exporter):
    """A streaming response still yields exactly one invoke span."""
    agent = _make_streaming_traced_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        response = await client.post("/invocations", content=b'{}')
        assert response.status_code == 200

    finished = span_exporter.get_finished_spans()
    matches = [s for s in finished if "execute_agent" in s.name]
    assert len(matches) == 1
    assert matches[0].status.status_code == trace.StatusCode.UNSET


@pytest.mark.asyncio
async def test_streaming_span_covers_full_body(span_exporter):
    """The invoke span covers the full streaming duration, not just the
    invoke() call that creates the StreamingResponse object."""
    agent = _make_slow_streaming_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        response = await client.post("/invocations", content=b'{}')
        assert response.status_code == 200
        # Every chunk made it to the client.
        assert b"slow-0" in response.content
        assert b"slow-2" in response.content

    finished = span_exporter.get_finished_spans()
    matches = [s for s in finished if "execute_agent" in s.name]
    assert len(matches) == 1
    span = matches[0]

    # Streaming takes ~150ms (3 chunks x 50ms); without the fix the span
    # would close in <1ms. Require at least 50ms as a conservative bound.
    duration_ns = span.end_time - span.start_time
    assert duration_ns > 50_000_000, f"Span duration {duration_ns}ns is too short for streaming"


@pytest.mark.asyncio
async def test_streaming_body_fully_received(span_exporter):
    """Every chunk of a streaming response reaches the client."""
    agent = _make_streaming_traced_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        response = await client.post("/invocations", content=b'{}')
        assert response.status_code == 200
        body = response.content.decode()
    for i in range(5):
        assert f"chunk-{i}" in body


@pytest.mark.asyncio
async def test_streaming_error_recorded_in_span(span_exporter):
    """A mid-stream failure is recorded as an error on the invoke span."""
    agent = _make_failing_stream_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        # Depending on the ASGI transport, the mid-stream error may surface
        # as an exception or as a truncated response body.
        try:
            await client.post("/invocations", content=b'{}')
        except Exception:
            pass  # connection reset / partial read is acceptable

    finished = span_exporter.get_finished_spans()
    matches = [s for s in finished if "execute_agent" in s.name]
    assert len(matches) == 1

    span = matches[0]
    assert span.status.status_code == trace.StatusCode.ERROR
    assert any("stream-exploded" in str(e.attributes) for e in span.events)


@pytest.mark.asyncio
async def test_streaming_tracing_disabled_no_span(span_exporter):
    """With tracing off, streaming still works but no spans are produced."""
    agent = _make_streaming_traced_agent()  # tracing off (default)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        response = await client.post("/invocations", content=b'{}')
        assert response.status_code == 200
        # Body is still delivered even though nothing is traced.
        assert "chunk-0" in response.content.decode()

    assert len(span_exporter.get_finished_spans()) == 0


@pytest.mark.asyncio
async def test_streaming_propagates_traceparent(span_exporter):
    """An incoming traceparent header is honoured for streaming responses."""
    agent = _make_streaming_traced_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        response = await client.post(
            "/invocations",
            content=b'{}',
            headers={"traceparent": traceparent},
        )
        assert response.status_code == 200

    finished = span_exporter.get_finished_spans()
    matches = [s for s in finished if "execute_agent" in s.name]
    assert len(matches) == 1
    assert matches[0].context.trace_id == int("0af7651916cd43dd8448eb211c80319c", 16)


# ---------------------------------------------------------------------------
# Tests: Application Insights connection string resolution
# ---------------------------------------------------------------------------


class TestAppInsightsConnectionStringResolution:
    """resolve_appinsights_connection_string and constructor wiring."""

    def test_explicit_param_takes_priority(self, monkeypatch):
        """An explicit constructor argument wins over the env var."""
        from azure.ai.agentserver.server._config import resolve_appinsights_connection_string

        monkeypatch.setenv("APPLICATIONINSIGHTS_CONNECTION_STRING", "env-standard")
        assert resolve_appinsights_connection_string("explicit-value") == "explicit-value"

    def test_standard_env_var_fallback(self, monkeypatch):
        """Falls back to APPLICATIONINSIGHTS_CONNECTION_STRING."""
        from azure.ai.agentserver.server._config import resolve_appinsights_connection_string

        monkeypatch.setenv("APPLICATIONINSIGHTS_CONNECTION_STRING", "env-standard")
        assert resolve_appinsights_connection_string(None) == "env-standard"

    def test_no_connection_string_returns_none(self, monkeypatch):
        """Yields None when neither param nor env var provides a value."""
        from azure.ai.agentserver.server._config import resolve_appinsights_connection_string

        monkeypatch.delenv("APPLICATIONINSIGHTS_CONNECTION_STRING", raising=False)
        assert resolve_appinsights_connection_string(None) is None


# ---------------------------------------------------------------------------
# Tests: _setup_azure_monitor
# ---------------------------------------------------------------------------


class TestSetupAzureMonitor:
    """_TracingHelper._setup_azure_monitor with mocked exporter imports."""

    def test_setup_configures_tracer_provider(self):
        """A global TracerProvider is installed with the Azure exporter."""
        from azure.ai.agentserver.server._tracing import _TracingHelper

        exporter_instance = MagicMock()
        exporter_cls = MagicMock(return_value=exporter_instance)
        fake_module = MagicMock(
            AzureMonitorTraceExporter=exporter_cls,
            AzureMonitorLogExporter=MagicMock(return_value=MagicMock()),
        )

        with patch.dict(
            "sys.modules",
            {"azure.monitor.opentelemetry.exporter": fake_module},
        ), patch("opentelemetry.trace.set_tracer_provider") as set_provider:
            _TracingHelper._setup_azure_monitor("InstrumentationKey=test")

        exporter_cls.assert_called_once_with(
            connection_string="InstrumentationKey=test"
        )
        set_provider.assert_called_once()

    def test_setup_logs_warning_when_packages_missing(self, caplog):
        """Missing azure-monitor packages produce a warning, not a crash."""
        import builtins
        import logging

        from azure.ai.agentserver.server._tracing import _TracingHelper

        real_import = builtins.__import__

        def _deny_monitor(name, *args, **kwargs):
            # Simulate azure-monitor exporters not being installed.
            if "azure.monitor" in name:
                raise ImportError("no monitor")
            return real_import(name, *args, **kwargs)

        with patch("builtins.__import__", side_effect=_deny_monitor):
            with caplog.at_level(logging.WARNING, logger="azure.ai.agentserver"):
                _TracingHelper._setup_azure_monitor("InstrumentationKey=test")

        assert "Traces will not be forwarded" in caplog.text

    def test_constructor_passes_connection_string(self, monkeypatch):
        """AgentServer forwards the resolved connection string to _TracingHelper."""
        monkeypatch.delenv("APPLICATIONINSIGHTS_CONNECTION_STRING", raising=False)

        with patch(
            "azure.ai.agentserver.server._tracing._TracingHelper._setup_azure_monitor"
        ) as setup_mock:
            _make_echo_traced_agent(
                enable_tracing=True,
                application_insights_connection_string="InstrumentationKey=from-param",
            )
            setup_mock.assert_called_once_with("InstrumentationKey=from-param")

    def test_constructor_no_connection_string_skips_setup(self, monkeypatch):
        """No connection string anywhere -> _setup_azure_monitor never runs."""
        monkeypatch.delenv("APPLICATIONINSIGHTS_CONNECTION_STRING", raising=False)

        with patch(
            "azure.ai.agentserver.server._tracing._TracingHelper._setup_azure_monitor"
        ) as setup_mock:
            _make_echo_traced_agent(enable_tracing=True)
            setup_mock.assert_not_called()

    def test_constructor_env_var_connection_string(self, monkeypatch):
        """A connection string from the env var reaches _setup_azure_monitor."""
        monkeypatch.setenv(
            "APPLICATIONINSIGHTS_CONNECTION_STRING",
            "InstrumentationKey=from-env",
        )

        with patch(
            "azure.ai.agentserver.server._tracing._TracingHelper._setup_azure_monitor"
        ) as setup_mock:
            _make_echo_traced_agent(enable_tracing=True)
            setup_mock.assert_called_once_with("InstrumentationKey=from-env")

    def test_tracing_disabled_skips_connection_string_resolution(self, monkeypatch):
        """Disabled tracing never resolves a connection string."""
        monkeypatch.setenv(
            "APPLICATIONINSIGHTS_CONNECTION_STRING",
            "InstrumentationKey=should-not-use",
        )

        agent = _make_echo_traced_agent(enable_tracing=False)
        assert agent._tracing is None  # noqa: SLF001


# ---------------------------------------------------------------------------
# Tests: span naming with agent name and version
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_span_name_includes_agent_label(monkeypatch, span_exporter):
    """Span name carries agent_name:agent_version when the env vars are set."""
    monkeypatch.setenv("AGENT_NAME", "my-agent")
    monkeypatch.setenv("AGENT_VERSION", "2.1")
    agent = _make_echo_traced_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        await client.post("/invocations", content=b'{}')

    named = [
        s for s in span_exporter.get_finished_spans()
        if s.name == "execute_agent my-agent:2.1"
    ]
    assert len(named) == 1
@pytest.mark.asyncio
async def test_span_name_without_agent_label(span_exporter, monkeypatch):
    """Span name is the bare operation when AGENT_NAME is unset."""
    monkeypatch.delenv("AGENT_NAME", raising=False)
    monkeypatch.delenv("AGENT_VERSION", raising=False)
    agent = _make_echo_traced_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        await client.post("/invocations", content=b'{}')

    named = [s for s in span_exporter.get_finished_spans() if s.name == "execute_agent"]
    assert len(named) == 1


@pytest.mark.asyncio
async def test_get_invocation_span_name_with_label(monkeypatch, span_exporter):
    """GET /invocations/{id} span carries the agent label."""
    monkeypatch.setenv("AGENT_NAME", "agent-x")
    monkeypatch.setenv("AGENT_VERSION", "0.5")
    agent = _make_echo_traced_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        await client.get("/invocations/test-id")

    named = [
        s for s in span_exporter.get_finished_spans()
        if s.name == "get_invocation agent-x:0.5"
    ]
    assert len(named) == 1


@pytest.mark.asyncio
async def test_cancel_invocation_span_name_with_label(monkeypatch, span_exporter):
    """POST /invocations/{id}/cancel span carries the agent label."""
    monkeypatch.setenv("AGENT_NAME", "agent-x")
    monkeypatch.setenv("AGENT_VERSION", "0.5")
    agent = _make_echo_traced_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        await client.post("/invocations/test-id/cancel")

    named = [
        s for s in span_exporter.get_finished_spans()
        if s.name == "cancel_invocation agent-x:0.5"
    ]
    assert len(named) == 1


# ---------------------------------------------------------------------------
# Tests: GenAI semantic convention attributes
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_genai_attributes_on_invoke_span(span_exporter, monkeypatch):
    """The invoke span carries GenAI semantic convention attributes."""
    monkeypatch.setenv("AGENT_NAME", "test-agent")
    monkeypatch.setenv("AGENT_VERSION", "1.0")
    agent = _make_echo_traced_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        await client.post("/invocations", content=b'{}')

    matches = [s for s in span_exporter.get_finished_spans() if "execute_agent" in s.name]
    assert len(matches) == 1

    attrs = dict(matches[0].attributes)
    assert attrs["gen_ai.operation.name"] == "invoke_agent"
    assert attrs["gen_ai.agent.id"] == "test-agent:1.0"
    assert attrs["gen_ai.provider.name"] == "microsoft.foundry"
    assert "gen_ai.response.id" in attrs  # UUID invocation ID


@pytest.mark.asyncio
async def test_genai_conversation_id_from_session_header(span_exporter, monkeypatch):
    """gen_ai.conversation.id reflects the agent_session_id query parameter."""
    monkeypatch.delenv("AGENT_NAME", raising=False)
    monkeypatch.delenv("AGENT_VERSION", raising=False)
    agent = _make_echo_traced_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        await client.post(
            "/invocations",
            content=b'{}',
            params={"agent_session_id": "session-abc-123"},
        )

    matches = [s for s in span_exporter.get_finished_spans() if "execute_agent" in s.name]
    assert len(matches) == 1

    attrs = dict(matches[0].attributes)
    assert attrs["gen_ai.conversation.id"] == "session-abc-123"


@pytest.mark.asyncio
async def test_genai_conversation_id_absent_when_no_header(span_exporter, monkeypatch):
    """gen_ai.conversation.id is omitted when agent_session_id is not supplied."""
    monkeypatch.delenv("AGENT_NAME", raising=False)
    monkeypatch.delenv("AGENT_VERSION", raising=False)
    agent = _make_echo_traced_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        await client.post("/invocations", content=b'{}')

    matches = [s for s in span_exporter.get_finished_spans() if "execute_agent" in s.name]
    assert len(matches) == 1

    attrs = dict(matches[0].attributes)
    assert "gen_ai.conversation.id" not in attrs


@pytest.mark.asyncio
async def test_genai_attributes_on_get_invocation_span(span_exporter, monkeypatch):
    """GET /invocations/{id} span has GenAI attributes (minus operation.name)."""
    monkeypatch.setenv("AGENT_NAME", "test-agent")
    monkeypatch.setenv("AGENT_VERSION", "1.0")
    agent = _make_echo_traced_agent(enable_tracing=True)
    transport = httpx.ASGITransport(app=agent.app)
    async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
        await client.get(
            "/invocations/inv-42",
            params={"agent_session_id": "sess-99"},
        )

    matches = [s for s in span_exporter.get_finished_spans() if "get_invocation" in s.name]
    assert len(matches) == 1

    attrs = dict(matches[0].attributes)
    assert attrs["gen_ai.agent.id"] == "test-agent:1.0"
    assert attrs["gen_ai.provider.name"] == "microsoft.foundry"
    assert attrs["gen_ai.response.id"] == "inv-42"
    assert attrs["gen_ai.conversation.id"] == "sess-99"
    assert attrs["invocation.id"] == "inv-42"


# ---------------------------------------------------------------------------
# Tests: baggage extraction and leaf_customer_span_id
--------------------------------------------------------------------------- + + +class TestBaggageParsing: + """Unit tests for _parse_baggage_key.""" + + def test_single_key(self): + from azure.ai.agentserver.server._tracing import _parse_baggage_key + assert _parse_baggage_key("leaf_customer_span_id=abc123", "leaf_customer_span_id") == "abc123" + + def test_multiple_keys(self): + from azure.ai.agentserver.server._tracing import _parse_baggage_key + baggage = "foo=bar,leaf_customer_span_id=deadbeef01234567,baz=qux" + assert _parse_baggage_key(baggage, "leaf_customer_span_id") == "deadbeef01234567" + + def test_key_not_present(self): + from azure.ai.agentserver.server._tracing import _parse_baggage_key + assert _parse_baggage_key("foo=bar,baz=qux", "leaf_customer_span_id") is None + + def test_empty_baggage(self): + from azure.ai.agentserver.server._tracing import _parse_baggage_key + assert _parse_baggage_key("", "leaf_customer_span_id") is None + + def test_key_with_properties(self): + from azure.ai.agentserver.server._tracing import _parse_baggage_key + baggage = "leaf_customer_span_id=abc123;property1=val1,other=2" + assert _parse_baggage_key(baggage, "leaf_customer_span_id") == "abc123" + + def test_whitespace_handling(self): + from azure.ai.agentserver.server._tracing import _parse_baggage_key + baggage = " leaf_customer_span_id = abc123 , other = val " + assert _parse_baggage_key(baggage, "leaf_customer_span_id") == "abc123" + + +@pytest.mark.asyncio +async def test_baggage_leaf_customer_span_id_overrides_parent(span_exporter): + """When baggage contains leaf_customer_span_id, the span's parent span ID + is overridden to match, while the trace ID stays the same as traceparent.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + + trace_id_hex = "0af7651916cd43dd8448eb211c80319c" + original_parent_hex = "b7ad6b7169203331" + leaf_span_hex = "00f067aa0ba902b7" + + traceparent = 
f"00-{trace_id_hex}-{original_parent_hex}-01" + baggage = f"leaf_customer_span_id={leaf_span_hex},other=val" + + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b'{"data": 1}', + headers={"traceparent": traceparent, "baggage": baggage}, + ) + assert resp.status_code == 200 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + span = invoke_spans[0] + # Trace ID should match traceparent + expected_trace_id = int(trace_id_hex, 16) + assert span.context.trace_id == expected_trace_id + + # Parent span ID should be the leaf_customer_span_id, not the original + expected_parent_span_id = int(leaf_span_hex, 16) + assert span.parent.span_id == expected_parent_span_id + + +@pytest.mark.asyncio +async def test_baggage_without_leaf_uses_traceparent_parent(span_exporter): + """When baggage is present but does NOT contain leaf_customer_span_id, + the parent span ID comes from traceparent as usual.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + + trace_id_hex = "0af7651916cd43dd8448eb211c80319c" + parent_hex = "b7ad6b7169203331" + traceparent = f"00-{trace_id_hex}-{parent_hex}-01" + baggage = "some_other_key=value" + + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b'{}', + headers={"traceparent": traceparent, "baggage": baggage}, + ) + assert resp.status_code == 200 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + span = invoke_spans[0] + expected_parent_span_id = int(parent_hex, 16) + assert span.parent.span_id == expected_parent_span_id + + +@pytest.mark.asyncio +async def test_baggage_no_traceparent_no_crash(span_exporter): + 
"""When baggage is present but no traceparent, no crash occurs.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b'{}', + headers={"baggage": "leaf_customer_span_id=00f067aa0ba902b7"}, + ) + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_baggage_invalid_leaf_span_id_falls_back(span_exporter): + """Invalid hex in leaf_customer_span_id falls back to traceparent parent.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + + trace_id_hex = "0af7651916cd43dd8448eb211c80319c" + parent_hex = "b7ad6b7169203331" + traceparent = f"00-{trace_id_hex}-{parent_hex}-01" + baggage = "leaf_customer_span_id=not_valid_hex" + + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + resp = await client.post( + "/invocations", + content=b'{}', + headers={"traceparent": traceparent, "baggage": baggage}, + ) + assert resp.status_code == 200 + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + span = invoke_spans[0] + # Should fall back to traceparent's parent span ID + expected_parent = int(parent_hex, 16) + assert span.parent.span_id == expected_parent + + +@pytest.mark.asyncio +async def test_baggage_leaf_on_get_invocation(span_exporter): + """Baggage leaf_customer_span_id also works on GET /invocations/{id}.""" + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + + trace_id_hex = "0af7651916cd43dd8448eb211c80319c" + original_parent_hex = "b7ad6b7169203331" + leaf_span_hex = "00f067aa0ba902b7" + + traceparent = f"00-{trace_id_hex}-{original_parent_hex}-01" + baggage = f"leaf_customer_span_id={leaf_span_hex}" + + async 
with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.get( + "/invocations/test-id", + headers={"traceparent": traceparent, "baggage": baggage}, + ) + + spans = span_exporter.get_finished_spans() + get_spans = [s for s in spans if "get_invocation" in s.name] + assert len(get_spans) == 1 + + expected_trace_id = int(trace_id_hex, 16) + expected_parent = int(leaf_span_hex, 16) + assert get_spans[0].context.trace_id == expected_trace_id + assert get_spans[0].parent.span_id == expected_parent + + +# --------------------------------------------------------------------------- +# Tests: agent name / version resolution +# --------------------------------------------------------------------------- + + +class TestAgentNameVersionResolution: + """Tests for resolve_agent_name and resolve_agent_version.""" + + def test_agent_name_from_env(self, monkeypatch): + from azure.ai.agentserver.server._config import resolve_agent_name + monkeypatch.setenv("AGENT_NAME", "my-agent") + assert resolve_agent_name() == "my-agent" + + def test_agent_name_default_empty(self, monkeypatch): + from azure.ai.agentserver.server._config import resolve_agent_name + monkeypatch.delenv("AGENT_NAME", raising=False) + assert resolve_agent_name() == "" + + def test_agent_version_from_env(self, monkeypatch): + from azure.ai.agentserver.server._config import resolve_agent_version + monkeypatch.setenv("AGENT_VERSION", "3.0.1") + assert resolve_agent_version() == "3.0.1" + + def test_agent_version_default_empty(self, monkeypatch): + from azure.ai.agentserver.server._config import resolve_agent_version + monkeypatch.delenv("AGENT_VERSION", raising=False) + assert resolve_agent_version() == "" + + +# --------------------------------------------------------------------------- +# Tests: project ID resolution and tracing attribute +# --------------------------------------------------------------------------- + + +class TestProjectIdResolution: + """Tests for 
resolve_project_id and microsoft.foundry.project.id span attribute.""" + + def test_project_id_from_env(self, monkeypatch): + from azure.ai.agentserver.server._config import resolve_project_id + monkeypatch.setenv("AGENT_PROJECT_NAME", "proj-abc-123") + assert resolve_project_id() == "proj-abc-123" + + def test_project_id_default_empty(self, monkeypatch): + from azure.ai.agentserver.server._config import resolve_project_id + monkeypatch.delenv("AGENT_PROJECT_NAME", raising=False) + assert resolve_project_id() == "" + + +@pytest.mark.asyncio +async def test_project_id_attribute_on_invoke_span(monkeypatch, span_exporter): + """microsoft.foundry.project.id is set on invoke span when AGENT_PROJECT_NAME is set.""" + monkeypatch.setenv("AGENT_PROJECT_NAME", "proj-xyz-789") + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b'{}') + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + attrs = dict(invoke_spans[0].attributes) + assert attrs["microsoft.foundry.project.id"] == "proj-xyz-789" + + +@pytest.mark.asyncio +async def test_project_id_attribute_absent_when_not_set(monkeypatch, span_exporter): + """microsoft.foundry.project.id is NOT set when AGENT_PROJECT_NAME env var is absent.""" + monkeypatch.delenv("AGENT_PROJECT_NAME", raising=False) + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.post("/invocations", content=b'{}') + + spans = span_exporter.get_finished_spans() + invoke_spans = [s for s in spans if "execute_agent" in s.name] + assert len(invoke_spans) == 1 + + attrs = dict(invoke_spans[0].attributes) + assert 
"microsoft.foundry.project.id" not in attrs + + +@pytest.mark.asyncio +async def test_project_id_attribute_on_get_invocation_span(monkeypatch, span_exporter): + """microsoft.foundry.project.id is set on get_invocation span too.""" + monkeypatch.setenv("AGENT_PROJECT_NAME", "proj-get-456") + agent = _make_echo_traced_agent(enable_tracing=True) + transport = httpx.ASGITransport(app=agent.app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + await client.get("/invocations/test-id") + + spans = span_exporter.get_finished_spans() + get_spans = [s for s in spans if "get_invocation" in s.name] + assert len(get_spans) == 1 + + attrs = dict(get_spans[0].attributes) + assert attrs["microsoft.foundry.project.id"] == "proj-get-456" diff --git a/sdk/agentserver/ci.yml b/sdk/agentserver/ci.yml index bb2d6f479b00..9c0128b8089e 100644 --- a/sdk/agentserver/ci.yml +++ b/sdk/agentserver/ci.yml @@ -40,6 +40,8 @@ extends: Selection: sparse GenerateVMJobs: true Artifacts: + - name: azure-ai-agentserver-server + safeName: azureaiagentserverserver - name: azure-ai-agentserver-core safeName: azureaiagentservercore - name: azure-ai-agentserver-agentframework diff --git a/shared_requirements.txt b/shared_requirements.txt index b5e1b85f184a..518ceb2272f0 100644 --- a/shared_requirements.txt +++ b/shared_requirements.txt @@ -87,6 +87,7 @@ azure-confidentialledger-certificate azure-ai-projects starlette uvicorn +hypercorn opentelemetry-exporter-otlp-proto-http agent-framework-azure-ai langgraph