Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions src/quant_platform_kit/longbridge/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import hmac
import json
import time
from datetime import datetime, timezone
from typing import Any, Callable


Expand Down Expand Up @@ -42,6 +43,26 @@ def _longport_sign(method: str, uri: str, headers: dict[str, str], params: str,
return f"HMAC-SHA256 SignedHeaders=authorization;x-api-key;x-timestamp, Signature={signature}"


def _decode_token_expiry(token: str) -> float | None:
try:
parts = token.split(".")
if len(parts) <= 1:
return None
payload_b64 = parts[1]
padded_payload = payload_b64 + "=" * (-len(payload_b64) % 4)
payload = json.loads(base64.urlsafe_b64decode(padded_payload).decode("utf-8"))
expiry = payload.get("exp")
if expiry is None:
return None
return float(expiry)
except Exception:
return None


def _format_expiry(expiry_timestamp: float) -> str:
return datetime.fromtimestamp(expiry_timestamp, timezone.utc).isoformat()


def refresh_token_if_needed(
current_token: str,
*,
Expand All @@ -53,19 +74,20 @@ def refresh_token_if_needed(
requests_module: Any | None = None,
secret_client_factory: Callable[[], Any] | None = None,
) -> str:
expiry_timestamp = _decode_token_expiry(current_token)
now = time.time()

if not app_key or not app_secret:
if expiry_timestamp is not None and expiry_timestamp <= now:
raise RuntimeError(
"LongPort token in secret "
f"'{secret_name}' expired at {_format_expiry(expiry_timestamp)} "
"and cannot be refreshed because LONGPORT_APP_KEY/LONGPORT_APP_SECRET is missing."
)
return current_token

try:
parts = current_token.split(".")
if len(parts) > 1:
payload_b64 = parts[1]
padded_payload = payload_b64 + "=" * (-len(payload_b64) % 4)
payload = json.loads(base64.urlsafe_b64decode(padded_payload).decode("utf-8"))
if (payload.get("exp", 0) - time.time()) / 86400 > refresh_threshold_days:
return current_token
except Exception:
pass
if expiry_timestamp is not None and (expiry_timestamp - now) / 86400 > refresh_threshold_days:
return current_token

if requests_module is None:
import requests as requests_module
Expand All @@ -83,6 +105,13 @@ def refresh_token_if_needed(
timeout=15,
).json()
if response.get("code") != 0:
if expiry_timestamp is not None and expiry_timestamp <= now:
code = response.get("code")
message = response.get("message") or "unknown error"
raise RuntimeError(
f"LongPort token in secret '{secret_name}' expired at {_format_expiry(expiry_timestamp)}; "
f"refresh failed with code {code}: {message}"
)
return current_token

new_token = response["data"]["token"]
Expand Down
63 changes: 63 additions & 0 deletions tests/test_longbridge_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import base64
import json
import sys
import time
import types
import unittest
from unittest.mock import patch
Expand Down Expand Up @@ -62,6 +63,17 @@ def json():
return Response()


class FakeFailedRequests:
@staticmethod
def get(url, headers, timeout):
class Response:
@staticmethod
def json():
return {"code": 401003, "message": "token expired", "data": None}

return Response()


class LongBridgeAuthTests(unittest.TestCase):
def test_fetch_token_from_secret_reads_latest_version(self) -> None:
client = FakeSecretClient("token-abc")
Expand Down Expand Up @@ -109,6 +121,57 @@ def test_refresh_token_if_needed_persists_new_token(self) -> None:
self.assertEqual(client.created_parent, "projects/demo/secrets/token")
self.assertEqual(client.destroyed, ["projects/demo/secrets/token/versions/1"])

def test_refresh_token_if_needed_raises_clear_error_when_expired_and_refresh_fails(self) -> None:
payload = {"exp": 1}
encoded = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8")).decode("utf-8").rstrip("=")
token = f"aaa.{encoded}.bbb"

with self.assertRaises(RuntimeError) as context:
refresh_token_if_needed(
token,
project_id="demo",
secret_name="longport_token_sg",
app_key="key",
app_secret="secret",
requests_module=FakeFailedRequests,
)

self.assertIn("longport_token_sg", str(context.exception))
self.assertIn("refresh failed with code 401003", str(context.exception))

def test_refresh_token_if_needed_returns_same_token_when_refresh_fails_but_token_not_expired(self) -> None:
payload = {"exp": int(time.time()) + 86400}
encoded = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8")).decode("utf-8").rstrip("=")
token = f"aaa.{encoded}.bbb"

refreshed = refresh_token_if_needed(
token,
project_id="demo",
secret_name="token",
app_key="key",
app_secret="secret",
refresh_threshold_days=30,
requests_module=FakeFailedRequests,
)

self.assertEqual(refreshed, token)

def test_refresh_token_if_needed_raises_clear_error_when_expired_and_app_credentials_missing(self) -> None:
payload = {"exp": 1}
encoded = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8")).decode("utf-8").rstrip("=")
token = f"aaa.{encoded}.bbb"

with self.assertRaises(RuntimeError) as context:
refresh_token_if_needed(
token,
project_id="demo",
secret_name="longport_token_sg",
app_key="",
app_secret="",
)

self.assertIn("LONGPORT_APP_KEY/LONGPORT_APP_SECRET is missing", str(context.exception))

def test_build_contexts_uses_longport_openapi(self) -> None:
longport_module = types.ModuleType("longport")
openapi_module = types.ModuleType("longport.openapi")
Expand Down