diff --git a/.coverage b/.coverage index ce2d005..de20ecf 100644 Binary files a/.coverage and b/.coverage differ diff --git a/coverage.xml b/coverage.xml index 4785cc5..83bdaba 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,5 +1,5 @@ - + @@ -171,5 +171,116 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/gavaconnect/auth/__init__.py b/gavaconnect/auth/__init__.py index 5093349..faaa1dd 100644 --- a/gavaconnect/auth/__init__.py +++ b/gavaconnect/auth/__init__.py @@ -1,10 +1,11 @@ """Authentication module for GavaConnect SDK.""" from .basic import BasicAuthPolicy, BasicCredentials -from .bearer import BearerAuthPolicy, TokenProvider +from .bearer import AuthPolicy, BearerAuthPolicy, TokenProvider from .providers import ClientCredentialsProvider __all__ = [ + "AuthPolicy", "BasicAuthPolicy", "BasicCredentials", "BearerAuthPolicy", diff --git a/gavaconnect/http/__init__.py b/gavaconnect/http/__init__.py new file mode 100644 index 0000000..c9afe49 --- /dev/null +++ b/gavaconnect/http/__init__.py @@ -0,0 +1,13 @@ +"""HTTP transport layer for the GavaConnect SDK.""" + +from .logging import log_request, log_response +from .telemetry import otel_request_span, otel_response_span +from .transport import AsyncTransport + +__all__ = [ + "log_request", + "log_response", + "otel_request_span", + "otel_response_span", + "AsyncTransport", +] diff --git a/gavaconnect/http/logging.py b/gavaconnect/http/logging.py new file mode 100644 index 0000000..362796d --- /dev/null +++ b/gavaconnect/http/logging.py @@ -0,0 +1,35 @@ +"""HTTP request and response logging utilities.""" + +import logging +import time + +import httpx + +logger = logging.getLogger("gavaconnect") + + +async def log_request(req: httpx.Request) -> None: + """Log an HTTP request with sanitized headers. + + Args: + req: The HTTP request to log. + + """ + req.extensions["start_time"] = time.perf_counter() + hdrs = dict(req.headers) + hdrs.pop("authorization", None) + logger.debug(f"HTTP {req.method} {req.url} headers={hdrs}") + + +async def log_response(req: httpx.Request, resp: httpx.Response) -> None: + """Log an HTTP response with timing information. + + Args: + req: The HTTP request. + resp: The HTTP response to log. + + """ + dur = time.perf_counter() - req.extensions.get("start_time", time.perf_counter()) + logger.info( + f"HTTP {req.method} {req.url} -> {resp.status_code} in {dur:.3f}s request_id={resp.headers.get('x-request-id')}" + ) diff --git a/gavaconnect/http/telemetry.py b/gavaconnect/http/telemetry.py new file mode 100644 index 0000000..9f7dd07 --- /dev/null +++ b/gavaconnect/http/telemetry.py @@ -0,0 +1,36 @@ +"""OpenTelemetry tracing utilities for HTTP requests.""" + +import httpx +from opentelemetry import trace + +tracer = trace.get_tracer("gavaconnect") + + +async def otel_request_span(req: httpx.Request) -> None: + """Start an OpenTelemetry span for an HTTP request. + + Args: + req: The HTTP request to trace. + + """ + span = tracer.start_span( + "http.client", attributes={"http.method": req.method, "http.url": str(req.url)} + ) + req.extensions["otel_span"] = span + + +async def otel_response_span(req: httpx.Request, resp: httpx.Response) -> None: + """Complete an OpenTelemetry span for an HTTP response. + + Args: + req: The HTTP request. + resp: The HTTP response. + + """ + span = req.extensions.pop("otel_span", None) + if span: + span.set_attribute("http.status_code", resp.status_code) + rid = resp.headers.get("x-request-id") + if rid: + span.set_attribute("http.response.request_id", rid) + span.end() diff --git a/gavaconnect/http/transport.py b/gavaconnect/http/transport.py new file mode 100644 index 0000000..34fbea0 --- /dev/null +++ b/gavaconnect/http/transport.py @@ -0,0 +1,138 @@ +"""HTTP transport implementation with retry logic and error handling.""" + +from __future__ import annotations + +import asyncio +import json +import random +from typing import Any + +import httpx + +from gavaconnect.auth import AuthPolicy +from gavaconnect.config import SDKConfig +from gavaconnect.errors import APIError, RateLimitError, TransportError + + +def _jitter(base: float, attempt: int) -> float: + return float(base * (2 ** (attempt - 1)) * (1 + random.random() * 0.2)) # nosec B311 + + +class AsyncTransport: + """Async HTTP transport with retry logic and authentication support.""" + + def __init__(self, cfg: SDKConfig) -> None: + """Initialize the async transport. + + Args: + cfg: SDK configuration containing timeout and retry settings. + + """ + self.cfg = cfg + self.client = httpx.AsyncClient( + base_url=cfg.base_url, + http2=True, + timeout=httpx.Timeout( + cfg.total_timeout_s, + read=cfg.read_timeout_s, + connect=cfg.connect_timeout_s, + ), + headers={"user-agent": cfg.user_agent, "x-client-version": cfg.user_agent}, + ) + + async def close(self) -> None: + """Close the underlying HTTP client.""" + await self.client.aclose() + + async def request( + self, + method: str, + url: str, + *, + auth: AuthPolicy | None = None, + **kw: Any, # noqa: ANN401 + ) -> httpx.Response: + """Make an HTTP request with retry logic and authentication. + + Args: + method: HTTP method (GET, POST, etc.). + url: Request URL. + auth: Optional authentication policy. + **kw: Additional keyword arguments for the request. + + Returns: + The HTTP response. + + Raises: + TransportError: If the request fails after all retries. + + """ + req = self.client.build_request(method, url, **kw) + if auth: + await auth.authorize(req) + attempt = 1 + while True: + try: + resp = await self.client.send(req, stream=False) + except httpx.HTTPError as e: + if attempt > self.cfg.retry.max_attempts: + raise TransportError(str(e)) from e + await asyncio.sleep(_jitter(self.cfg.retry.base_backoff_s, attempt)) + attempt += 1 + continue + if resp.status_code == 401 and auth and await auth.on_unauthorized(): + req = self.client.build_request(method, url, **kw) + await auth.authorize(req) + resp = await self.client.send(req, stream=False) + if ( + resp.status_code in self.cfg.retry.retry_on_status + and attempt <= self.cfg.retry.max_attempts + ): + ra = resp.headers.get("retry-after") + backoff = ( + float(ra) + if ra and ra.isdigit() + else _jitter(self.cfg.retry.base_backoff_s, attempt) + ) + await asyncio.sleep(backoff) + attempt += 1 + continue + return resp + + @staticmethod + def raise_for_api_error(resp: httpx.Response) -> None: + """Raise appropriate API error based on response status and content. + + Args: + resp: HTTP response to check for errors. + + Raises: + APIError: For general API errors. + RateLimitError: For rate limit errors (status 429). + + """ + if resp.status_code < 400: + return + try: + b = resp.json() + err = b.get("error", {}) + except (json.JSONDecodeError, ValueError) as e: + raise APIError( + resp.status_code, + "api_error", + resp.text, + None, + resp.headers.get("x-request-id"), + None, + resp.content, + ) from e + type_ = err.get("type") or "api_error" + msg = err.get("message") or resp.text + code = err.get("code") + rid = resp.headers.get("x-request-id") + ra = err.get("retry_after") + if resp.status_code == 429: + raise RateLimitError( + resp.status_code, type_, msg, code, rid, ra, resp.content + ) + raise APIError(resp.status_code, type_, msg, code, rid, ra, resp.content) diff --git a/pyproject.toml b/pyproject.toml index 1bc3681..c8ab5f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,4 +133,7 @@ dev = [ "pytest-cov>=6.2.1", "pytest-asyncio>=0.25.0", "respx>=0.22.0", + "opentelemetry-api>=1.36.0", + "opentelemetry-sdk>=1.36.0", + "h2>=4.2.0", ] diff --git a/tests/test_auth_module.py b/tests/test_auth_module.py index 540e1d4..9b83238 100644 --- a/tests/test_auth_module.py +++ b/tests/test_auth_module.py @@ -28,6 +28,7 @@ def test_module_has_all_attribute(self): assert isinstance(auth.__all__, list) expected_exports = { + "AuthPolicy", "BasicAuthPolicy", "BasicCredentials", "BearerAuthPolicy", diff --git a/tests/test_http_logging.py b/tests/test_http_logging.py new file mode 100644 index 0000000..5f49317 --- /dev/null +++ b/tests/test_http_logging.py @@ -0,0 +1,165 @@ +"""Tests for HTTP logging utilities.""" + +import logging +import time +from unittest.mock import patch + +import httpx +import pytest +from pytest import LogCaptureFixture + +from gavaconnect.http.logging import log_request, log_response + + +class TestLogRequest: + """Test log_request function.""" + + @pytest.mark.asyncio + async def test_log_request_basic(self, caplog: LogCaptureFixture): + """Test basic request logging.""" + # Create a mock request + req = httpx.Request("GET", "https://api.example.com/test") + req.extensions = {} + + with caplog.at_level(logging.DEBUG, logger="gavaconnect"): + await log_request(req) + + # Check that start_time was set + assert "start_time" in req.extensions + assert isinstance(req.extensions["start_time"], float) + + # Check the logged message + assert len(caplog.records) == 1 + record = caplog.records[0] + assert record.levelname == "DEBUG" + assert "HTTP GET https://api.example.com/test" in record.message + assert "headers=" in record.message + + @pytest.mark.asyncio + async def test_log_request_with_authorization_header( + self, caplog: LogCaptureFixture + ): + """Test that authorization headers are removed from logs.""" + headers = { + "authorization": "Bearer secret-token", + "content-type": "application/json", + "x-custom": "value", + } + req = httpx.Request("POST", "https://api.example.com/test", headers=headers) + req.extensions = {} + + with caplog.at_level(logging.DEBUG, logger="gavaconnect"): + await log_request(req) + + # Check that authorization header is not in the log + record = caplog.records[0] + assert "secret-token" not in record.message + assert "Bearer" not in record.message + # But other headers should be present + assert "content-type" in record.message + assert "x-custom" in record.message + + @pytest.mark.asyncio + async def test_log_request_timing(self): + """Test that timing is properly recorded.""" + req = httpx.Request("GET", "https://api.example.com/test") + req.extensions = {} + + before_time = time.perf_counter() + await log_request(req) + after_time = time.perf_counter() + + # Check that start_time is within reasonable bounds + start_time = req.extensions["start_time"] + assert before_time <= start_time <= after_time + + +class TestLogResponse: + """Test log_response function.""" + + @pytest.mark.asyncio + async def test_log_response_basic(self, caplog: LogCaptureFixture): + """Test basic response logging.""" + # Create a mock request with start_time + req = httpx.Request("GET", "https://api.example.com/test") + req.extensions = {"start_time": time.perf_counter() - 0.1} # 100ms ago + + # Create a mock response + resp = httpx.Response( + status_code=200, + headers={"x-request-id": "req-123"}, + content=b'{"result": "success"}', + ) + + with caplog.at_level(logging.INFO, logger="gavaconnect"): + await log_response(req, resp) + + # Check the logged message + assert len(caplog.records) == 1 + record = caplog.records[0] + assert record.levelname == "INFO" + assert "HTTP GET https://api.example.com/test -> 200" in record.message + assert "request_id=req-123" in record.message + assert "in " in record.message and "s" in record.message # timing info + + @pytest.mark.asyncio + async def test_log_response_without_start_time(self, caplog: LogCaptureFixture): + """Test response logging when start_time is missing.""" + # Create a mock request without start_time + req = httpx.Request("POST", "https://api.example.com/test") + req.extensions = {} + + # Create a mock response + resp = httpx.Response(status_code=201) + + with caplog.at_level(logging.INFO, logger="gavaconnect"): + await log_response(req, resp) + + # Should still log without error + assert len(caplog.records) == 1 + record = caplog.records[0] + assert "HTTP POST https://api.example.com/test -> 201" in record.message + + @pytest.mark.asyncio + async def test_log_response_without_request_id(self, caplog: LogCaptureFixture): + """Test response logging when request ID is missing.""" + # Create a mock request with start_time + req = httpx.Request("GET", "https://api.example.com/test") + req.extensions = {"start_time": time.perf_counter()} + + # Create a mock response without request ID + resp = httpx.Response(status_code=404) + + with caplog.at_level(logging.INFO, logger="gavaconnect"): + await log_response(req, resp) + + # Check the logged message + record = caplog.records[0] + assert "request_id=None" in record.message + + @pytest.mark.asyncio + async def test_log_response_timing_calculation(self): + """Test that timing calculation works correctly.""" + # Create a mock request with a specific start_time + start_time = time.perf_counter() - 0.5 # 500ms ago + req = httpx.Request("GET", "https://api.example.com/test") + req.extensions = {"start_time": start_time} + + # Create a mock response + resp = httpx.Response(status_code=200) + + with patch("gavaconnect.http.logging.logger") as mock_logger: + await log_response(req, resp) + + # Check that the timing was calculated + mock_logger.info.assert_called_once() + call_args = mock_logger.info.call_args[0][0] + + # The duration should be approximately 0.5 seconds + # Extract the duration from the log message + import re + + match = re.search(r"in (\d+\.\d+)s", call_args) + assert match is not None + duration = float(match.group(1)) + assert 0.4 <= duration <= 0.6 # Allow some tolerance diff --git a/tests/test_http_module.py b/tests/test_http_module.py new file mode 100644 index 0000000..c5398bd --- /dev/null +++ b/tests/test_http_module.py @@ -0,0 +1,69 @@ +"""Tests for HTTP module imports and initialization.""" + + +class TestHttpModuleImports: + """Test that the HTTP module imports work correctly.""" + + def test_import_logging_functions(self): + """Test importing logging functions.""" + from gavaconnect.http import log_request, log_response + + assert callable(log_request) + assert callable(log_response) + + def test_import_telemetry_functions(self): + """Test importing telemetry functions.""" + from gavaconnect.http import otel_request_span, otel_response_span + + assert callable(otel_request_span) + assert callable(otel_response_span) + + def test_import_transport_class(self): + """Test importing transport class.""" + from gavaconnect.http import AsyncTransport + + assert AsyncTransport is not None + # Verify it's a class + assert isinstance(AsyncTransport, type) + + def test_all_exports(self): + """Test that __all__ contains expected exports.""" + import gavaconnect.http as http_module + + expected_exports = [ + "log_request", + "log_response", + "otel_request_span", + "otel_response_span", + "AsyncTransport", + ] + + assert hasattr(http_module, "__all__") + + # Check that all expected items are in __all__ + for export in expected_exports: + assert export in http_module.__all__ + + # Check that all items in __all__ are actually available + for export in http_module.__all__: + assert hasattr(http_module, export) + + def test_direct_module_import(self): + """Test importing the module directly.""" + import gavaconnect.http + + # Should have a docstring + assert gavaconnect.http.__doc__ is not None + assert "HTTP transport layer" in gavaconnect.http.__doc__ + + def test_individual_submodule_imports(self): + """Test that individual submodules can be imported.""" + # Test individual imports don't raise errors + import gavaconnect.http.logging + import gavaconnect.http.telemetry + import gavaconnect.http.transport + + # Verify they have the expected content + assert hasattr(gavaconnect.http.logging, "log_request") + assert hasattr(gavaconnect.http.telemetry, "otel_request_span") + assert hasattr(gavaconnect.http.transport, "AsyncTransport") diff --git a/tests/test_http_telemetry.py b/tests/test_http_telemetry.py new file mode 100644 index 0000000..87450c6 --- /dev/null +++ b/tests/test_http_telemetry.py @@ -0,0 +1,203 @@ +"""Tests for HTTP telemetry utilities.""" + +from unittest.mock import Mock, patch + +import httpx +import pytest + +from gavaconnect.http.telemetry import otel_request_span, otel_response_span + + +class TestOtelRequestSpan: + """Test otel_request_span function.""" + + @pytest.mark.asyncio + async def test_otel_request_span_basic(self): + """Test basic OpenTelemetry span creation.""" + # Create a mock request + req = httpx.Request("GET", "https://api.example.com/test") + req.extensions = {} + + # Mock the tracer + with patch("gavaconnect.http.telemetry.tracer") as mock_tracer: + mock_span = Mock() + mock_tracer.start_span.return_value = mock_span + + await otel_request_span(req) + + # Verify span creation + mock_tracer.start_span.assert_called_once_with( + "http.client", + attributes={ + "http.method": "GET", + "http.url": "https://api.example.com/test", + }, + ) + + # Verify span is stored in extensions + assert req.extensions["otel_span"] == mock_span + + @pytest.mark.asyncio + async def test_otel_request_span_different_methods(self): + """Test span creation with different HTTP methods.""" + methods_and_urls = [ + ("POST", "https://api.example.com/create"), + ("PUT", "https://api.example.com/update/123"), + ("DELETE", "https://api.example.com/delete/456"), + ("PATCH", "https://api.example.com/patch/789"), + ] + + for method, url in methods_and_urls: + req = httpx.Request(method, url) + req.extensions = {} + + with patch("gavaconnect.http.telemetry.tracer") as mock_tracer: + mock_span = Mock() + mock_tracer.start_span.return_value = mock_span + + await otel_request_span(req) + + # Verify correct attributes + mock_tracer.start_span.assert_called_once_with( + "http.client", attributes={"http.method": method, "http.url": url} + ) + + +class TestOtelResponseSpan: + """Test otel_response_span function.""" + + @pytest.mark.asyncio + async def test_otel_response_span_basic(self): + """Test basic OpenTelemetry span completion.""" + # Create a mock request with an otel span + req = httpx.Request("GET", "https://api.example.com/test") + mock_span = Mock() + req.extensions = {"otel_span": mock_span} + + # Create a mock response + resp = httpx.Response(status_code=200, headers={"x-request-id": "req-123"}) + + await otel_response_span(req, resp) + + # Verify span attributes were set + mock_span.set_attribute.assert_any_call("http.status_code", 200) + mock_span.set_attribute.assert_any_call("http.response.request_id", "req-123") + + # Verify span was ended + mock_span.end.assert_called_once() + + # Verify span was removed from extensions + assert "otel_span" not in req.extensions + + @pytest.mark.asyncio + async def test_otel_response_span_without_request_id(self): + """Test span completion when response has no request ID.""" + # Create a mock request with an otel span + req = httpx.Request("POST", "https://api.example.com/test") + mock_span = Mock() + req.extensions = {"otel_span": mock_span} + + # Create a mock response without request ID + resp = httpx.Response(status_code=404) + + await otel_response_span(req, resp) + + # Verify only status code was set (no request ID) + mock_span.set_attribute.assert_called_once_with("http.status_code", 404) + + # Verify span was still ended + mock_span.end.assert_called_once() + + @pytest.mark.asyncio + async def test_otel_response_span_no_span_in_request(self): + """Test span completion when no span exists in request.""" + # Create a mock request without an otel span + req = httpx.Request("GET", "https://api.example.com/test") + req.extensions = {} + + # Create a mock response + resp = httpx.Response(status_code=200) + + # Should not raise an error + await otel_response_span(req, resp) + + # Extensions should still be empty + assert req.extensions == {} + + @pytest.mark.asyncio + async def test_otel_response_span_different_status_codes(self): + """Test span completion with different status codes.""" + status_codes = [200, 201, 400, 401, 404, 500, 502] + + for status_code in status_codes: + req = httpx.Request("GET", "https://api.example.com/test") + mock_span = Mock() + req.extensions = {"otel_span": mock_span} + + resp = httpx.Response(status_code=status_code) + + await otel_response_span(req, resp) + + # Verify correct status code was set + mock_span.set_attribute.assert_called_with("http.status_code", status_code) + mock_span.end.assert_called_once() + + # Reset for next iteration + mock_span.reset_mock() + + @pytest.mark.asyncio + async def test_otel_response_span_with_existing_extensions(self): + """Test that other extensions are preserved.""" + # Create a mock request with multiple extensions + req = httpx.Request("GET", "https://api.example.com/test") + mock_span = Mock() + req.extensions = { + "otel_span": mock_span, + "start_time": 12345.0, + "custom_data": "test_value", + } + + # Create a mock response + resp = httpx.Response(status_code=200) + + await otel_response_span(req, resp) + + # Verify span was removed but other extensions remain + assert "otel_span" not in req.extensions + assert req.extensions["start_time"] == 12345.0 + assert req.extensions["custom_data"] == "test_value" + + @pytest.mark.asyncio + async def test_integration_request_and_response_spans(self): + """Test integration between request and response span functions.""" + # Create a mock request + req = httpx.Request("POST", "https://api.example.com/test") + req.extensions = {} + + # Mock the tracer for request span + with patch("gavaconnect.http.telemetry.tracer") as mock_tracer: + mock_span = Mock() + mock_tracer.start_span.return_value = mock_span + + # Start request span + await otel_request_span(req) + + # Verify span is in extensions + assert req.extensions["otel_span"] == mock_span + + # Create response and complete span + resp = httpx.Response( + status_code=201, headers={"x-request-id": "integration-test-123"} + ) + + await otel_response_span(req, resp) + + # Verify span completion + mock_span.set_attribute.assert_any_call("http.status_code", 201) + mock_span.set_attribute.assert_any_call( + "http.response.request_id", "integration-test-123" + ) + mock_span.end.assert_called_once() + + # Verify span was removed + assert "otel_span" not in req.extensions diff --git a/tests/test_http_transport.py b/tests/test_http_transport.py new file mode 100644 index 0000000..7c22b9d --- /dev/null +++ b/tests/test_http_transport.py @@ -0,0 +1,486 @@ +"""Tests for HTTP transport layer.""" + +import json +from unittest.mock import AsyncMock, Mock, patch + +import httpx +import pytest + +from gavaconnect.auth import AuthPolicy +from gavaconnect.config import RetryPolicy, SDKConfig +from gavaconnect.errors import APIError, RateLimitError, TransportError +from gavaconnect.http.transport import AsyncTransport, _jitter + + +class TestJitter: + """Test the _jitter function.""" + + def test_jitter_calculation(self): + """Test jitter calculation with different inputs.""" + # Test with base=1.0, attempt=1 + result = _jitter(1.0, 1) + # Should be base * (2^0) * (1 + random * 0.2) = 1.0 * 1 * (1.0 to 1.2) + assert 1.0 <= result <= 1.2 + + # Test with base=0.5, attempt=2 + result = _jitter(0.5, 2) + # Should be 0.5 * (2^1) * (1.0 to 1.2) = 1.0 to 1.2 + assert 1.0 <= result <= 1.2 + + # Test with base=0.2, attempt=3 + result = _jitter(0.2, 3) + # Should be 0.2 * (2^2) * (1.0 to 1.2) = 0.8 to 0.96 + assert 0.8 <= result <= 0.96 + + def test_jitter_randomness(self): + """Test that jitter produces different results.""" + results = [_jitter(1.0, 1) for _ in range(10)] + # Results should not all be the same (very unlikely) + assert len(set(results)) > 1 + + +class TestAsyncTransport: + """Test AsyncTransport class.""" + + def test_init(self): + """Test AsyncTransport initialization.""" + config = SDKConfig( + base_url="https://api.example.com", + connect_timeout_s=10.0, + read_timeout_s=60.0, + total_timeout_s=70.0, + user_agent="test-agent/1.0.0", + ) + + transport = AsyncTransport(config) + + assert transport.cfg == config + assert isinstance(transport.client, httpx.AsyncClient) + assert str(transport.client.base_url).rstrip("/") == "https://api.example.com" + + @pytest.mark.asyncio + async def test_close(self): + """Test transport close method.""" + config = SDKConfig(base_url="https://api.example.com") + transport = AsyncTransport(config) + + # Test that close works without error + await transport.close() + + @pytest.mark.asyncio + async def test_successful_request(self): + """Test successful HTTP request.""" + config = SDKConfig(base_url="https://api.example.com") + transport = AsyncTransport(config) + + # Mock the client methods + mock_request = Mock() + mock_response = Mock() + mock_response.status_code = 200 + + with ( + patch.object(transport.client, "build_request", return_value=mock_request), + patch.object( + transport.client, + "send", + new_callable=AsyncMock, + return_value=mock_response, + ), + ): + result = await transport.request("GET", "/test") + + assert result == mock_response + transport.client.build_request.assert_called_once_with("GET", "/test") + transport.client.send.assert_called_once_with(mock_request, stream=False) + + await transport.close() + + @pytest.mark.asyncio + async def test_request_with_auth(self): + """Test request with authentication.""" + config = SDKConfig(base_url="https://api.example.com") + transport = AsyncTransport(config) + + # Mock auth policy + auth = Mock(spec=AuthPolicy) + auth.authorize = AsyncMock() + auth.on_unauthorized = AsyncMock(return_value=False) + + # Mock the client methods + mock_request = Mock() + mock_response = Mock() + mock_response.status_code = 200 + + with ( + patch.object(transport.client, "build_request", return_value=mock_request), + patch.object( + transport.client, + "send", + new_callable=AsyncMock, + return_value=mock_response, + ), + ): + result = await transport.request( + "POST", "/test", auth=auth, json={"data": "test"} + ) + + assert result == mock_response + auth.authorize.assert_called_once_with(mock_request) + + await transport.close() + + @pytest.mark.asyncio + async def test_request_with_401_and_retry(self): + """Test request handling 401 with auth retry.""" + config = SDKConfig(base_url="https://api.example.com") + transport = AsyncTransport(config) + + # Mock auth policy + auth = Mock(spec=AuthPolicy) + auth.authorize = AsyncMock() + auth.on_unauthorized = AsyncMock(return_value=True) # Retry auth + + # Mock responses: first 401, then 200 + mock_request = Mock() + mock_response_401 = Mock() + mock_response_401.status_code = 401 + mock_response_200 = Mock() + mock_response_200.status_code = 200 + + with ( + patch.object(transport.client, "build_request", return_value=mock_request), + patch.object( + transport.client, + "send", + new_callable=AsyncMock, + side_effect=[mock_response_401, mock_response_200], + ), + ): + result = await transport.request("GET", "/test", auth=auth) + + assert result == mock_response_200 + # Auth should be called twice (initial and retry) + assert auth.authorize.call_count == 2 + auth.on_unauthorized.assert_called_once() + + await transport.close() + + @pytest.mark.asyncio + async def test_request_with_http_error_retry(self): + """Test request retry on HTTP errors.""" + config = SDKConfig( + base_url="https://api.example.com", + retry=RetryPolicy( + max_attempts=2, base_backoff_s=0.01 + ), # Fast retry for testing + ) + transport = AsyncTransport(config) + + mock_request = Mock() + http_error = httpx.ConnectError("Connection failed") + mock_response = Mock() + mock_response.status_code = 200 + + with ( + patch.object(transport.client, "build_request", return_value=mock_request), + patch.object( + transport.client, + "send", + new_callable=AsyncMock, + side_effect=[http_error, mock_response], + ), + patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + ): + result = await transport.request("GET", "/test") + + assert result == mock_response + mock_sleep.assert_called_once() # Should have slept once for retry + + await transport.close() + + @pytest.mark.asyncio + async def test_request_max_retries_exceeded(self): + """Test request fails after max retries.""" + config = SDKConfig( + base_url="https://api.example.com", + retry=RetryPolicy(max_attempts=2, base_backoff_s=0.01), + ) + transport = AsyncTransport(config) + + mock_request = Mock() + http_error = httpx.ConnectError("Connection failed") + + with ( + patch.object(transport.client, "build_request", return_value=mock_request), + patch.object( + transport.client, "send", new_callable=AsyncMock, side_effect=http_error + ), + pytest.raises(TransportError, match="Connection failed"), + ): + await transport.request("GET", "/test") + + await transport.close() + + @pytest.mark.asyncio + async def test_request_with_status_code_retry(self): + """Test request retry on specific status codes.""" + config = SDKConfig( + base_url="https://api.example.com", + retry=RetryPolicy( + max_attempts=2, base_backoff_s=0.01, retry_on_status=(503,) + ), + ) + transport = AsyncTransport(config) + + mock_request = Mock() + mock_response_503 = Mock() + mock_response_503.status_code = 503 + mock_response_503.headers = {} + mock_response_200 = Mock() + mock_response_200.status_code = 200 + + with ( + patch.object(transport.client, "build_request", return_value=mock_request), + patch.object( + transport.client, + "send", + new_callable=AsyncMock, + side_effect=[mock_response_503, mock_response_200], + ), + patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + ): + result = await transport.request("GET", "/test") + + assert result == mock_response_200 + mock_sleep.assert_called_once() + + await transport.close() + + @pytest.mark.asyncio + async def test_request_with_retry_after_header(self): + """Test request respects retry-after header.""" + config = SDKConfig( + base_url="https://api.example.com", + retry=RetryPolicy(max_attempts=2, retry_on_status=(429,)), + ) + transport = AsyncTransport(config) + + mock_request = Mock() + mock_response_429 = Mock() + mock_response_429.status_code = 429 + mock_response_429.headers = {"retry-after": "2"} + mock_response_200 = Mock() + mock_response_200.status_code = 200 + + with ( + patch.object(transport.client, "build_request", return_value=mock_request), + patch.object( + transport.client, + "send", + new_callable=AsyncMock, + side_effect=[mock_response_429, mock_response_200], + ), + patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep, + ): + result = await transport.request("GET", "/test") + + assert result == mock_response_200 + mock_sleep.assert_called_once_with(2.0) + + await transport.close() + + +class TestRaiseForApiError: + """Test the raise_for_api_error static method.""" + + def test_no_error_for_success_status(self): + """Test no exception for successful status codes.""" + for status_code in [200, 201, 202, 204]: + resp = Mock() + resp.status_code = status_code + + # Should not raise any exception + AsyncTransport.raise_for_api_error(resp) + + def test_api_error_with_json_response(self): + """Test APIError with proper JSON error response.""" + resp = Mock() + resp.status_code = 400 + resp.json.return_value = { + "error": { + "type": "validation_error", + "message": "Invalid input", + "code": "INVALID_INPUT", + } + } + resp.headers = {"x-request-id": "req-123"} + resp.content = b'{"error": {"type": "validation_error"}}' + + with pytest.raises(APIError) as exc_info: + AsyncTransport.raise_for_api_error(resp) + + error = exc_info.value + assert error.status == 400 + assert error.type == "validation_error" + assert str(error) == "Invalid input" # message is in the exception string + assert error.code == "INVALID_INPUT" + assert error.request_id == "req-123" + + def test_rate_limit_error(self): + """Test RateLimitError for 429 status.""" + resp = Mock() + resp.status_code = 429 + resp.json.return_value = { + "error": { + "type": "rate_limit_exceeded", + "message": "Too many requests", + "retry_after": 30.0, + } + } + resp.headers = {"x-request-id": "req-456"} + resp.content = b'{"error": {"type": "rate_limit_exceeded"}}' + + with pytest.raises(RateLimitError) as exc_info: + AsyncTransport.raise_for_api_error(resp) + + error = exc_info.value + assert error.status == 429 + assert error.type == "rate_limit_exceeded" + assert str(error) == "Too many requests" + assert error.retry_after_s == 30.0 + + def test_api_error_with_invalid_json(self): + """Test APIError when response JSON is invalid.""" + resp = Mock() + resp.status_code = 500 + resp.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + resp.text = "Internal Server Error" + resp.headers = {"x-request-id": "req-789"} + resp.content = b"Internal Server Error" + + with pytest.raises(APIError) as exc_info: + AsyncTransport.raise_for_api_error(resp) + + error = exc_info.value + assert error.status == 500 + assert error.type == "api_error" + assert str(error) == "Internal Server Error" + assert error.request_id == "req-789" + + def test_api_error_with_missing_error_field(self): + """Test APIError when error field is missing from JSON.""" + resp = Mock() + resp.status_code = 404 + resp.json.return_value = {"message": "Not found"} # No "error" field + resp.text = "Not Found" + resp.headers = {} + resp.content = b'{"message": "Not found"}' + + with pytest.raises(APIError) as exc_info: + AsyncTransport.raise_for_api_error(resp) + + error = exc_info.value + assert error.status == 404 + assert error.type == "api_error" + assert str(error) == "Not Found" # Falls back to resp.text + + def test_api_error_defaults(self): + """Test APIError with minimal error information.""" + resp = Mock() + resp.status_code = 422 + resp.json.return_value = {"error": {}} # Empty error object + resp.text = "Unprocessable Entity" + resp.headers = {} + resp.content = b'{"error": {}}' + + with pytest.raises(APIError) as exc_info: + AsyncTransport.raise_for_api_error(resp) + + error = exc_info.value + assert error.status == 422 + assert error.type == "api_error" # Default type + assert str(error) == "Unprocessable Entity" # Falls back to resp.text + assert error.code is None + assert error.request_id is None + assert error.retry_after_s is None + + +@pytest.fixture +async def transport(): + """Fixture providing a configured AsyncTransport instance.""" + config = SDKConfig( + base_url="https://api.example.com", + connect_timeout_s=5.0, + read_timeout_s=30.0, + total_timeout_s=40.0, + ) + transport = AsyncTransport(config) + yield transport + await transport.close() + + +class TestAsyncTransportIntegration: + """Integration tests for AsyncTransport.""" + + @pytest.mark.asyncio + async def test_complete_request_flow(self, transport: AsyncTransport): + """Test complete request flow with mocked httpx client.""" + # Mock a complete successful request + mock_request = Mock() + mock_response = Mock() + mock_response.status_code = 200 + + with ( + patch.object(transport.client, "build_request", return_value=mock_request), + patch.object( + transport.client, + "send", + new_callable=AsyncMock, + return_value=mock_response, + ), + ): + result = await transport.request( + "POST", "/api/test", json={"test": "data"}, headers={"custom": "header"} + ) + + assert result == mock_response + + # Verify build_request was called with correct parameters + transport.client.build_request.assert_called_once_with( + "POST", "/api/test", json={"test": "data"}, headers={"custom": "header"} + ) + + @pytest.mark.asyncio + async def test_request_with_keyword_arguments(self, transport: AsyncTransport): + """Test that keyword arguments are properly passed through.""" + mock_request = Mock() + mock_response = Mock() + mock_response.status_code = 201 + + with ( + patch.object(transport.client, "build_request", return_value=mock_request), + patch.object( + transport.client, + "send", + new_callable=AsyncMock, + return_value=mock_response, + ), + ): + result = await transport.request( + "PUT", + "/api/update/123", + json={"name": "updated"}, + headers={"authorization": "Bearer token"}, + params={"version": "v1"}, + timeout=60.0, + ) + + assert result == mock_response + + # Verify all kwargs were passed to build_request + call_args = transport.client.build_request.call_args + assert call_args[0] == ("PUT", "/api/update/123") + assert call_args[1]["json"] == {"name": "updated"} + assert call_args[1]["headers"]["authorization"] == "Bearer token" + assert call_args[1]["params"] == {"version": "v1"} + assert call_args[1]["timeout"] == 60.0 diff --git a/uv.lock b/uv.lock index 23d6ad7..a68613b 100644 --- a/uv.lock +++ b/uv.lock @@ -122,6 +122,9 @@ dev = [ [package.dev-dependencies] dev = [ + { name = "h2" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-sdk" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, @@ -143,6 +146,9 @@ provides-extras = ["dev"] [package.metadata.requires-dev] dev = [ + { name = "h2", specifier = ">=4.2.0" }, + { name = "opentelemetry-api", specifier = ">=1.36.0" }, + { name = "opentelemetry-sdk", specifier = ">=1.36.0" }, { name = "pytest", specifier = ">=8.4.1" }, { name = "pytest-asyncio", specifier = ">=0.25.0" }, { name = "pytest-cov", specifier = ">=6.2.1" }, @@ -158,6 +164,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, ] +[[package]] +name = "h2" +version = "4.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "hpack" }, + { name = "hyperframe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1b/38/d7f80fd13e6582fb8e0df8c9a653dcc02b03ca34f4d72f34869298c5baf8/h2-4.2.0.tar.gz", hash = "sha256:c8a52129695e88b1a0578d8d2cc6842bbd79128ac685463b887ee278126ad01f", size = 2150682, upload-time = "2025-02-02T07:43:51.815Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/9e/984486f2d0a0bd2b024bf4bc1c62688fcafa9e61991f041fb0e2def4a982/h2-4.2.0-py3-none-any.whl", hash = "sha256:479a53ad425bb29af087f3458a61d30780bc818e4ebcf01f0b536ba916462ed0", size = 60957, upload-time = "2025-02-01T11:02:26.481Z" }, +] + +[[package]] +name = "hpack" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276, upload-time = "2025-01-22T21:44:58.347Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357, upload-time = "2025-01-22T21:44:56.92Z" }, +] + [[package]] name = "httpcore" version = "1.0.9" @@ -186,6 +214,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "hyperframe" +version = "6.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566, upload-time = "2025-01-22T21:41:49.302Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -195,6 +232,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "importlib-metadata" +version = "8.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/76/66/650a33bd90f786193e4de4b3ad86ea60b53c89b669a5c7be931fac31cdb0/importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000", size = 56641, upload-time = "2025-04-27T15:29:01.736Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656, upload-time = "2025-04-27T15:29:00.214Z" }, +] + [[package]] name = "iniconfig" version = "2.1.0" @@ -260,6 +309,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, ] +[[package]] +name = "opentelemetry-api" +version = "1.36.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/d2/c782c88b8afbf961d6972428821c302bd1e9e7bc361352172f0ca31296e2/opentelemetry_api-1.36.0.tar.gz", hash = "sha256:9a72572b9c416d004d492cbc6e61962c0501eaf945ece9b5a0f56597d8348aa0", size = 64780, upload-time = "2025-07-29T15:12:06.02Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/ee/6b08dde0a022c463b88f55ae81149584b125a42183407dc1045c486cc870/opentelemetry_api-1.36.0-py3-none-any.whl", hash = "sha256:02f20bcacf666e1333b6b1f04e647dc1d5111f86b8e510238fcc56d7762cda8c", size = 65564, upload-time = "2025-07-29T15:11:47.998Z" }, +] + +[[package]] +name = "opentelemetry-sdk" +version = "1.36.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4c/85/8567a966b85a2d3f971c4d42f781c305b2b91c043724fa08fd37d158e9dc/opentelemetry_sdk-1.36.0.tar.gz", hash = "sha256:19c8c81599f51b71670661ff7495c905d8fdf6976e41622d5245b791b06fa581", size = 162557, upload-time = "2025-07-29T15:12:16.76Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/59/7bed362ad1137ba5886dac8439e84cd2df6d087be7c09574ece47ae9b22c/opentelemetry_sdk-1.36.0-py3-none-any.whl", hash = "sha256:19fe048b42e98c5c1ffe85b569b7073576ad4ce0bcb6e9b4c6a39e890a6c45fb", size = 119995, upload-time = "2025-07-29T15:12:03.181Z" }, +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.57b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7e/31/67dfa252ee88476a29200b0255bda8dfc2cf07b56ad66dc9a6221f7dc787/opentelemetry_semantic_conventions-0.57b0.tar.gz", hash = "sha256:609a4a79c7891b4620d64c7aac6898f872d790d75f22019913a660756f27ff32", size = 124225, upload-time = "2025-07-29T15:12:17.873Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/75/7d591371c6c39c73de5ce5da5a2cc7b72d1d1cd3f8f4638f553c01c37b11/opentelemetry_semantic_conventions-0.57b0-py3-none-any.whl", hash = "sha256:757f7e76293294f124c827e514c2a3144f191ef175b069ce8d1211e1e38e9e78", size = 201627, upload-time = "2025-07-29T15:12:04.174Z" }, +] + [[package]] name = "packaging" version = "25.0" @@ -456,3 +545,12 @@ sdist = { url = "https://files.pythonhosted.org/packages/98/5a/da40306b885cc8c09 wheels = [ { url = "https://files.pythonhosted.org/packages/b5/00/d631e67a838026495268c2f6884f3711a15a9a2a96cd244fdaea53b823fb/typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76", size = 43906, upload-time = "2025-07-04T13:28:32.743Z" }, ] + +[[package]] +name = "zipp" +version = "3.23.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, +]