diff --git a/.coverage b/.coverage index c6d8598..ce2d005 100644 Binary files a/.coverage and b/.coverage differ diff --git a/coverage.xml b/coverage.xml index 002187c..4785cc5 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,5 +1,5 @@ - + @@ -11,9 +11,162 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/gavaconnect/auth/README.md b/gavaconnect/auth/README.md new file mode 100644 index 0000000..d204d08 --- /dev/null +++ b/gavaconnect/auth/README.md @@ -0,0 +1,48 @@ +# Authentication Design — `gavaconnect` Python SDK + +## Introduction + +The SDK implements authentication as a **pluggable policy** so each endpoint family (`checkers`, `tax`, `payments`, `authorization`) can use the scheme it requires while sharing a common transport layer. The SDK supports: + +* **Basic** (static header from `client_id:client_secret`) +* **Bearer** (OAuth2 Client Credentials) with **concurrency-safe caching**, **early refresh**, and **401-triggered single retry** + +Design goals: **credential isolation per resource**, **safe token lifecycle**, **consistent retries/timeouts**, and **extensibility** (e.g., API-Key, HMAC, mTLS) without changing call sites. + +--- + +## High-Level Architecture + +* Each resource client is constructed with an **`AuthPolicy`**: `BasicAuthPolicy` or `BearerAuthPolicy(TokenProvider)`. +* The shared **AsyncTransport**: + + * Calls `authorize(request)` before send. + * On **401**, calls `on_unauthorized()` (Bearer refresh), then **retries once**. + * Applies **timeouts** and **retry/backoff** for **429/5xx** (honors `Retry-After`). +* Hooks provide **logging** (with redaction) and **OpenTelemetry** spans. + +```mermaid +flowchart LR + A[Your code] -->|calls| R[Resource Client (e.g., payments)] + R -->|build request| T[AsyncTransport] + T -->|authorize(request)| AP[AuthPolicy
Basic or Bearer] + AP -->|add Authorization header| T + T -->|HTTP send| API[(Service API)] + API -- 200/2xx --> T + T -- return --> R --> A + + API -- 401 Unauthorized --> T + T -->|on_unauthorized()| AP + AP -->|refresh token (Bearer only)| T + T -->|retry once| API +``` + +--- + +## Why Per-Resource Auth? + +* **Safety by construction:** Credentials for `payments` cannot be sent to `tax` endpoints (and vice versa). This prevents cross-tenant or scope leakage. +* **Clarity & DX:** The chosen auth scheme is explicit at the resource constructor—no hidden URL regex routing or magic defaults. +* **Heterogeneous schemes:** Some families can remain on **Basic** while others adopt **Bearer** with scopes/rotation, without affecting call sites. +* **Testability:** You can unit-test each resource with its auth policy, mock token refresh, and assert no credential cross-talk. +* **Compliance & least privilege:** Bind the **minimal** credentials/scopes to only the endpoints that require them, simplifying audits and rotation. diff --git a/gavaconnect/auth/__init__.py b/gavaconnect/auth/__init__.py new file mode 100644 index 0000000..5093349 --- /dev/null +++ b/gavaconnect/auth/__init__.py @@ -0,0 +1,13 @@ +"""Authentication module for GavaConnect SDK.""" + +from .basic import BasicAuthPolicy, BasicCredentials +from .bearer import BearerAuthPolicy, TokenProvider +from .providers import ClientCredentialsProvider + +__all__ = [ + "BasicAuthPolicy", + "BasicCredentials", + "BearerAuthPolicy", + "TokenProvider", + "ClientCredentialsProvider", +] diff --git a/gavaconnect/auth/basic.py b/gavaconnect/auth/basic.py new file mode 100644 index 0000000..741705f --- /dev/null +++ b/gavaconnect/auth/basic.py @@ -0,0 +1,48 @@ +"""Basic authentication implementation for GavaConnect SDK.""" + +import base64 +from dataclasses import dataclass + +import httpx + + +@dataclass(frozen=True, slots=True) +class BasicCredentials: + """Basic authentication credentials.""" + + client_id: str + client_secret: str + + +class BasicAuthPolicy: + """HTTP Basic authentication policy.""" + + def __init__(self, creds: BasicCredentials) -> None: + """Initialize the basic auth policy. + + Args: + creds: Basic authentication credentials. + + """ + token = base64.b64encode( + f"{creds.client_id}:{creds.client_secret}".encode() + ).decode() + self._header = f"Basic {token}" + + async def authorize(self, request: httpx.Request) -> None: + """Add basic authentication header to the request. + + Args: + request: The HTTP request to authorize. + + """ + request.headers["authorization"] = self._header + + async def on_unauthorized(self) -> bool: + """Handle unauthorized response. + + Returns: + False, as basic auth cannot refresh credentials. + + """ + return False diff --git a/gavaconnect/auth/bearer.py b/gavaconnect/auth/bearer.py new file mode 100644 index 0000000..612b7d9 --- /dev/null +++ b/gavaconnect/auth/bearer.py @@ -0,0 +1,87 @@ +"""Bearer token authentication implementation for GavaConnect SDK.""" + +from __future__ import annotations + +from typing import Protocol + +import httpx + + +class AuthPolicy(Protocol): + """Protocol for authentication policies.""" + + async def authorize(self, request: httpx.Request) -> None: + """Add authentication to the request. + + Args: + request: The HTTP request to authorize. + + """ + ... + + async def on_unauthorized(self) -> bool: + """Handle unauthorized response. + + Returns: + True if authentication was refreshed, False otherwise. + + """ + return False + + +class TokenProvider(Protocol): + """Protocol for token providers.""" + + async def get_token(self) -> str: + """Get the current access token. + + Returns: + The access token. + + """ + ... + + async def refresh(self) -> str: + """Refresh and return a new access token. + + Returns: + The new access token. + + """ + ... + + +class BearerAuthPolicy: + """Bearer token authentication policy.""" + + def __init__(self, provider: TokenProvider) -> None: + """Initialize the bearer auth policy. + + Args: + provider: Token provider for obtaining access tokens. + + """ + self._p, self._last = provider, "" + + async def authorize(self, request: httpx.Request) -> None: + """Add bearer token to the request. + + Args: + request: The HTTP request to authorize. + + """ + token = await self._p.get_token() + self._last = token + request.headers["authorization"] = f"Bearer {token}" + + async def on_unauthorized(self) -> bool: + """Handle unauthorized response by refreshing the token. + + Returns: + True if the token was refreshed, False otherwise. + + """ + new_token = await self._p.refresh() + changed = new_token != self._last + self._last = new_token + return changed diff --git a/gavaconnect/auth/providers.py b/gavaconnect/auth/providers.py new file mode 100644 index 0000000..9117fa5 --- /dev/null +++ b/gavaconnect/auth/providers.py @@ -0,0 +1,82 @@ +"""Token provider implementations for GavaConnect SDK.""" + +import asyncio +import time + +import httpx + + +class ClientCredentialsProvider: + """OAuth2 client credentials token provider.""" + + def __init__( + self, + token_url: str, + client_id: str, + client_secret: str, + scope: str | None = None, + early_refresh_s: int = 60, + client: httpx.AsyncClient | None = None, + ) -> None: + """Initialize the client credentials provider. + + Args: + token_url: OAuth2 token endpoint URL. + client_id: OAuth2 client ID. + client_secret: OAuth2 client secret. + scope: Optional scope for the token. + early_refresh_s: Seconds before expiry to refresh token. + client: Optional HTTP client to use. + + """ + self._url, self._cid, self._sec, self._scope = ( + token_url, + client_id, + client_secret, + scope, + ) + self._early, self._client = ( + early_refresh_s, + (client or httpx.AsyncClient(timeout=10)), + ) + self._lock = asyncio.Lock() + self._token, self._exp = "", 0.0 + + async def _fetch(self) -> tuple[str, float]: + data = {"grant_type": "client_credentials"} | ( + {"scope": self._scope} if self._scope else {} + ) + r = await self._client.post( + self._url, + auth=(self._cid, self._sec), + data=data, + headers={"content-type": "application/x-www-form-urlencoded"}, + ) + r.raise_for_status() + p = r.json() + ttl = float(p.get("expires_in", 3600)) + return p["access_token"], time.time() + max(30.0, ttl - self._early) + + async def get_token(self) -> str: + """Get the current access token, refreshing if necessary. + + Returns: + The access token. + + """ + async with self._lock: + if self._token and time.time() < self._exp: + return self._token + self._token, self._exp = await self._fetch() + return self._token + + async def refresh(self) -> str: + """Force refresh the access token. + + Returns: + The new access token. + + """ + async with self._lock: + self._token, self._exp = await self._fetch() + return self._token diff --git a/pyproject.toml b/pyproject.toml index d1c7ac3..1bc3681 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ classifiers = [ dev = [ "pytest>=8.0.0", "pytest-cov>=4.0.0", + "pytest-asyncio>=0.25.0", + "respx>=0.22.0", "mypy>=1.8.0", "ruff>=0.2.0", "bandit>=1.7.0", @@ -95,6 +97,10 @@ testpaths = ["tests"] python_files = ["test_*.py", "*_test.py"] python_classes = ["Test*"] python_functions = ["test_*"] +asyncio_mode = "auto" +markers = [ + "asyncio: mark test as asyncio", +] # Coverage configuration [tool.coverage.run] @@ -125,4 +131,6 @@ include = [ dev = [ "pytest>=8.4.1", "pytest-cov>=6.2.1", + "pytest-asyncio>=0.25.0", + "respx>=0.22.0", ] diff --git a/tests/test_auth_basic.py b/tests/test_auth_basic.py new file mode 100644 index 0000000..1b25d60 --- /dev/null +++ b/tests/test_auth_basic.py @@ -0,0 +1,77 @@ +"""Tests for basic authentication module.""" + +import base64 + +import httpx +import pytest + +from gavaconnect.auth.basic import BasicAuthPolicy, BasicCredentials + + +class TestBasicCredentials: + """Test BasicCredentials dataclass.""" + + def test_creation(self): + """Test creating BasicCredentials.""" + creds = BasicCredentials(client_id="test_id", client_secret="test_secret") + assert creds.client_id == "test_id" + assert creds.client_secret == "test_secret" + + def test_immutable(self): + """Test that BasicCredentials is immutable.""" + creds = BasicCredentials(client_id="test_id", client_secret="test_secret") + with pytest.raises(AttributeError): + creds.client_id = "new_id" + + +class TestBasicAuthPolicy: + """Test BasicAuthPolicy class.""" + + def test_init(self): + """Test BasicAuthPolicy initialization.""" + creds = BasicCredentials(client_id="test_id", client_secret="test_secret") + policy = BasicAuthPolicy(creds) + + # Verify the header is created correctly + expected_token = base64.b64encode(b"test_id:test_secret").decode() + assert policy._header == f"Basic {expected_token}" + + @pytest.mark.asyncio + async def test_authorize(self): + """Test authorization of a request.""" + creds = BasicCredentials(client_id="test_id", client_secret="test_secret") + policy = BasicAuthPolicy(creds) + + request = httpx.Request("GET", "https://example.com") + await policy.authorize(request) + + expected_token = base64.b64encode(b"test_id:test_secret").decode() + assert request.headers["authorization"] == f"Basic {expected_token}" + + @pytest.mark.asyncio + async def test_on_unauthorized(self): + """Test unauthorized response handling.""" + creds = BasicCredentials(client_id="test_id", client_secret="test_secret") + policy = BasicAuthPolicy(creds) + + # Basic auth cannot refresh, so should always return False + result = await policy.on_unauthorized() + assert result is False + + def test_different_credentials(self): + """Test with different credentials produce different headers.""" + creds1 = BasicCredentials(client_id="id1", client_secret="secret1") + creds2 = BasicCredentials(client_id="id2", client_secret="secret2") + + policy1 = BasicAuthPolicy(creds1) + policy2 = BasicAuthPolicy(creds2) + + assert policy1._header != policy2._header + + def test_special_characters_in_credentials(self): + """Test credentials with special characters.""" + creds = BasicCredentials(client_id="test:id", client_secret="test@secret!") + policy = BasicAuthPolicy(creds) + + expected_token = base64.b64encode(b"test:id:test@secret!").decode() + assert policy._header == f"Basic {expected_token}" diff --git a/tests/test_auth_bearer.py b/tests/test_auth_bearer.py new file mode 100644 index 0000000..a3db503 --- /dev/null +++ b/tests/test_auth_bearer.py @@ -0,0 +1,331 @@ +"""Tests for bearer authentication module.""" + +from unittest.mock import AsyncMock, Mock + +import httpx +import pytest +import respx + +from gavaconnect.auth.bearer import BearerAuthPolicy +from gavaconnect.auth.providers import ClientCredentialsProvider + + +class MockTokenProvider: + """Mock token provider for testing.""" + + def __init__( + self, token: str = "test_token", refresh_token: str = "new_token" + ) -> None: + """Initialize the mock token provider. + + Args: + token: The initial token to return. + refresh_token: The token to return after refresh. + + """ + self.token = token + self.refresh_token = refresh_token + self.get_token_calls = 0 + self.refresh_calls = 0 + + async def get_token(self) -> str: + """Mock get_token method.""" + self.get_token_calls += 1 + return self.token + + async def refresh(self) -> str: + """Mock refresh method.""" + self.refresh_calls += 1 + self.token = self.refresh_token + return self.refresh_token + + +class TestBearerAuthPolicy: + """Test BearerAuthPolicy class.""" + + def test_init(self): + """Test BearerAuthPolicy initialization.""" + provider = MockTokenProvider() + policy = BearerAuthPolicy(provider) + + assert policy._p is provider + assert policy._last == "" + + @pytest.mark.asyncio + async def test_authorize(self): + """Test authorization of a request.""" + provider = MockTokenProvider(token="test_access_token") + policy = BearerAuthPolicy(provider) + + request = httpx.Request("GET", "https://example.com") + await policy.authorize(request) + + assert request.headers["authorization"] == "Bearer test_access_token" + assert policy._last == "test_access_token" + assert provider.get_token_calls == 1 + + @pytest.mark.asyncio + async def test_authorize_multiple_calls(self): + """Test multiple authorization calls.""" + provider = MockTokenProvider(token="token123") + policy = BearerAuthPolicy(provider) + + request1 = httpx.Request("GET", "https://example.com/1") + request2 = httpx.Request("GET", "https://example.com/2") + + await policy.authorize(request1) + await policy.authorize(request2) + + assert request1.headers["authorization"] == "Bearer token123" + assert request2.headers["authorization"] == "Bearer token123" + assert provider.get_token_calls == 2 + + @pytest.mark.asyncio + async def test_on_unauthorized_token_changed(self): + """Test unauthorized handling when token changes.""" + provider = MockTokenProvider(token="old_token", refresh_token="new_token") + policy = BearerAuthPolicy(provider) + + # Set initial token + policy._last = "old_token" + + result = await policy.on_unauthorized() + + assert result is True # Token changed + assert policy._last == "new_token" + assert provider.refresh_calls == 1 + + @pytest.mark.asyncio + async def test_on_unauthorized_token_unchanged(self): + """Test unauthorized handling when token doesn't change.""" + provider = MockTokenProvider(token="same_token", refresh_token="same_token") + policy = BearerAuthPolicy(provider) + + # Set initial token to same as refresh token + policy._last = "same_token" + + result = await policy.on_unauthorized() + + assert result is False # Token didn't change + assert policy._last == "same_token" + assert provider.refresh_calls == 1 + + @pytest.mark.asyncio + async def test_on_unauthorized_empty_last_token(self): + """Test unauthorized handling with empty last token.""" + provider = MockTokenProvider(refresh_token="new_token") + policy = BearerAuthPolicy(provider) + + # _last starts as empty string + result = await policy.on_unauthorized() + + assert result is True # Empty string != "new_token" + assert policy._last == "new_token" + + @pytest.mark.asyncio + async def test_full_flow(self): + """Test complete authorization and refresh flow.""" + provider = MockTokenProvider( + token="initial_token", refresh_token="refreshed_token" + ) + policy = BearerAuthPolicy(provider) + + # Initial authorization + request = httpx.Request("GET", "https://example.com") + await policy.authorize(request) + assert request.headers["authorization"] == "Bearer initial_token" + + # Unauthorized response triggers refresh + changed = await policy.on_unauthorized() + assert changed is True + assert policy._last == "refreshed_token" + + # New authorization uses refreshed token + request2 = httpx.Request("GET", "https://example.com/2") + await policy.authorize(request2) + assert request2.headers["authorization"] == "Bearer refreshed_token" + + +class TestTokenProviderProtocol: + """Test TokenProvider protocol compliance.""" + + @pytest.mark.asyncio + async def test_mock_provider_compliance(self): + """Test that mock provider implements the protocol correctly.""" + provider = MockTokenProvider() + + # Should have async get_token and refresh methods + token = await provider.get_token() + assert isinstance(token, str) + + refresh_token = await provider.refresh() + assert isinstance(refresh_token, str) + + @pytest.mark.asyncio + async def test_provider_with_async_mock(self): + """Test using AsyncMock for token provider.""" + provider = Mock() + provider.get_token = AsyncMock(return_value="mocked_token") + provider.refresh = AsyncMock(return_value="mocked_refresh") + + policy = BearerAuthPolicy(provider) + + request = httpx.Request("GET", "https://example.com") + await policy.authorize(request) + + assert request.headers["authorization"] == "Bearer mocked_token" + provider.get_token.assert_called_once() + + result = await policy.on_unauthorized() + assert result is True # "" != "mocked_refresh" + provider.refresh.assert_called_once() + + +class TestBearerAuthPolicyIntegration: + """Integration tests for BearerAuthPolicy with real token providers.""" + + @respx.mock + @pytest.mark.asyncio + async def test_integration_with_client_credentials_provider(self): + """Test BearerAuthPolicy with ClientCredentialsProvider using real HTTP mocking.""" + # Mock the OAuth2 token endpoint + token_route = respx.post("https://auth.example.com/oauth/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "real_integration_token", + "expires_in": 3600, + "token_type": "Bearer", + }, + ) + ) + + # Create a real ClientCredentialsProvider + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/oauth/token", + client_id="integration_client", + client_secret="integration_secret", + scope="api:read api:write", + ) + + # Create BearerAuthPolicy with the real provider + auth_policy = BearerAuthPolicy(provider) + + # Test authorization + request = httpx.Request("GET", "https://api.example.com/data") + await auth_policy.authorize(request) + + # Verify the request was authorized correctly + assert request.headers["authorization"] == "Bearer real_integration_token" + assert token_route.called + + # Verify the OAuth request was made correctly + oauth_request = token_route.calls[0].request + assert oauth_request.method == "POST" + form_data = dict(httpx.QueryParams(oauth_request.content.decode())) + assert form_data["grant_type"] == "client_credentials" + assert form_data["scope"] == "api:read api:write" + + @respx.mock + @pytest.mark.asyncio + async def test_integration_refresh_flow(self): + """Test complete refresh flow with real HTTP mocking.""" + call_count = 0 + + def token_response(request: httpx.Request) -> httpx.Response: + nonlocal call_count + call_count += 1 + return httpx.Response( + 200, json={"access_token": f"token_v{call_count}", "expires_in": 3600} + ) + + # Mock endpoint that returns different tokens + token_route = respx.post("https://auth.example.com/oauth/token").mock( + side_effect=token_response + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/oauth/token", + client_id="refresh_client", + client_secret="refresh_secret", + ) + + auth_policy = BearerAuthPolicy(provider) + + # First authorization + request1 = httpx.Request("GET", "https://api.example.com/resource1") + await auth_policy.authorize(request1) + assert request1.headers["authorization"] == "Bearer token_v1" + assert token_route.call_count == 1 + + # Simulate unauthorized response and refresh + changed = await auth_policy.on_unauthorized() + assert changed is True # Token should have changed + assert token_route.call_count == 2 + + # New authorization should use refreshed token (cached) + request2 = httpx.Request("GET", "https://api.example.com/resource2") + await auth_policy.authorize(request2) + assert ( + request2.headers["authorization"] == "Bearer token_v2" + ) # Uses cached refreshed token + # Should still be 2 calls since token is cached + assert token_route.call_count == 2 + + @respx.mock + @pytest.mark.asyncio + async def test_integration_error_handling(self): + """Test error handling in integration scenario.""" + # Mock OAuth endpoint that returns an error + respx.post("https://auth.example.com/oauth/token").mock( + return_value=httpx.Response( + 401, + json={ + "error": "invalid_client", + "error_description": "Client authentication failed", + }, + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/oauth/token", + client_id="invalid_client", + client_secret="invalid_secret", + ) + + auth_policy = BearerAuthPolicy(provider) + + # Authorization should fail with HTTP error + request = httpx.Request("GET", "https://api.example.com/protected") + + with pytest.raises(httpx.HTTPStatusError) as exc_info: + await auth_policy.authorize(request) + + assert exc_info.value.response.status_code == 401 + + @respx.mock + @pytest.mark.asyncio + async def test_integration_caching_behavior(self): + """Test that token caching works correctly in integration.""" + token_route = respx.post("https://auth.example.com/oauth/token").mock( + return_value=httpx.Response( + 200, json={"access_token": "cached_token", "expires_in": 3600} + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/oauth/token", + client_id="cache_client", + client_secret="cache_secret", + ) + + auth_policy = BearerAuthPolicy(provider) + + # Multiple authorizations should use cached token + for i in range(3): + request = httpx.Request("GET", f"https://api.example.com/endpoint{i}") + await auth_policy.authorize(request) + assert request.headers["authorization"] == "Bearer cached_token" + + # Only first call should hit the token endpoint (due to caching) + assert token_route.call_count == 1 diff --git a/tests/test_auth_module.py b/tests/test_auth_module.py new file mode 100644 index 0000000..540e1d4 --- /dev/null +++ b/tests/test_auth_module.py @@ -0,0 +1,67 @@ +"""Tests for auth module imports and exports.""" + +from gavaconnect import auth +from gavaconnect.auth import ( + BasicAuthPolicy, + BasicCredentials, + BearerAuthPolicy, + ClientCredentialsProvider, + TokenProvider, +) + + +class TestAuthModuleImports: + """Test that auth module exports work correctly.""" + + def test_all_exports_available(self): + """Test that all expected exports are available.""" + # Test direct imports + assert BasicAuthPolicy is not None + assert BasicCredentials is not None + assert BearerAuthPolicy is not None + assert TokenProvider is not None + assert ClientCredentialsProvider is not None + + def test_module_has_all_attribute(self): + """Test that __all__ is properly defined.""" + assert hasattr(auth, "__all__") + assert isinstance(auth.__all__, list) + + expected_exports = { + "BasicAuthPolicy", + "BasicCredentials", + "BearerAuthPolicy", + "TokenProvider", + "ClientCredentialsProvider", + } + + assert set(auth.__all__) == expected_exports + + def test_module_docstring(self): + """Test that module has proper docstring.""" + assert auth.__doc__ is not None + assert "Authentication module for GavaConnect SDK" in auth.__doc__ + + def test_classes_importable_from_module(self): + """Test that classes can be imported from the module.""" + assert hasattr(auth, "BasicAuthPolicy") + assert hasattr(auth, "BasicCredentials") + assert hasattr(auth, "BearerAuthPolicy") + assert hasattr(auth, "TokenProvider") + assert hasattr(auth, "ClientCredentialsProvider") + + def test_class_types(self): + """Test that imported classes are the correct types.""" + from gavaconnect.auth.basic import BasicAuthPolicy as BasicAuthPolicyOrig + from gavaconnect.auth.basic import BasicCredentials as BasicCredentialsOrig + from gavaconnect.auth.bearer import BearerAuthPolicy as BearerAuthPolicyOrig + from gavaconnect.auth.bearer import TokenProvider as TokenProviderOrig + from gavaconnect.auth.providers import ( + ClientCredentialsProvider as ClientCredentialsProviderOrig, + ) + + assert BasicAuthPolicy is BasicAuthPolicyOrig + assert BasicCredentials is BasicCredentialsOrig + assert BearerAuthPolicy is BearerAuthPolicyOrig + assert TokenProvider is TokenProviderOrig + assert ClientCredentialsProvider is ClientCredentialsProviderOrig diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py new file mode 100644 index 0000000..d0ac81d --- /dev/null +++ b/tests/test_auth_providers.py @@ -0,0 +1,463 @@ +"""Tests for token provider implementations.""" + +import asyncio +from unittest.mock import patch + +import httpx +import pytest +import respx + +from gavaconnect.auth.providers import ClientCredentialsProvider + + +class TestClientCredentialsProvider: + """Test ClientCredentialsProvider class.""" + + def test_init_minimal(self): + """Test initialization with minimal parameters.""" + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + ) + + assert provider._url == "https://auth.example.com/token" + assert provider._cid == "test_client" + assert provider._sec == "test_secret" + assert provider._scope is None + assert provider._early == 60 + assert isinstance(provider._client, httpx.AsyncClient) + assert isinstance(provider._lock, asyncio.Lock) + assert provider._token == "" + assert provider._exp == 0.0 + + def test_init_full_parameters(self): + """Test initialization with all parameters.""" + custom_client = httpx.AsyncClient() + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + scope="read write", + early_refresh_s=120, + client=custom_client, + ) + + assert provider._scope == "read write" + assert provider._early == 120 + assert provider._client is custom_client + + @respx.mock + @pytest.mark.asyncio + async def test_fetch_success_without_scope(self): + """Test successful token fetch without scope.""" + # Mock the token endpoint + token_route = respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, json={"access_token": "test_access_token", "expires_in": 3600} + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + ) + + with patch("time.time", return_value=1000.0): + token, exp_time = await provider._fetch() + + assert token == "test_access_token" + assert exp_time == 1000.0 + max(30.0, 3600 - 60) # 4540.0 + + # Verify the request was made correctly + assert token_route.called + request = token_route.calls[0].request + assert request.method == "POST" + assert request.url == "https://auth.example.com/token" + + # Check the form data + form_data = dict(httpx.QueryParams(request.content.decode())) + assert form_data["grant_type"] == "client_credentials" + assert "scope" not in form_data + + @respx.mock + @pytest.mark.asyncio + async def test_fetch_success_with_scope(self): + """Test successful token fetch with scope.""" + token_route = respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, json={"access_token": "scoped_token", "expires_in": 7200} + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + scope="read write admin", + ) + + await provider._fetch() + + # Verify scope was included in request + assert token_route.called + request = token_route.calls[0].request + form_data = dict(httpx.QueryParams(request.content.decode())) + assert form_data["grant_type"] == "client_credentials" + assert form_data["scope"] == "read write admin" + + @respx.mock + @pytest.mark.asyncio + async def test_fetch_with_custom_expires_in(self): + """Test fetch with custom expires_in value.""" + respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "short_lived_token", + "expires_in": 300, # 5 minutes + }, + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + early_refresh_s=60, + ) + + with patch("time.time", return_value=2000.0): + token, exp_time = await provider._fetch() + + # Should use max(30.0, 300 - 60) = 240 + assert exp_time == 2000.0 + 240.0 + + @respx.mock + @pytest.mark.asyncio + async def test_fetch_without_expires_in(self): + """Test fetch when response doesn't include expires_in.""" + respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "default_ttl_token" + # No expires_in field + }, + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + ) + + with patch("time.time", return_value=3000.0): + token, exp_time = await provider._fetch() + + # Should use default 3600 seconds: max(30.0, 3600 - 60) = 3540 + assert exp_time == 3000.0 + 3540.0 + + @respx.mock + @pytest.mark.asyncio + async def test_fetch_http_error(self): + """Test fetch when HTTP request fails.""" + respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response(401, json={"error": "invalid_client"}) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + ) + + with pytest.raises(httpx.HTTPStatusError): + await provider._fetch() + + @pytest.mark.asyncio + async def test_get_token_first_call(self): + """Test get_token on first call (no cached token).""" + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + ) + + # Mock the _fetch method directly for this test + async def mock_fetch() -> tuple[str, float]: + return "fresh_token", 5000.0 + + provider._fetch = mock_fetch + + with patch("time.time", return_value=1000.0): + token = await provider.get_token() + + assert token == "fresh_token" + assert provider._token == "fresh_token" + assert provider._exp == 5000.0 + + @pytest.mark.asyncio + async def test_get_token_cached_valid(self): + """Test get_token with valid cached token.""" + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + ) + + # Set up cached token + provider._token = "cached_token" + provider._exp = 5000.0 + + # Mock _fetch to track if it's called + fetch_called = False + + async def mock_fetch() -> tuple[str, float]: + nonlocal fetch_called + fetch_called = True + return "new_token", 8000.0 + + provider._fetch = mock_fetch + + with patch("time.time", return_value=4000.0): # Before expiry + token = await provider.get_token() + + assert token == "cached_token" + assert not fetch_called + + @pytest.mark.asyncio + async def test_get_token_cached_expired(self): + """Test get_token with expired cached token.""" + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + ) + + # Set up expired cached token + provider._token = "expired_token" + provider._exp = 4000.0 + + fetch_called = False + + async def mock_fetch() -> tuple[str, float]: + nonlocal fetch_called + fetch_called = True + return "new_token", 8000.0 + + provider._fetch = mock_fetch + + with patch("time.time", return_value=5000.0): # After expiry + token = await provider.get_token() + + assert token == "new_token" + assert provider._token == "new_token" + assert provider._exp == 8000.0 + assert fetch_called + + @pytest.mark.asyncio + async def test_refresh(self): + """Test refresh method.""" + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + ) + + # Set up existing token + provider._token = "old_token" + provider._exp = 5000.0 + + fetch_called = False + + async def mock_fetch() -> tuple[str, float]: + nonlocal fetch_called + fetch_called = True + return "refreshed_token", 8000.0 + + provider._fetch = mock_fetch + + token = await provider.refresh() + + assert token == "refreshed_token" + assert provider._token == "refreshed_token" + assert provider._exp == 8000.0 + assert fetch_called + + @pytest.mark.asyncio + async def test_concurrent_get_token_calls(self): + """Test that concurrent get_token calls are properly synchronized.""" + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + ) + + fetch_call_count = 0 + + async def mock_fetch() -> tuple[str, float]: + nonlocal fetch_call_count + fetch_call_count += 1 + await asyncio.sleep(0.1) # Simulate network delay + return f"token_{fetch_call_count}", 8000.0 + + provider._fetch = mock_fetch + + with patch("time.time", return_value=1000.0): + # Make multiple concurrent calls + tasks = [provider.get_token() for _ in range(5)] + tokens = await asyncio.gather(*tasks) + + # All should get the same token + assert all(token == "token_1" for token in tokens) + # _fetch should only be called once due to the lock + assert fetch_call_count == 1 + + @respx.mock + @pytest.mark.asyncio + async def test_early_refresh_parameter(self): + """Test that early_refresh_s parameter affects token expiry calculation.""" + # Mock responses for both providers + respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, json={"access_token": "test_token", "expires_in": 3600} + ) + ) + + # Test with different early refresh values + provider1 = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + early_refresh_s=60, + ) + + provider2 = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + early_refresh_s=300, + ) + + with patch("time.time", return_value=1000.0): + _, exp1 = await provider1._fetch() + _, exp2 = await provider2._fetch() + + # Provider1: 1000 + max(30, 3600-60) = 1000 + 3540 = 4540 + # Provider2: 1000 + max(30, 3600-300) = 1000 + 3300 = 4300 + assert exp1 == 4540.0 + assert exp2 == 4300.0 + + @respx.mock + @pytest.mark.asyncio + async def test_minimum_ttl_enforcement(self): + """Test that minimum TTL of 30 seconds is enforced.""" + respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "short_token", + "expires_in": 10, # Very short expiry + }, + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + early_refresh_s=60, + ) + + with patch("time.time", return_value=2000.0): + _, exp_time = await provider._fetch() + + # Should use minimum of 30 seconds: 2000 + max(30, 10-60) = 2000 + 30 = 2030 + assert exp_time == 2030.0 + + @respx.mock + @pytest.mark.asyncio + async def test_authentication_headers(self): + """Test that authentication headers are sent correctly.""" + token_route = respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, json={"access_token": "test_token", "expires_in": 3600} + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + ) + + await provider._fetch() + + # Verify authentication was sent + assert token_route.called + request = token_route.calls[0].request + assert "authorization" in request.headers + + # Basic auth should be base64 encoded client_id:client_secret + import base64 + + expected_auth = base64.b64encode(b"test_client:test_secret").decode() + assert request.headers["authorization"] == f"Basic {expected_auth}" + + @respx.mock + @pytest.mark.asyncio + async def test_content_type_header(self): + """Test that correct content-type header is sent.""" + token_route = respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, json={"access_token": "test_token", "expires_in": 3600} + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + ) + + await provider._fetch() + + assert token_route.called + request = token_route.calls[0].request + assert request.headers["content-type"] == "application/x-www-form-urlencoded" + + @respx.mock + @pytest.mark.asyncio + async def test_full_integration_flow(self): + """Test complete token lifecycle with real HTTP mocking.""" + token_route = respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, json={"access_token": "integration_token", "expires_in": 3600} + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="integration_client", + client_secret="integration_secret", + scope="read write", + ) + + with patch("time.time", return_value=1000.0): + # First call should fetch token + token1 = await provider.get_token() + assert token1 == "integration_token" + assert token_route.call_count == 1 + + # Second call should use cached token + token2 = await provider.get_token() + assert token2 == "integration_token" + assert token_route.call_count == 1 # No additional calls + + # Refresh should force new fetch + token3 = await provider.refresh() + assert token3 == "integration_token" + assert token_route.call_count == 2 # One additional call diff --git a/uv.lock b/uv.lock index 1cdd770..23d6ad7 100644 --- a/uv.lock +++ b/uv.lock @@ -114,14 +114,18 @@ dev = [ { name = "bandit" }, { name = "mypy" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "respx" }, { name = "ruff" }, ] [package.dev-dependencies] dev = [ { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "respx" }, ] [package.metadata] @@ -130,7 +134,9 @@ requires-dist = [ { name = "httpx", specifier = ">=0.28.1" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, + { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.25.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0.0" }, + { name = "respx", marker = "extra == 'dev'", specifier = ">=0.22.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.2.0" }, ] provides-extras = ["dev"] @@ -138,7 +144,9 @@ provides-extras = ["dev"] [package.metadata.requires-dev] dev = [ { name = "pytest", specifier = ">=8.4.1" }, + { name = "pytest-asyncio", specifier = ">=0.25.0" }, { name = "pytest-cov", specifier = ">=6.2.1" }, + { name = "respx", specifier = ">=0.22.0" }, ] [[package]] @@ -316,6 +324,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4e/51/f8794af39eeb870e87a8c8068642fc07bce0c854d6865d7dd0f2a9d338c2/pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea", size = 46652, upload-time = "2025-07-16T04:29:26.393Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/9d/bf86eddabf8c6c9cb1ea9a869d6873b46f105a5d292d3a6f7071f5b07935/pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf", size = 15157, upload-time = "2025-07-16T04:29:24.929Z" }, +] + [[package]] name = "pytest-cov" version = "6.2.1" @@ -347,6 +367,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446, upload-time = "2024-08-06T20:33:04.33Z" }, ] +[[package]] +name = "respx" +version = "0.22.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/7c/96bd0bc759cf009675ad1ee1f96535edcb11e9666b985717eb8c87192a95/respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91", size = 28439, upload-time = "2024-12-19T22:33:59.374Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/67/afbb0978d5399bc9ea200f1d4489a23c9a1dad4eee6376242b8182389c79/respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0", size = 25127, upload-time = "2024-12-19T22:33:57.837Z" }, +] + [[package]] name = "rich" version = "14.1.0"