diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index c049c5179..b3fe502f3 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -12,7 +12,6 @@ on: env: REGISTRY: ghcr.io - IMAGE_NAME: ${{ github.repository }} jobs: build-and-push: @@ -40,11 +39,15 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Normalize image name + id: image + run: echo "name=${GITHUB_REPOSITORY,,}" >> "$GITHUB_OUTPUT" + - name: Extract metadata id: meta uses: docker/metadata-action@v5 with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + images: ${{ env.REGISTRY }}/${{ steps.image.outputs.name }} tags: | type=raw,value=latest,enable=${{ github.ref_type == 'branch' }} type=ref,event=tag @@ -60,4 +63,4 @@ jobs: labels: ${{ steps.meta.outputs.labels }} cache-from: type=gha cache-to: type=gha,mode=max - pull: true \ No newline at end of file + pull: true diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml index 25c6b6964..f6d72dafc 100644 --- a/.github/workflows/security.yml +++ b/.github/workflows/security.yml @@ -19,7 +19,7 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 - + # 针对 PR:使用原生逻辑 - name: Setup PR commits & Run Gitleaks if: github.event_name == 'pull_request' @@ -29,7 +29,7 @@ jobs: git remote add pr-head "https://github.com/${{ github.event.pull_request.head.repo.full_name }}.git" git fetch --no-tags --prune --depth=1 pr-head \ "+${{ github.event.pull_request.head.sha }}:refs/remotes/pr-head/current" - + docker run --rm \ -v "$PWD:/repo" \ -w /repo \ @@ -54,4 +54,4 @@ jobs: - name: Export requirements run: uv export --frozen --no-dev --format requirements-txt -o /tmp/req.txt - name: Run pip-audit - run: uvx pip-audit -r /tmp/req.txt --progress-spinner off \ No newline at end of file + run: uvx pip-audit -r /tmp/req.txt --progress-spinner off diff --git a/.gitignore b/.gitignore index e94a37e41..652d6743b 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,4 @@ htmlcov/ # Project specific .ace-tool/ +.tmp_live/ diff --git a/README.md b/README.md index 0363662d8..29a6131ed 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,7 @@ docker compose up -d ### Vercel -[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,LOG_DIR,ACCOUNT_STORAGE,ACCOUNT_REDIS_URL,ACCOUNT_MYSQL_URL,ACCOUNT_POSTGRESQL_URL) +[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,LOG_DIR,ACCOUNT_STORAGE,ACCOUNT_REDIS_URL,ACCOUNT_MYSQL_URL,ACCOUNT_POSTGRESQL_URL,CONFIG_STORAGE,CONFIG_LOCAL_PATH) ### Render @@ -187,7 +187,8 @@ docker compose up -d | `ACCOUNT_SQL_MAX_OVERFLOW` | SQL 连接池最大溢出连接数 | `10` | | `ACCOUNT_SQL_POOL_TIMEOUT` | 等待连接池空闲连接的超时时间(秒) | `30` | | `ACCOUNT_SQL_POOL_RECYCLE` | 连接最大复用时间(秒),超时后自动重连 | `1800` | -| `CONFIG_LOCAL_PATH` | `local` 模式运行时配置文件路径 | `${DATA_DIR}/config.toml` | +| `CONFIG_STORAGE` | 运行时配置存储后端;默认本地 TOML,不跟随 `ACCOUNT_STORAGE` | `local` | +| `CONFIG_LOCAL_PATH` | 本地运行时配置文件路径 | `${DATA_DIR}/config.toml` | 运行时配置也支持 `GROK_` 前缀环境变量覆盖,例如 `GROK_APP_API_KEY` 会覆盖 `app.api_key`,`GROK_FEATURES_STREAM` 会覆盖 `features.stream`。 @@ -362,15 +363,31 @@ curl http://localhost:8000/v1/chat/completions \ "model": "grok-imagine-video", "stream": true, "messages": [ - {"role":"user","content":"霓虹雨夜街头,电影感慢镜头追拍"} + {"role":"user","content":"霓虹雨夜街头,电影感慢镜头追拍"}, + {"role":"user","content":"镜头穿过霓虹招牌,人物回头看向远处车灯"} ], - "video_config": { + "metadata": { + "video_config": { + "seconds": 16, + "size": "1792x1024", + "resolution_name": "720p", + "preset": "normal" + } + } + }' +``` + +`image_config` / `video_config` 也可以放在请求顶层;顶层字段优先于 `metadata` 内的同名配置。 + +```json +{ + "video_config": { "seconds": 10, "size": "1792x1024", "resolution_name": "720p", "preset": "normal" - } - }' + } +} ```
@@ -390,10 +407,11 @@ curl http://localhost:8000/v1/chat/completions \ | \|_ `size` | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | | \|_ `response_format` | `url`, `b64_json` | | `video_config` | 视频模型参数 | -| \|_ `seconds` | `6`, `10`, `12`, `16`, `20` | +| \|_ `seconds` | `6`, `10`, `12`, `16`, `20`, `22`, `26`, `30`, `32`, `36`, `40`;长视频需要按分段数量提供同数量 user 提示词 | | \|_ `size` | `720x1280`, `1280x720`, `1024x1024`, `1024x1792`, `1792x1024` | | \|_ `resolution_name` | `480p`, `720p` | | \|_ `preset` | `fun`, `normal`, `spicy`, `custom` | +| `metadata.image_config` / `metadata.video_config` | `image_config` / `video_config` 的兼容位置;顶层配置优先 |
@@ -589,7 +607,7 @@ curl -L http://localhost:8000/v1/videos//content \ | :-- | :-- | | `model` | 视频模型,目前为 `grok-imagine-video` | | `prompt` | 视频生成提示词 | -| `seconds` | 视频长度:`6`, `10`, `12`, `16`, `20` | +| `seconds` | 视频长度:`6`, `10`;更长视频请使用 `/v1/chat/completions` | | `size` | 支持 `720x1280`, `1280x720`, `1024x1024`, `1024x1792`, `1792x1024` | | `resolution_name` | `480p` 或 `720p` | | `preset` | `fun`, `normal`, `spicy`, `custom` | diff --git a/app/control/account/backends/local.py b/app/control/account/backends/local.py index 85eb3971f..ffd7f6274 100644 --- a/app/control/account/backends/local.py +++ b/app/control/account/backends/local.py @@ -55,6 +55,7 @@ def _init_sync(self) -> None: CREATE TABLE IF NOT EXISTS {_TBL} ( token TEXT NOT NULL PRIMARY KEY, + account_id TEXT, pool TEXT NOT NULL DEFAULT 'basic', status TEXT NOT NULL DEFAULT 'active', created_at INTEGER NOT NULL, @@ -85,7 +86,13 @@ def _init_sync(self) -> None: CREATE INDEX IF NOT EXISTS idx_acc_deleted ON {_TBL} (deleted_at) WHERE deleted_at IS NOT NULL; """) + # Migration: add account_id column if missing. + self._ensure_column_sync(conn, "account_id", "TEXT") self._ensure_column_sync(conn, "quota_grok_4_3", "TEXT NOT NULL DEFAULT '{}'") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_acc_account_id " + f"ON {_TBL} (account_id) WHERE account_id IS NOT NULL" + ) conn.commit() @staticmethod @@ -112,6 +119,7 @@ def _get_revision_sync(self, conn: sqlite3.Connection) -> int: @staticmethod def _row_to_record(row: sqlite3.Row) -> AccountRecord: d = dict(row) + d["account_id"] = d.get("account_id") or None d["tags"] = json.loads(d.get("tags") or "[]") heavy_raw = d.pop("quota_heavy", "{}") or "{}" grok_4_3_raw = d.pop("quota_grok_4_3", "{}") or "{}" @@ -132,6 +140,7 @@ def _record_to_row(record: AccountRecord, revision: int) -> dict[str, Any]: qs = record.quota_set() return { "token": record.token, + "account_id": record.account_id, "pool": record.pool, "status": record.status.value, "created_at": record.created_at, @@ -171,19 +180,36 @@ def _upsert_sync( continue pool = item.pool if item.pool in ("basic", "super", "heavy") else "basic" qs = default_quota_set(pool) + account_id = item.account_id or None + + # Dedup by account_id: soft-delete old token records with same account_id. + if account_id: + dups = conn.execute( + f"SELECT token FROM {_TBL} " + f"WHERE account_id = ? AND token != ? AND deleted_at IS NULL", + (account_id, token), + ).fetchall() + for dup in dups: + conn.execute( + f"UPDATE {_TBL} SET deleted_at = ?, updated_at = ?, revision = ? " + f"WHERE token = ?", + (ts, ts, revision, dup[0]), + ) + conn.execute( f""" INSERT INTO {_TBL} ( - token, pool, status, created_at, updated_at, + token, account_id, pool, status, created_at, updated_at, tags, quota_auto, quota_fast, quota_expert, quota_heavy, quota_grok_4_3, usage_use_count, usage_fail_count, usage_sync_count, ext, revision ) VALUES ( - :token, :pool, 'active', :ts, :ts, + :token, :account_id, :pool, 'active', :ts, :ts, :tags, :qa, :qf, :qe, :qh, :qg, 0, 0, 0, :ext, :rev ) ON CONFLICT(token) DO UPDATE SET + account_id = COALESCE(excluded.account_id, account_id), pool = excluded.pool, status = 'active', deleted_at = NULL, @@ -193,17 +219,18 @@ def _upsert_sync( revision = excluded.revision """, { - "token": token, - "pool": pool, - "ts": ts, - "tags": json.dumps(item.tags), - "qa": json.dumps(qs.auto.to_dict()), - "qf": json.dumps(qs.fast.to_dict()), - "qe": json.dumps(qs.expert.to_dict()), - "qh": json.dumps(qs.heavy.to_dict()) if qs.heavy else "{}", - "qg": json.dumps(qs.grok_4_3.to_dict()) if qs.grok_4_3 else "{}", - "ext": json.dumps(item.ext), - "rev": revision, + "token": token, + "account_id": account_id, + "pool": pool, + "ts": ts, + "tags": json.dumps(item.tags), + "qa": json.dumps(qs.auto.to_dict()), + "qf": json.dumps(qs.fast.to_dict()), + "qe": json.dumps(qs.expert.to_dict()), + "qh": json.dumps(qs.heavy.to_dict()) if qs.heavy else "{}", + "qg": json.dumps(qs.grok_4_3.to_dict()) if qs.grok_4_3 else "{}", + "ext": json.dumps(item.ext), + "rev": revision, }, ) count += conn.execute("SELECT changes()").fetchone()[0] @@ -229,6 +256,8 @@ def _patch_sync( sets: dict[str, Any] = {"updated_at": ts, "revision": revision} + if patch.account_id is not None: + sets["account_id"] = patch.account_id if patch.pool is not None: sets["pool"] = patch.pool if patch.status is not None: diff --git a/app/control/account/backends/redis.py b/app/control/account/backends/redis.py index 68c4d4b49..3c3aece09 100644 --- a/app/control/account/backends/redis.py +++ b/app/control/account/backends/redis.py @@ -55,6 +55,7 @@ def __init__(self, redis: "Redis") -> None: def _to_hash(record: AccountRecord, revision: int) -> dict[str, str]: qs = record.quota_set() return { + "account_id": record.account_id or "", "pool": record.pool, "status": record.status.value, "created_at": str(record.created_at), @@ -64,6 +65,7 @@ def _to_hash(record: AccountRecord, revision: int) -> dict[str, str]: "quota_fast": json.dumps(qs.fast.to_dict()), "quota_expert": json.dumps(qs.expert.to_dict()), "quota_heavy": json.dumps(qs.heavy.to_dict()) if qs.heavy else "{}", + "quota_grok_4_3": json.dumps(qs.grok_4_3.to_dict()) if qs.grok_4_3 else "{}", "usage_use_count": str(record.usage_use_count), "usage_fail_count": str(record.usage_fail_count), "usage_sync_count": str(record.usage_sync_count), @@ -88,8 +90,10 @@ def _i(k: str) -> int | None: v = _s(k) return int(v) if v else None + aid = _s("account_id") return AccountRecord.model_validate({ "token": token, + "account_id": aid if aid else None, "pool": _s("pool") or "basic", "status": _s("status") or "active", "created_at": _i("created_at") or now_ms(), @@ -102,6 +106,9 @@ def _i(k: str) -> int | None: **({ "heavy": json.loads(_s("quota_heavy")) } if _s("quota_heavy") and _s("quota_heavy") != "{}" else {}), + **({ + "grok_4_3": json.loads(_s("quota_grok_4_3")) + } if _s("quota_grok_4_3") and _s("quota_grok_4_3") != "{}" else {}), }, "usage_use_count": int(_s("usage_use_count") or 0), "usage_fail_count": int(_s("usage_fail_count") or 0), @@ -207,14 +214,40 @@ async def upsert_accounts( pool = item.pool if item.pool in ("basic", "super", "heavy") else "basic" qs = default_quota_set(pool) ts = now_ms() + account_id = item.account_id or None + + # Dedup by account_id: scan for existing records with same account_id. + if account_id: + async for key in self._r.scan_iter("accounts:record:*"): + h = await self._r.hgetall(key) + if not h: + continue + existing_aid = (h.get(b"account_id") or h.get("account_id") or b"") + if isinstance(existing_aid, bytes): + existing_aid = existing_aid.decode() + if existing_aid == account_id: + dup_token = (key.decode() if isinstance(key, bytes) else key).split(":", 2)[-1] + if dup_token != token: + await self._r.hset(key, mapping={ + "deleted_at": str(ts), + "updated_at": str(ts), + "revision": str(rev), + }) + # Remove from pool set and add to deleted pool tracking. + dup_pool = (h.get(b"pool") or h.get("pool") or b"basic") + if isinstance(dup_pool, bytes): + dup_pool = dup_pool.decode() + await self._r.srem(_pool_key(dup_pool), dup_token) + record = AccountRecord( - token = token, - pool = pool, - tags = item.tags, - ext = item.ext, - quota = qs.to_dict(), - created_at = ts, - updated_at = ts, + token = token, + account_id = account_id, + pool = pool, + tags = item.tags, + ext = item.ext, + quota = qs.to_dict(), + created_at = ts, + updated_at = ts, ) key = _record_key(token) await self._r.hset(key, mapping=self._to_hash(record, rev)) @@ -258,6 +291,8 @@ async def patch_accounts( updates["last_sync_at"] = str(patch.last_sync_at) if patch.last_clear_at is not None: updates["last_clear_at"] = str(patch.last_clear_at) + if patch.account_id is not None: + updates["account_id"] = patch.account_id if patch.pool is not None: updates["pool"] = patch.pool if patch.quota_auto is not None: @@ -268,6 +303,8 @@ async def patch_accounts( updates["quota_expert"] = json.dumps(patch.quota_expert) if patch.quota_heavy is not None: updates["quota_heavy"] = json.dumps(patch.quota_heavy) + if patch.quota_grok_4_3 is not None: + updates["quota_grok_4_3"] = json.dumps(patch.quota_grok_4_3) # Usage counters. if patch.usage_use_delta is not None: diff --git a/app/control/account/backends/sql.py b/app/control/account/backends/sql.py index 66b605bb6..aca46705a 100644 --- a/app/control/account/backends/sql.py +++ b/app/control/account/backends/sql.py @@ -37,6 +37,7 @@ _TBL_ACCOUNTS, metadata, sa.Column("token", sa.String(512), primary_key=True), + sa.Column("account_id", sa.String(64), nullable=True, index=True), # xaiUserId UUID for dedup sa.Column("pool", sa.Text, nullable=False, default="basic"), sa.Column("status", sa.Text, nullable=False, default="active"), sa.Column("created_at", sa.BigInteger, nullable=False), @@ -404,6 +405,8 @@ def _evict_cached_engine(engine: AsyncEngine) -> None: def _row_to_record(row: Any) -> AccountRecord: d = dict(row._mapping) + # Normalise account_id — stored as string, can be empty. + d["account_id"] = d.get("account_id") or None d["tags"] = json.loads(d.get("tags") or "[]") heavy_raw = d.pop("quota_heavy", "{}") or "{}" grok_4_3_raw = d.pop("quota_grok_4_3", "{}") or "{}" @@ -522,6 +525,23 @@ async def _do_initialize(self) -> None: async def _ensure_columns(self, conn: Any) -> None: """Idempotent ALTER TABLE migrations for columns added after the initial schema.""" existing = await self._table_columns(conn, _TBL_ACCOUNTS) + if "account_id" not in existing: + if self._dialect == "mysql": + await conn.exec_driver_sql( + f"ALTER TABLE {_TBL_ACCOUNTS} " + f"ADD COLUMN account_id VARCHAR(64) NULL" + ) + await conn.exec_driver_sql( + f"CREATE INDEX idx_accounts_account_id ON {_TBL_ACCOUNTS} (account_id)" + ) + else: + await conn.exec_driver_sql( + f"ALTER TABLE {_TBL_ACCOUNTS} " + f"ADD COLUMN account_id VARCHAR(64)" + ) + await conn.exec_driver_sql( + f"CREATE INDEX idx_accounts_account_id ON {_TBL_ACCOUNTS} (account_id)" + ) if "quota_grok_4_3" not in existing: if self._dialect == "mysql": # MySQL forbids DEFAULT values on TEXT/BLOB columns; @@ -629,8 +649,10 @@ async def upsert_accounts( continue pool = item.pool if item.pool in ("basic", "super", "heavy") else "basic" qs = default_quota_set(pool) + account_id = item.account_id or None row = { "token": token, + "account_id": account_id, "pool": pool, "status": "active", "created_at": ts, @@ -648,6 +670,29 @@ async def upsert_accounts( "ext": json.dumps(item.ext), "revision": rev, } + # Dedup by account_id: if this account_id already exists under a + # *different* token, soft-delete the old token record so we + # maintain one canonical token per xaiUserId. + if account_id: + existing = (await conn.execute( + sa.select(accounts_table.c.token).where( + accounts_table.c.account_id == account_id, + accounts_table.c.token != token, + accounts_table.c.deleted_at.is_(None), + ) + )).fetchall() + for dup in existing: + dup_token = dup[0] + await conn.execute( + accounts_table.update() + .where(accounts_table.c.token == dup_token) + .values(deleted_at=ts, updated_at=ts, revision=rev) + ) + logger.info( + "account deduplicated by account_id: old_token={}... new_token={}... account_id={}", + dup_token[:10], token[:10], account_id, + ) + await conn.execute(self._build_upsert(row)) count += 1 return AccountMutationResult(upserted=count, revision=rev) @@ -672,6 +717,8 @@ async def patch_accounts( record = _row_to_record(row) updates: dict[str, Any] = {"updated_at": ts, "revision": rev} + if patch.account_id is not None: + updates["account_id"] = patch.account_id if patch.pool is not None: updates["pool"] = patch.pool if patch.status is not None: diff --git a/app/control/account/commands.py b/app/control/account/commands.py index ba5011432..c62d103dc 100644 --- a/app/control/account/commands.py +++ b/app/control/account/commands.py @@ -14,6 +14,7 @@ class AccountUpsert(BaseModel): """ token: str + account_id: str | None = None # Grok xaiUserId for deduplication pool: str = "basic" tags: list[str] = Field(default_factory=list) ext: dict[str, Any] = Field(default_factory=dict) @@ -26,6 +27,7 @@ class AccountPatch(BaseModel): """ token: str + account_id: str | None = None # update account ID (from subscription API) pool: str | None = None # update pool type when inferred status: AccountStatus | None = None tags: list[str] | None = None diff --git a/app/control/account/models.py b/app/control/account/models.py index cb8e8e17f..dc6288d50 100644 --- a/app/control/account/models.py +++ b/app/control/account/models.py @@ -176,6 +176,7 @@ class AccountRecord(BaseModel): """ token: str + account_id: str | None = None # Grok's xaiUserId — UUID, used for deduplication pool: str = "basic" status: AccountStatus = AccountStatus.ACTIVE created_at: int = Field(default_factory=now_ms) diff --git a/app/control/account/quota_defaults.py b/app/control/account/quota_defaults.py index eb3250c8d..a8e913316 100644 --- a/app/control/account/quota_defaults.py +++ b/app/control/account/quota_defaults.py @@ -75,6 +75,7 @@ def _w(remaining: int, total: int, window_seconds: int) -> QuotaWindow: "basic": frozenset((1,)), "super": frozenset((0, 1, 2, 4)), "heavy": frozenset((0, 1, 2, 3, 4)), + "auto": frozenset((0, 1, 2, 3, 4)), # 探测用全集,确保 infer_pool 能拿到 auto.total } # --------------------------------------------------------------------------- diff --git a/app/control/account/refresh.py b/app/control/account/refresh.py index f95a1baa6..71313984b 100644 --- a/app/control/account/refresh.py +++ b/app/control/account/refresh.py @@ -2,7 +2,7 @@ import asyncio from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Any, TYPE_CHECKING from app.platform.errors import UpstreamError from app.platform.config.snapshot import get_config @@ -126,17 +126,91 @@ async def _fetch_mode_quota( ) return None + # ------------------------------------------------------------------ + # Subscription API fetch (account type + account_id) + # ------------------------------------------------------------------ + + async def _fetch_subscription(self, token: str): + """Fetch subscription info (account type + xaiUserId) for *token*.""" + try: + from app.dataplane.reverse.protocol.xai_subscription import ( + fetch_subscription, + ) + return await fetch_subscription(token) + except UpstreamError: + raise + except Exception as exc: + logger.debug( + "subscription fetch failed: token={}... error={}", + token[:10], exc, + ) + return None + + async def _refresh_subscription(self, record: AccountRecord) -> None: + """Fetch subscription info, persist account_id, and store tier hint for Phase 2. + + Does NOT set pool — that's Phase 2's job (multi-factor: rate-limits + subscription). + Stores ``sub_tier`` and ``sub_active`` in ext so _refresh_one can cross-validate. + """ + if record.is_deleted(): + return + try: + info = await self._fetch_subscription(record.token) + except UpstreamError as exc: + if await self._expire_invalid_credentials(record, exc): + return + raise + except Exception: + return + + if info is None: + return + + from .commands import AccountPatch + + patch_fields: dict[str, Any] = {} + ext_merge: dict[str, Any] = { + "sub_tier": info.tier, + "sub_active": info.is_active, + } + + if info.xai_user_id and not record.account_id: + patch_fields["account_id"] = info.xai_user_id + + if patch_fields or ext_merge: + patch_fields["ext_merge"] = ext_merge + await self._repo.patch_accounts([ + AccountPatch(token=record.token, **patch_fields) + ]) + logger.info( + "subscription info updated: token={}... account_id={} tier={} active={}", + record.token[:10], + info.xai_user_id or record.account_id, + info.tier, info.is_active, + ) + # ------------------------------------------------------------------ # Core refresh logic # ------------------------------------------------------------------ async def refresh_on_import(self, tokens: list[str]) -> RefreshResult: - """Called after bulk import — sync real quotas for all accounts.""" + """Called after bulk import — sync real quotas and subscription info for all accounts.""" records = await self._repo.get_accounts(tokens) active = [r for r in records if is_manageable(r)] if not active: return RefreshResult(checked=len(records)) + # Phase 1: Fetch subscription info (account_id + tier) for accounts missing it. + sub_results = await run_batch( + [r for r in active if not r.account_id], + lambda r: self._refresh_subscription(r), + concurrency=get_config("account.refresh.usage_concurrency", 50), + ) + # Refresh the records after subscription updates. + records = await self._repo.get_accounts(tokens) + active = [r for r in records if is_manageable(r)] + + # Phase 2: Fetch quota windows. concurrency = get_config("account.refresh.usage_concurrency", 50) results = await run_batch( active, @@ -145,7 +219,8 @@ async def refresh_on_import(self, tokens: list[str]) -> RefreshResult: ) agg = RefreshResult(checked=len(records)) for r in results: - agg.merge(r) + if r: + agg.merge(r) return agg async def refresh_call_async(self, token: str, mode_id: int) -> None: @@ -236,8 +311,11 @@ async def _refresh_one( if record.is_deleted(): return RefreshResult() + # For basic accounts, fetch all modes (like "auto" pool) so infer_pool + # can see auto.total and correctly upgrade to super/heavy when warranted. + fetch_pool = "auto" if record.pool == "basic" else record.pool try: - windows = await self._fetch_all_quotas(record.token, record.pool) + windows = await self._fetch_all_quotas(record.token, fetch_pool) except UpstreamError as exc: if await self._expire_invalid_credentials(record, exc): return RefreshResult(checked=1, expired=1, failed=0) @@ -256,10 +334,17 @@ async def _refresh_one( patches: dict[str, dict] = {} refreshed = False + # Use the inferred pool for quota normalization when it represents + # an upgrade (e.g. basic→super). Without this, auto/expert/grok_4_3 + # quota data fetched for basic accounts gets discarded by + # normalize_quota_window (basic pool doesn't support those modes). + inferred = infer_pool(windows) # type: ignore[arg-type] + effective_pool = inferred if inferred != "basic" else record.pool + for mode in ALL_MODES_FULL: mode_id = int(mode) if mode_id in windows: - window = normalize_quota_window(record.pool, mode_id, windows[mode_id]) + window = normalize_quota_window(effective_pool, mode_id, windows[mode_id]) if window is None: continue patches[_MODE_KEYS[mode_id]] = window.to_dict() @@ -293,12 +378,38 @@ async def _refresh_one( if not patches: return RefreshResult(checked=1, failed=0 if refreshed else 1) - # Infer pool type from live quota data and patch if it changed. - inferred = infer_pool(windows) # type: ignore[arg-type] + # ── Multi-factor pool inference ────────────────────────────── + # 'inferred' was set above from rate-limits auto.total (primary). + # Now cross-validate with subscription data (fallback). + # + # Scenarios handled: + # auto=50, any sub status → super (rate-limits wins) + # auto=7, INACTIVE sub → basic (genuinely downgraded) + # auto=?, ACTIVE sub → super (subscription wins, API fluke) + # auto=7, ACTIVE sub (anomaly) → super (active sub > anomalous quota) + if inferred == "basic": + sub_tier = record.ext.get("sub_tier", "") + sub_active = record.ext.get("sub_active", False) + if sub_active and sub_tier in ( + "SUBSCRIPTION_TIER_GROK_PRO", + "SUBSCRIPTION_TIER_SUPER_GROK", + "SUBSCRIPTION_TIER_SUPER_GROK_PRO", + "SUBSCRIPTION_TIER_GROK_PRO_HEAVY", + "SUBSCRIPTION_TIER_SUPER_GROK_LITE", + "SUBSCRIPTION_TIER_GROK_TEAMS", + ): + from app.dataplane.reverse.protocol.xai_subscription import tier_to_pool as _t2p + inferred = _t2p(sub_tier) + logger.info( + "account pool upgraded by subscription: token={}... " + "rate_limits_pool=basic subscription_tier={} -> pool={}", + record.token[:10], sub_tier, inferred, + ) + pool_patch = inferred if inferred != record.pool else None if pool_patch: logger.info( - "account pool updated from live quota: token={}... previous_pool={} current_pool={}", + "account pool updated: token={}... previous_pool={} current_pool={}", record.token[:10], record.pool, inferred, diff --git a/app/control/proxy/__init__.py b/app/control/proxy/__init__.py index 9d7581d60..1af2dcd55 100644 --- a/app/control/proxy/__init__.py +++ b/app/control/proxy/__init__.py @@ -228,12 +228,8 @@ async def _pick_proxy_url(self, resource: bool = False) -> str | None: if self._egress_mode == EgressMode.DIRECT: return None async with self._lock: - # Prefer resource-specific nodes when available; fall back to base nodes. - nodes = ( - self._resource_nodes - if resource and self._resource_nodes - else self._nodes - ) + # Resource downloads are direct unless a resource proxy is explicit. + nodes = self._resource_nodes if resource else self._nodes if not nodes: return None if self._egress_mode == EgressMode.SINGLE_PROXY: diff --git a/app/dataplane/reverse/protocol/xai_subscription.py b/app/dataplane/reverse/protocol/xai_subscription.py new file mode 100644 index 000000000..e3753c641 --- /dev/null +++ b/app/dataplane/reverse/protocol/xai_subscription.py @@ -0,0 +1,259 @@ +"""XAI subscription / account-type protocol — fetch subscription tier and account ID. + +Provides the official account type query by calling ``GET /rest/subscriptions``. +Returns the user's subscription tier (basic / super / heavy) and unique account +ID (``xaiUserId``) for deduplication purposes. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from app.platform.errors import UpstreamError +from app.platform.logging.logger import logger + + +# --------------------------------------------------------------------------- +# Tier mapping — Grok subscription tiers → Grok2API pool names +# --------------------------------------------------------------------------- + +# Known subscription tier strings returned by grok.com/rest/subscriptions. +_TIER_TO_POOL: dict[str, str] = { + "SUBSCRIPTION_TIER_UNKNOWN": "basic", + "SUBSCRIPTION_TIER_FREE": "basic", + "SUBSCRIPTION_TIER_GROK_PRO": "super", + "SUBSCRIPTION_TIER_SUPER_GROK": "super", + "SUBSCRIPTION_TIER_SUPER_GROK_PRO": "heavy", + "SUBSCRIPTION_TIER_GROK_PRO_HEAVY": "heavy", + "SUBSCRIPTION_TIER_SUPER_GROK_LITE": "super", + "SUBSCRIPTION_TIER_GROK_TEAMS": "super", +} + +# Active subscription statuses. +_ACTIVE_STATUSES: frozenset[str] = frozenset({ + "SUBSCRIPTION_STATUS_ACTIVE", + "SUBSCRIPTION_STATUS_TRIAL", + "SUBSCRIPTION_STATUS_TRIALING", +}) + + +# --------------------------------------------------------------------------- +# Result type +# --------------------------------------------------------------------------- + + +@dataclass(slots=True) +class SubscriptionInfo: + """Parsed account-type information from the subscription API. + + ``xai_user_id`` — unique account identifier (UUID); ``None`` if unavailable. + ``pool`` — Grok2API pool name (``"basic"`` / ``"super"`` / ``"heavy"``). + ``tier`` — raw upstream tier string (e.g. ``"SUBSCRIPTION_TIER_GROK_PRO"``). + ``is_active`` — whether at least one subscription is currently active. + ``raw`` — the full parsed JSON dict for debugging / extension use. + """ + + xai_user_id: str | None + pool: str + tier: str + is_active: bool + raw: dict | None + + +def tier_to_pool(tier: str) -> str: + """Map an upstream subscription tier string to a Grok2API pool name.""" + return _TIER_TO_POOL.get(tier, "basic") + + +def is_active_status(status: str) -> bool: + """Return True if *status* represents an active subscription.""" + return status in _ACTIVE_STATUSES + + +# --------------------------------------------------------------------------- +# Response parser +# --------------------------------------------------------------------------- + + +def parse_subscription(body: dict) -> SubscriptionInfo: + """Parse the ``/rest/subscriptions`` response body. + + Expected format:: + + { + "subscriptions": [ + { + "xaiUserId": "uuid-here", + "tier": "SUBSCRIPTION_TIER_GROK_PRO", + "status": "SUBSCRIPTION_STATUS_ACTIVE", + ... + }, + ... + ] + } + + Returns a ``SubscriptionInfo`` populated from the most recent (last in + list) subscription entry. If the subscriptions array is empty, returns + a ``"basic"`` pool with no account ID. + """ + subs: list[dict] = body.get("subscriptions", []) if body else [] + + if not subs: + return SubscriptionInfo( + xai_user_id=None, + pool="basic", + tier="", + is_active=False, + raw=body, + ) + + # Use the *last* subscription (typically most recent). + last = subs[-1] + + xai_user_id = last.get("xaiUserId") or None + tier = last.get("tier", "") + status = last.get("status", "") + + pool = tier_to_pool(tier) + + # If the most recent subscription is not active but a later entry is, use + # the best active tier we can find. Also collect account ID from any entry. + if not xai_user_id: + for sub in subs: + xai_user_id = sub.get("xaiUserId") or xai_user_id + if xai_user_id: + break + + is_active = False + best_active_pool = pool + for sub in subs: + if is_active_status(sub.get("status", "")): + is_active = True + active_tier = sub.get("tier", "") + active_pool = tier_to_pool(active_tier) + if _pool_rank(active_pool) > _pool_rank(best_active_pool): + best_active_pool = active_pool + + if is_active: + pool = best_active_pool + + # If all subscriptions are inactive (expired/cancelled), the account + # reverts to basic tier on the Grok side, but we keep the historical tier + # for information and let the rate-limits API refine it. + if not is_active and pool != "basic": + logger.debug( + "subscription inactive — may have reverted to basic: tier={} status={}", + tier, status, + ) + + return SubscriptionInfo( + xai_user_id=xai_user_id, + pool=pool, + tier=tier, + is_active=is_active, + raw=body, + ) + + +def _pool_rank(pool: str) -> int: + """Return an ordinal rank for a pool name (higher = better).""" + return {"basic": 0, "super": 1, "heavy": 2}.get(pool, 0) + + +# --------------------------------------------------------------------------- +# HTTP fetch +# --------------------------------------------------------------------------- + + +async def _do_fetch(token: str) -> dict: + """GET the subscriptions endpoint and return parsed JSON body.""" + from app.dataplane.reverse.transport.http import get_json + from app.dataplane.proxy import get_proxy_runtime + from app.control.proxy.models import ProxyFeedback, ProxyFeedbackKind + from app.dataplane.reverse.runtime.endpoint_table import SUBSCRIPTIONS + + proxy = await get_proxy_runtime() + lease = await proxy.acquire() + try: + body = await get_json( + SUBSCRIPTIONS, + token, + lease=lease, + timeout_s=20.0, + ) + await proxy.feedback( + lease, ProxyFeedback(kind=ProxyFeedbackKind.SUCCESS, status_code=200) + ) + return body + except Exception as exc: + status = getattr(exc, "status", None) or getattr(exc, "status_code", None) + from app.dataplane.reverse.protocol.xai_usage import _proxy_feedback_kind_for_error + kind = _proxy_feedback_kind_for_error(exc, status=status) + await proxy.feedback(lease, ProxyFeedback(kind=kind, status_code=status)) + raise + + +async def fetch_subscription(token: str) -> SubscriptionInfo | None: + """Fetch account-type / subscription information for *token*. + + Returns a ``SubscriptionInfo``, or ``None`` if the endpoint is unreachable + (network timeout, 5xx, etc.). + """ + import asyncio + + try: + body = await asyncio.wait_for(_do_fetch(token), timeout=25.0) + except asyncio.TimeoutError: + logger.debug( + "subscription fetch timed out: token={}...", token[:10] + ) + return None + except UpstreamError as exc: + if getattr(exc, "status", None) in (401, 403): + # Invalid token — let caller handle. + raise + logger.debug( + "subscription fetch failed: token={}... status={}", + token[:10], exc.status if hasattr(exc, "status") else "?", + ) + return None + except Exception as exc: + logger.debug( + "subscription fetch error: token={}... error={}", + token[:10], exc, + ) + return None + + return parse_subscription(body) + + +async def fetch_subscription_for_import(token: str) -> SubscriptionInfo: + """Fetch subscription info during account import. + + Always returns a ``SubscriptionInfo`` — falls back to ``"basic"`` with + no account ID when the API is unreachable. + + Raises ``UpstreamError`` for 401/403 (invalid credentials). + """ + result = await fetch_subscription(token) + if result is not None: + return result + + # API unreachable — return a safe default. + return SubscriptionInfo( + xai_user_id=None, + pool="basic", + tier="", + is_active=False, + raw=None, + ) + + +__all__ = [ + "SubscriptionInfo", + "parse_subscription", + "fetch_subscription", + "fetch_subscription_for_import", + "tier_to_pool", + "is_active_status", +] diff --git a/app/dataplane/reverse/runtime/endpoint_table.py b/app/dataplane/reverse/runtime/endpoint_table.py index bb390e503..13da5fc41 100644 --- a/app/dataplane/reverse/runtime/endpoint_table.py +++ b/app/dataplane/reverse/runtime/endpoint_table.py @@ -23,6 +23,11 @@ # ── Rate limits (usage / quota sync) ───────────────────────────────────── RATE_LIMITS = f"{BASE}/rest/rate-limits" # POST +# ── Subscription / account type ───────────────────────────────────────── +SUBSCRIPTIONS = f"{BASE}/rest/subscriptions" # GET +PRODUCTS = f"{BASE}/rest/products" # GET ?provider=SUBSCRIPTION_PROVIDER_STRIPE +MODES = f"{BASE}/rest/modes" # POST + # ── gRPC-Web endpoints ────────────────────────────────────────────────── ACCEPT_TOS = "https://accounts.x.ai/auth_mgmt.AuthManagement/SetTosAcceptedVersion" NSFW_MGMT = f"{BASE}/auth_mgmt.AuthManagement/UpdateUserFeatureControls" @@ -47,7 +52,7 @@ "BASE", "ASSETS_CDN", "CHAT", "ASSETS_UPLOAD", "ASSETS_LIST", "ASSETS_DELETE", "ASSETS_DOWNLOAD", - "RATE_LIMITS", + "RATE_LIMITS", "SUBSCRIPTIONS", "PRODUCTS", "MODES", "ACCEPT_TOS", "NSFW_MGMT", "SET_BIRTH", "MEDIA_POST", "MEDIA_POST_LINK", "VIDEO_UPSCALE", "WS_IMAGINE", "WS_LIVEKIT", "LIVEKIT_TOKENS", diff --git a/app/dataplane/reverse/transport/asset_upload.py b/app/dataplane/reverse/transport/asset_upload.py index 891594424..dc3ce32ab 100644 --- a/app/dataplane/reverse/transport/asset_upload.py +++ b/app/dataplane/reverse/transport/asset_upload.py @@ -184,7 +184,7 @@ async def upload_from_input(token: str, file_input: str) -> tuple[str, str]: if _is_url(file_input): # Fetch the remote URL and re-upload as base64. proxy = await get_proxy_runtime() - lease = await proxy.acquire() + lease = await proxy.acquire(resource=True) try: headers = build_http_headers(token, lease=lease) kwargs = build_session_kwargs(lease=lease) diff --git a/app/dataplane/reverse/transport/assets.py b/app/dataplane/reverse/transport/assets.py index b786447ff..faa7135a0 100644 --- a/app/dataplane/reverse/transport/assets.py +++ b/app/dataplane/reverse/transport/assets.py @@ -194,7 +194,11 @@ async def download_asset( } proxy = await get_proxy_runtime() - lease = await proxy.acquire(scope=ProxyScope.ASSET, kind=RequestKind.HTTP) + lease = await proxy.acquire( + scope=ProxyScope.ASSET, + kind=RequestKind.HTTP, + resource=True, + ) try: stream = await get_bytes_stream( diff --git a/app/platform/config/backends/factory.py b/app/platform/config/backends/factory.py index 258b385eb..eda5e53b5 100644 --- a/app/platform/config/backends/factory.py +++ b/app/platform/config/backends/factory.py @@ -1,4 +1,4 @@ -"""Config backend factory — follows ACCOUNT_STORAGE automatically.""" +"""Config backend factory.""" import os from pathlib import Path @@ -8,19 +8,24 @@ def get_config_backend_name() -> str: - """Return the active config backend name (mirrors ACCOUNT_STORAGE).""" - return os.getenv("ACCOUNT_STORAGE", "local").strip().lower() + """Return the active config backend name. + + Runtime configuration is local by default even when account data is stored + in Redis or SQL. Set CONFIG_STORAGE explicitly to opt into a remote config + backend. + """ + return os.getenv("CONFIG_STORAGE", "local").strip().lower() def create_config_backend() -> ConfigBackend: - """Instantiate the config backend that matches the account storage backend. + """Instantiate the configured runtime config backend. - ``ACCOUNT_STORAGE=local`` → TOML file (``${DATA_DIR}/config.toml``) - ``ACCOUNT_STORAGE=redis`` → Redis (ACCOUNT_REDIS_URL) - ``ACCOUNT_STORAGE=mysql`` → MySQL (ACCOUNT_MYSQL_URL) - ``ACCOUNT_STORAGE=postgresql`` → PostgreSQL (ACCOUNT_POSTGRESQL_URL) + ``CONFIG_STORAGE=local`` → TOML file (``${DATA_DIR}/config.toml``) + ``CONFIG_STORAGE=redis`` → Redis (ACCOUNT_REDIS_URL) + ``CONFIG_STORAGE=mysql`` → MySQL (ACCOUNT_MYSQL_URL) + ``CONFIG_STORAGE=postgresql`` → PostgreSQL (ACCOUNT_POSTGRESQL_URL) - No extra env vars needed — reuses the same connection settings as accounts. + Remote config backends reuse the same connection settings as accounts. """ backend = get_config_backend_name() @@ -31,7 +36,7 @@ def create_config_backend() -> ConfigBackend: if backend in ("mysql", "postgresql"): return _make_sql(backend) - raise ValueError(f"Unknown account storage backend: {backend!r}") + raise ValueError(f"Unknown config storage backend: {backend!r}") def _make_toml() -> ConfigBackend: diff --git a/app/platform/config/snapshot.py b/app/platform/config/snapshot.py index f489f01ad..8b14b4575 100644 --- a/app/platform/config/snapshot.py +++ b/app/platform/config/snapshot.py @@ -22,6 +22,25 @@ def _mtime(path: Path) -> float: return 0.0 +def _apply_legacy_cache_overrides(user_overrides: dict[str, Any]) -> dict[str, Any]: + """Map legacy [storage] cache limits into [cache.local] at load time.""" + storage = user_overrides.get("storage") + if not isinstance(storage, dict): + return user_overrides + + cache = user_overrides.setdefault("cache", {}) + if not isinstance(cache, dict): + return user_overrides + local = cache.setdefault("local", {}) + if not isinstance(local, dict): + return user_overrides + + for key in ("image_max_mb", "video_max_mb"): + if key not in local and key in storage: + local[key] = storage[key] + return user_overrides + + class ConfigSnapshot: """Immutable view over the loaded configuration dict. @@ -74,6 +93,7 @@ async def load(self, defaults_path: Path | None = None) -> None: defaults = await asyncio.to_thread(load_toml, dp) user_overrides = await backend.load() + user_overrides = _apply_legacy_cache_overrides(user_overrides) self._data = _deep_merge(defaults, user_overrides) self._data = _apply_env(self._data) diff --git a/app/platform/startup/migration.py b/app/platform/startup/migration.py index e8cb7fc79..063b2be31 100644 --- a/app/platform/startup/migration.py +++ b/app/platform/startup/migration.py @@ -2,9 +2,10 @@ Config migration ---------------- -local : seeds ``${DATA_DIR}/config.toml`` from ``config.defaults.toml`` if +local : default; seeds ``${DATA_DIR}/config.toml`` from ``config.defaults.toml`` if the file does not exist yet — gives users an editable copy on first run. -redis / sql : if the backend is empty (version == 0) AND +redis / sql : only when CONFIG_STORAGE is explicitly set to a remote backend. + If the backend is empty (version == 0) AND ``${DATA_DIR}/config.toml`` exists, migrates the user overrides into the DB backend. If it does not exist either, nothing is written (defaults are always loaded from ``config.defaults.toml`` at runtime). diff --git a/app/platform/storage/__init__.py b/app/platform/storage/__init__.py index 895f9a58e..c57f78cae 100644 --- a/app/platform/storage/__init__.py +++ b/app/platform/storage/__init__.py @@ -3,6 +3,8 @@ from .media_cache import ( clear_local_media_files, delete_local_media_file, + list_local_media_files, + local_media_stats, reconcile_local_media_cache_async, save_local_image, save_local_video, @@ -13,6 +15,8 @@ "clear_local_media_files", "delete_local_media_file", "image_files_dir", + "list_local_media_files", + "local_media_stats", "reconcile_local_media_cache_async", "save_local_image", "save_local_video", diff --git a/app/platform/storage/media_cache.py b/app/platform/storage/media_cache.py index f91b06d3f..171cf2466 100644 --- a/app/platform/storage/media_cache.py +++ b/app/platform/storage/media_cache.py @@ -88,6 +88,57 @@ def reconcile(self, media_type: MediaType) -> None: self._enforce_limit_locked(conn, media_type) conn.commit() + def limit_mb(self, media_type: MediaType) -> int: + """Return the configured per-type cache limit in MB.""" + cfg = self._config_provider() + limit_mb = max(0, int(cfg.get_int(f"cache.local.{media_type}_max_mb", 0))) + return limit_mb + + def stats(self, media_type: MediaType) -> dict[str, Any]: + """Return count, size, and configured limit for one media type.""" + files = list(self._iter_files(media_type)) + total_size = sum(path.stat().st_size for path in files) + limit_mb = self.limit_mb(media_type) + limit_bytes = limit_mb * 1024 * 1024 + usage_ratio = (total_size / limit_bytes) if limit_bytes > 0 else None + usage_percent = round(usage_ratio * 100, 1) if usage_ratio is not None else None + return { + "count": len(files), + "size_mb": round(total_size / 1024 / 1024, 2), + "size_bytes": total_size, + "limit_mb": limit_mb, + "limit_bytes": limit_bytes, + "limited": limit_bytes > 0, + "usage_ratio": round(usage_ratio, 4) if usage_ratio is not None else None, + "usage_percent": usage_percent, + } + + def list_files( + self, + media_type: MediaType, + *, + page: int = 1, + page_size: int = 1000, + ) -> dict[str, Any]: + """Return paginated local media files newest first.""" + files = sorted( + self._iter_files(media_type), + key=lambda path: path.stat().st_mtime, + reverse=True, + ) + total = len(files) + start = max(0, page - 1) * page_size + chunk = files[start : start + page_size] + items = [] + for path in chunk: + stat = path.stat() + items.append({ + "name": path.name, + "size_bytes": stat.st_size, + "modified_at": stat.st_mtime, + }) + return {"total": total, "page": page, "page_size": page_size, "items": items} + def delete(self, media_type: MediaType, name: str) -> bool: """Delete a single local media file and keep the index consistent.""" safe_name = self._validate_name(media_type, name) @@ -138,9 +189,7 @@ def _save( return path def _limit_bytes(self, media_type: MediaType) -> int: - cfg = self._config_provider() - limit_mb = max(0, int(cfg.get_int(f"cache.local.{media_type}_max_mb", 0))) - return limit_mb * 1024 * 1024 + return self.limit_mb(media_type) * 1024 * 1024 def _target_bytes(self, max_bytes: int) -> int: return max(0, int(max_bytes * _LOW_WATERMARK_RATIO)) @@ -479,6 +528,21 @@ def delete_local_media_file(media_type: MediaType, name: str) -> bool: return local_media_cache.delete(media_type, name) +def local_media_stats(media_type: MediaType) -> dict[str, Any]: + """Return count, size, and configured limit for local media files.""" + return local_media_cache.stats(media_type) + + +def list_local_media_files( + media_type: MediaType, + *, + page: int = 1, + page_size: int = 1000, +) -> dict[str, Any]: + """Return paginated local media files newest first.""" + return local_media_cache.list_files(media_type, page=page, page_size=page_size) + + async def reconcile_local_media_cache_async( media_type: MediaType | None = None, ) -> None: @@ -494,7 +558,9 @@ async def reconcile_local_media_cache_async( "LocalMediaCacheStore", "clear_local_media_files", "delete_local_media_file", + "list_local_media_files", "local_media_cache", + "local_media_stats", "reconcile_local_media_cache_async", "save_local_image", "save_local_video", diff --git a/app/products/anthropic/messages.py b/app/products/anthropic/messages.py index 42ba98a17..be01b35e1 100644 --- a/app/products/anthropic/messages.py +++ b/app/products/anthropic/messages.py @@ -31,7 +31,8 @@ from app.products.openai.chat import ( _stream_chat, _extract_message, _resolve_image, _quota_sync, _fail_sync, _parse_retry_codes, _feedback_kind, _log_task_exception, - _configured_retry_codes, _should_retry_upstream, + _adapter_has_visible_output, _configured_retry_codes, + _empty_upstream_response_error, _should_retry_upstream, _StreamStartGate, ) from app.products._account_selection import reserve_account, selection_max_retries from app.products.openai._tool_sieve import ToolSieve @@ -345,11 +346,12 @@ async def _run_stream() -> AsyncGenerator[str, None]: tool_output_tokens = 0 block_index = 0 # tracks next content_block index collected_annotations: list[dict] = [] + gate = _StreamStartGate() try: try: # message_start - yield _sse("message_start", { + for out in gate.emit(_sse("message_start", { "type": "message_start", "message": { "id": msg_id, @@ -360,8 +362,10 @@ async def _run_stream() -> AsyncGenerator[str, None]: "stop_reason": None, "usage": {"input_tokens": estimate_prompt_tokens(internal_message), "output_tokens": 0}, }, - }) - yield _sse("ping", {"type": "ping"}) + })): + yield out + for out in gate.emit(_sse("ping", {"type": "ping"})): + yield out ended = False async for line in _stream_chat( @@ -385,35 +389,38 @@ async def _run_stream() -> AsyncGenerator[str, None]: if ev.kind == "thinking" and emit_think and not think_closed: if not think_started: think_started = True - yield _sse("content_block_start", { + for out in gate.emit(_sse("content_block_start", { "type": "content_block_start", "index": block_index, "content_block": {"type": "thinking", "thinking": ""}, - }) + })): + yield out think_buf.append(ev.content) - yield _sse("content_block_delta", { + for out in gate.emit(_sse("content_block_delta", { "type": "content_block_delta", "index": block_index, "delta": {"type": "thinking_delta", "thinking": ev.content}, - }) + })): + yield out elif ev.kind == "text": # Close thinking block if open if think_started and not think_closed: think_closed = True - yield _sse("content_block_stop", { + for out in gate.emit(_sse("content_block_stop", { "type": "content_block_stop", "index": block_index, - }) + })): + yield out block_index += 1 # Feed through ToolSieve if tools active if sieve is not None: safe_text, calls = sieve.feed(ev.content) - if calls is not None: + if calls: # Emit tool_use blocks for call in calls: - yield _sse("content_block_start", { + for out in gate.emit(_sse("content_block_start", { "type": "content_block_start", "index": block_index, "content_block": { @@ -422,19 +429,22 @@ async def _run_stream() -> AsyncGenerator[str, None]: "name": call.name, "input": {}, }, - }) - yield _sse("content_block_delta", { + }), visible=True): + yield out + for out in gate.emit(_sse("content_block_delta", { "type": "content_block_delta", "index": block_index, "delta": { "type": "input_json_delta", "partial_json": call.arguments, }, - }) - yield _sse("content_block_stop", { + })): + yield out + for out in gate.emit(_sse("content_block_stop", { "type": "content_block_stop", "index": block_index, - }) + })): + yield out block_index += 1 tool_output_tokens = estimate_tool_call_tokens(calls) tool_calls_emitted = True @@ -447,17 +457,19 @@ async def _run_stream() -> AsyncGenerator[str, None]: if text_chunk: if not text_started: text_started = True - yield _sse("content_block_start", { + for out in gate.emit(_sse("content_block_start", { "type": "content_block_start", "index": block_index, "content_block": {"type": "text", "text": ""}, - }) + })): + yield out text_buf.append(text_chunk) - yield _sse("content_block_delta", { + for out in gate.emit(_sse("content_block_delta", { "type": "content_block_delta", "index": block_index, "delta": {"type": "text_delta", "text": text_chunk}, - }) + }), visible=bool(text_chunk.strip())): + yield out elif ev.kind == "annotation" and ev.annotation_data: collected_annotations.append(ev.annotation_data) @@ -475,14 +487,15 @@ async def _run_stream() -> AsyncGenerator[str, None]: if calls: # Close text block if open if text_started: - yield _sse("content_block_stop", { + for out in gate.emit(_sse("content_block_stop", { "type": "content_block_stop", "index": block_index, - }) + })): + yield out block_index += 1 text_started = False for call in calls: - yield _sse("content_block_start", { + for out in gate.emit(_sse("content_block_start", { "type": "content_block_start", "index": block_index, "content_block": { @@ -491,19 +504,22 @@ async def _run_stream() -> AsyncGenerator[str, None]: "name": call.name, "input": {}, }, - }) - yield _sse("content_block_delta", { + }), visible=True): + yield out + for out in gate.emit(_sse("content_block_delta", { "type": "content_block_delta", "index": block_index, "delta": { "type": "input_json_delta", "partial_json": call.arguments, }, - }) - yield _sse("content_block_stop", { + })): + yield out + for out in gate.emit(_sse("content_block_stop", { "type": "content_block_stop", "index": block_index, - }) + })): + yield out block_index += 1 tool_output_tokens = estimate_tool_call_tokens(calls) tool_calls_emitted = True @@ -514,13 +530,16 @@ async def _run_stream() -> AsyncGenerator[str, None]: sources = adapter.search_sources_list() if sources: tool_delta["search_sources"] = sources - yield _sse("message_delta", { + for out in gate.emit(_sse("message_delta", { "type": "message_delta", "delta": tool_delta, "usage": {"output_tokens": tool_output_tokens}, - }) - yield _sse("message_stop", {"type": "message_stop"}) - yield "data: [DONE]\n\n" + }), visible=True): + yield out + for out in gate.emit(_sse("message_stop", {"type": "message_stop"})): + yield out + for out in gate.emit("data: [DONE]\n\n"): + yield out success = True logger.info("messages stream tool_calls: attempt={}/{} model={}", attempt + 1, max_retries + 1, model) @@ -532,37 +551,43 @@ async def _run_stream() -> AsyncGenerator[str, None]: chunk = img_text + "\n" text_buf.append(chunk) if text_started: - yield _sse("content_block_delta", { + for out in gate.emit(_sse("content_block_delta", { "type": "content_block_delta", "index": block_index, "delta": {"type": "text_delta", "text": chunk}, - }) + }), visible=bool(img_text.strip())): + yield out references = adapter.references_suffix() if references: text_buf.append(references) if text_started: - yield _sse("content_block_delta", { + for out in gate.emit(_sse("content_block_delta", { "type": "content_block_delta", "index": block_index, "delta": {"type": "text_delta", "text": references}, - }) + }), visible=True): + yield out # Close open blocks if think_started and not think_closed: - yield _sse("content_block_stop", { + for out in gate.emit(_sse("content_block_stop", { "type": "content_block_stop", "index": block_index, - }) + })): + yield out block_index += 1 if text_started: - yield _sse("content_block_stop", { + for out in gate.emit(_sse("content_block_stop", { "type": "content_block_stop", "index": block_index, - }) + })): + yield out full_text = "".join(text_buf) + if not _adapter_has_visible_output(adapter, extra_text=full_text): + raise _empty_upstream_response_error() full_think = "".join(think_buf) out_tokens = estimate_tokens(full_text) if full_think: @@ -575,13 +600,16 @@ async def _run_stream() -> AsyncGenerator[str, None]: msg_delta["search_sources"] = sources if collected_annotations: msg_delta["annotations"] = collected_annotations - yield _sse("message_delta", { + for out in gate.emit(_sse("message_delta", { "type": "message_delta", "delta": msg_delta, "usage": {"output_tokens": out_tokens}, - }) - yield _sse("message_stop", {"type": "message_stop"}) - yield "data: [DONE]\n\n" + }), visible=True): + yield out + for out in gate.emit(_sse("message_stop", {"type": "message_stop"})): + yield out + for out in gate.emit("data: [DONE]\n\n"): + yield out success = True logger.info( "messages stream completed: attempt={}/{} model={} text_len={} think_len={} images={}", @@ -664,6 +692,8 @@ async def _run_stream() -> AsyncGenerator[str, None]: break if ended: break + if not _adapter_has_visible_output(adapter): + raise _empty_upstream_response_error() success = True except UpstreamError as exc: diff --git a/app/products/anthropic/router.py b/app/products/anthropic/router.py index 34b1f2126..cacfc3575 100644 --- a/app/products/anthropic/router.py +++ b/app/products/anthropic/router.py @@ -1,6 +1,6 @@ """Anthropic Messages API router (/v1/messages).""" -from typing import Any +from typing import Any, AsyncGenerator, AsyncIterable import orjson from fastapi import APIRouter, Depends, Request @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field from app.platform.auth.middleware import verify_api_key -from app.platform.errors import AppError, ValidationError +from app.platform.errors import AppError, UpstreamError, ValidationError from app.platform.logging.logger import logger from app.control.model import registry as model_registry @@ -72,6 +72,26 @@ async def _safe_sse_anthropic(stream): yield "data: [DONE]\n\n" +async def _prime_sse(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]: + """Read the first chunk before StreamingResponse sends HTTP 200.""" + iterator = stream.__aiter__() + try: + first = await anext(iterator) + except StopAsyncIteration as exc: + raise UpstreamError( + "Upstream returned empty response", + status=429, + body="empty_response", + ) from exc + + async def _primed() -> AsyncGenerator[str, None]: + yield first + async for chunk in iterator: + yield chunk + + return _primed() + + # --------------------------------------------------------------------------- # /v1/messages # --------------------------------------------------------------------------- @@ -118,6 +138,7 @@ async def messages_endpoint(req: MessagesRequest): if isinstance(result, dict): return JSONResponse(result) + result = await _prime_sse(result) return StreamingResponse( _safe_sse_anthropic(result), media_type = "text/event-stream", diff --git a/app/products/openai/__init__.py b/app/products/openai/__init__.py index 5bc0c2e6d..c3904ab36 100644 --- a/app/products/openai/__init__.py +++ b/app/products/openai/__init__.py @@ -1,3 +1,10 @@ -from .router import router +"""OpenAI product package exports.""" __all__ = ["router"] + + +def __getattr__(name: str): + if name == "router": + from .router import router + return router + raise AttributeError(name) diff --git a/app/products/openai/chat.py b/app/products/openai/chat.py index 7a551b13e..b87f35390 100644 --- a/app/products/openai/chat.py +++ b/app/products/openai/chat.py @@ -192,6 +192,64 @@ def _feedback_kind(exc: BaseException) -> "FeedbackKind": return feedback_kind_for_error(exc) +EMPTY_UPSTREAM_BODY = "empty_response" + + +def _empty_upstream_response_error() -> UpstreamError: + return UpstreamError( + "Upstream returned empty response", + status=429, + body=EMPTY_UPSTREAM_BODY, + ) + + +def _adapter_has_visible_output( + adapter: StreamAdapter, + *, + extra_text: str = "", + has_tool_calls: bool = False, + include_thinking: bool = False, +) -> bool: + """Return True when the adapter has user-visible response content. + + Thinking/reasoning buffers intentionally do not count here. + """ + if has_tool_calls: + return True + if include_thinking and "".join(adapter.thinking_buf).strip(): + return True + if (extra_text or "".join(adapter.text_buf)).strip(): + return True + if adapter.image_urls: + return True + if adapter.references_suffix().strip(): + return True + if adapter.search_sources_list(): + return True + return False + + +class _StreamStartGate: + """Buffer stream events until a visible payload is available.""" + + __slots__ = ("_pending", "visible") + + def __init__(self) -> None: + self._pending: list[str] = [] + self.visible = False + + def emit(self, chunk: str, *, visible: bool = False) -> list[str]: + if visible and not self.visible: + self.visible = True + out = self._pending + [chunk] + self._pending = [] + return out + if self.visible: + return [chunk] + self._pending.append(chunk) + return [] + + async def _download_image_bytes(token: str, url: str) -> tuple[bytes, str]: """Download image bytes via the shared asset transport used by /v1/images.""" from app.dataplane.reverse.protocol.xai_assets import infer_content_type @@ -527,6 +585,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: ended = False sieve = ToolSieve(tool_names) tool_calls_emitted = False + gate = _StreamStartGate() async for line in _stream_chat( token=token, mode_id=ModeId(selected_mode_id), @@ -552,8 +611,15 @@ async def _run_stream() -> AsyncGenerator[str, None]: chunk = make_stream_chunk( response_id, model, safe_text ) - yield f"data: {orjson.dumps(chunk).decode()}\n\n" - if parsed_calls is not None: + payload = f"data: {orjson.dumps(chunk).decode()}\n\n" + for out in gate.emit( + payload, visible=bool(safe_text.strip()) + ): + yield out + if parsed_calls: + for out in gate.emit("", visible=True): + if out: + yield out for i, tc in enumerate(parsed_calls): chunk = make_tool_call_chunk( response_id, @@ -585,12 +651,18 @@ async def _run_stream() -> AsyncGenerator[str, None]: chunk = make_stream_chunk( response_id, model, ev.content ) - yield f"data: {orjson.dumps(chunk).decode()}\n\n" + payload = f"data: {orjson.dumps(chunk).decode()}\n\n" + for out in gate.emit( + payload, visible=bool(ev.content.strip()) + ): + yield out elif ev.kind == "thinking" and emit_think: chunk = make_thinking_chunk( response_id, model, ev.content ) - yield f"data: {orjson.dumps(chunk).decode()}\n\n" + payload = f"data: {orjson.dumps(chunk).decode()}\n\n" + for out in gate.emit(payload, visible=True): + yield out elif ev.kind == "annotation" and ev.annotation_data: collected_annotations.append(ev.annotation_data) elif ev.kind == "soft_stop": @@ -603,6 +675,9 @@ async def _run_stream() -> AsyncGenerator[str, None]: # Stream ended — flush sieve for any buffered XML flushed_calls = sieve.flush() if flushed_calls: + for out in gate.emit("", visible=True): + if out: + yield out for i, tc in enumerate(flushed_calls): chunk = make_tool_call_chunk( response_id, @@ -637,14 +712,27 @@ async def _run_stream() -> AsyncGenerator[str, None]: chunk = make_stream_chunk( response_id, model, img_text + "\n" ) - yield f"data: {orjson.dumps(chunk).decode()}\n\n" + payload = f"data: {orjson.dumps(chunk).decode()}\n\n" + for out in gate.emit( + payload, visible=bool(img_text.strip()) + ): + yield out references = adapter.references_suffix() if references: chunk = make_stream_chunk( response_id, model, references ) - yield f"data: {orjson.dumps(chunk).decode()}\n\n" + payload = f"data: {orjson.dumps(chunk).decode()}\n\n" + for out in gate.emit(payload, visible=True): + yield out + + if not _adapter_has_visible_output( + adapter, + has_tool_calls=tool_calls_emitted, + include_thinking=emit_think, + ): + raise _empty_upstream_response_error() chat_anns = _to_chat_annotations(collected_annotations) final = make_stream_chunk( @@ -658,8 +746,11 @@ async def _run_stream() -> AsyncGenerator[str, None]: sources = adapter.search_sources_list() if sources: final["search_sources"] = sources - yield f"data: {orjson.dumps(final).decode()}\n\n" - yield "data: [DONE]\n\n" + final_payload = f"data: {orjson.dumps(final).decode()}\n\n" + for out in gate.emit(final_payload, visible=True): + yield out + for out in gate.emit("data: [DONE]\n\n"): + yield out success = True logger.info( "chat stream completed: attempt={}/{} model={} image_count={}", @@ -764,6 +855,8 @@ async def _run_stream() -> AsyncGenerator[str, None]: break if ended: break + if not _adapter_has_visible_output(adapter): + raise _empty_upstream_response_error() success = True except UpstreamError as exc: @@ -884,6 +977,10 @@ async def _run_stream() -> AsyncGenerator[str, None]: __all__ = [ "completions", + "EMPTY_UPSTREAM_BODY", + "_adapter_has_visible_output", "_configured_retry_codes", + "_empty_upstream_response_error", "_should_retry_upstream", + "_StreamStartGate", ] diff --git a/app/products/openai/responses.py b/app/products/openai/responses.py index d816c7a92..9b94b881e 100644 --- a/app/products/openai/responses.py +++ b/app/products/openai/responses.py @@ -7,8 +7,6 @@ import asyncio from typing import Any, AsyncGenerator -import orjson - from app.platform.logging.logger import logger from app.platform.config.snapshot import get_config from app.platform.errors import RateLimitError, UpstreamError @@ -20,13 +18,14 @@ from app.dataplane.reverse.protocol.xai_chat import classify_line, StreamAdapter from app.products._account_selection import reserve_account, selection_max_retries -from .chat import _stream_chat, _extract_message, _resolve_image, _quota_sync, _fail_sync, _parse_retry_codes, _feedback_kind, _log_task_exception, _upstream_body_excerpt +from .chat import _stream_chat, _extract_message, _resolve_image, _quota_sync, _fail_sync, _feedback_kind, _log_task_exception, _upstream_body_excerpt +from .chat import _adapter_has_visible_output, _empty_upstream_response_error, _StreamStartGate from .chat import _configured_retry_codes, _should_retry_upstream from ._format import ( make_resp_id, build_resp_usage, make_resp_object, format_sse, ) from app.dataplane.reverse.protocol.tool_prompt import ( - build_tool_system_prompt, extract_tool_names, inject_into_message, tool_calls_to_xml, + build_tool_system_prompt, extract_tool_names, inject_into_message, ) from app.dataplane.reverse.protocol.tool_parser import parse_tool_calls from ._tool_sieve import ToolSieve @@ -221,7 +220,6 @@ async def create( cfg = get_config() spec = resolve_model(model) - mode_id = int(spec.mode_id) # cast once, reuse everywhere messages: list[dict] = [] if instructions: @@ -283,13 +281,15 @@ async def _run_stream() -> AsyncGenerator[str, None]: tool_calls_emitted = False detected_fc_items: list[dict] = [] collected_annotations: list[dict] = [] + gate = _StreamStartGate() try: try: - yield format_sse("response.created", { + for out in gate.emit(format_sse("response.created", { "type": "response.created", "response": make_resp_object(response_id, model, "in_progress", []), - }) + })): + yield out ended = False async for line in _stream_chat( @@ -313,49 +313,54 @@ async def _run_stream() -> AsyncGenerator[str, None]: if ev.kind == "thinking" and emit_think and not reasoning_closed: if not reasoning_started: reasoning_started = True - yield format_sse("response.output_item.added", { + for out in gate.emit(format_sse("response.output_item.added", { "type": "response.output_item.added", "output_index": 0, "item": { "id": reasoning_id, "type": "reasoning", "summary": [], "status": "in_progress", }, - }) - yield format_sse("response.reasoning_summary_part.added", { + }), visible=True): + yield out + for out in gate.emit(format_sse("response.reasoning_summary_part.added", { "type": "response.reasoning_summary_part.added", "item_id": reasoning_id, "output_index": 0, "summary_index": 0, "part": {"type": "summary_text", "text": ""}, - }) + }), visible=True): + yield out think_buf.append(ev.content) - yield format_sse("response.reasoning_summary_text.delta", { + for out in gate.emit(format_sse("response.reasoning_summary_text.delta", { "type": "response.reasoning_summary_text.delta", "item_id": reasoning_id, "output_index": 0, "summary_index": 0, "delta": ev.content, - }) + }), visible=True): + yield out elif ev.kind == "text": if reasoning_started and not reasoning_closed: reasoning_closed = True full_think = "".join(think_buf) - yield format_sse("response.reasoning_summary_text.done", { + for out in gate.emit(format_sse("response.reasoning_summary_text.done", { "type": "response.reasoning_summary_text.done", "item_id": reasoning_id, "output_index": 0, "summary_index": 0, "text": full_think, - }) - yield format_sse("response.reasoning_summary_part.done", { + })): + yield out + for out in gate.emit(format_sse("response.reasoning_summary_part.done", { "type": "response.reasoning_summary_part.done", "item_id": reasoning_id, "output_index": 0, "summary_index": 0, "part": {"type": "summary_text", "text": full_think}, - }) - yield format_sse("response.output_item.done", { + })): + yield out + for out in gate.emit(format_sse("response.output_item.done", { "type": "response.output_item.done", "output_index": 0, "item": { @@ -364,17 +369,19 @@ async def _run_stream() -> AsyncGenerator[str, None]: "summary": [{"type": "summary_text", "text": full_think}], "status": "completed", }, - }) + })): + yield out # Feed through ToolSieve if tools are active if sieve is not None: safe_text, calls = sieve.feed(ev.content) - if calls is not None: + if calls: fc_items = _build_fc_items(calls) detected_fc_items = fc_items base_idx = 1 if reasoning_started else 0 async for evt in _emit_fc_events(fc_items, base_idx): - yield evt + for out in gate.emit(evt, visible=True): + yield out tool_calls_emitted = True ended = True break @@ -386,43 +393,47 @@ async def _run_stream() -> AsyncGenerator[str, None]: msg_idx = 1 if reasoning_started else 0 if not message_started: message_started = True - yield format_sse("response.output_item.added", { + for out in gate.emit(format_sse("response.output_item.added", { "type": "response.output_item.added", "output_index": msg_idx, "item": { "id": message_id, "type": "message", "role": "assistant", "content": [], "status": "in_progress", }, - }) - yield format_sse("response.content_part.added", { + })): + yield out + for out in gate.emit(format_sse("response.content_part.added", { "type": "response.content_part.added", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "part": {"type": "output_text", "text": "", "annotations": []}, - }) + })): + yield out text_buf.append(text_chunk) - yield format_sse("response.output_text.delta", { + for out in gate.emit(format_sse("response.output_text.delta", { "type": "response.output_text.delta", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "delta": text_chunk, - }) + }), visible=bool(text_chunk.strip())): + yield out elif ev.kind == "annotation" and ev.annotation_data: if message_started: collected_annotations.append(ev.annotation_data) msg_idx = 1 if reasoning_started else 0 - yield format_sse("response.output_text.annotation.added", { + for out in gate.emit(format_sse("response.output_text.annotation.added", { "type": "response.output_text.annotation.added", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "annotation_index": len(collected_annotations) - 1, "annotation": ev.annotation_data, - }) + })): + yield out elif ev.kind == "soft_stop": ended = True @@ -439,7 +450,8 @@ async def _run_stream() -> AsyncGenerator[str, None]: detected_fc_items = fc_items base_idx = 1 if reasoning_started else 0 async for evt in _emit_fc_events(fc_items, base_idx): - yield evt + for out in gate.emit(evt, visible=True): + yield out tool_calls_emitted = True if tool_calls_emitted: @@ -457,14 +469,16 @@ async def _run_stream() -> AsyncGenerator[str, None]: pt = estimate_prompt_tokens(message) ct = estimate_tool_call_tokens(detected_fc_items) rt = estimate_tokens(full_think) if full_think else 0 - yield format_sse("response.completed", { + for out in gate.emit(format_sse("response.completed", { "type": "response.completed", "response": make_resp_object( response_id, model, "completed", output, build_resp_usage(pt, ct + rt, rt), ), - }) - yield "data: [DONE]\n\n" + }), visible=True): + yield out + for out in gate.emit("data: [DONE]\n\n"): + yield out success = True logger.info("responses stream tool_calls: attempt={}/{} model={}", attempt + 1, max_retries + 1, model) @@ -476,42 +490,52 @@ async def _run_stream() -> AsyncGenerator[str, None]: img_md = img_text + "\n" text_buf.append(img_md) if message_started: - yield format_sse("response.output_text.delta", { + for out in gate.emit(format_sse("response.output_text.delta", { "type": "response.output_text.delta", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "delta": img_md, - }) + }), visible=bool(img_text.strip())): + yield out references = adapter.references_suffix() if references: text_buf.append(references) if message_started: - yield format_sse("response.output_text.delta", { + for out in gate.emit(format_sse("response.output_text.delta", { "type": "response.output_text.delta", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "delta": references, - }) + }), visible=True): + yield out full_text = "".join(text_buf) + if not _adapter_has_visible_output( + adapter, + extra_text=full_text, + include_thinking=emit_think, + ): + raise _empty_upstream_response_error() if message_started: - yield format_sse("response.output_text.done", { + for out in gate.emit(format_sse("response.output_text.done", { "type": "response.output_text.done", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "text": full_text, - }) - yield format_sse("response.content_part.done", { + })): + yield out + for out in gate.emit(format_sse("response.content_part.done", { "type": "response.content_part.done", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "part": {"type": "output_text", "text": full_text, "annotations": collected_annotations}, - }) + })): + yield out # 构建 message item(流式 output_item.done + response.completed 共用) sources = adapter.search_sources_list() msg_item: dict = { @@ -523,11 +547,12 @@ async def _run_stream() -> AsyncGenerator[str, None]: } if sources: msg_item["search_sources"] = sources - yield format_sse("response.output_item.done", { + for out in gate.emit(format_sse("response.output_item.done", { "type": "response.output_item.done", "output_index": msg_idx, "item": msg_item, - }) + })): + yield out full_think = "".join(think_buf) output = [] @@ -550,19 +575,22 @@ async def _run_stream() -> AsyncGenerator[str, None]: sources = adapter.search_sources_list() if sources: msg_item["search_sources"] = sources - output.append(msg_item) + if message_started or full_text: + output.append(msg_item) pt = estimate_prompt_tokens(message) ct = estimate_tokens(full_text) rt = estimate_tokens(full_think) if full_think else 0 - yield format_sse("response.completed", { + for out in gate.emit(format_sse("response.completed", { "type": "response.completed", "response": make_resp_object( response_id, model, "completed", output, build_resp_usage(pt, ct + rt, rt), ), - }) - yield "data: [DONE]\n\n" + }), visible=True): + yield out + for out in gate.emit("data: [DONE]\n\n"): + yield out success = True logger.info("responses stream completed: attempt={}/{} model={} text_len={} reasoning_len={} image_count={}", attempt + 1, max_retries + 1, model, @@ -644,6 +672,8 @@ async def _run_stream() -> AsyncGenerator[str, None]: break if ended: break + if not _adapter_has_visible_output(adapter): + raise _empty_upstream_response_error() success = True except UpstreamError as exc: diff --git a/app/products/openai/router.py b/app/products/openai/router.py index 01a27504a..731654dff 100644 --- a/app/products/openai/router.py +++ b/app/products/openai/router.py @@ -3,15 +3,16 @@ import base64 import binascii import mimetypes -from typing import Annotated, AsyncGenerator, AsyncIterable, Literal +from typing import Annotated, AsyncGenerator, AsyncIterable, Literal, TypeVar import orjson from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse, FileResponse +from pydantic import ValidationError as PydanticValidationError from app.control.account.state_machine import is_manageable from app.platform.auth.middleware import verify_api_key -from app.platform.errors import AppError, ValidationError +from app.platform.errors import AppError, UpstreamError, ValidationError from app.platform.logging.logger import logger from app.platform.storage import image_files_dir, video_files_dir from app.control.model import registry as model_registry @@ -131,6 +132,26 @@ async def _safe_sse(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]: yield "data: [DONE]\n\n" +async def _prime_sse(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]: + """Read the first chunk before StreamingResponse sends HTTP 200.""" + iterator = stream.__aiter__() + try: + first = await anext(iterator) + except StopAsyncIteration as exc: + raise UpstreamError( + "Upstream returned empty response", + status=429, + body="empty_response", + ) from exc + + async def _primed() -> AsyncGenerator[str, None]: + yield first + async for chunk in iterator: + yield chunk + + return _primed() + + _SSE_HEADERS = {"Cache-Control": "no-cache", "Connection": "keep-alive"} @@ -143,6 +164,48 @@ async def _safe_sse(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]: _ALLOWED_SIZES = {"1280x720", "720x1280", "1792x1024", "1024x1792", "1024x1024"} _EFFORT_VALUES = {"none", "minimal", "low", "medium", "high", "xhigh"} _LITE_IMAGE_MODELS = {"grok-imagine-image-lite"} +_ConfigT = TypeVar("_ConfigT", ImageConfig, VideoConfig) + + +def _metadata_config( + req: ChatCompletionRequest, + key: str, + model_cls: type[_ConfigT], +) -> _ConfigT | None: + metadata = req.metadata + if not isinstance(metadata, dict) or key not in metadata: + return None + value = metadata.get(key) + if value is None: + return None + if not isinstance(value, dict): + raise ValidationError( + f"metadata.{key} must be an object", + param=f"metadata.{key}", + ) + try: + return model_cls.model_validate(value) + except PydanticValidationError as exc: + raise ValidationError( + f"metadata.{key} is invalid: {exc.errors()[0].get('msg', 'invalid value')}", + param=f"metadata.{key}", + ) from exc + + +def _resolve_image_config(req: ChatCompletionRequest) -> ImageConfig: + return ( + req.image_config + or _metadata_config(req, "image_config", ImageConfig) + or ImageConfig() + ) + + +def _resolve_video_config(req: ChatCompletionRequest) -> VideoConfig: + return ( + req.video_config + or _metadata_config(req, "video_config", VideoConfig) + or VideoConfig() + ) def _validate_chat(req: ChatCompletionRequest) -> None: @@ -236,7 +299,7 @@ async def chat_completions_endpoint(req: ChatCompletionRequest): if spec.is_image_edit(): from .images import edit as img_edit - cfg = req.image_config or ImageConfig() + cfg = _resolve_image_config(req) _validate_image_edit_n(cfg.n or 1, param="image_config.n") result = await img_edit( model=req.model, @@ -251,7 +314,7 @@ async def chat_completions_endpoint(req: ChatCompletionRequest): elif spec.is_image(): from .images import generate as img_gen - cfg = req.image_config or ImageConfig() + cfg = _resolve_image_config(req) size = cfg.size or "1024x1024" fmt = cfg.response_format or "url" n = cfg.n or 1 @@ -280,7 +343,7 @@ async def chat_completions_endpoint(req: ChatCompletionRequest): elif spec.is_video(): from .video import completions as vid_comp - vcfg = req.video_config or VideoConfig() + vcfg = _resolve_video_config(req) from .video import validate_video_length as _validate_video_length _validate_video_length(vcfg.seconds or 6) @@ -339,6 +402,8 @@ async def _err_stream(): if isinstance(result, dict): return JSONResponse(result) + if not (spec.is_image_edit() or spec.is_image()): + result = await _prime_sse(result) return StreamingResponse( _safe_sse(result), media_type="text/event-stream", headers=_SSE_HEADERS ) @@ -418,6 +483,7 @@ async def responses_endpoint(req: ResponsesCreateRequest): if isinstance(result, dict): return JSONResponse(result) + result = await _prime_sse(result) return StreamingResponse( _safe_sse_responses(result), media_type = "text/event-stream", @@ -480,9 +546,14 @@ async def videos_create( references_payload = None if input_reference: + if len(input_reference) > 7: + raise ValidationError( + "Video generation supports at most 7 reference images", + param="input_reference", + ) references_payload = [ {"image_url": await _upload_to_data_uri(f, param="input_reference")} - for f in input_reference[:7] + for f in input_reference ] result = await create_video( diff --git a/app/products/openai/schemas.py b/app/products/openai/schemas.py index f4caa0fee..a42ff1c2e 100644 --- a/app/products/openai/schemas.py +++ b/app/products/openai/schemas.py @@ -39,6 +39,7 @@ class ChatCompletionRequest(BaseModel): tool_choice: str | dict[str, Any] | None = None parallel_tool_calls: bool | None = True max_tokens: int | None = None + metadata: dict[str, Any] | None = None class ImageGenerationRequest(BaseModel): diff --git a/app/products/openai/video.py b/app/products/openai/video.py index f4e75baca..1240d0d7b 100644 --- a/app/products/openai/video.py +++ b/app/products/openai/video.py @@ -29,8 +29,10 @@ from app.platform.runtime.clock import now_s from app.platform.storage import save_local_video from app.control.account.enums import FeedbackKind +from app.control.account.runtime import get_refresh_service from app.control.model import registry as model_registry from app.control.model.registry import resolve as resolve_model +from app.dataplane.account.selector import current_strategy from app.dataplane.proxy import get_proxy_runtime from app.dataplane.proxy.adapters.headers import build_http_headers from app.dataplane.proxy.adapters.session import ResettableSession, build_session_kwargs @@ -52,7 +54,16 @@ make_stream_chunk, make_thinking_chunk, ) -from .chat import _fail_sync, _quota_sync, _feedback_kind +from .chat import ( + EMPTY_UPSTREAM_BODY, + _configured_retry_codes, + _fail_sync, + _feedback_kind, + _log_task_exception, + _quota_sync, + _should_retry_upstream, +) +from app.products._account_selection import selection_max_retries _IMAGE_MEDIA_TYPE = "MEDIA_POST_TYPE_IMAGE" _VIDEO_MEDIA_TYPE = "MEDIA_POST_TYPE_VIDEO" @@ -61,7 +72,23 @@ _VIDEO_OBJECT = "video" _VIDEO_JOB_TTL_S = 3600 _VIDEO_EXTENSION_REF_TYPE = "ORIGINAL_REF_TYPE_VIDEO_EXTENSION" -_SUPPORTED_VIDEO_LENGTHS = frozenset({6, 10, 12, 16, 20}) +_VIDEO_MAX_REFERENCES = 7 +_VIDEO_ALLOWED_POOL_IDS = frozenset((1, 2)) # super / heavy only +_SUPPORTED_VIDEO_LENGTHS = frozenset({6, 10, 12, 16, 20, 22, 26, 30, 32, 36, 40}) +_ASYNC_VIDEO_LENGTHS = frozenset({6, 10}) +_VIDEO_SEGMENT_LENGTHS: dict[int, list[int]] = { + 6: [6], + 10: [10], + 12: [6, 6], + 16: [10, 6], + 20: [10, 10], + 22: [10, 6, 6], + 26: [10, 10, 6], + 30: [10, 10, 10], + 32: [10, 10, 6, 6], + 36: [10, 10, 10, 6], + 40: [10, 10, 10, 10], +} _VIDEO_SIZE_MAP: dict[str, tuple[str, str]] = { "720x1280": ("9:16", "720p"), "1280x720": ("16:9", "720p"), @@ -108,6 +135,7 @@ class _VideoJob: remixed_from_video_id: str | None = None video_url: str = "" content_path: str = "" + content_file_id: str = "" def to_dict(self) -> dict[str, Any]: payload: dict[str, Any] = { @@ -128,6 +156,11 @@ def to_dict(self) -> dict[str, Any]: payload["error"] = self.error if self.remixed_from_video_id: payload["remixed_from_video_id"] = self.remixed_from_video_id + if self.status == "completed": + if self.content_file_id: + payload["metadata"] = {"url": _local_video_url(self.content_file_id)} + elif self.content_path: + payload["metadata"] = {"url": _video_content_url(self.id)} return payload @@ -139,6 +172,12 @@ def _build_message(prompt: str, preset: str) -> str: return f"{prompt} {_PRESET_FLAGS.get(preset, '--mode=custom')}".strip() +def _video_content_url(video_id: str) -> str: + app_url = get_config().get_str("app.app_url", "").rstrip("/") + path = f"/v1/videos/{video_id}/content" + return f"{app_url}{path}" if app_url else path + + def _progress_reason(progress: int) -> str: return f"视频正在生成 {max(0, min(100, int(progress)))}%" @@ -169,6 +208,16 @@ def validate_video_length(seconds: int) -> None: raise ValidationError(f"seconds must be one of [{allowed}]", param="seconds") +def validate_async_video_length(seconds: int) -> None: + if seconds not in _ASYNC_VIDEO_LENGTHS: + allowed = ", ".join(str(item) for item in sorted(_ASYNC_VIDEO_LENGTHS)) + raise ValidationError( + f"seconds must be one of [{allowed}] for /v1/videos; " + "use /v1/chat/completions for longer videos", + param="seconds", + ) + + def _resolve_video_size(size: str) -> tuple[str, str]: normalized = (size or "720x1280").strip() config = _VIDEO_SIZE_MAP.get(normalized) @@ -196,20 +245,62 @@ def _resolve_video_preset(value: str | None, *, default: str = "custom") -> str: def _build_segment_lengths(seconds: int) -> list[int]: - if seconds == 6: - return [6] - if seconds == 10: - return [10] - if seconds == 12: - return [6, 6] - if seconds == 16: - return [10, 6] - if seconds == 20: - return [10, 10] + segments = _VIDEO_SEGMENT_LENGTHS.get(seconds) + if segments is not None: + return list(segments) validate_video_length(seconds) raise AssertionError("unreachable") +def _normalize_segment_prompts( + prompt: str, + segment_lengths: list[int], + segment_prompts: list[str] | None = None, + *, + strict: bool = False, +) -> list[str]: + prompts = [p.strip() for p in (segment_prompts or [prompt]) if p and p.strip()] + if not prompts: + raise ValidationError("Video prompt cannot be empty", param="messages") + if strict and len(prompts) != len(segment_lengths): + raise ValidationError( + "Video generation requires exactly " + f"{len(segment_lengths)} prompt segment(s) for this duration", + param="messages", + ) + if len(prompts) > len(segment_lengths): + raise ValidationError( + f"Video generation uses {len(segment_lengths)} prompt segment(s) for this duration", + param="messages", + ) + while len(prompts) < len(segment_lengths): + prompts.append(prompts[-1]) + return prompts + + +def _video_pool_candidates(spec) -> tuple[int, ...]: + return tuple( + pool_id + for pool_id in spec.pool_candidates() + if pool_id in _VIDEO_ALLOWED_POOL_IDS + ) + + +def _empty_video_response_error() -> UpstreamError: + return UpstreamError( + "Video upstream returned empty response", + status=429, + body=EMPTY_UPSTREAM_BODY, + ) + + +def _transport_upstream_error(exc: BaseException, *, context: str) -> UpstreamError: + if isinstance(exc, UpstreamError): + return exc + body = str(exc).replace("\n", "\\n")[:400] + return UpstreamError(f"{context}: {exc}", status=502, body=body) + + def _video_create_payload( *, prompt: str, @@ -337,13 +428,18 @@ async def _stream_video_request( kwargs = build_session_kwargs(lease=lease) async with ResettableSession(**kwargs) as session: - response = await session.post( - CHAT, - headers=headers, - data=orjson.dumps(payload), - timeout=timeout_s, - stream=True, - ) + try: + response = await session.post( + CHAT, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout_s, + stream=True, + ) + except Exception as exc: + raise _transport_upstream_error( + exc, context="Video transport failed" + ) from exc if response.status_code != 200: body = response.content.decode("utf-8", "replace")[:300] raise UpstreamError( @@ -351,8 +447,13 @@ async def _stream_video_request( status=response.status_code, body=body, ) - async for line in response.aiter_lines(): - yield line + try: + async for line in response.aiter_lines(): + yield line + except Exception as exc: + raise _transport_upstream_error( + exc, context="Video stream read failed" + ) from exc def _absolutize_video_url(url: str) -> str: @@ -436,6 +537,12 @@ async def _prepare_video_references( input_references: list[dict[str, Any]], ) -> list[_VideoReference]: """Upload multiple video references concurrently and preserve order.""" + if len(input_references) > _VIDEO_MAX_REFERENCES: + raise ValidationError( + f"Video generation supports at most {_VIDEO_MAX_REFERENCES} reference images", + param="input_reference", + ) + tasks = [ _prepare_video_reference(token, ref) for ref in input_references @@ -490,6 +597,7 @@ async def _collect_video_segment( final_thumbnail = "" video_post_id = "" stream_data_items: list[str] = [] + saw_video_signal = False async for line in _stream_video_request( token, @@ -511,6 +619,7 @@ async def _collect_video_segment( stream = _extract_streaming_video_response(obj) if stream: + saw_video_signal = True try: progress = int(stream.get("progress") or 0) except (TypeError, ValueError): @@ -537,6 +646,8 @@ async def _collect_video_segment( final_thumbnail = _absolutize_video_url(thumbnail) attachments = _extract_model_response_file_attachments(obj) + if attachments: + saw_video_signal = True if attachments and not final_asset_id: final_asset_id = attachments[0] @@ -549,6 +660,8 @@ async def _collect_video_segment( body="\n".join(stream_data_items), ) if not final_url: + if not stream_data_items or not saw_video_signal: + raise _empty_video_response_error() raise UpstreamError( "Video generation returned no final video URL", body="\n".join(stream_data_items), @@ -638,8 +751,11 @@ async def _generate_video_with_token( preset: str, timeout_s: float, input_references: list[dict[str, Any]] | None = None, + segment_prompts: list[str] | None = None, progress_cb: Callable[[int], Awaitable[None]] | None = None, ) -> _VideoArtifact: + segments = _build_segment_lengths(seconds) + prompts = _normalize_segment_prompts(prompt, segments, segment_prompts) references: list[_VideoReference] = [] if input_references: references = await _prepare_video_references(token, input_references) @@ -648,7 +764,7 @@ async def _generate_video_with_token( post = await create_media_post( token, media_type=_VIDEO_MEDIA_TYPE, - prompt=prompt, + prompt=prompts[0], referer="https://grok.com/imagine", ) post_data = post.get("post") @@ -658,16 +774,16 @@ async def _generate_video_with_token( if not parent_post_id: raise UpstreamError("Video create-post returned no post id") - segments = _build_segment_lengths(seconds) total_segments = len(segments) artifact: _VideoArtifact | None = None extend_post_id = parent_post_id elapsed_seconds = 0 for index, segment_length in enumerate(segments): + segment_prompt = prompts[index] if index == 0: payload = _video_create_payload( - prompt=prompt, + prompt=segment_prompt, parent_post_id=parent_post_id, aspect_ratio=aspect_ratio, resolution_name=resolution_name, @@ -680,7 +796,7 @@ async def _generate_video_with_token( referer = "https://grok.com/imagine" else: payload = _video_extend_payload( - prompt=prompt, + prompt=segment_prompt, parent_post_id=parent_post_id, extend_post_id=extend_post_id, aspect_ratio=aspect_ratio, @@ -725,6 +841,7 @@ async def _run_video_generation( seconds: int, preset: str = "custom", input_references: list[dict[str, Any]] | None = None, + segment_prompts: list[str] | None = None, progress_cb: Callable[[int], Awaitable[None]] | None = None, ) -> _VideoArtifact: async def _runner(token: str, timeout_s: float) -> _VideoArtifact: @@ -737,6 +854,7 @@ async def _runner(token: str, timeout_s: float) -> _VideoArtifact: preset=preset, timeout_s=timeout_s, input_references=input_references, + segment_prompts=segment_prompts, progress_cb=progress_cb, ) @@ -753,44 +871,94 @@ async def _run_video_with_account( spec = resolve_model(model) if not spec.is_video(): raise ValidationError(f"Model {model!r} is not a video model", param="model") + pool_candidates = _video_pool_candidates(spec) + if not pool_candidates: + raise RateLimitError("No super/heavy accounts available for video generation") from app.dataplane.account import _directory as _acct_dir if _acct_dir is None: raise RateLimitError("Account directory not initialised") - acct = await _acct_dir.reserve( - pool_candidates=spec.pool_candidates(), - mode_id=int(spec.mode_id), - now_s_override=now_s(), - ) - if acct is None: - raise RateLimitError("No available accounts for video generation") + max_retries = selection_max_retries() + retry_codes = _configured_retry_codes(cfg) + excluded: list[str] = [] + mode_id = int(spec.mode_id) - token = acct.token - success = False - fail_exc: BaseException | None = None - try: - artifact = await runner(token, timeout_s) - success = True - return artifact - except BaseException as exc: - fail_exc = exc - raise - finally: - await _acct_dir.release(acct) - kind = ( - FeedbackKind.SUCCESS - if success - else _feedback_kind(fail_exc) - if fail_exc - else FeedbackKind.SERVER_ERROR + async def _reserve(): + acct = await _acct_dir.reserve( + pool_candidates=pool_candidates, + mode_id=mode_id, + now_s_override=now_s(), + exclude_tokens=excluded or None, ) - await _acct_dir.feedback(token, kind, int(spec.mode_id)) - if success: - asyncio.create_task(_quota_sync(token, int(spec.mode_id))) - else: - asyncio.create_task(_fail_sync(token, int(spec.mode_id), fail_exc)) + if acct is not None or current_strategy() == "random": + return acct + refresh_svc = get_refresh_service() + if refresh_svc is not None: + await refresh_svc.refresh_on_demand() + acct = await _acct_dir.reserve( + pool_candidates=pool_candidates, + mode_id=mode_id, + now_s_override=now_s(), + exclude_tokens=excluded or None, + ) + return acct + + for attempt in range(max_retries + 1): + acct = await _reserve() + if acct is None: + raise RateLimitError("No available super/heavy accounts for video generation") + + token = acct.token + success = False + retry = False + fail_exc: BaseException | None = None + try: + artifact = await runner(token, timeout_s) + success = True + return artifact + except UpstreamError as exc: + fail_exc = exc + if _should_retry_upstream(exc, retry_codes) and attempt < max_retries: + retry = True + logger.warning( + "video retry scheduled: attempt={}/{} status={} token={}...", + attempt + 1, + max_retries, + exc.status, + token[:8], + ) + else: + raise + except BaseException as exc: + fail_exc = exc + raise + finally: + await _acct_dir.release(acct) + kind = ( + FeedbackKind.SUCCESS + if success + else _feedback_kind(fail_exc) + if fail_exc + else FeedbackKind.SERVER_ERROR + ) + await _acct_dir.feedback(token, kind, mode_id) + if success: + asyncio.create_task(_quota_sync(token, mode_id)).add_done_callback( + _log_task_exception + ) + else: + asyncio.create_task(_fail_sync(token, mode_id, fail_exc)).add_done_callback( + _log_task_exception + ) + + if retry: + excluded.append(token) + continue + break + + raise RateLimitError("No available super/heavy accounts for video generation") async def _put_video_job(job: _VideoJob) -> None: @@ -840,33 +1008,11 @@ async def _run_video_job( default=default_resolution_name, ) resolved_preset = _resolve_video_preset(preset) - spec = resolve_model(job.model) - - from app.dataplane.account import _directory as _acct_dir - - if _acct_dir is None: - raise RateLimitError("Account directory not initialised") - - acct = await _acct_dir.reserve( - pool_candidates=spec.pool_candidates(), - mode_id=int(spec.mode_id), - now_s_override=now_s(), - ) - if acct is None: - raise RateLimitError("No available accounts for video generation") - token = acct.token - success = False - fail_exc: BaseException | None = None - try: - cfg = get_config() - timeout_s = cfg.get_float("video.timeout", 180.0) - - async def _progress(progress: int) -> None: - await _set_job_status( - job, status="in_progress", progress=max(1, progress) - ) + async def _progress(progress: int) -> None: + await _set_job_status(job, status="in_progress", progress=max(1, progress)) + async def _runner(token: str, timeout_s: float) -> tuple[_VideoArtifact, bytes]: artifact = await _generate_video_with_token( token=token, prompt=prompt, @@ -879,32 +1025,19 @@ async def _progress(progress: int) -> None: progress_cb=_progress, ) raw, _mime = await _download_video_bytes(token, artifact.video_url) - success = True - except BaseException as exc: - fail_exc = exc - raise - finally: - await _acct_dir.release(acct) - kind = ( - FeedbackKind.SUCCESS - if success - else _feedback_kind(fail_exc) - if fail_exc - else FeedbackKind.SERVER_ERROR - ) - await _acct_dir.feedback(token, kind, int(spec.mode_id)) - if success: - asyncio.create_task(_quota_sync(token, int(spec.mode_id))) - else: - asyncio.create_task(_fail_sync(token, int(spec.mode_id), fail_exc)) + return artifact, raw - path = _save_video_bytes(raw, job.id) + artifact, raw = await _run_video_with_account(model=job.model, runner=_runner) + + file_id = hashlib.sha1(artifact.video_url.encode("utf-8")).hexdigest()[:32] + path = _save_video_bytes(raw, file_id) async with _VIDEO_JOBS_LOCK: job.status = "completed" job.progress = 100 job.completed_at = int(time.time()) job.video_url = artifact.video_url job.content_path = str(path) + job.content_file_id = file_id job.remixed_from_video_id = artifact.remixed_from_video_id except Exception as exc: logger.exception("video job failed: job_id={} error={}", job.id, exc) @@ -932,7 +1065,7 @@ async def create_video( raise ValidationError("prompt cannot be empty", param="prompt") normalized_seconds = _coerce_seconds(seconds) - validate_video_length(normalized_seconds) + validate_async_video_length(normalized_seconds) normalized_size = (size or "720x1280").strip() _aspect_ratio, default_resolution_name = _resolve_video_size(normalized_size) _resolve_video_resolution_name(resolution_name, default=default_resolution_name) @@ -992,52 +1125,68 @@ async def content_path(video_id: str) -> Path: def _extract_video_prompt_and_reference( messages: list[dict], ) -> tuple[str, list[dict[str, Any]] | None]: - prompt = "" - reference_urls: list[str] = [] + prompts, input_references = _extract_video_segment_prompts_and_references(messages) + return prompts[-1], input_references - for msg in reversed(messages): - content = msg.get("content", "") - if isinstance(content, str) and content.strip(): - prompt = content.strip() - if prompt: - break - continue - if not isinstance(content, list): + +def _text_and_references_from_content( + content: str | list[dict[str, Any]] | None, +) -> tuple[str, list[str]]: + if isinstance(content, str): + return content.strip(), [] + if not isinstance(content, list): + return "", [] + + text_parts: list[str] = [] + reference_urls: list[str] = [] + for item in content: + if not isinstance(item, dict): continue + item_type = item.get("type") + if item_type == "text": + text = str(item.get("text") or "").strip() + if text: + text_parts.append(text) + elif item_type == "image_url": + image_url = item.get("image_url") + if isinstance(image_url, dict): + url = str(image_url.get("url") or "").strip() + if url: + reference_urls.append(url) + elif isinstance(image_url, str) and image_url.strip(): + reference_urls.append(image_url.strip()) + return " ".join(text_parts).strip(), reference_urls + + +def _extract_video_segment_prompts_and_references( + messages: list[dict], +) -> tuple[list[str], list[dict[str, Any]] | None]: + prompts: list[str] = [] + reference_urls: list[str] = [] - text_parts: list[str] = [] - block_references: list[str] = [] - for item in content: - if not isinstance(item, dict): - continue - item_type = item.get("type") - if item_type == "text": - text = str(item.get("text") or "").strip() - if text: - text_parts.append(text) - elif item_type == "image_url": - image_url = item.get("image_url") - if isinstance(image_url, dict): - url = str(image_url.get("url") or "").strip() - if url: - block_references.append(url) - elif isinstance(image_url, str) and image_url.strip(): - block_references.append(image_url.strip()) - - if text_parts: - prompt = " ".join(text_parts) + for msg in messages: + if msg.get("role", "user") != "user": + continue + prompt, block_references = _text_and_references_from_content( + msg.get("content", "") + ) + if prompt: + prompts.append(prompt) if block_references and not reference_urls: reference_urls = block_references - if prompt: - break - if not prompt: + if not prompts: raise ValidationError("Video prompt cannot be empty", param="messages") input_references: list[dict[str, Any]] | None = None if reference_urls: - input_references = [{"image_url": url} for url in reference_urls[:7]] - return prompt, input_references + if len(reference_urls) > _VIDEO_MAX_REFERENCES: + raise ValidationError( + f"Video generation supports at most {_VIDEO_MAX_REFERENCES} reference images", + param="messages", + ) + input_references = [{"image_url": url} for url in reference_urls] + return prompts, input_references async def completions( @@ -1058,7 +1207,17 @@ async def completions( default=default_resolution_name, ) resolved_preset = _resolve_video_preset(preset) - prompt, input_references = _extract_video_prompt_and_reference(messages) + segments = _build_segment_lengths(seconds) + segment_prompts, input_references = _extract_video_segment_prompts_and_references( + messages + ) + normalized_segment_prompts = _normalize_segment_prompts( + segment_prompts[0], + segments, + segment_prompts, + strict=True, + ) + prompt = normalized_segment_prompts[0] cfg = get_config() is_stream = stream if stream is not None else cfg.get_bool("features.stream", False) @@ -1075,6 +1234,7 @@ async def _runner(token: str, timeout_s: float) -> str: preset=resolved_preset, timeout_s=timeout_s, input_references=input_references, + segment_prompts=normalized_segment_prompts, progress_cb=progress_cb, ) file_id = hashlib.sha1(artifact.video_url.encode("utf-8")).hexdigest()[:32] @@ -1140,7 +1300,12 @@ async def _progress(progress: int) -> None: "retrieve", "content_path", "validate_video_length", + "validate_async_video_length", "completions", "_build_segment_lengths", + "_empty_video_response_error", + "_extract_video_segment_prompts_and_references", + "_normalize_segment_prompts", "_resolve_video_size", + "_video_pool_candidates", ] diff --git a/app/products/web/admin/__init__.py b/app/products/web/admin/__init__.py index 88093af78..313789485 100644 --- a/app/products/web/admin/__init__.py +++ b/app/products/web/admin/__init__.py @@ -196,7 +196,10 @@ async def update_config(req: ConfigPatchRequest): patch = _sanitize_proxy_config(req.root) _ensure_runtime_patch_allowed(patch) - cache_local_changed = _patch_touches_prefix(patch, "cache.local") + cache_local_changed = ( + _patch_touches_prefix(patch, "cache.local") + or _patch_touches_prefix(patch, "storage") + ) await config.update(patch) # config.update() only writes to the backend and invalidates the in-memory # snapshot (_version = None); it does not refresh the data. load() is diff --git a/app/products/web/admin/cache.py b/app/products/web/admin/cache.py index fc49ab597..c6abb504a 100644 --- a/app/products/web/admin/cache.py +++ b/app/products/web/admin/cache.py @@ -1,29 +1,21 @@ """Local media cache management — stats, list, clear, delete.""" import asyncio -from pathlib import Path -from typing import Any, Literal +from typing import Literal from fastapi import APIRouter, Query from pydantic import BaseModel -from app.platform.config.snapshot import get_config from app.platform.errors import AppError, ErrorKind from app.platform.storage import ( clear_local_media_files, delete_local_media_file, - image_files_dir, - video_files_dir, + list_local_media_files, + local_media_stats, ) router = APIRouter(prefix="/cache", tags=["Admin - Cache"]) -# --------------------------------------------------------------------------- -# Lightweight local media cache service. -# --------------------------------------------------------------------------- -_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} -_VIDEO_EXTS = {".mp4", ".mov", ".m4v", ".webm", ".avi", ".mkv"} - class ClearCacheRequest(BaseModel): type: Literal["image", "video"] = "image" @@ -39,67 +31,6 @@ class DeleteCacheItemsRequest(BaseModel): names: list[str] -def _dir(media_type: str) -> Path: - return image_files_dir() if media_type == "image" else video_files_dir() - - -def _exts(media_type: str): - return _IMAGE_EXTS if media_type == "image" else _VIDEO_EXTS - - -def _limit_mb(media_type: str) -> int: - cfg = get_config() - return max(0, int(cfg.get_int(f"cache.local.{media_type}_max_mb", 0))) - - -def _stats(media_type: str) -> dict[str, Any]: - d = _dir(media_type) - files = [] - if d.exists(): - allowed = _exts(media_type) - files = [f for f in d.glob("*") if f.is_file() and f.suffix.lower() in allowed] - - total_size = sum(f.stat().st_size for f in files) - limit_mb = _limit_mb(media_type) - limit_bytes = limit_mb * 1024 * 1024 - usage_ratio = (total_size / limit_bytes) if limit_bytes > 0 else None - usage_percent = round(usage_ratio * 100, 1) if usage_ratio is not None else None - return { - "count": len(files), - "size_mb": round(total_size / 1024 / 1024, 2), - "size_bytes": total_size, - "limit_mb": limit_mb, - "limit_bytes": limit_bytes, - "limited": limit_bytes > 0, - "usage_ratio": round(usage_ratio, 4) if usage_ratio is not None else None, - "usage_percent": usage_percent, - } - - -def _list_files(media_type: str, page: int, page_size: int) -> dict[str, Any]: - d = _dir(media_type) - if not d.exists(): - return {"total": 0, "page": page, "page_size": page_size, "items": []} - allowed = _exts(media_type) - files = sorted( - (f for f in d.glob("*") if f.is_file() and f.suffix.lower() in allowed), - key=lambda f: f.stat().st_mtime, - reverse=True, - ) - total = len(files) - start = (page - 1) * page_size - chunk = files[start : start + page_size] - items = [] - for f in chunk: - st = f.stat() - items.append({ - "name": f.name, - "size_bytes": st.st_size, - "modified_at": st.st_mtime, - }) - return {"total": total, "page": page, "page_size": page_size, "items": items} - - # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @@ -107,8 +38,8 @@ def _list_files(media_type: str, page: int, page_size: int) -> dict[str, Any]: @router.get("") async def cache_stats(): return { - "local_image": _stats("image"), - "local_video": _stats("video"), + "local_image": local_media_stats("image"), + "local_video": local_media_stats("video"), } @@ -120,7 +51,10 @@ async def list_local( page_size: int = 1000, ): media_type = type_ or cache_type - return {"status": "success", **_list_files(media_type, page, page_size)} + return { + "status": "success", + **list_local_media_files(media_type, page=page, page_size=page_size), + } @router.post("/clear") diff --git a/app/products/web/admin/tokens.py b/app/products/web/admin/tokens.py index bf5bf10f6..70e535b5e 100644 --- a/app/products/web/admin/tokens.py +++ b/app/products/web/admin/tokens.py @@ -121,6 +121,7 @@ def _quota_brief(q: dict) -> dict: def _serialize_record(r) -> dict: return { "token": r.token, + "account_id": r.account_id, "pool": r.pool or "basic", "status": r.status, "quota": _quota_brief(r.quota) if isinstance(r.quota, dict) else {}, diff --git a/config.defaults.toml b/config.defaults.toml index f7a9b9437..8c132b639 100644 --- a/config.defaults.toml +++ b/config.defaults.toml @@ -78,9 +78,9 @@ mode = "direct" proxy_url = "" # 基础代理池(API 流量,proxy_pool 模式必填) proxy_pool = [] -# 资源代理 URL(图片/视频下载,未配置则回落到 proxy_url) +# 资源代理 URL(图片/视频下载,未配置则直连) resource_proxy_url = "" -# 资源代理池(图片/视频下载,未配置则回落到 proxy_pool) +# 资源代理池(图片/视频下载,未配置则直连) resource_proxy_pool = [] # 跳过代理 SSL 证书验证(代理使用自签名证书时启用) skip_ssl_verify = false diff --git a/docs/README.en.md b/docs/README.en.md index 1cf8a7a6d..2968baf61 100644 --- a/docs/README.en.md +++ b/docs/README.en.md @@ -111,7 +111,7 @@ docker compose up -d ### Vercel -[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,LOG_DIR,ACCOUNT_STORAGE,ACCOUNT_REDIS_URL,ACCOUNT_MYSQL_URL,ACCOUNT_POSTGRESQL_URL) +[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,LOG_DIR,ACCOUNT_STORAGE,ACCOUNT_REDIS_URL,ACCOUNT_MYSQL_URL,ACCOUNT_POSTGRESQL_URL,CONFIG_STORAGE,CONFIG_LOCAL_PATH) ### Render @@ -186,7 +186,8 @@ docker compose up -d | `ACCOUNT_SQL_MAX_OVERFLOW` | Maximum overflow connections above pool size | `10` | | `ACCOUNT_SQL_POOL_TIMEOUT` | Seconds to wait for a free connection from the pool | `30` | | `ACCOUNT_SQL_POOL_RECYCLE` | Max connection lifetime in seconds before reconnect | `1800` | -| `CONFIG_LOCAL_PATH` | Runtime config file path for `local` config storage | `${DATA_DIR}/config.toml` | +| `CONFIG_STORAGE` | Runtime config storage backend; defaults to local TOML and does not follow `ACCOUNT_STORAGE` | `local` | +| `CONFIG_LOCAL_PATH` | Local runtime config file path | `${DATA_DIR}/config.toml` | Runtime config can also be overridden with `GROK_`-prefixed environment variables. For example, `GROK_APP_API_KEY` overrides `app.api_key`, and `GROK_FEATURES_STREAM` overrides `features.stream`. @@ -361,17 +362,33 @@ curl http://localhost:8000/v1/chat/completions \ "model": "grok-imagine-video", "stream": true, "messages": [ - {"role":"user","content":"A neon rainy street at night, cinematic slow tracking shot"} + {"role":"user","content":"A neon rainy street at night, cinematic slow tracking shot"}, + {"role":"user","content":"The camera passes under glowing signs as the character turns toward distant headlights"} ], - "video_config": { - "seconds": 10, - "size": "1792x1024", - "resolution_name": "720p", - "preset": "normal" + "metadata": { + "video_config": { + "seconds": 16, + "size": "1792x1024", + "resolution_name": "720p", + "preset": "normal" + } } }' ``` +`image_config` / `video_config` may also be sent at the request top level; top-level fields take precedence over matching `metadata` fields. + +```json +{ + "video_config": { + "seconds": 10, + "size": "1792x1024", + "resolution_name": "720p", + "preset": "normal" + } +} +``` +
Field Notes
@@ -389,10 +406,11 @@ curl http://localhost:8000/v1/chat/completions \ | \|_ `size` | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | | \|_ `response_format` | `url`, `b64_json` | | `video_config` | Video model parameters | -| \|_ `seconds` | `6`, `10`, `12`, `16`, `20` | +| \|_ `seconds` | `6`, `10`, `12`, `16`, `20`, `22`, `26`, `30`, `32`, `36`, `40`; longer videos require one user prompt per segment | | \|_ `size` | `720x1280`, `1280x720`, `1024x1024`, `1024x1792`, `1792x1024` | | \|_ `resolution_name` | `480p`, `720p` | | \|_ `preset` | `fun`, `normal`, `spicy`, `custom` | +| `metadata.image_config` / `metadata.video_config` | Compatibility location for `image_config` / `video_config`; top-level config wins |
@@ -588,7 +606,7 @@ curl -L http://localhost:8000/v1/videos//content \ | :-- | :-- | | `model` | Video model, currently `grok-imagine-video` | | `prompt` | Video generation prompt | -| `seconds` | Video length: `6`, `10`, `12`, `16`, `20` | +| `seconds` | Video length: `6`, `10`; use `/v1/chat/completions` for longer videos | | `size` | Supports `720x1280`, `1280x720`, `1024x1024`, `1024x1792`, `1792x1024` | | `resolution_name` | `480p` or `720p` | | `preset` | `fun`, `normal`, `spicy`, `custom` | diff --git a/scripts/inject_cookie.py b/scripts/inject_cookie.py new file mode 100644 index 000000000..a78776559 --- /dev/null +++ b/scripts/inject_cookie.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +"""Grok cookie injector — set the SSO cookie in a browser and open Grok. + +Usage: + python inject_cookie.py + python inject_cookie.py --url https://grok.com + python inject_cookie.py --browser firefox + +The cookie value can be: + - A raw JWT: eyJ0eXAiOiJKV1Qi... + - With sso= prefix: sso=eyJ0eXAiOiJKV1Qi... + - A full cookie header: sso=eyJ...; Domain=.grok.com; Path=/ + +Requirements: + pip install browser-cookie3 (optional, for reading existing cookies) + +The script uses webbrowser to open the URL and prints instructions for +manual cookie injection into Chrome DevTools. +""" + +from __future__ import annotations + +import argparse +import json +import re +import sys +import webbrowser + +# ── JWT decode (no signature verification) ────────────────────────────── + + +def decode_jwt_payload(raw: str) -> dict: + import base64 + + parts = raw.strip().split(".") + if len(parts) != 3: + raise ValueError("Not a valid JWT (expected 3 segments)") + + payload_b64 = parts[1] + payload_b64 += "=" * (4 - len(payload_b64) % 4) + try: + payload = json.loads(base64.urlsafe_b64decode(payload_b64)) + except Exception as exc: + raise ValueError(f"Cannot decode JWT payload: {exc}") from exc + return payload + + +def parse_cookie(raw: str) -> tuple[str, str, dict]: + """Parse a Grok SSO cookie value. + + Returns (cookie_name, cookie_value, jwt_payload). + """ + text = raw.strip() + + # Strip browser export format: "Cookie: sso=..." + if text.lower().startswith("cookie:"): + text = text[len("cookie:"):].strip() + + # Find the JWT by looking for sso= prefix + cookie_value = text + if text.startswith("sso="): + cookie_value = text[4:] + if ";" in cookie_value: + cookie_value = cookie_value.split(";")[0].strip() + + # If no sso= prefix, assume raw JWT + cookie_value = cookie_value.strip() + payload = decode_jwt_payload(cookie_value) + return ("sso", cookie_value, payload) + + +# ── Browser injection helpers ─────────────────────────────────────────── + + +def inject_via_playwright(cookie_value: str, url: str) -> int: + """Use Playwright to inject the cookie and open the browser.""" + try: + from playwright.sync_api import sync_playwright + except ImportError: + print("[!] playwright not installed. Install with: pip install playwright && playwright install chromium") + return 1 + + with sync_playwright() as p: + browser = p.chromium.launch(headless=False) + context = browser.new_context() + context.add_cookies([ + { + "name": "sso", + "value": cookie_value, + "domain": ".grok.com", + "path": "/", + "httpOnly": False, + "secure": True, + "sameSite": "Lax", + } + ]) + page = context.new_page() + page.goto(url) + print(f"[+] Cookie injected — browser opened at {url}") + print("[+] Press Ctrl+C to close the browser...") + try: + page.wait_for_timeout(60_000) # 60s before auto-close + except KeyboardInterrupt: + pass + browser.close() + return 0 + + +def print_manual_instructions(cookie_value: str, url: str) -> None: + """Print manual DevTools injection instructions.""" + js = ( + "document.cookie = 'sso=" + cookie_value + "; " + "domain=.grok.com; path=/; SameSite=Lax; Secure';" + ) + + print(f""" +╔══════════════════════════════════════════════════════════════╗ +║ Grok Cookie Injection — Manual Instructions ║ +╠══════════════════════════════════════════════════════════════╣ +║ ║ +║ 1. Open {url} in your browser ║ +║ 2. Press F12 to open DevTools ║ +║ 3. Go to the Console tab ║ +║ 4. Paste this command: ║ +║ ║ +║ {js} +║ ║ +║ 5. Press Enter (the page will reload) ║ +║ 6. You should be logged in ║ +║ ║ +╠══════════════════════════════════════════════════════════════╣ +║ Alternative — Application tab: ║ +║ 1. F12 → Application → Cookies → grok.com ║ +║ 2. Add cookie: name=sso, value=, path=/ ║ +║ 3. Refresh the page ║ +║ ║ +╠══════════════════════════════════════════════════════════════╣ +║ For grok2api import: ║ +║ Use the token value directly: ║ +║ curl -X POST .../admin/api/tokens/add ║ +║ -d '{{"tokens": ["{cookie_value}"], "pool": "auto"}}' ║ +╚══════════════════════════════════════════════════════════════╝ +""") + + +# ── Main ───────────────────────────────────────────────────────────────── + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Inject a Grok SSO cookie into a browser and open Grok", + ) + parser.add_argument( + "cookie", + help="Grok SSO cookie value (JWT, optionally with 'sso=' prefix)", + ) + parser.add_argument( + "--url", default="https://grok.com", + help="Target URL (default: https://grok.com)", + ) + parser.add_argument( + "--manual", action="store_true", + help="Print manual DevTools injection instructions (no automation)", + ) + parser.add_argument( + "--playwright", action="store_true", + help="Use Playwright to automate cookie injection", + ) + args = parser.parse_args() + + try: + name, value, payload = parse_cookie(args.cookie) + except ValueError as exc: + print(f"[!] Cookie parse error: {exc}", file=sys.stderr) + return 1 + + session_id = payload.get("session_id", "unknown") + print(f"[+] Parsed cookie: name={name}, session_id={session_id}") + print(f"[+] Full token: {value[:40]}...{value[-20:]}") + + if args.playwright: + return inject_via_playwright(value, args.url) + + # Default: print instructions and offer to open browser. + print_manual_instructions(value, args.url) + + try: + choice = input("\n[?] Open browser now? (y/N): ").strip().lower() + if choice in ("y", "yes"): + webbrowser.open(args.url) + except (KeyboardInterrupt, EOFError): + pass + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test_chat_metadata_config.py b/tests/test_chat_metadata_config.py new file mode 100644 index 000000000..9ec196727 --- /dev/null +++ b/tests/test_chat_metadata_config.py @@ -0,0 +1,84 @@ +import unittest + +from app.platform.errors import ValidationError +from app.products.openai import router +from app.products.openai.schemas import ChatCompletionRequest, ImageConfig, VideoConfig + + +def _request(**kwargs) -> ChatCompletionRequest: + payload = { + "model": "grok-imagine-video", + "messages": [{"role": "user", "content": "hello"}], + } + payload.update(kwargs) + return ChatCompletionRequest.model_validate(payload) + + +class ChatMetadataConfigTests(unittest.TestCase): + def test_metadata_video_config_is_used_as_fallback(self) -> None: + req = _request( + metadata={ + "video_config": { + "seconds": 16, + "size": "1792x1024", + "resolution_name": "720p", + "preset": "normal", + } + } + ) + + cfg = router._resolve_video_config(req) + + self.assertEqual(cfg.seconds, 16) + self.assertEqual(cfg.size, "1792x1024") + self.assertEqual(cfg.resolution_name, "720p") + self.assertEqual(cfg.preset, "normal") + + def test_top_level_video_config_wins_over_metadata(self) -> None: + req = _request( + video_config=VideoConfig(seconds=6, size="720x1280"), + metadata={"video_config": {"seconds": 16, "size": "1792x1024"}}, + ) + + cfg = router._resolve_video_config(req) + + self.assertEqual(cfg.seconds, 6) + self.assertEqual(cfg.size, "720x1280") + + def test_metadata_image_config_is_used_as_fallback(self) -> None: + req = _request( + metadata={ + "image_config": { + "n": 3, + "size": "1792x1024", + "response_format": "url", + } + } + ) + + cfg = router._resolve_image_config(req) + + self.assertEqual(cfg.n, 3) + self.assertEqual(cfg.size, "1792x1024") + self.assertEqual(cfg.response_format, "url") + + def test_top_level_image_config_wins_over_metadata(self) -> None: + req = _request( + image_config=ImageConfig(n=1, size="1024x1024"), + metadata={"image_config": {"n": 3, "size": "1792x1024"}}, + ) + + cfg = router._resolve_image_config(req) + + self.assertEqual(cfg.n, 1) + self.assertEqual(cfg.size, "1024x1024") + + def test_metadata_config_must_be_object(self) -> None: + req = _request(metadata={"video_config": "bad"}) + + with self.assertRaises(ValidationError): + router._resolve_video_config(req) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_config_backend_factory.py b/tests/test_config_backend_factory.py new file mode 100644 index 000000000..35ef254cf --- /dev/null +++ b/tests/test_config_backend_factory.py @@ -0,0 +1,63 @@ +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from app.platform.config.backends.factory import ( + create_config_backend, + get_config_backend_name, +) +from app.platform.config.backends.sql import SqlConfigBackend +from app.platform.config.backends.toml import TomlConfigBackend + + +class ConfigBackendFactoryTests(unittest.TestCase): + def test_config_backend_defaults_to_local_when_account_storage_is_mysql(self) -> None: + with patch.dict(os.environ, {"ACCOUNT_STORAGE": "mysql"}, clear=True): + self.assertEqual(get_config_backend_name(), "local") + + def test_config_backend_defaults_to_local_when_account_storage_is_redis(self) -> None: + with patch.dict(os.environ, {"ACCOUNT_STORAGE": "redis"}, clear=True): + self.assertEqual(get_config_backend_name(), "local") + + def test_create_config_backend_uses_toml_for_explicit_local(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + config_path = Path(tmpdir) / "config.toml" + with patch.dict( + os.environ, + { + "ACCOUNT_STORAGE": "mysql", + "CONFIG_STORAGE": "local", + "CONFIG_LOCAL_PATH": str(config_path), + }, + clear=True, + ): + backend = create_config_backend() + + self.assertIsInstance(backend, TomlConfigBackend) + + def test_create_config_backend_uses_sql_only_when_config_storage_is_mysql(self) -> None: + sentinel_engine = object() + with patch.dict( + os.environ, + { + "ACCOUNT_STORAGE": "local", + "CONFIG_STORAGE": "mysql", + "ACCOUNT_MYSQL_URL": "mysql://user:pass@example.com/db", + }, + clear=True, + ): + with patch( + "app.control.account.backends.sql.create_mysql_engine", + return_value=sentinel_engine, + ): + backend = create_config_backend() + + self.assertIsInstance(backend, SqlConfigBackend) + self.assertIs(backend._engine, sentinel_engine) + self.assertEqual(backend._dialect, "mysql") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_empty_response_handling.py b/tests/test_empty_response_handling.py new file mode 100644 index 000000000..769e8673a --- /dev/null +++ b/tests/test_empty_response_handling.py @@ -0,0 +1,430 @@ +import asyncio +import json +import unittest +from types import SimpleNamespace +from unittest.mock import patch + +from app.control.account.enums import FeedbackKind +from app.platform.errors import UpstreamError +from app.products.anthropic import messages as anthropic_messages +from app.products.anthropic import router as anthropic_router +from app.products.openai import chat, responses, video +from app.products.openai import router as openai_router + + +class _FakeConfig: + def get(self, key: str, default=None): + return default + + def get_bool(self, key: str, default: bool = False) -> bool: + return default + + def get_float(self, key: str, default: float = 0.0) -> float: + return default + + def get_str(self, key: str, default: str = "") -> str: + return default + + +class _FakeDirectory: + def __init__(self) -> None: + self.feedbacks: list[tuple[str, FeedbackKind, int]] = [] + + async def release(self, acct) -> None: + return None + + async def feedback( + self, + token: str, + kind: FeedbackKind, + selected_mode_id: int, + now_s_val=None, + ) -> None: + self.feedbacks.append((token, kind, selected_mode_id)) + + +class _FakeAdapter: + def __init__( + self, + *, + text: str = "", + images: list[tuple[str, str]] | None = None, + references: str = "", + sources: list[dict] | None = None, + thinking: str = "", + ) -> None: + self.text_buf = [text] if text else [] + self.thinking_buf = [thinking] if thinking else [] + self.image_urls = images or [] + self._references = references + self._sources = sources + + def references_suffix(self) -> str: + return self._references + + def search_sources_list(self): + return self._sources + + +async def _empty_stream(**kwargs): + yield "data: [DONE]" + + +def _chat_frame(response: dict) -> str: + return "data: " + json.dumps({"result": {"response": response}}) + "\n\n" + + +async def _thinking_stream(**kwargs): + yield _chat_frame({"token": "thinking one", "isThinking": True}) + yield "data: [DONE]" + + +async def _thinking_then_text_stream(**kwargs): + yield _chat_frame({"token": "thinking one", "isThinking": True}) + yield _chat_frame({"token": "hello", "messageTag": "final"}) + yield "data: [DONE]" + + +async def _empty_video_stream(*args, **kwargs): + yield "data: [DONE]" + + +async def _noop_async(*args, **kwargs): + return None + + +async def _empty_generator(): + if False: + yield "" + + +class EmptyResponseHandlingTests(unittest.TestCase): + def test_visible_output_helper_ignores_thinking_only(self) -> None: + self.assertFalse( + chat._adapter_has_visible_output(_FakeAdapter(thinking="reasoning only")) + ) + + def test_visible_output_helper_accepts_text_tool_images_and_sources(self) -> None: + self.assertTrue(chat._adapter_has_visible_output(_FakeAdapter(text="hello"))) + self.assertTrue( + chat._adapter_has_visible_output( + _FakeAdapter(images=[("https://example.com/a.png", "img-1")]) + ) + ) + self.assertTrue( + chat._adapter_has_visible_output( + _FakeAdapter(sources=[{"url": "https://example.com", "title": "x"}]) + ) + ) + self.assertTrue( + chat._adapter_has_visible_output(_FakeAdapter(), has_tool_calls=True) + ) + + def test_empty_response_error_is_retryable_429(self) -> None: + exc = chat._empty_upstream_response_error() + + self.assertEqual(exc.status, 429) + self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY) + + def test_video_empty_response_error_is_retryable_429(self) -> None: + exc = video._empty_video_response_error() + + self.assertEqual(exc.status, 429) + self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY) + + def test_video_segment_empty_stream_raises_429(self) -> None: + async def _run() -> None: + await video._collect_video_segment( + token="tok-test", + payload={}, + referer="https://grok.com/imagine", + timeout_s=1, + ) + + with self.assertRaises(UpstreamError) as ctx: + with patch.object( + video, "_stream_video_request", side_effect=_empty_video_stream + ): + asyncio.run(_run()) + + self.assertEqual(ctx.exception.status, 429) + self.assertEqual(ctx.exception.details["body"], chat.EMPTY_UPSTREAM_BODY) + + def test_stream_prime_turns_no_first_chunk_into_429(self) -> None: + async def _run() -> None: + await openai_router._prime_sse(_empty_generator()) + + with self.assertRaises(UpstreamError) as ctx: + asyncio.run(_run()) + + self.assertEqual(ctx.exception.status, 429) + + def test_anthropic_stream_prime_turns_no_first_chunk_into_429(self) -> None: + async def _run() -> None: + await anthropic_router._prime_sse(_empty_generator()) + + with self.assertRaises(UpstreamError) as ctx: + asyncio.run(_run()) + + self.assertEqual(ctx.exception.status, 429) + + def test_chat_non_stream_empty_response_raises_429_feedback(self) -> None: + directory = _FakeDirectory() + exc = self._run_chat_non_stream_empty(directory) + + self.assertEqual(exc.status, 429) + self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + + def test_chat_stream_empty_response_raises_429_before_first_chunk(self) -> None: + directory = _FakeDirectory() + + async def _run() -> None: + stream = await chat.completions( + model="grok-test", + messages=[{"role": "user", "content": "hello"}], + stream=True, + ) + await openai_router._prime_sse(stream) + + with self.assertRaises(UpstreamError) as ctx: + with self._patch_common(chat, directory): + asyncio.run(_run()) + + self.assertEqual(ctx.exception.status, 429) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + + def test_chat_stream_thinking_flushes_before_text(self) -> None: + directory = _FakeDirectory() + + async def _run() -> str: + stream = await chat.completions( + model="grok-test", + messages=[{"role": "user", "content": "hello"}], + stream=True, + emit_think=True, + ) + return await anext(stream) + + with self._patch_common(chat, directory, stream_func=_thinking_then_text_stream): + first = asyncio.run(_run()) + + self.assertIn("reasoning_content", first) + self.assertIn("thinking one", first) + + def test_chat_stream_thinking_only_is_not_empty_when_enabled(self) -> None: + directory = _FakeDirectory() + + async def _run() -> list[str]: + stream = await chat.completions( + model="grok-test", + messages=[{"role": "user", "content": "hello"}], + stream=True, + emit_think=True, + ) + chunks = [] + async for chunk in stream: + chunks.append(chunk) + return chunks + + with self._patch_common(chat, directory, stream_func=_thinking_stream): + chunks = asyncio.run(_run()) + + joined = "".join(chunks) + self.assertIn("reasoning_content", joined) + self.assertIn("thinking one", joined) + self.assertIn("data: [DONE]", joined) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.SUCCESS) + + def test_chat_stream_thinking_only_raises_429_when_disabled(self) -> None: + directory = _FakeDirectory() + + async def _run() -> None: + stream = await chat.completions( + model="grok-test", + messages=[{"role": "user", "content": "hello"}], + stream=True, + emit_think=False, + ) + async for _chunk in stream: + pass + + with self.assertRaises(UpstreamError) as ctx: + with self._patch_common(chat, directory, stream_func=_thinking_stream): + asyncio.run(_run()) + + self.assertEqual(ctx.exception.status, 429) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + + def test_responses_stream_reasoning_flushes_before_text(self) -> None: + directory = _FakeDirectory() + + async def _run() -> list[str]: + stream = await responses.create( + model="grok-test", + input_val="hello", + instructions=None, + stream=True, + emit_think=True, + temperature=0.8, + top_p=0.95, + ) + return [await anext(stream), await anext(stream), await anext(stream)] + + with self._patch_common(responses, directory, stream_func=_thinking_then_text_stream): + chunks = asyncio.run(_run()) + + self.assertIn("response.created", chunks[0]) + self.assertIn("response.output_item.added", chunks[1]) + self.assertIn("response.reasoning_summary_part.added", chunks[2]) + + def test_responses_stream_reasoning_only_is_not_empty_when_enabled(self) -> None: + directory = _FakeDirectory() + + async def _run() -> list[str]: + stream = await responses.create( + model="grok-test", + input_val="hello", + instructions=None, + stream=True, + emit_think=True, + temperature=0.8, + top_p=0.95, + ) + chunks = [] + async for chunk in stream: + chunks.append(chunk) + return chunks + + with self._patch_common(responses, directory, stream_func=_thinking_stream): + chunks = asyncio.run(_run()) + + joined = "".join(chunks) + self.assertIn("response.reasoning_summary_text.delta", joined) + self.assertIn("response.completed", joined) + self.assertIn("data: [DONE]", joined) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.SUCCESS) + + def test_responses_stream_reasoning_only_raises_429_when_disabled(self) -> None: + directory = _FakeDirectory() + + async def _run() -> None: + stream = await responses.create( + model="grok-test", + input_val="hello", + instructions=None, + stream=True, + emit_think=False, + temperature=0.8, + top_p=0.95, + ) + async for _chunk in stream: + pass + + with self.assertRaises(UpstreamError) as ctx: + with self._patch_common(responses, directory, stream_func=_thinking_stream): + asyncio.run(_run()) + + self.assertEqual(ctx.exception.status, 429) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + + def test_responses_non_stream_empty_response_raises_429_feedback(self) -> None: + directory = _FakeDirectory() + exc = self._run_responses_non_stream_empty(directory) + + self.assertEqual(exc.status, 429) + self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + + def test_anthropic_non_stream_empty_response_raises_429_feedback(self) -> None: + directory = _FakeDirectory() + exc = self._run_anthropic_non_stream_empty(directory) + + self.assertEqual(exc.status, 429) + self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + + def _run_chat_non_stream_empty(self, directory: _FakeDirectory) -> UpstreamError: + async def _run() -> None: + await chat.completions( + model="grok-test", + messages=[{"role": "user", "content": "hello"}], + stream=False, + ) + + with self.assertRaises(UpstreamError) as ctx: + with self._patch_common(chat, directory): + asyncio.run(_run()) + return ctx.exception + + def _run_responses_non_stream_empty(self, directory: _FakeDirectory) -> UpstreamError: + async def _run() -> None: + await responses.create( + model="grok-test", + input_val="hello", + instructions=None, + stream=False, + emit_think=True, + temperature=0.8, + top_p=0.95, + ) + + with self.assertRaises(UpstreamError) as ctx: + with self._patch_common(responses, directory): + asyncio.run(_run()) + return ctx.exception + + def _run_anthropic_non_stream_empty( + self, directory: _FakeDirectory + ) -> UpstreamError: + async def _run() -> None: + await anthropic_messages.create( + model="grok-test", + messages=[{"role": "user", "content": "hello"}], + stream=False, + emit_think=True, + temperature=0.8, + top_p=0.95, + ) + + with self.assertRaises(UpstreamError) as ctx: + with self._patch_common(anthropic_messages, directory): + asyncio.run(_run()) + return ctx.exception + + def _patch_common( + self, + module, + directory: _FakeDirectory, + *, + stream_func=_empty_stream, + ): + import contextlib + + @contextlib.contextmanager + def _ctx(): + with patch("app.dataplane.account._directory", directory): + with patch.object(module, "get_config", return_value=_FakeConfig()): + with patch.object( + module, + "resolve_model", + return_value=SimpleNamespace(mode_id=0), + ): + with patch.object(module, "selection_max_retries", return_value=0): + with patch.object( + module, + "reserve_account", + return_value=(SimpleNamespace(token="tok-test"), 0), + ): + with patch.object( + module, "_stream_chat", side_effect=stream_func + ): + with patch.object(module, "_fail_sync", _noop_async): + with patch.object(module, "_quota_sync", _noop_async): + yield + + return _ctx() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_image_output_format.py b/tests/test_image_output_format.py new file mode 100644 index 000000000..dffc082df --- /dev/null +++ b/tests/test_image_output_format.py @@ -0,0 +1,87 @@ +import asyncio +import base64 +import tempfile +import unittest +from pathlib import Path +from urllib.parse import parse_qs, urlparse +from unittest.mock import patch + +from app.products.openai import images + + +class _StubConfig: + def __init__(self, *, image_format: str, app_url: str) -> None: + self.image_format = image_format + self.app_url = app_url + + def get_str(self, key: str, default: str = "") -> str: + if key == "features.image_format": + return self.image_format + if key == "app.app_url": + return self.app_url + return default + + +class ImageOutputFormatTests(unittest.TestCase): + def test_resolve_image_output_uses_local_proxy_when_configured(self) -> None: + blob_b64 = base64.b64encode(b"fake-image-bytes").decode("ascii") + asset_id = "12345678-1234-1234-1234-123456789abc" + with tempfile.TemporaryDirectory() as tmpdir: + config = _StubConfig( + image_format="local_url", + app_url="https://app.example.com", + ) + + def _save_local_image(raw: bytes, mime: str, file_id: str) -> str: + suffix = ".png" if "png" in mime else ".jpg" + (Path(tmpdir) / f"{file_id}{suffix}").write_bytes(raw) + return file_id + + with patch.object(images, "get_config", return_value=config): + with patch.object(images, "save_local_image", side_effect=_save_local_image): + result = asyncio.run( + images._resolve_image_output( + token="unused", + url=f"https://assets.grok.com/users/user-1/{asset_id}.png", + response_format="url", + blob_b64=blob_b64, + ) + ) + file_id = parse_qs(urlparse(result.api_value).query)["id"][0] + self.assertTrue((Path(tmpdir) / f"{file_id}.png").exists()) + + self.assertEqual( + result.api_value, + f"https://app.example.com/v1/files/image?id={asset_id}", + ) + self.assertEqual( + result.markdown_value, + f"![image](https://app.example.com/v1/files/image?id={asset_id})", + ) + + def test_resolve_image_output_keeps_upstream_url_when_configured(self) -> None: + config = _StubConfig( + image_format="grok_url", + app_url="", + ) + with patch.object(images, "get_config", return_value=config): + result = asyncio.run( + images._resolve_image_output( + token="unused", + url="https://assets.grok.com/users/user-1/file-abc123/content.png", + response_format="url", + ) + ) + + self.assertEqual( + result.api_value, + "https://assets.grok.com/users/user-1/file-abc123/content.png", + ) + self.assertEqual( + result.markdown_value, + "![image](https://assets.grok.com/users/user-1/file-abc123/content.png)", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_legacy_cache_config.py b/tests/test_legacy_cache_config.py new file mode 100644 index 000000000..8b3bf186c --- /dev/null +++ b/tests/test_legacy_cache_config.py @@ -0,0 +1,66 @@ +import asyncio +import tempfile +import unittest +from pathlib import Path +from typing import Any + +from app.platform.config.snapshot import ConfigSnapshot + + +class _Backend: + def __init__(self, data: dict[str, Any]) -> None: + self.data = data + + async def load(self) -> dict[str, Any]: + return self.data + + async def apply_patch(self, patch: dict[str, Any]) -> None: + self.data.update(patch) + + async def version(self) -> object: + return 1 + + +class LegacyCacheConfigTests(unittest.TestCase): + def test_legacy_storage_cache_limits_map_to_cache_local(self) -> None: + cfg = asyncio.run(self._load({ + "storage": { + "image_max_mb": 12, + "video_max_mb": 34, + }, + })) + + self.assertEqual(cfg.get_int("cache.local.image_max_mb"), 12) + self.assertEqual(cfg.get_int("cache.local.video_max_mb"), 34) + + def test_cache_local_limits_win_over_legacy_storage(self) -> None: + cfg = asyncio.run(self._load({ + "storage": { + "image_max_mb": 12, + "video_max_mb": 34, + }, + "cache": { + "local": { + "image_max_mb": 56, + "video_max_mb": 78, + }, + }, + })) + + self.assertEqual(cfg.get_int("cache.local.image_max_mb"), 56) + self.assertEqual(cfg.get_int("cache.local.video_max_mb"), 78) + + async def _load(self, overrides: dict[str, Any]) -> ConfigSnapshot: + with tempfile.TemporaryDirectory() as tmpdir: + defaults = Path(tmpdir) / "config.defaults.toml" + defaults.write_text( + "[cache.local]\nimage_max_mb = 0\nvideo_max_mb = 0\n", + encoding="utf-8", + ) + cfg = ConfigSnapshot(backend=_Backend(overrides)) + await cfg.load(defaults) + return cfg + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_media_cache_limits.py b/tests/test_media_cache_limits.py new file mode 100644 index 000000000..a8477517d --- /dev/null +++ b/tests/test_media_cache_limits.py @@ -0,0 +1,113 @@ +import tempfile +import unittest +import os +from pathlib import Path +from unittest.mock import patch + +from app.platform.storage.media_cache import LocalMediaCacheStore + + +class _StubConfig: + def __init__(self, *, image_max_mb: int = 0, video_max_mb: int = 0) -> None: + self._ints = { + "cache.local.image_max_mb": image_max_mb, + "cache.local.video_max_mb": video_max_mb, + } + + def get_int(self, key: str, default: int = 0) -> int: + return self._ints.get(key, default) + + +class MediaCacheLimitTests(unittest.TestCase): + def test_save_image_prunes_oldest_file_when_type_limit_exceeded(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + image_dir = root / "images" + video_dir = root / "videos" + image_dir.mkdir() + video_dir.mkdir() + + store = LocalMediaCacheStore( + config_provider=lambda: _StubConfig(image_max_mb=1) + ) + with patch("app.platform.storage.media_cache.image_files_dir", return_value=image_dir): + with patch("app.platform.storage.media_cache.video_files_dir", return_value=video_dir): + with patch( + "app.platform.storage.media_cache.local_media_cache_db_path", + return_value=root / "cache.db", + ): + with patch( + "app.platform.storage.media_cache.local_media_lock_path", + return_value=root / "cache.lock", + ): + store.save_image(b"a" * 800_000, "image/png", "old") + file_id = store.save_image(b"b" * 400_000, "image/png", "new") + + self.assertEqual(file_id, "new") + self.assertTrue((image_dir / "new.png").exists()) + self.assertFalse((image_dir / "old.png").exists()) + + def test_save_video_prunes_oldest_video_only(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + image_dir = root / "images" + video_dir = root / "videos" + image_dir.mkdir() + video_dir.mkdir() + + image_path = image_dir / "keep.png" + image_path.write_bytes(b"i" * 800_000) + + store = LocalMediaCacheStore( + config_provider=lambda: _StubConfig(video_max_mb=1) + ) + with patch("app.platform.storage.media_cache.image_files_dir", return_value=image_dir): + with patch("app.platform.storage.media_cache.video_files_dir", return_value=video_dir): + with patch( + "app.platform.storage.media_cache.local_media_cache_db_path", + return_value=root / "cache.db", + ): + with patch( + "app.platform.storage.media_cache.local_media_lock_path", + return_value=root / "cache.lock", + ): + store.save_video(b"a" * 800_000, "old") + new_video = store.save_video(b"b" * 400_000, "new") + + self.assertTrue(new_video.exists()) + self.assertFalse((video_dir / "old.mp4").exists()) + self.assertTrue(image_path.exists()) + + def test_stats_and_list_files_use_shared_cache_rules(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + image_dir = root / "images" + video_dir = root / "videos" + image_dir.mkdir() + video_dir.mkdir() + + older = image_dir / "older.png" + newer = image_dir / "newer.jpg" + ignored = image_dir / "ignored.txt" + older.write_bytes(b"a" * 128) + newer.write_bytes(b"b" * 256) + ignored.write_text("skip") + os.utime(older, (100, 100)) + os.utime(newer, (200, 200)) + + store = LocalMediaCacheStore( + config_provider=lambda: _StubConfig(image_max_mb=1) + ) + with patch("app.platform.storage.media_cache.image_files_dir", return_value=image_dir): + with patch("app.platform.storage.media_cache.video_files_dir", return_value=video_dir): + stats = store.stats("image") + listing = store.list_files("image", page=1, page_size=10) + + self.assertEqual(stats["count"], 2) + self.assertEqual(stats["size_bytes"], 384) + self.assertEqual(stats["limit_mb"], 1) + self.assertEqual([item["name"] for item in listing["items"]], ["newer.jpg", "older.png"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_startup_config_migration.py b/tests/test_startup_config_migration.py new file mode 100644 index 000000000..39415ec55 --- /dev/null +++ b/tests/test_startup_config_migration.py @@ -0,0 +1,91 @@ +import asyncio +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from app.platform.startup import migration + + +class _Backend: + def __init__(self, version: int = 0) -> None: + self.version_calls = 0 + self.applied: list[dict] = [] + self._version = version + + async def version(self) -> int: + self.version_calls += 1 + return self._version + + async def apply_patch(self, patch: dict) -> None: + self.applied.append(patch) + + +class StartupConfigMigrationTests(unittest.TestCase): + def test_account_storage_mysql_does_not_migrate_config_when_config_storage_unset(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + defaults = root / "config.defaults.toml" + user_config = root / "config.toml" + defaults.write_text("[app]\napi_key = \"default\"\n", encoding="utf-8") + user_config.write_text("[app]\napi_key = \"local\"\n", encoding="utf-8") + backend = _Backend() + + with patch.dict(os.environ, {"ACCOUNT_STORAGE": "mysql"}, clear=True): + with patch.object(migration, "_DEFAULTS_PATH", defaults): + with patch.object(migration, "_USER_CFG_PATH", user_config): + asyncio.run(migration._migrate_config(backend)) + + self.assertEqual(backend.version_calls, 0) + self.assertEqual(backend.applied, []) + self.assertEqual( + user_config.read_text(encoding="utf-8"), + "[app]\napi_key = \"local\"\n", + ) + + def test_local_config_is_seeded_from_defaults_when_missing(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + defaults = root / "config.defaults.toml" + user_config = root / "config.toml" + defaults.write_text("[app]\napi_key = \"default\"\n", encoding="utf-8") + backend = _Backend() + + with patch.dict(os.environ, {"ACCOUNT_STORAGE": "postgresql"}, clear=True): + with patch.object(migration, "_DEFAULTS_PATH", defaults): + with patch.object(migration, "_USER_CFG_PATH", user_config): + asyncio.run(migration._migrate_config(backend)) + + self.assertTrue(user_config.exists()) + self.assertEqual( + user_config.read_text(encoding="utf-8"), + defaults.read_text(encoding="utf-8"), + ) + self.assertEqual(backend.version_calls, 0) + self.assertEqual(backend.applied, []) + + def test_explicit_remote_config_storage_still_migrates_local_config(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + defaults = root / "config.defaults.toml" + user_config = root / "config.toml" + defaults.write_text("[app]\napi_key = \"default\"\n", encoding="utf-8") + user_config.write_text("[app]\napi_key = \"local\"\n", encoding="utf-8") + backend = _Backend() + + with patch.dict( + os.environ, + {"ACCOUNT_STORAGE": "local", "CONFIG_STORAGE": "mysql"}, + clear=True, + ): + with patch.object(migration, "_DEFAULTS_PATH", defaults): + with patch.object(migration, "_USER_CFG_PATH", user_config): + asyncio.run(migration._migrate_config(backend)) + + self.assertEqual(backend.version_calls, 1) + self.assertEqual(backend.applied, [{"app": {"api_key": "local"}}]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_video_proxy_and_newapi.py b/tests/test_video_proxy_and_newapi.py new file mode 100644 index 000000000..fffd1d201 --- /dev/null +++ b/tests/test_video_proxy_and_newapi.py @@ -0,0 +1,264 @@ +import asyncio +import hashlib +import tempfile +import time +import unittest +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch + +from app.control.proxy import ProxyDirectory +from app.dataplane.reverse.transport import assets +from app.products.openai import video + + +class _ProxyConfig: + def __init__(self, *, resource_proxy_url: str = "") -> None: + self.resource_proxy_url = resource_proxy_url + + def get_str(self, key: str, default: str = "") -> str: + values = { + "proxy.egress.mode": "single_proxy", + "proxy.clearance.mode": "none", + "proxy.egress.proxy_url": "http://proxy.example:8080", + "proxy.egress.resource_proxy_url": self.resource_proxy_url, + } + return values.get(key, default) + + def get_list(self, key: str, default=None): + return default or [] + + def get_int(self, key: str, default: int = 0) -> int: + return default + + +class _AppConfig: + def __init__(self, *, app_url: str = "") -> None: + self.app_url = app_url + + def get_str(self, key: str, default: str = "") -> str: + if key == "app.app_url": + return self.app_url + return default + + +class _AssetConfig: + def get_float(self, key: str, default: float = 0.0) -> float: + return default + + +class _VideoConfig: + def get_float(self, key: str, default: float = 0.0) -> float: + return default + + def get(self, key: str, default=None): + return default + + +class _FakeProxyRuntime: + def __init__(self) -> None: + self.acquire_kwargs = [] + + async def acquire(self, **kwargs): + self.acquire_kwargs.append(kwargs) + return SimpleNamespace(proxy_url=None) + + async def feedback(self, lease, result) -> None: + return None + + +class _FakeAccountDirectory: + def __init__(self) -> None: + self.reserve_calls = [] + self.feedbacks = [] + + async def reserve(self, **kwargs): + self.reserve_calls.append(kwargs) + return SimpleNamespace(token="tok-video") + + async def release(self, acct) -> None: + return None + + async def feedback(self, token, kind, mode_id) -> None: + self.feedbacks.append((token, kind, mode_id)) + + +async def _empty_bytes_stream(*args, **kwargs): + if False: + yield b"" + + +async def _noop_async(*args, **kwargs): + return None + + +class VideoProxyAndNewApiTests(unittest.TestCase): + def test_resource_download_is_direct_without_resource_proxy(self) -> None: + async def _run(): + directory = ProxyDirectory() + with patch("app.control.proxy.get_config", return_value=_ProxyConfig()): + await directory.load() + return await directory.acquire(resource=True) + + lease = asyncio.run(_run()) + + self.assertIsNone(lease.proxy_url) + + def test_resource_download_uses_explicit_resource_proxy(self) -> None: + async def _run(): + directory = ProxyDirectory() + cfg = _ProxyConfig(resource_proxy_url="http://res-proxy.example:8080") + with patch("app.control.proxy.get_config", return_value=cfg): + await directory.load() + return await directory.acquire(resource=True) + + lease = asyncio.run(_run()) + + self.assertEqual(lease.proxy_url, "http://res-proxy.example:8080") + + def test_download_asset_acquires_resource_lease(self) -> None: + async def _run(): + proxy = _FakeProxyRuntime() + with patch.object(assets, "get_proxy_runtime", return_value=proxy): + with patch.object(assets, "get_config", return_value=_AssetConfig()): + with patch.object( + assets, "get_bytes_stream", return_value=_empty_bytes_stream() + ): + await assets.download_asset( + "tok-test", + "https://assets.grok.com/users/u/file.mp4", + ) + return proxy.acquire_kwargs + + calls = asyncio.run(_run()) + + self.assertEqual(calls[-1]["resource"], True) + + def test_completed_video_job_returns_newapi_metadata_url(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "video.mp4" + path.write_bytes(b"fake-video") + job = video._VideoJob( + id="video_test", + model="grok-imagine-video", + prompt="hello", + seconds="6", + size="720x1280", + quality="standard", + created_at=int(time.time()), + status="completed", + progress=100, + completed_at=int(time.time()), + content_path=str(path), + content_file_id="08b8238c88e9bbe8423c692cbd04ec52", + ) + + with patch.object(video, "get_config", return_value=_AppConfig(app_url="https://api.example.com")): + payload = job.to_dict() + + self.assertEqual( + payload["metadata"]["url"], + "https://api.example.com/v1/files/video?id=08b8238c88e9bbe8423c692cbd04ec52", + ) + + def test_video_job_saves_with_sync_file_id_and_metadata_url(self) -> None: + async def _run(tmpdir: str): + artifact = video._VideoArtifact( + video_url="https://assets.grok.com/users/u/generated.mp4", + video_post_id="post_1", + asset_id="asset_1", + thumbnail_url="https://assets.grok.com/thumb.jpg", + ) + raw = b"fake-video" + saved_ids: list[str] = [] + job = video._VideoJob( + id="video_async", + model="grok-imagine-video", + prompt="hello", + seconds="6", + size="720x1280", + quality="standard", + created_at=int(time.time()), + ) + + async def _fake_run_video_with_account(*, model, runner): + return artifact, raw + + def _fake_save_video_bytes(raw_bytes: bytes, file_id: str) -> Path: + saved_ids.append(file_id) + path = Path(tmpdir) / f"{file_id}.mp4" + path.write_bytes(raw_bytes) + return path + + with patch.object( + video, "_run_video_with_account", _fake_run_video_with_account + ): + with patch.object(video, "_save_video_bytes", _fake_save_video_bytes): + await video._run_video_job( + job, + size="720x1280", + resolution_name=None, + prompt="hello", + seconds=6, + preset=None, + ) + + with patch.object( + video, + "get_config", + return_value=_AppConfig(app_url="https://api.example.com"), + ): + payload = job.to_dict() + + return job, payload, saved_ids + + with tempfile.TemporaryDirectory() as tmpdir: + job, payload, saved_ids = asyncio.run(_run(tmpdir)) + + expected_file_id = hashlib.sha1( + "https://assets.grok.com/users/u/generated.mp4".encode("utf-8") + ).hexdigest()[:32] + self.assertEqual(saved_ids, [expected_file_id]) + self.assertEqual(job.content_file_id, expected_file_id) + self.assertTrue(job.content_path.endswith(f"{expected_file_id}.mp4")) + self.assertEqual( + payload["metadata"]["url"], + f"https://api.example.com/v1/files/video?id={expected_file_id}", + ) + + def test_video_pool_candidates_exclude_basic(self) -> None: + spec = SimpleNamespace(pool_candidates=lambda: (0, 1, 2)) + + self.assertEqual(video._video_pool_candidates(spec), (1, 2)) + + def test_video_account_runner_reserves_only_super_or_heavy(self) -> None: + async def _run(): + directory = _FakeAccountDirectory() + spec = SimpleNamespace( + is_video=lambda: True, + mode_id=0, + pool_candidates=lambda: (0, 1, 2), + ) + + async def _runner(token: str, timeout_s: float) -> str: + return token + + with patch("app.dataplane.account._directory", directory): + with patch.object(video, "get_config", return_value=_VideoConfig()): + with patch.object(video, "resolve_model", return_value=spec): + with patch.object(video, "selection_max_retries", return_value=0): + with patch.object(video, "_quota_sync", _noop_async): + result = await video._run_video_with_account( + model="grok-imagine-video", + runner=_runner, + ) + return result, directory.reserve_calls + + result, calls = asyncio.run(_run()) + + self.assertEqual(result, "tok-video") + self.assertEqual(calls[0]["pool_candidates"], (1, 2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_video_reference_helpers.py b/tests/test_video_reference_helpers.py new file mode 100644 index 000000000..e47c1ad33 --- /dev/null +++ b/tests/test_video_reference_helpers.py @@ -0,0 +1,151 @@ +import asyncio +import unittest + +from app.platform.errors import ValidationError +from app.products.openai import router, video + + +def _message_with_references(count: int) -> list[dict]: + return [ + { + "role": "user", + "content": [ + {"type": "text", "text": "生成一个参考图视频"}, + *[ + { + "type": "image_url", + "image_url": {"url": f"https://example.com/ref-{idx}.png"}, + } + for idx in range(count) + ], + ], + } + ] + + +class VideoReferenceHelperTests(unittest.TestCase): + def test_chat_video_prompt_allows_seven_references(self) -> None: + prompt, refs = video._extract_video_prompt_and_reference( + _message_with_references(7) + ) + + self.assertEqual(prompt, "生成一个参考图视频") + self.assertEqual( + refs, + [ + {"image_url": f"https://example.com/ref-{idx}.png"} + for idx in range(7) + ], + ) + + def test_chat_video_prompt_rejects_more_than_seven_references(self) -> None: + with self.assertRaises(ValidationError): + video._extract_video_prompt_and_reference(_message_with_references(8)) + + def test_prepare_video_references_rejects_more_than_seven_references(self) -> None: + refs = [{"image_url": f"https://example.com/ref-{idx}.png"} for idx in range(8)] + + async def _run() -> None: + await video._prepare_video_references("token", refs) + + with self.assertRaises(ValidationError): + asyncio.run(_run()) + + def test_videos_create_rejects_more_than_seven_multipart_references(self) -> None: + async def _run() -> None: + await router.videos_create( + model="grok-imagine-video", + prompt="生成一个参考图视频", + input_reference=[object() for _ in range(8)], # type: ignore[list-item] + ) + + with self.assertRaises(ValidationError): + asyncio.run(_run()) + + def test_chat_video_segment_prompts_follow_user_order(self) -> None: + prompts, refs = video._extract_video_segment_prompts_and_references( + [ + {"role": "system", "content": "ignore"}, + {"role": "user", "content": "第一段"}, + {"role": "assistant", "content": "ignore"}, + {"role": "user", "content": "第二段"}, + ] + ) + + self.assertEqual(prompts, ["第一段", "第二段"]) + self.assertIsNone(refs) + + def test_segment_prompts_reuse_last_prompt_when_short(self) -> None: + self.assertEqual( + video._normalize_segment_prompts("第一段", [10, 6], ["第一段"]), + ["第一段", "第一段"], + ) + + def test_segment_prompts_strict_rejects_missing_prompts(self) -> None: + with self.assertRaises(ValidationError): + video._normalize_segment_prompts( + "第一段", + [10, 6], + ["第一段"], + strict=True, + ) + + def test_segment_prompts_reject_extra_prompts(self) -> None: + with self.assertRaises(ValidationError): + video._normalize_segment_prompts("一", [6], ["一", "二"]) + + def test_video_segment_lengths_prefer_ten_second_segments(self) -> None: + cases = { + 6: [6], + 10: [10], + 12: [6, 6], + 16: [10, 6], + 20: [10, 10], + 22: [10, 6, 6], + 26: [10, 10, 6], + 30: [10, 10, 10], + 32: [10, 10, 6, 6], + 36: [10, 10, 10, 6], + 40: [10, 10, 10, 10], + } + + for seconds, expected in cases.items(): + with self.subTest(seconds=seconds): + self.assertEqual(video._build_segment_lengths(seconds), expected) + + def test_async_video_length_rejects_long_videos(self) -> None: + video.validate_async_video_length(6) + video.validate_async_video_length(10) + with self.assertRaises(ValidationError): + video.validate_async_video_length(12) + + def test_async_video_create_rejects_long_videos(self) -> None: + async def _run() -> None: + await video.create_video( + model="grok-imagine-video", + prompt="长视频", + seconds=12, + ) + + with self.assertRaises(ValidationError): + asyncio.run(_run()) + + def test_chat_video_length_rejects_non_exact_duration(self) -> None: + with self.assertRaises(ValidationError): + video.validate_video_length(28) + + def test_chat_video_completions_requires_prompt_for_each_segment(self) -> None: + async def _run() -> None: + await video.completions( + model="grok-imagine-video", + messages=[{"role": "user", "content": "第一段"}], + stream=False, + seconds=16, + ) + + with self.assertRaises(ValidationError): + asyncio.run(_run()) + + +if __name__ == "__main__": + unittest.main()