Skip to content

Commit 05597ca

Browse files
author
anna-singleton-resolver
committed
refactor: TokenData as frozen dataclass and param validity checks at init
1 parent f27f6a4 commit 05597ca

2 files changed

Lines changed: 25 additions & 17 deletions

File tree

src/resolver_athena_client/client/channel.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import logging
55
import threading
66
import time
7-
from typing import NamedTuple, override
7+
from dataclasses import dataclass
8+
from typing import override
89

910
import grpc
1011
import httpx
@@ -19,19 +20,30 @@
1920
logger = logging.getLogger(__name__)
2021

2122

22-
class TokenData(NamedTuple):
23+
@dataclass(frozen=True)
24+
class TokenData:
2325
"""Immutable snapshot of token state."""
2426

2527
access_token: str
2628
expires_at: float
2729
scheme: str
2830
issued_at: float
31+
proactive_refresh_threshold: float = 0.25
32+
33+
def __post_init__(self) -> None:
34+
"""Validate that proactive_refresh_threshold is between 0 and 1."""
35+
if (
36+
self.proactive_refresh_threshold <= 0
37+
or self.proactive_refresh_threshold >= 1
38+
):
39+
msg = "proactive_refresh_threshold must be between 0 and 1"
40+
raise ValueError(msg)
2941

3042
def is_valid(self) -> bool:
3143
"""Check if this token is still valid (with a 30-second buffer)."""
3244
return time.time() < (self.expires_at - 30)
3345

34-
def is_old(self, proactive_refresh_threshold: float) -> bool:
46+
def is_old(self) -> bool:
3547
"""Check if this token should be proactively refreshed.
3648
3749
A token is considered "old" if less than the
@@ -45,13 +57,12 @@ def is_old(self, proactive_refresh_threshold: float) -> bool:
4557
to trigger proactive refresh (e.g. 0.25 for 25%)
4658
4759
"""
48-
if proactive_refresh_threshold <= 0 or proactive_refresh_threshold >= 1:
49-
msg = "proactive_refresh_threshold must be between 0 and 1"
50-
raise ValueError(msg)
5160
current_time = time.time()
5261
total_lifetime = self.expires_at - self.issued_at
5362
time_remaining = self.expires_at - current_time
54-
return time_remaining < (total_lifetime * proactive_refresh_threshold)
63+
return time_remaining < (
64+
total_lifetime * self.proactive_refresh_threshold
65+
)
5566

5667

5768
class CredentialHelper:
@@ -116,7 +127,7 @@ def get_token(self) -> TokenData:
116127
# Fast path: token is valid and fresh
117128
if token_data is not None and token_data.is_valid():
118129
# If token is old, trigger background refresh
119-
if token_data.is_old(self._proactive_refresh_threshold):
130+
if token_data.is_old():
120131
self._start_background_refresh()
121132
return token_data
122133

@@ -157,10 +168,7 @@ def _start_background_refresh(self) -> None:
157168
or not self._refresh_thread.is_alive()
158169
)
159170
token_needs_refresh = (
160-
self._token_data is None
161-
or self._token_data.is_old(
162-
self._proactive_refresh_threshold
163-
)
171+
self._token_data is None or self._token_data.is_old()
164172
)
165173
refresh_needed = refresh_not_active and token_needs_refresh
166174
if refresh_needed:
@@ -182,9 +190,7 @@ def _background_refresh(self) -> None:
182190
with self._lock:
183191
# Check if token still needs refresh (prevent stampede)
184192
token_data = self._token_data
185-
if token_data is not None and not token_data.is_old(
186-
self._proactive_refresh_threshold
187-
):
193+
if token_data is not None and not token_data.is_old():
188194
# Token was already refreshed by another thread
189195
return
190196

@@ -240,6 +246,7 @@ def _refresh_token(self) -> None:
240246
expires_at=current_time + expires_in,
241247
scheme=scheme,
242248
issued_at=current_time,
249+
proactive_refresh_threshold=self._proactive_refresh_threshold,
243250
)
244251

245252
except httpx.HTTPStatusError as e:

tests/client/test_channel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def test_is_token_valid_with_expired_token(self) -> None:
135135
expires_at=time.time() - 100,
136136
scheme="Bearer",
137137
issued_at=time.time() - 3700,
138+
proactive_refresh_threshold=0.25,
138139
)
139140

140141
assert not helper._token_data.is_valid()
@@ -505,7 +506,7 @@ def test_token_is_old_when_past_halfway_lifetime(self) -> None:
505506
issued_at=current_time - 3_000, # 50 minutes ago
506507
)
507508
# Total lifetime = 3600s, remaining = 600s (1/6th), so it's old
508-
assert token.is_old(0.25)
509+
assert token.is_old()
509510

510511
def test_token_is_not_old_when_fresh(self) -> None:
511512
"""Test that a token is not old when more than 25% lifetime remains."""
@@ -518,7 +519,7 @@ def test_token_is_not_old_when_fresh(self) -> None:
518519
issued_at=current_time - 1200, # 20 minutes ago
519520
)
520521
# Total lifetime = 3600s, remaining = 2400s (67%), so it's fresh
521-
assert not token.is_old(0.25)
522+
assert not token.is_old()
522523

523524
def test_get_token_triggers_background_refresh_for_old_token(self) -> None:
524525
"""Test that get_token triggers background refresh for old tokens."""

0 commit comments

Comments
 (0)