diff --git a/app/products/_account_selection.py b/app/products/_account_selection.py
new file mode 100644
index 00000000..5e9c7d65
--- /dev/null
+++ b/app/products/_account_selection.py
@@ -0,0 +1,79 @@
+"""Account-reservation helpers shared by the products-layer request handlers."""
+
+from app.control.model.enums import ModeId
+from app.control.model.spec import ModelSpec
+from app.control.account.runtime import get_refresh_service
+from app.platform.config.snapshot import get_config
+
+
+def mode_candidates(spec: ModelSpec) -> tuple[int, ...]:
+    """List the mode IDs to attempt for *spec*, most preferred first.
+
+    The spec's own mode always comes first. A chat spec in ``AUTO`` mode also
+    gets ``FAST`` and then ``EXPERT`` appended, so a request can still be
+    served when the upstream ``auto`` quota window is exhausted but the other
+    chat windows have quota left. The extra modes are only added while the
+    ``features.auto_chat_mode_fallback`` setting (default ``True``) is truthy.
+    """
+    own_mode = int(spec.mode_id)
+    # Guard clauses: anything that is not an AUTO chat spec keeps its single mode.
+    if not spec.is_chat():
+        return (own_mode,)
+    if spec.mode_id != ModeId.AUTO:
+        return (own_mode,)
+    if not get_config("features.auto_chat_mode_fallback", True):
+        return (own_mode,)
+    return (own_mode, int(ModeId.FAST), int(ModeId.EXPERT))
+
+
+async def reserve_account(
+    directory,
+    spec: ModelSpec,
+    *,
+    exclude_tokens: list[str] | None = None,
+    now_s_override: int | None = None,
+):
+    """Lease an account for *spec* and report which mode it was leased under.
+
+    Each mode from :func:`mode_candidates` is offered to ``directory.reserve``
+    in order; the first lease obtained wins. When a full sweep yields nothing
+    and ``account.refresh.on_empty_retry_enabled`` (default ``True``) is
+    truthy, the refresh service -- if one is available -- performs an
+    on-demand refresh and the sweep is repeated once.
+
+    Returns ``(lease, selected_mode_id)``, or ``(None, int(spec.mode_id))``
+    when no account could be reserved.
+    """
+    fallback_mode_id = int(spec.mode_id)
+
+    async def _sweep_modes():
+        # One pass over the candidate modes; stops at the first lease granted.
+        for mode_id in mode_candidates(spec):
+            granted = await directory.reserve(
+                pool_candidates=spec.pool_candidates(),
+                mode_id=mode_id,
+                now_s_override=now_s_override,
+                exclude_tokens=exclude_tokens,
+            )
+            if granted is None:
+                continue
+            return granted, mode_id
+        return None, fallback_mode_id
+
+    first_try = await _sweep_modes()
+    if first_try[0] is not None:
+        return first_try
+
+    # NOTE(review): config.defaults.toml in this same change adds
+    # ``on_empty_retry_enabled`` right next to ``auto_chat_mode_fallback``,
+    # which is read above under ``features``. Confirm the key really resolves
+    # under ``account.refresh``; otherwise only the default is ever seen here.
+    if not get_config("account.refresh.on_empty_retry_enabled", True):
+        return None, fallback_mode_id
+    refresher = get_refresh_service()
+    if refresher is None:
+        return None, fallback_mode_id
+
+    await refresher.refresh_on_demand()
+    # A failed second sweep already yields ``(None, fallback_mode_id)``.
+    return await _sweep_modes()
diff --git a/app/products/anthropic/messages.py b/app/products/anthropic/messages.py index cacbe97b..3e7064dc 100644 --- a/app/products/anthropic/messages.py +++ b/app/products/anthropic/messages.py @@ -19,6 +19,7 @@ from app.platform.errors import RateLimitError, UpstreamError from app.platform.runtime.clock import now_s from app.platform.tokens import estimate_prompt_tokens, estimate_tokens, estimate_tool_call_tokens +from app.control.model.enums import ModeId from app.control.model.registry import resolve as resolve_model from app.control.account.enums import FeedbackKind from app.dataplane.reverse.protocol.xai_chat import classify_line, StreamAdapter @@ -32,6 +33,7 @@ _quota_sync, _fail_sync, _parse_retry_codes, _feedback_kind, _log_task_exception, _configured_retry_codes, _should_retry_upstream, ) +from app.products._account_selection import reserve_account from app.products.openai._tool_sieve import ToolSieve @@ -319,11 +321,11 @@ async def create( async def _run_stream() -> AsyncGenerator[str, None]: excluded: list[str] = [] for attempt in range(max_retries + 1): - acct = await directory.reserve( - pool_candidates = spec.pool_candidates(), - mode_id = mode_id, 
- now_s_override = now_s(), - exclude_tokens = excluded or None, + acct, selected_mode_id = await reserve_account( + directory, + spec, + now_s_override=now_s(), + exclude_tokens=excluded or None, ) if acct is None: raise RateLimitError("No available accounts for this model tier") @@ -363,7 +365,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: ended = False async for line in _stream_chat( token = token, - mode_id = spec.mode_id, + mode_id = ModeId(selected_mode_id), message = internal_message, files = files, timeout_s = timeout_s, @@ -589,11 +591,11 @@ async def _run_stream() -> AsyncGenerator[str, None]: else _feedback_kind(fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR ) - await directory.feedback(token, kind, mode_id, now_s_val=now_s()) + await directory.feedback(token, kind, selected_mode_id, now_s_val=now_s()) if success: - asyncio.create_task(_quota_sync(token, mode_id)).add_done_callback(_log_task_exception) + asyncio.create_task(_quota_sync(token, selected_mode_id)).add_done_callback(_log_task_exception) else: - asyncio.create_task(_fail_sync(token, mode_id, fail_exc)).add_done_callback(_log_task_exception) + asyncio.create_task(_fail_sync(token, selected_mode_id, fail_exc)).add_done_callback(_log_task_exception) if success or not _retry: return @@ -610,11 +612,11 @@ async def _run_stream() -> AsyncGenerator[str, None]: adapter = StreamAdapter() for attempt in range(max_retries + 1): - acct = await directory.reserve( - pool_candidates = spec.pool_candidates(), - mode_id = mode_id, - now_s_override = now_s(), - exclude_tokens = excluded or None, + acct, selected_mode_id = await reserve_account( + directory, + spec, + now_s_override=now_s(), + exclude_tokens=excluded or None, ) if acct is None: raise RateLimitError("No available accounts for this model tier") @@ -630,7 +632,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: ended = False async for line in _stream_chat( token = token, - mode_id = spec.mode_id, + mode_id = 
ModeId(selected_mode_id), message = internal_message, files = files, timeout_s = timeout_s, @@ -666,11 +668,11 @@ async def _run_stream() -> AsyncGenerator[str, None]: else _feedback_kind(fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR ) - await directory.feedback(token, kind, mode_id, now_s_val=now_s()) + await directory.feedback(token, kind, selected_mode_id, now_s_val=now_s()) if success: - asyncio.create_task(_quota_sync(token, mode_id)).add_done_callback(_log_task_exception) + asyncio.create_task(_quota_sync(token, selected_mode_id)).add_done_callback(_log_task_exception) else: - asyncio.create_task(_fail_sync(token, mode_id, fail_exc)).add_done_callback(_log_task_exception) + asyncio.create_task(_fail_sync(token, selected_mode_id, fail_exc)).add_done_callback(_log_task_exception) if success or not _retry: break diff --git a/app/products/openai/chat.py b/app/products/openai/chat.py index f704040a..72021c76 100644 --- a/app/products/openai/chat.py +++ b/app/products/openai/chat.py @@ -50,6 +50,7 @@ build_usage, ) from ._tool_sieve import ToolSieve +from app.products._account_selection import reserve_account def _log_task_exception(task: "asyncio.Task") -> None: @@ -400,9 +401,9 @@ async def completions( async def _run_stream() -> AsyncGenerator[str, None]: excluded: list[str] = [] for attempt in range(max_retries + 1): - acct = await directory.reserve( - pool_candidates=spec.pool_candidates(), - mode_id=mode_id, + acct, selected_mode_id = await reserve_account( + directory, + spec, now_s_override=now_s(), exclude_tokens=excluded or None, ) @@ -422,7 +423,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: tool_calls_emitted = False async for line in _stream_chat( token=token, - mode_id=spec.mode_id, + mode_id=ModeId(selected_mode_id), message=message, files=files, tool_overrides=tool_overrides, @@ -570,14 +571,14 @@ async def _run_stream() -> AsyncGenerator[str, None]: if fail_exc else FeedbackKind.SERVER_ERROR ) - await directory.feedback(token, 
kind, mode_id, now_s_val=now_s()) + await directory.feedback(token, kind, selected_mode_id, now_s_val=now_s()) if success: asyncio.create_task( - _quota_sync(token, mode_id) + _quota_sync(token, selected_mode_id) ).add_done_callback(_log_task_exception) else: asyncio.create_task( - _fail_sync(token, mode_id, fail_exc) + _fail_sync(token, selected_mode_id, fail_exc) ).add_done_callback(_log_task_exception) if success or not _retry: @@ -591,9 +592,9 @@ async def _run_stream() -> AsyncGenerator[str, None]: token = "" adapter = StreamAdapter() for attempt in range(max_retries + 1): - acct = await directory.reserve( - pool_candidates=spec.pool_candidates(), - mode_id=mode_id, + acct, selected_mode_id = await reserve_account( + directory, + spec, now_s_override=now_s(), exclude_tokens=excluded or None, ) @@ -610,7 +611,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: try: async for line in _stream_chat( token=token, - mode_id=spec.mode_id, + mode_id=ModeId(selected_mode_id), message=message, files=files, tool_overrides=tool_overrides, @@ -654,14 +655,14 @@ async def _run_stream() -> AsyncGenerator[str, None]: if fail_exc else FeedbackKind.SERVER_ERROR ) - await directory.feedback(token, kind, mode_id, now_s_val=now_s()) + await directory.feedback(token, kind, selected_mode_id, now_s_val=now_s()) if success: - asyncio.create_task(_quota_sync(token, mode_id)).add_done_callback( + asyncio.create_task(_quota_sync(token, selected_mode_id)).add_done_callback( _log_task_exception ) else: asyncio.create_task( - _fail_sync(token, mode_id, fail_exc) + _fail_sync(token, selected_mode_id, fail_exc) ).add_done_callback(_log_task_exception) if success or not _retry: diff --git a/app/products/openai/responses.py b/app/products/openai/responses.py index c8c3d7bc..d586a553 100644 --- a/app/products/openai/responses.py +++ b/app/products/openai/responses.py @@ -14,9 +14,11 @@ from app.platform.errors import RateLimitError, UpstreamError from app.platform.runtime.clock import 
now_s from app.platform.tokens import estimate_prompt_tokens, estimate_tokens, estimate_tool_call_tokens +from app.control.model.enums import ModeId from app.control.model.registry import resolve as resolve_model from app.control.account.enums import FeedbackKind from app.dataplane.reverse.protocol.xai_chat import classify_line, StreamAdapter +from app.products._account_selection import reserve_account from .chat import _stream_chat, _extract_message, _resolve_image, _quota_sync, _fail_sync, _parse_retry_codes, _feedback_kind, _log_task_exception from .chat import _configured_retry_codes, _should_retry_upstream @@ -258,11 +260,11 @@ async def create( async def _run_stream() -> AsyncGenerator[str, None]: excluded: list[str] = [] for attempt in range(max_retries + 1): - acct = await directory.reserve( - pool_candidates = spec.pool_candidates(), - mode_id = mode_id, - now_s_override = now_s(), - exclude_tokens = excluded or None, + acct, selected_mode_id = await reserve_account( + directory, + spec, + now_s_override=now_s(), + exclude_tokens=excluded or None, ) if acct is None: raise RateLimitError("No available accounts for this model tier") @@ -291,7 +293,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: ended = False async for line in _stream_chat( token = token, - mode_id = spec.mode_id, + mode_id = ModeId(selected_mode_id), message = message, files = files, timeout_s = timeout_s, @@ -553,11 +555,11 @@ async def _run_stream() -> AsyncGenerator[str, None]: finally: await directory.release(acct) kind = FeedbackKind.SUCCESS if success else _feedback_kind(fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR - await directory.feedback(token, kind, mode_id, now_s_val=now_s()) + await directory.feedback(token, kind, selected_mode_id, now_s_val=now_s()) if success: - asyncio.create_task(_quota_sync(token, mode_id)).add_done_callback(_log_task_exception) + asyncio.create_task(_quota_sync(token, selected_mode_id)).add_done_callback(_log_task_exception) else: - 
asyncio.create_task(_fail_sync(token, mode_id, fail_exc)).add_done_callback(_log_task_exception) + asyncio.create_task(_fail_sync(token, selected_mode_id, fail_exc)).add_done_callback(_log_task_exception) if success or not _retry: return @@ -573,11 +575,11 @@ async def _run_stream() -> AsyncGenerator[str, None]: token = "" adapter = StreamAdapter() for attempt in range(max_retries + 1): - acct = await directory.reserve( - pool_candidates = spec.pool_candidates(), - mode_id = mode_id, - now_s_override = now_s(), - exclude_tokens = excluded or None, + acct, selected_mode_id = await reserve_account( + directory, + spec, + now_s_override=now_s(), + exclude_tokens=excluded or None, ) if acct is None: raise RateLimitError("No available accounts for this model tier") @@ -592,7 +594,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: try: async for line in _stream_chat( token = token, - mode_id = spec.mode_id, + mode_id = ModeId(selected_mode_id), message = message, files = files, timeout_s = timeout_s, @@ -623,11 +625,11 @@ async def _run_stream() -> AsyncGenerator[str, None]: finally: await directory.release(acct) kind = FeedbackKind.SUCCESS if success else _feedback_kind(fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR - await directory.feedback(token, kind, mode_id) + await directory.feedback(token, kind, selected_mode_id) if success: - asyncio.create_task(_quota_sync(token, mode_id)).add_done_callback(_log_task_exception) + asyncio.create_task(_quota_sync(token, selected_mode_id)).add_done_callback(_log_task_exception) else: - asyncio.create_task(_fail_sync(token, mode_id, fail_exc)).add_done_callback(_log_task_exception) + asyncio.create_task(_fail_sync(token, selected_mode_id, fail_exc)).add_done_callback(_log_task_exception) if success or not _retry: break diff --git a/config.defaults.toml b/config.defaults.toml index 639cf26d..44071988 100644 --- a/config.defaults.toml +++ b/config.defaults.toml @@ -30,6 +30,10 @@ memory = false stream = true # 是否输出思考过程 
thinking = true +# 聊天类 AUTO 模型在 auto 额度耗尽时,是否自动降级到 fast/expert +auto_chat_mode_fallback = true +# 当本地额度缓存判定无可用账号时,是否先触发一次按需刷新再重试 +on_empty_retry_enabled = true # 思考精简输出(false=完整原始推理过程,true=提炼结构化摘要) thinking_summary = false # 是否动态生成 Statsig 指纹 
diff --git a/tests/test_account_selection.py b/tests/test_account_selection.py
new file mode 100644
index 00000000..cbc073ed
--- /dev/null
+++ b/tests/test_account_selection.py
@@ -0,0 +1,166 @@
+"""Unit tests for the shared account selection helpers."""
+
+import unittest
+from dataclasses import dataclass
+from unittest.mock import patch
+
+from app.control.model.enums import Capability, ModeId, Tier
+from app.control.model.spec import ModelSpec
+from app.products._account_selection import mode_candidates, reserve_account
+
+# Patch where the helpers look the names up, i.e. in the module under test.
+_GET_CONFIG = "app.products._account_selection.get_config"
+_GET_REFRESH_SERVICE = "app.products._account_selection.get_refresh_service"
+
+_AUTO = int(ModeId.AUTO)
+_FAST = int(ModeId.FAST)
+_EXPERT = int(ModeId.EXPERT)
+
+
+@dataclass
+class _Lease:
+    """Minimal stand-in for the lease a directory hands out."""
+
+    token: str
+
+
+class _FakeDirectory:
+    """Directory double with a fixed token (or none) per mode ID."""
+
+    def __init__(self, available_by_mode: dict[int, str | None]) -> None:
+        self.calls: list[int] = []
+        self.available_by_mode = available_by_mode
+
+    async def reserve(
+        self,
+        *,
+        pool_candidates,
+        mode_id,
+        now_s_override=None,
+        exclude_tokens=None,
+    ):
+        """Record the requested mode and lease its token when it has one."""
+        self.calls.append(mode_id)
+        token = self.available_by_mode.get(mode_id)
+        if not token:
+            return None
+        return _Lease(token)
+
+
+class _RefreshingDirectory:
+    """Directory double that grants an ``AUTO`` lease only after a refresh."""
+
+    def __init__(self) -> None:
+        self.after_refresh = False
+        self.calls: list[int] = []
+
+    async def reserve(
+        self,
+        *,
+        pool_candidates,
+        mode_id,
+        now_s_override=None,
+        exclude_tokens=None,
+    ):
+        """Record the requested mode; lease ``auto-token`` once refreshed."""
+        self.calls.append(mode_id)
+        refreshed_auto = self.after_refresh and mode_id == _AUTO
+        return _Lease("auto-token") if refreshed_auto else None
+
+
+class _RefreshService:
+    """Refresh-service double that marks its directory as refreshed."""
+
+    def __init__(self, directory: _RefreshingDirectory) -> None:
+        self.calls = 0
+        self.directory = directory
+
+    async def refresh_on_demand(self):
+        """Count the call and switch the directory to its refreshed state."""
+        self.directory.after_refresh = True
+        self.calls += 1
+
+
+def _chat_spec(ident: str, mode: ModeId, title: str) -> ModelSpec:
+    """Build a ``BASIC``-tier ``CHAT`` spec; only the three arguments vary."""
+    # NOTE(review): ident/title are presumably the model name and display
+    # name -- ModelSpec's field names are not visible from this file.
+    return ModelSpec(ident, mode, Tier.BASIC, Capability.CHAT, True, title)
+
+
+class AccountSelectionTests(unittest.IsolatedAsyncioTestCase):
+    # Test methods carry comments rather than docstrings so that verbose
+    # runner output keeps showing the method names.
+
+    def setUp(self) -> None:
+        self.auto_chat_spec = _chat_spec(
+            "grok-4.20-0309", ModeId.AUTO, "Grok 4.20 0309"
+        )
+        self.fast_chat_spec = _chat_spec(
+            "grok-4.20-fast", ModeId.FAST, "Grok 4.20 Fast"
+        )
+
+    def test_auto_chat_mode_candidates_include_fallbacks(self) -> None:
+        # Fallback switched on: AUTO is followed by FAST and EXPERT.
+        with patch(_GET_CONFIG, return_value=True):
+            candidates = mode_candidates(self.auto_chat_spec)
+        self.assertEqual(candidates, (_AUTO, _FAST, _EXPERT))
+
+    def test_auto_chat_mode_candidates_can_disable_fallback(self) -> None:
+        # Fallback switched off: only the spec's own mode is left.
+        with patch(_GET_CONFIG, return_value=False):
+            candidates = mode_candidates(self.auto_chat_spec)
+        self.assertEqual(candidates, (_AUTO,))
+
+    def test_non_auto_models_do_not_change_mode_order(self) -> None:
+        # A FAST spec never picks up extra modes, even with fallback on.
+        with patch(_GET_CONFIG, return_value=True):
+            candidates = mode_candidates(self.fast_chat_spec)
+        self.assertEqual(candidates, (_FAST,))
+
+    async def test_reserve_account_falls_back_to_fast(self) -> None:
+        # AUTO has no token, so the lease must come from FAST; EXPERT is
+        # never asked because the search stops at the first lease.
+        directory = _FakeDirectory(
+            {_AUTO: None, _FAST: "fast-token", _EXPERT: "expert-token"}
+        )
+
+        with patch(_GET_CONFIG, return_value=True):
+            lease, selected_mode_id = await reserve_account(
+                directory, self.auto_chat_spec
+            )
+
+        self.assertIsNotNone(lease)
+        self.assertEqual(lease.token, "fast-token")
+        self.assertEqual(selected_mode_id, _FAST)
+        self.assertEqual(directory.calls, [_AUTO, _FAST])
+
+    async def test_reserve_account_retries_after_on_demand_refresh(self) -> None:
+        # Fallback is off, so both attempts ask for AUTO only; the second one
+        # succeeds because the refresh flipped the directory's state.
+        settings = {
+            "features.auto_chat_mode_fallback": False,
+            "account.refresh.on_empty_retry_enabled": True,
+        }
+
+        def lookup(key, default=None):
+            return settings.get(key, default)
+
+        directory = _RefreshingDirectory()
+        refresh_service = _RefreshService(directory)
+
+        with patch(_GET_CONFIG, side_effect=lookup):
+            with patch(_GET_REFRESH_SERVICE, return_value=refresh_service):
+                lease, selected_mode_id = await reserve_account(
+                    directory, self.auto_chat_spec
+                )
+
+        self.assertIsNotNone(lease)
+        self.assertEqual(lease.token, "auto-token")
+        self.assertEqual(selected_mode_id, _AUTO)
+        self.assertEqual(refresh_service.calls, 1)
+        self.assertEqual(directory.calls, [_AUTO, _AUTO])
+
+
+if __name__ == "__main__":
+    unittest.main()