From 415b9de7485580bf9518fec80b22f94fa2408ccd Mon Sep 17 00:00:00 2001 From: Musale Martin Date: Thu, 21 Aug 2025 16:07:53 +0300 Subject: [PATCH 1/4] Implement authentication module with Basic and Bearer policies in GavaConnect SDK --- gavaconnect/auth/README.md | 48 +++++++++++++++++++ gavaconnect/auth/__init__.py | 13 ++++++ gavaconnect/auth/basic.py | 48 +++++++++++++++++++ gavaconnect/auth/bearer.py | 87 +++++++++++++++++++++++++++++++++++ gavaconnect/auth/providers.py | 82 +++++++++++++++++++++++++++++++++ 5 files changed, 278 insertions(+) create mode 100644 gavaconnect/auth/README.md create mode 100644 gavaconnect/auth/__init__.py create mode 100644 gavaconnect/auth/basic.py create mode 100644 gavaconnect/auth/bearer.py create mode 100644 gavaconnect/auth/providers.py diff --git a/gavaconnect/auth/README.md b/gavaconnect/auth/README.md new file mode 100644 index 0000000..d204d08 --- /dev/null +++ b/gavaconnect/auth/README.md @@ -0,0 +1,48 @@ +# Authentication Design — `gavaconnect` Python SDK + +## Introduction + +The SDK implements authentication as a **pluggable policy** so each endpoint family (`checkers`, `tax`, `payments`, `authorization`) can use the scheme it requires while sharing a common transport layer. The SDK supports: + +* **Basic** (static header from `client_id:client_secret`) +* **Bearer** (OAuth2 Client Credentials) with **concurrency-safe caching**, **early refresh**, and **401-triggered single retry** + +Design goals: **credential isolation per resource**, **safe token lifecycle**, **consistent retries/timeouts**, and **extensibility** (e.g., API-Key, HMAC, mTLS) without changing call sites. + +--- + +## High-Level Architecture + +* Each resource client is constructed with an **`AuthPolicy`**: `BasicAuthPolicy` or `BearerAuthPolicy(TokenProvider)`. +* The shared **AsyncTransport**: + + * Calls `authorize(request)` before send. + * On **401**, calls `on_unauthorized()` (Bearer refresh), then **retries once**. + * Applies **timeouts** and **retry/backoff** for **429/5xx** (honors `Retry-After`). +* Hooks provide **logging** (with redaction) and **OpenTelemetry** spans. + +```mermaid +flowchart LR + A[Your code] -->|calls| R[Resource Client (e.g., payments)] + R -->|build request| T[AsyncTransport] + T -->|authorize(request)| AP[AuthPolicy
Basic or Bearer] + AP -->|add Authorization header| T + T -->|HTTP send| API[(Service API)] + API -- 200/2xx --> T + T -- return --> R --> A + + API -- 401 Unauthorized --> T + T -->|on_unauthorized()| AP + AP -->|refresh token (Bearer only)| T + T -->|retry once| API +``` + +--- + +## Why Per-Resource Auth? + +* **Safety by construction:** Credentials for `payments` cannot be sent to `tax` endpoints (and vice versa). This prevents cross-tenant or scope leakage. +* **Clarity & DX:** The chosen auth scheme is explicit at the resource constructor—no hidden URL regex routing or magic defaults. +* **Heterogeneous schemes:** Some families can remain on **Basic** while others adopt **Bearer** with scopes/rotation, without affecting call sites. +* **Testability:** You can unit-test each resource with its auth policy, mock token refresh, and assert no credential cross-talk. +* **Compliance & least privilege:** Bind the **minimal** credentials/scopes to only the endpoints that require them, simplifying audits and rotation. diff --git a/gavaconnect/auth/__init__.py b/gavaconnect/auth/__init__.py new file mode 100644 index 0000000..5093349 --- /dev/null +++ b/gavaconnect/auth/__init__.py @@ -0,0 +1,13 @@ +"""Authentication module for GavaConnect SDK.""" + +from .basic import BasicAuthPolicy, BasicCredentials +from .bearer import BearerAuthPolicy, TokenProvider +from .providers import ClientCredentialsProvider + +__all__ = [ + "BasicAuthPolicy", + "BasicCredentials", + "BearerAuthPolicy", + "TokenProvider", + "ClientCredentialsProvider", +] diff --git a/gavaconnect/auth/basic.py b/gavaconnect/auth/basic.py new file mode 100644 index 0000000..741705f --- /dev/null +++ b/gavaconnect/auth/basic.py @@ -0,0 +1,48 @@ +"""Basic authentication implementation for GavaConnect SDK.""" + +import base64 +from dataclasses import dataclass + +import httpx + + +@dataclass(frozen=True, slots=True) +class BasicCredentials: + """Basic authentication credentials.""" + + client_id: str + client_secret: str + + +class BasicAuthPolicy: + """HTTP Basic authentication policy.""" + + def __init__(self, creds: BasicCredentials) -> None: + """Initialize the basic auth policy. + + Args: + creds: Basic authentication credentials. + + """ + token = base64.b64encode( + f"{creds.client_id}:{creds.client_secret}".encode() + ).decode() + self._header = f"Basic {token}" + + async def authorize(self, request: httpx.Request) -> None: + """Add basic authentication header to the request. + + Args: + request: The HTTP request to authorize. + + """ + request.headers["authorization"] = self._header + + async def on_unauthorized(self) -> bool: + """Handle unauthorized response. + + Returns: + False, as basic auth cannot refresh credentials. + + """ + return False diff --git a/gavaconnect/auth/bearer.py b/gavaconnect/auth/bearer.py new file mode 100644 index 0000000..612b7d9 --- /dev/null +++ b/gavaconnect/auth/bearer.py @@ -0,0 +1,87 @@ +"""Bearer token authentication implementation for GavaConnect SDK.""" + +from __future__ import annotations + +from typing import Protocol + +import httpx + + +class AuthPolicy(Protocol): + """Protocol for authentication policies.""" + + async def authorize(self, request: httpx.Request) -> None: + """Add authentication to the request. + + Args: + request: The HTTP request to authorize. + + """ + ... + + async def on_unauthorized(self) -> bool: + """Handle unauthorized response. + + Returns: + True if authentication was refreshed, False otherwise. + + """ + return False + + +class TokenProvider(Protocol): + """Protocol for token providers.""" + + async def get_token(self) -> str: + """Get the current access token. + + Returns: + The access token. + + """ + ... + + async def refresh(self) -> str: + """Refresh and return a new access token. + + Returns: + The new access token. + + """ + ... + + +class BearerAuthPolicy: + """Bearer token authentication policy.""" + + def __init__(self, provider: TokenProvider) -> None: + """Initialize the bearer auth policy. + + Args: + provider: Token provider for obtaining access tokens. + + """ + self._p, self._last = provider, "" + + async def authorize(self, request: httpx.Request) -> None: + """Add bearer token to the request. + + Args: + request: The HTTP request to authorize. + + """ + token = await self._p.get_token() + self._last = token + request.headers["authorization"] = f"Bearer {token}" + + async def on_unauthorized(self) -> bool: + """Handle unauthorized response by refreshing the token. + + Returns: + True if the token was refreshed, False otherwise. + + """ + new_token = await self._p.refresh() + changed = new_token != self._last + self._last = new_token + return changed diff --git a/gavaconnect/auth/providers.py b/gavaconnect/auth/providers.py new file mode 100644 index 0000000..9117fa5 --- /dev/null +++ b/gavaconnect/auth/providers.py @@ -0,0 +1,82 @@ +"""Token provider implementations for GavaConnect SDK.""" + +import asyncio +import time + +import httpx + + +class ClientCredentialsProvider: + """OAuth2 client credentials token provider.""" + + def __init__( + self, + token_url: str, + client_id: str, + client_secret: str, + scope: str | None = None, + early_refresh_s: int = 60, + client: httpx.AsyncClient | None = None, + ) -> None: + """Initialize the client credentials provider. + + Args: + token_url: OAuth2 token endpoint URL. + client_id: OAuth2 client ID. + client_secret: OAuth2 client secret. + scope: Optional scope for the token. + early_refresh_s: Seconds before expiry to refresh token. + client: Optional HTTP client to use. + + """ + self._url, self._cid, self._sec, self._scope = ( + token_url, + client_id, + client_secret, + scope, + ) + self._early, self._client = ( + early_refresh_s, + (client or httpx.AsyncClient(timeout=10)), + ) + self._lock = asyncio.Lock() + self._token, self._exp = "", 0.0 + + async def _fetch(self) -> tuple[str, float]: + data = {"grant_type": "client_credentials"} | ( + {"scope": self._scope} if self._scope else {} + ) + r = await self._client.post( + self._url, + auth=(self._cid, self._sec), + data=data, + headers={"content-type": "application/x-www-form-urlencoded"}, + ) + r.raise_for_status() + p = r.json() + ttl = float(p.get("expires_in", 3600)) + return p["access_token"], time.time() + max(30.0, ttl - self._early) + + async def get_token(self) -> str: + """Get the current access token, refreshing if necessary. + + Returns: + The access token. + + """ + async with self._lock: + if self._token and time.time() < self._exp: + return self._token + self._token, self._exp = await self._fetch() + return self._token + + async def refresh(self) -> str: + """Force refresh the access token. + + Returns: + The new access token. + + """ + async with self._lock: + self._token, self._exp = await self._fetch() + return self._token From 86b35243c9734cfe9239fda57580ebb3aab29b13 Mon Sep 17 00:00:00 2001 From: Musale Martin Date: Thu, 21 Aug 2025 16:28:57 +0300 Subject: [PATCH 2/4] Add tests for authentication modules and update dependencies --- .coverage | Bin 53248 -> 53248 bytes coverage.xml | 155 ++++++++++- pyproject.toml | 8 + tests/test_auth_basic.py | 77 ++++++ tests/test_auth_bearer.py | 323 +++++++++++++++++++++++ tests/test_auth_module.py | 67 +++++ tests/test_auth_providers.py | 484 +++++++++++++++++++++++++++++++++++ uv.lock | 32 +++ 8 files changed, 1145 insertions(+), 1 deletion(-) create mode 100644 tests/test_auth_basic.py create mode 100644 tests/test_auth_bearer.py create mode 100644 tests/test_auth_module.py create mode 100644 tests/test_auth_providers.py diff --git a/.coverage b/.coverage index c6d8598b1d8d9b881c82f982fe0edb7ff3da6151..ce2d00532dbb40028610acec03a77a78a88687b7 100644 GIT binary patch literal 53248 zcmeI53w&HvoyX5)=AQGv@4TjIn?7z|Nt5Q$^p&=hQkp(U3#Bc6Kq*Zp$)uf zv{echP@t~J0|* zMo;Abku&&QsI!LMrMA07{x*Qyqz1s&|u@+nl7@ec5DZaHErD@>UmAY$&u3s4dkgTji3QWEWdA zsh&h;hj(>yM_r*pc(q(|dkza^C&|>N-k^pzF>GGrWs+UVOtQBlnZ;2|Om)^6H8)*v zh>EhZ$SujjoD!K1`C*g!Ei)0gaD{5o%-WnEz!z3DnHn1qag2_~+J}(N;a57M|xW_lh~AO2+yJ|{JjH7hE9{tD;ADESLpPIKwDEk8vZaY zDk>`@H|m8j3Rf%qPn;D(w32UOqjah#naK8K$VmR2Nm7dem!v%so{`h37rf%a%5zO8 ztk>2Q0{9{(D&p~ok2@mprQlOLYZ5`pz!vzE8^K>b{H=-HX79~o4~pf$t2JJ6Ui@4G z4~Edz6dd_YR~8k9e9N$c`aRP0gU{$$Lk>y?hP*VnqkdbWTe`{CL@Fb{`b};aqw3oGN3f1NIUY^coQ}VVN+_ti_sno09oX7?) zl&$vK)9LPHq8INCvPElFUNe>T3S)TlUFm{>KC(N4J~nNuE{ut8E`iOjlADh9mxP|4 zL~mz!zY1={*sJvZU~>l7K@*V;y=u8BvvQ06NuShdcs6h>>c3&D)FkM&C(()9Jy?`8 zdc1aX@=e|0-Dp7~m&~Pll3q}yFqlGMU7N^y9hsy&Det8<)zLo}2El5zH3i2YONQV} zIe&CSR7{!_@%zHfh(6Pnl{fXC1PwcUT69*B{K?6Ic)3&>uKa4}%s2_OJ424(DPAEk zH(bGXRDo)RU+Wk(L4qG~-(d1XJ(L1U0i}RaKq;UUPzopolmbctrGQdEDWDX1`za8O z6h%ak{};K3n0r(n)I%wt6i^B%1(X6x0i}RaKq;UUPzopolmbeDe@F$aqUcDMzQd!b zI1;TK7Jd&vb3@bQhG~YJRgrs`xrg2V{fC^O8njYCDWDWk3Md7X0!jg;fKosypcGIF zCfx&Z!TbM))y>?Gxyzg*&h1XKeaODiuC~4{3)Mp@pcGIF zCX0ZskgMr8Wgx#0FCf`9t;{R|m1yCIe0OzA-9$UpQD=4D>#r_SzzOg{gyeg|;w9 zmqTA08PtvkMjX8VZ&*FdeaOAkzR8|wJ!SPczjfZ{%(Q>@S8j9FOevrgPzopolmbct zrGQdEDWDWk3NW@I=3fvQtnBDb`M+^x%%3&5-S_AJ`ns6kJji{A`9EA0^UojLuFuH- z#hRGkIB321=l^_AyY0{awL#TlAVbIeKbDC33kOY$f!+tQ>FDJ#f9jxJ;f(yhXszs6 zAFQo&MtrE>|EG*uDWDWk3Md7X0!jg;fKosypcGIFCiDJ|NU|(PzopolmbctrGQdEDWDWk3Md7X0!jg; zz}r%R;Qjwscq?;%@4n*x%6-B8q5G8kb@vPIXWR$f``kO-+uR%6ZEla-;jVL+yNld8 z?i6>TJIXD04L9Qa-Z|>L*1GVe19!2iAV;N$YXzVe0|wUh8)2W^1Rl)#|d=TPv-Lthv^7tIis04YeFA zX1-w_HxHXHm_IQ0n@^gLn-7~0nD?5un>U*~&8=pax!zo9US!TSrZA2?U57uw@8B>z56{3p_!4{$9)u6V z9q=yL32EqrD_|M4z)Yxzu}}dPM8pa4vN$AuDE5o5h{wdo#r@(gajUpaWW*-1UaSy{ z#Cc+}m>`A=SH$@1{3!o9f0loTe~mxEALbw7_we`d8+aez!rS;Nxxnh76i^B%1^!kF zfWK9X=#hxA{0Z>Wh%1+a-;20nCHOsIg%#j;BQ9zI|7ygA61N~eA)BQTTTg<&8F9fv z@HZhYmUX)j7cB-qiP*9i{LZlMBJexH(gonRBkq;W5{R#j0e>T+RS$k!h~?mK2(b+O zt3oUVe?8)yR`9PxoHZN#b%--(fqw;J^GxtBM{H^ae=XwlCh*rFPMZ$?YQ(A2z+Z(p zWh(fWAvR3`et($VQC}ymm(gJrOOcYJouL&*3SX|V#G~_|1r}tUCkI zkvKi9;DSFbL{7b+ejW}-~_@fZ_%>#cV;tV;+5s0(qgI|f*EJJp9SV6|=FvR^c zz#ocOcL@9nL{F9uK^#{HejIU(41w~ng4|4Hh=*jjlp+o}0e%UhAxn!9?IGa1h^7s` z6IL+6w-HasMX?ZtL=#aH;2VfhP58P*xd3t{ASl#>=of|4F!><(9P!{}@HNDmW8lXS ztL4n1h=*jr79oy00)7P1s{x-OR*r(}kw#X+uCQz*T!%EG65bh>jevI`9jSz!NF{mL zfmH6ncBImBxHddn3fqv5l|Ubo49#3fI%JWYlaN8O9k>Qb3Wr_`x(R5tAOjek7GzYT z(SnR*^jVOBj5Z50jsus~BM=n@kJTX$=&(ZM0S#72xzS%mOaru6AzKFSsz*dr6uecJ z_0d`R%O0Sy^7v6eU*+*pfVRrxBLQ8N$43B~Dle@8^i*Dx2eedPdl1l3c}*>#q4I+@ zfPTuWUIw&NKA{TGP5Jl63|2W@Hn7_^7wE-2j#=!fCkFr!vOu04~+xbC$ESDx+jkh1vF0{uK@H; zUKR(mPA+{7os%Ce1vF0XlmhxDcaHJP+uTT5%rGCbeQ0&?U8@uL7E+ zRxuRNBefwF5Ljd+qQ!?m!6GACc^uFpwX$+Rhtx{S01Z-eO91^*d({QBN6jq;bVp6N zfaa(P4(N@VbqdfLHPZrgMorEejZp&w&=)lhfVQZ~Iif3Sa^7f)nw)LmiF!nfih?8R zS{~33W5(-%b{Lbrp&QCv56}!_paXhg>~#RN!kBg%&5*s@(EXxvL}-3d>39Y2)1!F^oG%h7l2mZMNKrm2 zgeDtCATM|UI0&r(a_X@H#555tKs3gSB*@F~NKl~XgZKYm(bfg||1tL$?oZtByZhWP zyN}4f2lyfPI+_39;*NCZx>>i?-QaF^FO#|dRJYVQ?F#2r=L%<^^DF0h=NXy%f6m$I zq@4$y4?1@^?{YewW%kSV38%%G>C`)8oeIZtBGx?fdHZ}jWxwBk-2Rz;(0{>JRBR>-Xt*=-c%<`bL@SU!;%Mhr&;F2CvC1|9kLxco;qmcWV#u zBK}M6$n5?$*aaErf_3~j%!e6J3!}s-v4_v)Q+O304yC|luK$QQDjpOsiXVv&h#SQ= zu|;eUE5&TlAjXOz!r*`4&&dq`Y5q9h$v?sG=XdgVi`#iGZ|7^oBECd>K|7&+QTvT{ zkG5NTM*D{LX>FahOUr0of2%M1RliCBrNCP$VDf@f%Y2xB|A#Id`*_Kv;XPmGWJ40m$K)H=d`k)63?2=enLER7JH7kc_#ZYaZ@vU zmUwy-J4ifjI{Ojv)M@O8#8alS9}qW9VF!pCn%FbMjScMk#0`z?d&CDC*msHbJbRkB zeh&K%@ub(-Q^d8C*tdx%R{{#zV#i{S6I(9(&wh5;W5h-Ydz4r=*dxRzb@n-8uCvb)$2j{8vHW99pEl$k z(IfKj5wlMP7}1N2JbO5-A(LwMN#ZqY*(ZosuVEi2UbUJ%MEvL~_A%m1_Ol0xTlcX4 zAYQtM{X6lJrR<}`txMPg#CuxVM~LTN!ahtqZy)_fz}=Ck{Vn`f~P z63=L6A0Xa8gZ(RU-63`_u{VSL3-P!*b`SBGaqRuX-WYZ_@ga}hMLgsLyOY@P*gq57 zL)aa}rp?|*Y?$nJ;u8jYFR?J#KM`w!y@xoevEAWn=}~rDxLR@#-5S=Ai8s52_~2yr zZsM9_>|MmwhuF=;hpO34#G{U|8;QLd<`Y+rVmA~nCBM8KS$RES<;Y!xBPy>WtQ_%9 z!XuUMAS}u6BrNxK5SEs2CoCzwmhf1~HbSGMuOIasp>r}zXge7~)4qmKH@6a=)YF8z z(M!nn9>N&!CXB|e)(cxfrsZ1#lK&u|3Qx<#d^4dYHW9|OF2ZOm8MYVE&al16cZ8>9 zdfrYLKbq*rQ5*YlWLrOu*g#lXa}{Atem!CB!7B-CYS$4Sths`)>gCG`CseH^96wD+ou%mlKYNUrJaRUq%?Oyo7Lg{9?lR@TG*q;!6nQ z!&(W4#xEkQh%Y9L4_(xc6$=T=;tL2%%UTGJmd+=1N-rdIkIy4?+zSYgJ97yQ=X^r* z#2iAyJdf~%F`JMZvk1khnS`f!Ga(l<2%~&@7!5o+EsO@9pBkQ)>HL&_bejl=zB-w( zVrV1bkctMv_>g+S@^~F#S@|TwlCoMtx1@&fRd*txTU<>j+^T-$69}zS;|Wb`9HC*3 zB?My(AqTG?#c0A)d=z1nj|^4Nq9Z~Tw0vcFTBiBK3BedfsJ}jx5cCSd*I@{ub~;Wd z^5ul5wX*P9v3w~N#zYAf78O%rQOpg;63aW`Y58;b;c1!lTZB=~3~NU7Mp!8Gew{EM z1)T?pe~|fy#N}&#Sf0+3=_ETm{GD$fT^qD3M&DPK1Y-v2k;rNPSL>w&Gr9&459YoTC zpdb!XbkLy%5o{5v76%n|=+H&M4iZEV6htr(OVA)isn_S8+dh1s;a>R7FNfxrLyyDb zT98%(w6fmYvqaidFdSH@EpDASwm3SW8%7`y2-0eRR^9V5MW(_FHQm2Hoeq!BjNxQ} z{?Imkr_c0(-qHrWpeOW@mgo-Mq^oq9&eJSS(oWmJBgOl#Ff5hrA7SZ)!uvvf_L$zSkLmoyvUT&(>O&Zam3xkAJy1K;5WKF2kDfcM;0T*I??0cYJkyo}X#ZSRfTJ&cJ|tL7^|>MJ+m zD?2PQogtC4TNPCsj^C<-qJ}ddsx&L2a-}S)m&zrfU9v?*yC|}XmZ)GAM0vGOlv8<8 zR^>z)l@+N>MwC`*zok>6luGWY?tZ&UB_t+QzbK)UNG1DxCE}vEGDT**SJX0lL}-$T z+UPfmiQ3pLGBB!7c)vf|G5i?SMIEC{ - + @@ -11,9 +11,162 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/pyproject.toml b/pyproject.toml index d1c7ac3..1bc3681 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ classifiers = [ dev = [ "pytest>=8.0.0", "pytest-cov>=4.0.0", + "pytest-asyncio>=0.25.0", + "respx>=0.22.0", "mypy>=1.8.0", "ruff>=0.2.0", "bandit>=1.7.0", @@ -95,6 +97,10 @@ testpaths = ["tests"] python_files = ["test_*.py", "*_test.py"] python_classes = ["Test*"] python_functions = ["test_*"] +asyncio_mode = "auto" +markers = [ + "asyncio: mark test as asyncio", +] # Coverage configuration [tool.coverage.run] @@ -125,4 +131,6 @@ include = [ dev = [ "pytest>=8.4.1", "pytest-cov>=6.2.1", + "pytest-asyncio>=0.25.0", + "respx>=0.22.0", ] diff --git a/tests/test_auth_basic.py b/tests/test_auth_basic.py new file mode 100644 index 0000000..a11c364 --- /dev/null +++ b/tests/test_auth_basic.py @@ -0,0 +1,77 @@ +"""Tests for basic authentication module.""" + +import base64 + +import httpx +import pytest + +from gavaconnect.auth.basic import BasicAuthPolicy, BasicCredentials + + +class TestBasicCredentials: + """Test BasicCredentials dataclass.""" + + def test_creation(self): + """Test creating BasicCredentials.""" + creds = BasicCredentials(client_id="test_id", client_secret="test_secret") + assert creds.client_id == "test_id" + assert creds.client_secret == "test_secret" + + def test_immutable(self): + """Test that BasicCredentials is immutable.""" + creds = BasicCredentials(client_id="test_id", client_secret="test_secret") + with pytest.raises(AttributeError): + creds.client_id = "new_id" + + +class TestBasicAuthPolicy: + """Test BasicAuthPolicy class.""" + + def test_init(self): + """Test BasicAuthPolicy initialization.""" + creds = BasicCredentials(client_id="test_id", client_secret="test_secret") + policy = BasicAuthPolicy(creds) + + # Verify the header is created correctly + expected_token = base64.b64encode(b"test_id:test_secret").decode() + assert policy._header == f"Basic {expected_token}" + + @pytest.mark.asyncio + async def test_authorize(self): + """Test authorization of a request.""" + creds = BasicCredentials(client_id="test_id", client_secret="test_secret") + policy = BasicAuthPolicy(creds) + + request = httpx.Request("GET", "https://example.com") + await policy.authorize(request) + + expected_token = base64.b64encode(b"test_id:test_secret").decode() + assert request.headers["authorization"] == f"Basic {expected_token}" + + @pytest.mark.asyncio + async def test_on_unauthorized(self): + """Test unauthorized response handling.""" + creds = BasicCredentials(client_id="test_id", client_secret="test_secret") + policy = BasicAuthPolicy(creds) + + # Basic auth cannot refresh, so should always return False + result = await policy.on_unauthorized() + assert result is False + + def test_different_credentials(self): + """Test with different credentials produce different headers.""" + creds1 = BasicCredentials(client_id="id1", client_secret="secret1") + creds2 = BasicCredentials(client_id="id2", client_secret="secret2") + + policy1 = BasicAuthPolicy(creds1) + policy2 = BasicAuthPolicy(creds2) + + assert policy1._header != policy2._header + + def test_special_characters_in_credentials(self): + """Test credentials with special characters.""" + creds = BasicCredentials(client_id="test:id", client_secret="test@secret!") + policy = BasicAuthPolicy(creds) + + expected_token = base64.b64encode(b"test:id:test@secret!").decode() + assert policy._header == f"Basic {expected_token}" diff --git a/tests/test_auth_bearer.py b/tests/test_auth_bearer.py new file mode 100644 index 0000000..8f5d9ce --- /dev/null +++ b/tests/test_auth_bearer.py @@ -0,0 +1,323 @@ +"""Tests for bearer authentication module.""" + +from unittest.mock import AsyncMock, Mock + +import httpx +import pytest +import respx + +from gavaconnect.auth.bearer import BearerAuthPolicy +from gavaconnect.auth.providers import ClientCredentialsProvider + + +class MockTokenProvider: + """Mock token provider for testing.""" + + def __init__(self, token: str = "test_token", refresh_token: str = "new_token"): + self.token = token + self.refresh_token = refresh_token + self.get_token_calls = 0 + self.refresh_calls = 0 + + async def get_token(self) -> str: + """Mock get_token method.""" + self.get_token_calls += 1 + return self.token + + async def refresh(self) -> str: + """Mock refresh method.""" + self.refresh_calls += 1 + self.token = self.refresh_token + return self.refresh_token + + +class TestBearerAuthPolicy: + """Test BearerAuthPolicy class.""" + + def test_init(self): + """Test BearerAuthPolicy initialization.""" + provider = MockTokenProvider() + policy = BearerAuthPolicy(provider) + + assert policy._p is provider + assert policy._last == "" + + @pytest.mark.asyncio + async def test_authorize(self): + """Test authorization of a request.""" + provider = MockTokenProvider(token="test_access_token") + policy = BearerAuthPolicy(provider) + + request = httpx.Request("GET", "https://example.com") + await policy.authorize(request) + + assert request.headers["authorization"] == "Bearer test_access_token" + assert policy._last == "test_access_token" + assert provider.get_token_calls == 1 + + @pytest.mark.asyncio + async def test_authorize_multiple_calls(self): + """Test multiple authorization calls.""" + provider = MockTokenProvider(token="token123") + policy = BearerAuthPolicy(provider) + + request1 = httpx.Request("GET", "https://example.com/1") + request2 = httpx.Request("GET", "https://example.com/2") + + await policy.authorize(request1) + await policy.authorize(request2) + + assert request1.headers["authorization"] == "Bearer token123" + assert request2.headers["authorization"] == "Bearer token123" + assert provider.get_token_calls == 2 + + @pytest.mark.asyncio + async def test_on_unauthorized_token_changed(self): + """Test unauthorized handling when token changes.""" + provider = MockTokenProvider(token="old_token", refresh_token="new_token") + policy = BearerAuthPolicy(provider) + + # Set initial token + policy._last = "old_token" + + result = await policy.on_unauthorized() + + assert result is True # Token changed + assert policy._last == "new_token" + assert provider.refresh_calls == 1 + + @pytest.mark.asyncio + async def test_on_unauthorized_token_unchanged(self): + """Test unauthorized handling when token doesn't change.""" + provider = MockTokenProvider(token="same_token", refresh_token="same_token") + policy = BearerAuthPolicy(provider) + + # Set initial token to same as refresh token + policy._last = "same_token" + + result = await policy.on_unauthorized() + + assert result is False # Token didn't change + assert policy._last == "same_token" + assert provider.refresh_calls == 1 + + @pytest.mark.asyncio + async def test_on_unauthorized_empty_last_token(self): + """Test unauthorized handling with empty last token.""" + provider = MockTokenProvider(refresh_token="new_token") + policy = BearerAuthPolicy(provider) + + # _last starts as empty string + result = await policy.on_unauthorized() + + assert result is True # Empty string != "new_token" + assert policy._last == "new_token" + + @pytest.mark.asyncio + async def test_full_flow(self): + """Test complete authorization and refresh flow.""" + provider = MockTokenProvider(token="initial_token", refresh_token="refreshed_token") + policy = BearerAuthPolicy(provider) + + # Initial authorization + request = httpx.Request("GET", "https://example.com") + await policy.authorize(request) + assert request.headers["authorization"] == "Bearer initial_token" + + # Unauthorized response triggers refresh + changed = await policy.on_unauthorized() + assert changed is True + assert policy._last == "refreshed_token" + + # New authorization uses refreshed token + request2 = httpx.Request("GET", "https://example.com/2") + await policy.authorize(request2) + assert request2.headers["authorization"] == "Bearer refreshed_token" + + +class TestTokenProviderProtocol: + """Test TokenProvider protocol compliance.""" + + @pytest.mark.asyncio + async def test_mock_provider_compliance(self): + """Test that mock provider implements the protocol correctly.""" + provider = MockTokenProvider() + + # Should have async get_token and refresh methods + token = await provider.get_token() + assert isinstance(token, str) + + refresh_token = await provider.refresh() + assert isinstance(refresh_token, str) + + @pytest.mark.asyncio + async def test_provider_with_async_mock(self): + """Test using AsyncMock for token provider.""" + provider = Mock() + provider.get_token = AsyncMock(return_value="mocked_token") + provider.refresh = AsyncMock(return_value="mocked_refresh") + + policy = BearerAuthPolicy(provider) + + request = httpx.Request("GET", "https://example.com") + await policy.authorize(request) + + assert request.headers["authorization"] == "Bearer mocked_token" + provider.get_token.assert_called_once() + + result = await policy.on_unauthorized() + assert result is True # "" != "mocked_refresh" + provider.refresh.assert_called_once() + + +class TestBearerAuthPolicyIntegration: + """Integration tests for BearerAuthPolicy with real token providers.""" + + @respx.mock + @pytest.mark.asyncio + async def test_integration_with_client_credentials_provider(self): + """Test BearerAuthPolicy with ClientCredentialsProvider using real HTTP mocking.""" + # Mock the OAuth2 token endpoint + token_route = respx.post("https://auth.example.com/oauth/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "real_integration_token", + "expires_in": 3600, + "token_type": "Bearer" + } + ) + ) + + # Create a real ClientCredentialsProvider + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/oauth/token", + client_id="integration_client", + client_secret="integration_secret", + scope="api:read api:write" + ) + + # Create BearerAuthPolicy with the real provider + auth_policy = BearerAuthPolicy(provider) + + # Test authorization + request = httpx.Request("GET", "https://api.example.com/data") + await auth_policy.authorize(request) + + # Verify the request was authorized correctly + assert request.headers["authorization"] == "Bearer real_integration_token" + assert token_route.called + + # Verify the OAuth request was made correctly + oauth_request = token_route.calls[0].request + assert oauth_request.method == "POST" + form_data = dict(httpx.QueryParams(oauth_request.content.decode())) + assert form_data["grant_type"] == "client_credentials" + assert form_data["scope"] == "api:read api:write" + + @respx.mock + @pytest.mark.asyncio + async def test_integration_refresh_flow(self): + """Test complete refresh flow with real HTTP mocking.""" + call_count = 0 + + def token_response(request): + nonlocal call_count + call_count += 1 + return httpx.Response( + 200, + json={ + "access_token": f"token_v{call_count}", + "expires_in": 3600 + } + ) + + # Mock endpoint that returns different tokens + token_route = respx.post("https://auth.example.com/oauth/token").mock( + side_effect=token_response + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/oauth/token", + client_id="refresh_client", + client_secret="refresh_secret" + ) + + auth_policy = BearerAuthPolicy(provider) + + # First authorization + request1 = httpx.Request("GET", "https://api.example.com/resource1") + await auth_policy.authorize(request1) + assert request1.headers["authorization"] == "Bearer token_v1" + assert token_route.call_count == 1 + + # Simulate unauthorized response and refresh + changed = await auth_policy.on_unauthorized() + assert changed is True # Token should have changed + assert token_route.call_count == 2 + + # New authorization should use refreshed token (cached) + request2 = httpx.Request("GET", "https://api.example.com/resource2") + await auth_policy.authorize(request2) + assert request2.headers["authorization"] == "Bearer token_v2" # Uses cached refreshed token + # Should still be 2 calls since token is cached + assert token_route.call_count == 2 + + @respx.mock + @pytest.mark.asyncio + async def test_integration_error_handling(self): + """Test error handling in integration scenario.""" + # Mock OAuth endpoint that returns an error + respx.post("https://auth.example.com/oauth/token").mock( + return_value=httpx.Response( + 401, + json={"error": "invalid_client", "error_description": "Client authentication failed"} + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/oauth/token", + client_id="invalid_client", + client_secret="invalid_secret" + ) + + auth_policy = BearerAuthPolicy(provider) + + # Authorization should fail with HTTP error + request = httpx.Request("GET", "https://api.example.com/protected") + + with pytest.raises(httpx.HTTPStatusError) as exc_info: + await auth_policy.authorize(request) + + assert exc_info.value.response.status_code == 401 + + @respx.mock + @pytest.mark.asyncio + async def test_integration_caching_behavior(self): + """Test that token caching works correctly in integration.""" + token_route = respx.post("https://auth.example.com/oauth/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "cached_token", + "expires_in": 3600 + } + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/oauth/token", + client_id="cache_client", + client_secret="cache_secret" + ) + + auth_policy = BearerAuthPolicy(provider) + + # Multiple authorizations should use cached token + for i in range(3): + request = httpx.Request("GET", f"https://api.example.com/endpoint{i}") + await auth_policy.authorize(request) + assert request.headers["authorization"] == "Bearer cached_token" + + # Only first call should hit the token endpoint (due to caching) + assert token_route.call_count == 1 diff --git a/tests/test_auth_module.py b/tests/test_auth_module.py new file mode 100644 index 0000000..6dd89ed --- /dev/null +++ b/tests/test_auth_module.py @@ -0,0 +1,67 @@ +"""Tests for auth module imports and exports.""" + +import pytest + +from gavaconnect import auth +from gavaconnect.auth import ( + BasicAuthPolicy, + BasicCredentials, + BearerAuthPolicy, + ClientCredentialsProvider, + TokenProvider, +) + + +class TestAuthModuleImports: + """Test that auth module exports work correctly.""" + + def test_all_exports_available(self): + """Test that all expected exports are available.""" + # Test direct imports + assert BasicAuthPolicy is not None + assert BasicCredentials is not None + assert BearerAuthPolicy is not None + assert TokenProvider is not None + assert ClientCredentialsProvider is not None + + def test_module_has_all_attribute(self): + """Test that __all__ is properly defined.""" + assert hasattr(auth, '__all__') + assert isinstance(auth.__all__, list) + + expected_exports = { + "BasicAuthPolicy", + "BasicCredentials", + "BearerAuthPolicy", + "TokenProvider", + "ClientCredentialsProvider", + } + + assert set(auth.__all__) == expected_exports + + def test_module_docstring(self): + """Test that module has proper docstring.""" + assert auth.__doc__ is not None + assert "Authentication module for GavaConnect SDK" in auth.__doc__ + + def test_classes_importable_from_module(self): + """Test that classes can be imported from the module.""" + assert hasattr(auth, 'BasicAuthPolicy') + assert hasattr(auth, 'BasicCredentials') + assert hasattr(auth, 'BearerAuthPolicy') + assert hasattr(auth, 'TokenProvider') + assert hasattr(auth, 'ClientCredentialsProvider') + + def test_class_types(self): + """Test that imported classes are the correct types.""" + from gavaconnect.auth.basic import BasicAuthPolicy as BasicAuthPolicyOrig + from gavaconnect.auth.basic import BasicCredentials as BasicCredentialsOrig + from gavaconnect.auth.bearer import BearerAuthPolicy as BearerAuthPolicyOrig + from gavaconnect.auth.bearer import TokenProvider as TokenProviderOrig + from gavaconnect.auth.providers import ClientCredentialsProvider as ClientCredentialsProviderOrig + + assert BasicAuthPolicy is BasicAuthPolicyOrig + assert BasicCredentials is BasicCredentialsOrig + assert BearerAuthPolicy is BearerAuthPolicyOrig + assert TokenProvider is TokenProviderOrig + assert ClientCredentialsProvider is ClientCredentialsProviderOrig diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py new file mode 100644 index 0000000..c2cb51a --- /dev/null +++ b/tests/test_auth_providers.py @@ -0,0 +1,484 @@ +"""Tests for token provider implementations.""" + +import asyncio +import time +from unittest.mock import patch + +import httpx +import pytest +import respx + +from gavaconnect.auth.providers import ClientCredentialsProvider + + +class TestClientCredentialsProvider: + """Test ClientCredentialsProvider class.""" + + def test_init_minimal(self): + """Test initialization with minimal parameters.""" + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret" + ) + + assert provider._url == "https://auth.example.com/token" + assert provider._cid == "test_client" + assert provider._sec == "test_secret" + assert provider._scope is None + assert provider._early == 60 + assert isinstance(provider._client, httpx.AsyncClient) + assert isinstance(provider._lock, asyncio.Lock) + assert provider._token == "" + assert provider._exp == 0.0 + + def test_init_full_parameters(self): + """Test initialization with all parameters.""" + custom_client = httpx.AsyncClient() + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + scope="read write", + early_refresh_s=120, + client=custom_client + ) + + assert provider._scope == "read write" + assert provider._early == 120 + assert provider._client is custom_client + + @respx.mock + @pytest.mark.asyncio + async def test_fetch_success_without_scope(self): + """Test successful token fetch without scope.""" + # Mock the token endpoint + token_route = respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "test_access_token", + "expires_in": 3600 + } + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret" + ) + + with patch('time.time', return_value=1000.0): + token, exp_time = await provider._fetch() + + assert token == "test_access_token" + assert exp_time == 1000.0 + max(30.0, 3600 - 60) # 4540.0 + + # Verify the request was made correctly + assert token_route.called + request = token_route.calls[0].request + assert request.method == "POST" + assert request.url == "https://auth.example.com/token" + + # Check the form data + form_data = dict(httpx.QueryParams(request.content.decode())) + assert form_data["grant_type"] == "client_credentials" + assert "scope" not in form_data + + @respx.mock + @pytest.mark.asyncio + async def test_fetch_success_with_scope(self): + """Test successful token fetch with scope.""" + token_route = respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "scoped_token", + "expires_in": 7200 + } + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + scope="read write admin" + ) + + await provider._fetch() + + # Verify scope was included in request + assert token_route.called + request = token_route.calls[0].request + form_data = dict(httpx.QueryParams(request.content.decode())) + assert form_data["grant_type"] == "client_credentials" + assert form_data["scope"] == "read write admin" + + @respx.mock + @pytest.mark.asyncio + async def test_fetch_with_custom_expires_in(self): + """Test fetch with custom expires_in value.""" + respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "short_lived_token", + "expires_in": 300 # 5 minutes + } + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + early_refresh_s=60 + ) + + with patch('time.time', return_value=2000.0): + token, exp_time = await provider._fetch() + + # Should use max(30.0, 300 - 60) = 240 + assert exp_time == 2000.0 + 240.0 + + @respx.mock + @pytest.mark.asyncio + async def test_fetch_without_expires_in(self): + """Test fetch when response doesn't include expires_in.""" + respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "default_ttl_token" + # No expires_in field + } + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret" + ) + + with patch('time.time', return_value=3000.0): + token, exp_time = await provider._fetch() + + # Should use default 3600 seconds: max(30.0, 3600 - 60) = 3540 + assert exp_time == 3000.0 + 3540.0 + + @respx.mock + @pytest.mark.asyncio + async def test_fetch_http_error(self): + """Test fetch when HTTP request fails.""" + respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response(401, json={"error": "invalid_client"}) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret" + ) + + with pytest.raises(httpx.HTTPStatusError): + await provider._fetch() + + @pytest.mark.asyncio + async def test_get_token_first_call(self): + """Test get_token on first call (no cached token).""" + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret" + ) + + # Mock the _fetch method directly for this test + async def mock_fetch(): + return "fresh_token", 5000.0 + + provider._fetch = mock_fetch + + with patch('time.time', return_value=1000.0): + token = await provider.get_token() + + assert token == "fresh_token" + assert provider._token == "fresh_token" + assert provider._exp == 5000.0 + + @pytest.mark.asyncio + async def test_get_token_cached_valid(self): + """Test get_token with valid cached token.""" + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret" + ) + + # Set up cached token + provider._token = "cached_token" + provider._exp = 5000.0 + + # Mock _fetch to track if it's called + fetch_called = False + async def mock_fetch(): + nonlocal fetch_called + fetch_called = True + return "new_token", 8000.0 + + provider._fetch = mock_fetch + + with patch('time.time', return_value=4000.0): # Before expiry + token = await provider.get_token() + + assert token == "cached_token" + assert not fetch_called + + @pytest.mark.asyncio + async def test_get_token_cached_expired(self): + """Test get_token with expired cached token.""" + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret" + ) + + # Set up expired cached token + provider._token = "expired_token" + provider._exp = 4000.0 + + fetch_called = False + async def mock_fetch(): + nonlocal fetch_called + fetch_called = True + return "new_token", 8000.0 + + provider._fetch = mock_fetch + + with patch('time.time', return_value=5000.0): # After expiry + token = await provider.get_token() + + assert token == "new_token" + assert provider._token == "new_token" + assert provider._exp == 8000.0 + assert fetch_called + + @pytest.mark.asyncio + async def test_refresh(self): + """Test refresh method.""" + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret" + ) + + # Set up existing token + provider._token = "old_token" + provider._exp = 5000.0 + + fetch_called = False + async def mock_fetch(): + nonlocal fetch_called + fetch_called = True + return "refreshed_token", 8000.0 + + provider._fetch = mock_fetch + + token = await provider.refresh() + + assert token == "refreshed_token" + assert provider._token == "refreshed_token" + assert provider._exp == 8000.0 + assert fetch_called + + @pytest.mark.asyncio + async def test_concurrent_get_token_calls(self): + """Test that concurrent get_token calls are properly synchronized.""" + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret" + ) + + fetch_call_count = 0 + + async def mock_fetch(): + nonlocal fetch_call_count + fetch_call_count += 1 + await asyncio.sleep(0.1) # Simulate network delay + return f"token_{fetch_call_count}", 8000.0 + + provider._fetch = mock_fetch + + with patch('time.time', return_value=1000.0): + # Make multiple concurrent calls + tasks = [provider.get_token() for _ in range(5)] + tokens = await asyncio.gather(*tasks) + + # All should get the same token + assert all(token == "token_1" for token in tokens) + # _fetch should only be called once due to the lock + assert fetch_call_count == 1 + + @respx.mock + @pytest.mark.asyncio + async def test_early_refresh_parameter(self): + """Test that early_refresh_s parameter affects token expiry calculation.""" + # Mock responses for both providers + respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "test_token", + "expires_in": 3600 + } + ) + ) + + # Test with different early refresh values + provider1 = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + early_refresh_s=60 + ) + + provider2 = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + early_refresh_s=300 + ) + + with patch('time.time', return_value=1000.0): + _, exp1 = await provider1._fetch() + _, exp2 = await provider2._fetch() + + # Provider1: 1000 + max(30, 3600-60) = 1000 + 3540 = 4540 + # Provider2: 1000 + max(30, 3600-300) = 1000 + 3300 = 4300 + assert exp1 == 4540.0 + assert exp2 == 4300.0 + + @respx.mock + @pytest.mark.asyncio + async def test_minimum_ttl_enforcement(self): + """Test that minimum TTL of 30 seconds is enforced.""" + respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "short_token", + "expires_in": 10 # Very short expiry + } + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret", + early_refresh_s=60 + ) + + with patch('time.time', return_value=2000.0): + _, exp_time = await provider._fetch() + + # Should use minimum of 30 seconds: 2000 + max(30, 10-60) = 2000 + 30 = 2030 + assert exp_time == 2030.0 + + @respx.mock + @pytest.mark.asyncio + async def test_authentication_headers(self): + """Test that authentication headers are sent correctly.""" + token_route = respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "test_token", + "expires_in": 3600 + } + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret" + ) + + await provider._fetch() + + # Verify authentication was sent + assert token_route.called + request = token_route.calls[0].request + assert "authorization" in request.headers + + # Basic auth should be base64 encoded client_id:client_secret + import base64 + expected_auth = base64.b64encode(b"test_client:test_secret").decode() + assert request.headers["authorization"] == f"Basic {expected_auth}" + + @respx.mock + @pytest.mark.asyncio + async def test_content_type_header(self): + """Test that correct content-type header is sent.""" + token_route = respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "test_token", + "expires_in": 3600 + } + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="test_client", + client_secret="test_secret" + ) + + await provider._fetch() + + assert token_route.called + request = token_route.calls[0].request + assert request.headers["content-type"] == "application/x-www-form-urlencoded" + + @respx.mock + @pytest.mark.asyncio + async def test_full_integration_flow(self): + """Test complete token lifecycle with real HTTP mocking.""" + token_route = respx.post("https://auth.example.com/token").mock( + return_value=httpx.Response( + 200, + json={ + "access_token": "integration_token", + "expires_in": 3600 + } + ) + ) + + provider = ClientCredentialsProvider( + token_url="https://auth.example.com/token", + client_id="integration_client", + client_secret="integration_secret", + scope="read write" + ) + + with patch('time.time', return_value=1000.0): + # First call should fetch token + token1 = await provider.get_token() + assert token1 == "integration_token" + assert token_route.call_count == 1 + + # Second call should use cached token + token2 = await provider.get_token() + assert token2 == "integration_token" + assert token_route.call_count == 1 # No additional calls + + # Refresh should force new fetch + token3 = await provider.refresh() + assert token3 == "integration_token" + assert token_route.call_count == 2 # One additional call diff --git a/uv.lock b/uv.lock index 1cdd770..23d6ad7 100644 --- a/uv.lock +++ b/uv.lock @@ -114,14 +114,18 @@ dev = [ { name = "bandit" }, { name = "mypy" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "respx" }, { name = "ruff" }, ] [package.dev-dependencies] dev = [ { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "respx" }, ] [package.metadata] @@ -130,7 +134,9 @@ requires-dist = [ { name = "httpx", specifier = ">=0.28.1" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, + { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.25.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0.0" }, + { name = "respx", marker = "extra == 'dev'", specifier = ">=0.22.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.2.0" }, ] provides-extras = ["dev"] @@ -138,7 +144,9 @@ provides-extras = ["dev"] [package.metadata.requires-dev] dev = [ { name = "pytest", specifier = ">=8.4.1" }, + { name = "pytest-asyncio", specifier = ">=0.25.0" }, { name = "pytest-cov", specifier = ">=6.2.1" }, + { name = "respx", specifier = ">=0.22.0" }, ] [[package]] @@ -316,6 +324,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4e/51/f8794af39eeb870e87a8c8068642fc07bce0c854d6865d7dd0f2a9d338c2/pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea", size = 46652, upload-time = "2025-07-16T04:29:26.393Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/9d/bf86eddabf8c6c9cb1ea9a869d6873b46f105a5d292d3a6f7071f5b07935/pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf", size = 15157, upload-time = "2025-07-16T04:29:24.929Z" }, +] + [[package]] name = "pytest-cov" version = "6.2.1" @@ -347,6 +367,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446, upload-time = "2024-08-06T20:33:04.33Z" }, ] +[[package]] +name = "respx" +version = "0.22.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/7c/96bd0bc759cf009675ad1ee1f96535edcb11e9666b985717eb8c87192a95/respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91", size = 28439, upload-time = "2024-12-19T22:33:59.374Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/67/afbb0978d5399bc9ea200f1d4489a23c9a1dad4eee6376242b8182389c79/respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0", size = 25127, upload-time = "2024-12-19T22:33:57.837Z" }, +] + [[package]] name = "rich" version = "14.1.0" From c8e48ad637cae7075a61491c533c5f04ce3f2be7 Mon Sep 17 00:00:00 2001 From: Musale Martin Date: Thu, 21 Aug 2025 16:29:54 +0300 Subject: [PATCH 3/4] Update coverage.xml timestamp for accurate reporting --- coverage.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coverage.xml b/coverage.xml index 7ada2f9..a2a85fa 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,5 +1,5 @@ - + From bb3b30488b8c971efe37c417b591bd9fd8949afb Mon Sep 17 00:00:00 2001 From: Musale Martin Date: Thu, 21 Aug 2025 16:40:45 +0300 Subject: [PATCH 4/4] Update coverage.xml timestamp and refactor test files for improved readability --- coverage.xml | 2 +- tests/test_auth_basic.py | 14 +-- tests/test_auth_bearer.py | 146 +++++++++++---------- tests/test_auth_module.py | 26 ++-- tests/test_auth_providers.py | 237 ++++++++++++++++------------------- 5 files changed, 206 insertions(+), 219 deletions(-) diff --git a/coverage.xml b/coverage.xml index a2a85fa..4785cc5 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,5 +1,5 @@ - + diff --git a/tests/test_auth_basic.py b/tests/test_auth_basic.py index a11c364..1b25d60 100644 --- a/tests/test_auth_basic.py +++ b/tests/test_auth_basic.py @@ -31,7 +31,7 @@ def test_init(self): """Test BasicAuthPolicy initialization.""" creds = BasicCredentials(client_id="test_id", client_secret="test_secret") policy = BasicAuthPolicy(creds) - + # Verify the header is created correctly expected_token = base64.b64encode(b"test_id:test_secret").decode() assert policy._header == f"Basic {expected_token}" @@ -41,10 +41,10 @@ async def test_authorize(self): """Test authorization of a request.""" creds = BasicCredentials(client_id="test_id", client_secret="test_secret") policy = BasicAuthPolicy(creds) - + request = httpx.Request("GET", "https://example.com") await policy.authorize(request) - + expected_token = base64.b64encode(b"test_id:test_secret").decode() assert request.headers["authorization"] == f"Basic {expected_token}" @@ -53,7 +53,7 @@ async def test_on_unauthorized(self): """Test unauthorized response handling.""" creds = BasicCredentials(client_id="test_id", client_secret="test_secret") policy = BasicAuthPolicy(creds) - + # Basic auth cannot refresh, so should always return False result = await policy.on_unauthorized() assert result is False @@ -62,16 +62,16 @@ def test_different_credentials(self): """Test with different credentials produce different headers.""" creds1 = BasicCredentials(client_id="id1", client_secret="secret1") creds2 = BasicCredentials(client_id="id2", client_secret="secret2") - + policy1 = BasicAuthPolicy(creds1) policy2 = BasicAuthPolicy(creds2) - + assert policy1._header != policy2._header def test_special_characters_in_credentials(self): """Test credentials with special characters.""" creds = BasicCredentials(client_id="test:id", client_secret="test@secret!") policy = BasicAuthPolicy(creds) - + expected_token = base64.b64encode(b"test:id:test@secret!").decode() assert policy._header == f"Basic {expected_token}" diff --git a/tests/test_auth_bearer.py b/tests/test_auth_bearer.py index 8f5d9ce..a3db503 100644 --- a/tests/test_auth_bearer.py +++ b/tests/test_auth_bearer.py @@ -12,18 +12,27 @@ class MockTokenProvider: """Mock token provider for testing.""" - - def __init__(self, token: str = "test_token", refresh_token: str = "new_token"): + + def __init__( + self, token: str = "test_token", refresh_token: str = "new_token" + ) -> None: + """Initialize the mock token provider. + + Args: + token: The initial token to return. + refresh_token: The token to return after refresh. + + """ self.token = token self.refresh_token = refresh_token self.get_token_calls = 0 self.refresh_calls = 0 - + async def get_token(self) -> str: """Mock get_token method.""" self.get_token_calls += 1 return self.token - + async def refresh(self) -> str: """Mock refresh method.""" self.refresh_calls += 1 @@ -38,7 +47,7 @@ def test_init(self): """Test BearerAuthPolicy initialization.""" provider = MockTokenProvider() policy = BearerAuthPolicy(provider) - + assert policy._p is provider assert policy._last == "" @@ -47,10 +56,10 @@ async def test_authorize(self): """Test authorization of a request.""" provider = MockTokenProvider(token="test_access_token") policy = BearerAuthPolicy(provider) - + request = httpx.Request("GET", "https://example.com") await policy.authorize(request) - + assert request.headers["authorization"] == "Bearer test_access_token" assert policy._last == "test_access_token" assert provider.get_token_calls == 1 @@ -60,13 +69,13 @@ async def test_authorize_multiple_calls(self): """Test multiple authorization calls.""" provider = MockTokenProvider(token="token123") policy = BearerAuthPolicy(provider) - + request1 = httpx.Request("GET", "https://example.com/1") request2 = httpx.Request("GET", "https://example.com/2") - + await policy.authorize(request1) await policy.authorize(request2) - + assert request1.headers["authorization"] == "Bearer token123" assert request2.headers["authorization"] == "Bearer token123" assert provider.get_token_calls == 2 @@ -76,12 +85,12 @@ async def test_on_unauthorized_token_changed(self): """Test unauthorized handling when token changes.""" provider = MockTokenProvider(token="old_token", refresh_token="new_token") policy = BearerAuthPolicy(provider) - + # Set initial token policy._last = "old_token" - + result = await policy.on_unauthorized() - + assert result is True # Token changed assert policy._last == "new_token" assert provider.refresh_calls == 1 @@ -91,12 +100,12 @@ async def test_on_unauthorized_token_unchanged(self): """Test unauthorized handling when token doesn't change.""" provider = MockTokenProvider(token="same_token", refresh_token="same_token") policy = BearerAuthPolicy(provider) - + # Set initial token to same as refresh token policy._last = "same_token" - + result = await policy.on_unauthorized() - + assert result is False # Token didn't change assert policy._last == "same_token" assert provider.refresh_calls == 1 @@ -106,29 +115,31 @@ async def test_on_unauthorized_empty_last_token(self): """Test unauthorized handling with empty last token.""" provider = MockTokenProvider(refresh_token="new_token") policy = BearerAuthPolicy(provider) - + # _last starts as empty string result = await policy.on_unauthorized() - + assert result is True # Empty string != "new_token" assert policy._last == "new_token" @pytest.mark.asyncio async def test_full_flow(self): """Test complete authorization and refresh flow.""" - provider = MockTokenProvider(token="initial_token", refresh_token="refreshed_token") + provider = MockTokenProvider( + token="initial_token", refresh_token="refreshed_token" + ) policy = BearerAuthPolicy(provider) - + # Initial authorization request = httpx.Request("GET", "https://example.com") await policy.authorize(request) assert request.headers["authorization"] == "Bearer initial_token" - + # Unauthorized response triggers refresh changed = await policy.on_unauthorized() assert changed is True assert policy._last == "refreshed_token" - + # New authorization uses refreshed token request2 = httpx.Request("GET", "https://example.com/2") await policy.authorize(request2) @@ -142,29 +153,29 @@ class TestTokenProviderProtocol: async def test_mock_provider_compliance(self): """Test that mock provider implements the protocol correctly.""" provider = MockTokenProvider() - + # Should have async get_token and refresh methods token = await provider.get_token() assert isinstance(token, str) - + refresh_token = await provider.refresh() assert isinstance(refresh_token, str) - @pytest.mark.asyncio + @pytest.mark.asyncio async def test_provider_with_async_mock(self): """Test using AsyncMock for token provider.""" provider = Mock() provider.get_token = AsyncMock(return_value="mocked_token") provider.refresh = AsyncMock(return_value="mocked_refresh") - + policy = BearerAuthPolicy(provider) - + request = httpx.Request("GET", "https://example.com") await policy.authorize(request) - + assert request.headers["authorization"] == "Bearer mocked_token" provider.get_token.assert_called_once() - + result = await policy.on_unauthorized() assert result is True # "" != "mocked_refresh" provider.refresh.assert_called_once() @@ -184,30 +195,30 @@ async def test_integration_with_client_credentials_provider(self): json={ "access_token": "real_integration_token", "expires_in": 3600, - "token_type": "Bearer" - } + "token_type": "Bearer", + }, ) ) - + # Create a real ClientCredentialsProvider provider = ClientCredentialsProvider( token_url="https://auth.example.com/oauth/token", client_id="integration_client", client_secret="integration_secret", - scope="api:read api:write" + scope="api:read api:write", ) - + # Create BearerAuthPolicy with the real provider auth_policy = BearerAuthPolicy(provider) - + # Test authorization request = httpx.Request("GET", "https://api.example.com/data") await auth_policy.authorize(request) - + # Verify the request was authorized correctly assert request.headers["authorization"] == "Bearer real_integration_token" assert token_route.called - + # Verify the OAuth request was made correctly oauth_request = token_route.calls[0].request assert oauth_request.method == "POST" @@ -220,46 +231,44 @@ async def test_integration_with_client_credentials_provider(self): async def test_integration_refresh_flow(self): """Test complete refresh flow with real HTTP mocking.""" call_count = 0 - - def token_response(request): + + def token_response(request: httpx.Request) -> httpx.Response: nonlocal call_count call_count += 1 return httpx.Response( - 200, - json={ - "access_token": f"token_v{call_count}", - "expires_in": 3600 - } + 200, json={"access_token": f"token_v{call_count}", "expires_in": 3600} ) - + # Mock endpoint that returns different tokens token_route = respx.post("https://auth.example.com/oauth/token").mock( side_effect=token_response ) - + provider = ClientCredentialsProvider( token_url="https://auth.example.com/oauth/token", client_id="refresh_client", - client_secret="refresh_secret" + client_secret="refresh_secret", ) - + auth_policy = BearerAuthPolicy(provider) - + # First authorization request1 = httpx.Request("GET", "https://api.example.com/resource1") await auth_policy.authorize(request1) assert request1.headers["authorization"] == "Bearer token_v1" assert token_route.call_count == 1 - + # Simulate unauthorized response and refresh changed = await auth_policy.on_unauthorized() assert changed is True # Token should have changed assert token_route.call_count == 2 - + # New authorization should use refreshed token (cached) request2 = httpx.Request("GET", "https://api.example.com/resource2") await auth_policy.authorize(request2) - assert request2.headers["authorization"] == "Bearer token_v2" # Uses cached refreshed token + assert ( + request2.headers["authorization"] == "Bearer token_v2" + ) # Uses cached refreshed token # Should still be 2 calls since token is cached assert token_route.call_count == 2 @@ -271,24 +280,27 @@ async def test_integration_error_handling(self): respx.post("https://auth.example.com/oauth/token").mock( return_value=httpx.Response( 401, - json={"error": "invalid_client", "error_description": "Client authentication failed"} + json={ + "error": "invalid_client", + "error_description": "Client authentication failed", + }, ) ) - + provider = ClientCredentialsProvider( token_url="https://auth.example.com/oauth/token", client_id="invalid_client", - client_secret="invalid_secret" + client_secret="invalid_secret", ) - + auth_policy = BearerAuthPolicy(provider) - + # Authorization should fail with HTTP error request = httpx.Request("GET", "https://api.example.com/protected") - + with pytest.raises(httpx.HTTPStatusError) as exc_info: await auth_policy.authorize(request) - + assert exc_info.value.response.status_code == 401 @respx.mock @@ -297,27 +309,23 @@ async def test_integration_caching_behavior(self): """Test that token caching works correctly in integration.""" token_route = respx.post("https://auth.example.com/oauth/token").mock( return_value=httpx.Response( - 200, - json={ - "access_token": "cached_token", - "expires_in": 3600 - } + 200, json={"access_token": "cached_token", "expires_in": 3600} ) ) - + provider = ClientCredentialsProvider( token_url="https://auth.example.com/oauth/token", client_id="cache_client", - client_secret="cache_secret" + client_secret="cache_secret", ) - + auth_policy = BearerAuthPolicy(provider) - + # Multiple authorizations should use cached token for i in range(3): request = httpx.Request("GET", f"https://api.example.com/endpoint{i}") await auth_policy.authorize(request) assert request.headers["authorization"] == "Bearer cached_token" - + # Only first call should hit the token endpoint (due to caching) assert token_route.call_count == 1 diff --git a/tests/test_auth_module.py b/tests/test_auth_module.py index 6dd89ed..540e1d4 100644 --- a/tests/test_auth_module.py +++ b/tests/test_auth_module.py @@ -1,7 +1,5 @@ """Tests for auth module imports and exports.""" -import pytest - from gavaconnect import auth from gavaconnect.auth import ( BasicAuthPolicy, @@ -26,17 +24,17 @@ def test_all_exports_available(self): def test_module_has_all_attribute(self): """Test that __all__ is properly defined.""" - assert hasattr(auth, '__all__') + assert hasattr(auth, "__all__") assert isinstance(auth.__all__, list) - + expected_exports = { "BasicAuthPolicy", - "BasicCredentials", + "BasicCredentials", "BearerAuthPolicy", "TokenProvider", "ClientCredentialsProvider", } - + assert set(auth.__all__) == expected_exports def test_module_docstring(self): @@ -46,11 +44,11 @@ def test_module_docstring(self): def test_classes_importable_from_module(self): """Test that classes can be imported from the module.""" - assert hasattr(auth, 'BasicAuthPolicy') - assert hasattr(auth, 'BasicCredentials') - assert hasattr(auth, 'BearerAuthPolicy') - assert hasattr(auth, 'TokenProvider') - assert hasattr(auth, 'ClientCredentialsProvider') + assert hasattr(auth, "BasicAuthPolicy") + assert hasattr(auth, "BasicCredentials") + assert hasattr(auth, "BearerAuthPolicy") + assert hasattr(auth, "TokenProvider") + assert hasattr(auth, "ClientCredentialsProvider") def test_class_types(self): """Test that imported classes are the correct types.""" @@ -58,8 +56,10 @@ def test_class_types(self): from gavaconnect.auth.basic import BasicCredentials as BasicCredentialsOrig from gavaconnect.auth.bearer import BearerAuthPolicy as BearerAuthPolicyOrig from gavaconnect.auth.bearer import TokenProvider as TokenProviderOrig - from gavaconnect.auth.providers import ClientCredentialsProvider as ClientCredentialsProviderOrig - + from gavaconnect.auth.providers import ( + ClientCredentialsProvider as ClientCredentialsProviderOrig, + ) + assert BasicAuthPolicy is BasicAuthPolicyOrig assert BasicCredentials is BasicCredentialsOrig assert BearerAuthPolicy is BearerAuthPolicyOrig diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index c2cb51a..d0ac81d 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -1,7 +1,6 @@ """Tests for token provider implementations.""" import asyncio -import time from unittest.mock import patch import httpx @@ -19,9 +18,9 @@ def test_init_minimal(self): provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", - client_secret="test_secret" + client_secret="test_secret", ) - + assert provider._url == "https://auth.example.com/token" assert provider._cid == "test_client" assert provider._sec == "test_secret" @@ -41,9 +40,9 @@ def test_init_full_parameters(self): client_secret="test_secret", scope="read write", early_refresh_s=120, - client=custom_client + client=custom_client, ) - + assert provider._scope == "read write" assert provider._early == 120 assert provider._client is custom_client @@ -55,32 +54,28 @@ async def test_fetch_success_without_scope(self): # Mock the token endpoint token_route = respx.post("https://auth.example.com/token").mock( return_value=httpx.Response( - 200, - json={ - "access_token": "test_access_token", - "expires_in": 3600 - } + 200, json={"access_token": "test_access_token", "expires_in": 3600} ) ) - + provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", - client_secret="test_secret" + client_secret="test_secret", ) - - with patch('time.time', return_value=1000.0): + + with patch("time.time", return_value=1000.0): token, exp_time = await provider._fetch() - + assert token == "test_access_token" assert exp_time == 1000.0 + max(30.0, 3600 - 60) # 4540.0 - + # Verify the request was made correctly assert token_route.called request = token_route.calls[0].request assert request.method == "POST" assert request.url == "https://auth.example.com/token" - + # Check the form data form_data = dict(httpx.QueryParams(request.content.decode())) assert form_data["grant_type"] == "client_credentials" @@ -92,23 +87,19 @@ async def test_fetch_success_with_scope(self): """Test successful token fetch with scope.""" token_route = respx.post("https://auth.example.com/token").mock( return_value=httpx.Response( - 200, - json={ - "access_token": "scoped_token", - "expires_in": 7200 - } + 200, json={"access_token": "scoped_token", "expires_in": 7200} ) ) - + provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", client_secret="test_secret", - scope="read write admin" + scope="read write admin", ) - + await provider._fetch() - + # Verify scope was included in request assert token_route.called request = token_route.calls[0].request @@ -125,21 +116,21 @@ async def test_fetch_with_custom_expires_in(self): 200, json={ "access_token": "short_lived_token", - "expires_in": 300 # 5 minutes - } + "expires_in": 300, # 5 minutes + }, ) ) - + provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", client_secret="test_secret", - early_refresh_s=60 + early_refresh_s=60, ) - - with patch('time.time', return_value=2000.0): + + with patch("time.time", return_value=2000.0): token, exp_time = await provider._fetch() - + # Should use max(30.0, 300 - 60) = 240 assert exp_time == 2000.0 + 240.0 @@ -153,19 +144,19 @@ async def test_fetch_without_expires_in(self): json={ "access_token": "default_ttl_token" # No expires_in field - } + }, ) ) - + provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", - client_secret="test_secret" + client_secret="test_secret", ) - - with patch('time.time', return_value=3000.0): + + with patch("time.time", return_value=3000.0): token, exp_time = await provider._fetch() - + # Should use default 3600 seconds: max(30.0, 3600 - 60) = 3540 assert exp_time == 3000.0 + 3540.0 @@ -176,13 +167,13 @@ async def test_fetch_http_error(self): respx.post("https://auth.example.com/token").mock( return_value=httpx.Response(401, json={"error": "invalid_client"}) ) - + provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", - client_secret="test_secret" + client_secret="test_secret", ) - + with pytest.raises(httpx.HTTPStatusError): await provider._fetch() @@ -192,18 +183,18 @@ async def test_get_token_first_call(self): provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", - client_secret="test_secret" + client_secret="test_secret", ) - + # Mock the _fetch method directly for this test - async def mock_fetch(): + async def mock_fetch() -> tuple[str, float]: return "fresh_token", 5000.0 - + provider._fetch = mock_fetch - - with patch('time.time', return_value=1000.0): + + with patch("time.time", return_value=1000.0): token = await provider.get_token() - + assert token == "fresh_token" assert provider._token == "fresh_token" assert provider._exp == 5000.0 @@ -214,25 +205,26 @@ async def test_get_token_cached_valid(self): provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", - client_secret="test_secret" + client_secret="test_secret", ) - + # Set up cached token provider._token = "cached_token" provider._exp = 5000.0 - + # Mock _fetch to track if it's called fetch_called = False - async def mock_fetch(): + + async def mock_fetch() -> tuple[str, float]: nonlocal fetch_called fetch_called = True return "new_token", 8000.0 - + provider._fetch = mock_fetch - - with patch('time.time', return_value=4000.0): # Before expiry + + with patch("time.time", return_value=4000.0): # Before expiry token = await provider.get_token() - + assert token == "cached_token" assert not fetch_called @@ -242,24 +234,25 @@ async def test_get_token_cached_expired(self): provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", - client_secret="test_secret" + client_secret="test_secret", ) - + # Set up expired cached token provider._token = "expired_token" provider._exp = 4000.0 - + fetch_called = False - async def mock_fetch(): + + async def mock_fetch() -> tuple[str, float]: nonlocal fetch_called fetch_called = True return "new_token", 8000.0 - + provider._fetch = mock_fetch - - with patch('time.time', return_value=5000.0): # After expiry + + with patch("time.time", return_value=5000.0): # After expiry token = await provider.get_token() - + assert token == "new_token" assert provider._token == "new_token" assert provider._exp == 8000.0 @@ -271,23 +264,24 @@ async def test_refresh(self): provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", - client_secret="test_secret" + client_secret="test_secret", ) - + # Set up existing token provider._token = "old_token" provider._exp = 5000.0 - + fetch_called = False - async def mock_fetch(): + + async def mock_fetch() -> tuple[str, float]: nonlocal fetch_called fetch_called = True return "refreshed_token", 8000.0 - + provider._fetch = mock_fetch - + token = await provider.refresh() - + assert token == "refreshed_token" assert provider._token == "refreshed_token" assert provider._exp == 8000.0 @@ -299,24 +293,24 @@ async def test_concurrent_get_token_calls(self): provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", - client_secret="test_secret" + client_secret="test_secret", ) - + fetch_call_count = 0 - - async def mock_fetch(): + + async def mock_fetch() -> tuple[str, float]: nonlocal fetch_call_count fetch_call_count += 1 await asyncio.sleep(0.1) # Simulate network delay return f"token_{fetch_call_count}", 8000.0 - + provider._fetch = mock_fetch - - with patch('time.time', return_value=1000.0): + + with patch("time.time", return_value=1000.0): # Make multiple concurrent calls tasks = [provider.get_token() for _ in range(5)] tokens = await asyncio.gather(*tasks) - + # All should get the same token assert all(token == "token_1" for token in tokens) # _fetch should only be called once due to the lock @@ -329,33 +323,29 @@ async def test_early_refresh_parameter(self): # Mock responses for both providers respx.post("https://auth.example.com/token").mock( return_value=httpx.Response( - 200, - json={ - "access_token": "test_token", - "expires_in": 3600 - } + 200, json={"access_token": "test_token", "expires_in": 3600} ) ) - + # Test with different early refresh values provider1 = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", client_secret="test_secret", - early_refresh_s=60 + early_refresh_s=60, ) - + provider2 = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", client_secret="test_secret", - early_refresh_s=300 + early_refresh_s=300, ) - - with patch('time.time', return_value=1000.0): + + with patch("time.time", return_value=1000.0): _, exp1 = await provider1._fetch() _, exp2 = await provider2._fetch() - + # Provider1: 1000 + max(30, 3600-60) = 1000 + 3540 = 4540 # Provider2: 1000 + max(30, 3600-300) = 1000 + 3300 = 4300 assert exp1 == 4540.0 @@ -370,21 +360,21 @@ async def test_minimum_ttl_enforcement(self): 200, json={ "access_token": "short_token", - "expires_in": 10 # Very short expiry - } + "expires_in": 10, # Very short expiry + }, ) ) - + provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", client_secret="test_secret", - early_refresh_s=60 + early_refresh_s=60, ) - - with patch('time.time', return_value=2000.0): + + with patch("time.time", return_value=2000.0): _, exp_time = await provider._fetch() - + # Should use minimum of 30 seconds: 2000 + max(30, 10-60) = 2000 + 30 = 2030 assert exp_time == 2030.0 @@ -394,29 +384,26 @@ async def test_authentication_headers(self): """Test that authentication headers are sent correctly.""" token_route = respx.post("https://auth.example.com/token").mock( return_value=httpx.Response( - 200, - json={ - "access_token": "test_token", - "expires_in": 3600 - } + 200, json={"access_token": "test_token", "expires_in": 3600} ) ) - + provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", - client_secret="test_secret" + client_secret="test_secret", ) - + await provider._fetch() - + # Verify authentication was sent assert token_route.called request = token_route.calls[0].request assert "authorization" in request.headers - + # Basic auth should be base64 encoded client_id:client_secret import base64 + expected_auth = base64.b64encode(b"test_client:test_secret").decode() assert request.headers["authorization"] == f"Basic {expected_auth}" @@ -426,22 +413,18 @@ async def test_content_type_header(self): """Test that correct content-type header is sent.""" token_route = respx.post("https://auth.example.com/token").mock( return_value=httpx.Response( - 200, - json={ - "access_token": "test_token", - "expires_in": 3600 - } + 200, json={"access_token": "test_token", "expires_in": 3600} ) ) - + provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="test_client", - client_secret="test_secret" + client_secret="test_secret", ) - + await provider._fetch() - + assert token_route.called request = token_route.calls[0].request assert request.headers["content-type"] == "application/x-www-form-urlencoded" @@ -452,32 +435,28 @@ async def test_full_integration_flow(self): """Test complete token lifecycle with real HTTP mocking.""" token_route = respx.post("https://auth.example.com/token").mock( return_value=httpx.Response( - 200, - json={ - "access_token": "integration_token", - "expires_in": 3600 - } + 200, json={"access_token": "integration_token", "expires_in": 3600} ) ) - + provider = ClientCredentialsProvider( token_url="https://auth.example.com/token", client_id="integration_client", client_secret="integration_secret", - scope="read write" + scope="read write", ) - - with patch('time.time', return_value=1000.0): + + with patch("time.time", return_value=1000.0): # First call should fetch token token1 = await provider.get_token() assert token1 == "integration_token" assert token_route.call_count == 1 - + # Second call should use cached token token2 = await provider.get_token() assert token2 == "integration_token" assert token_route.call_count == 1 # No additional calls - + # Refresh should force new fetch token3 = await provider.refresh() assert token3 == "integration_token"