diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index c049c5179..b3fe502f3 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -12,7 +12,6 @@ on:
env:
REGISTRY: ghcr.io
- IMAGE_NAME: ${{ github.repository }}
jobs:
build-and-push:
@@ -40,11 +39,15 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
+ - name: Normalize image name
+ id: image
+ run: echo "name=${GITHUB_REPOSITORY,,}" >> "$GITHUB_OUTPUT"
+
- name: Extract metadata
id: meta
uses: docker/metadata-action@v5
with:
- images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+ images: ${{ env.REGISTRY }}/${{ steps.image.outputs.name }}
tags: |
type=raw,value=latest,enable=${{ github.ref_type == 'branch' }}
type=ref,event=tag
@@ -60,4 +63,4 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
- pull: true
\ No newline at end of file
+ pull: true
diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml
index 25c6b6964..f6d72dafc 100644
--- a/.github/workflows/security.yml
+++ b/.github/workflows/security.yml
@@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 0
-
+
# 针对 PR:使用原生逻辑
- name: Setup PR commits & Run Gitleaks
if: github.event_name == 'pull_request'
@@ -29,7 +29,7 @@ jobs:
git remote add pr-head "https://github.com/${{ github.event.pull_request.head.repo.full_name }}.git"
git fetch --no-tags --prune --depth=1 pr-head \
"+${{ github.event.pull_request.head.sha }}:refs/remotes/pr-head/current"
-
+
docker run --rm \
-v "$PWD:/repo" \
-w /repo \
@@ -54,4 +54,4 @@ jobs:
- name: Export requirements
run: uv export --frozen --no-dev --format requirements-txt -o /tmp/req.txt
- name: Run pip-audit
- run: uvx pip-audit -r /tmp/req.txt --progress-spinner off
\ No newline at end of file
+ run: uvx pip-audit -r /tmp/req.txt --progress-spinner off
diff --git a/.gitignore b/.gitignore
index e94a37e41..652d6743b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -53,3 +53,4 @@ htmlcov/
# Project specific
.ace-tool/
+.tmp_live/
diff --git a/README.md b/README.md
index 0363662d8..29a6131ed 100644
--- a/README.md
+++ b/README.md
@@ -112,7 +112,7 @@ docker compose up -d
### Vercel
-[](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,LOG_DIR,ACCOUNT_STORAGE,ACCOUNT_REDIS_URL,ACCOUNT_MYSQL_URL,ACCOUNT_POSTGRESQL_URL)
+[](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,LOG_DIR,ACCOUNT_STORAGE,ACCOUNT_REDIS_URL,ACCOUNT_MYSQL_URL,ACCOUNT_POSTGRESQL_URL,CONFIG_STORAGE,CONFIG_LOCAL_PATH)
### Render
@@ -187,7 +187,8 @@ docker compose up -d
| `ACCOUNT_SQL_MAX_OVERFLOW` | SQL 连接池最大溢出连接数 | `10` |
| `ACCOUNT_SQL_POOL_TIMEOUT` | 等待连接池空闲连接的超时时间(秒) | `30` |
| `ACCOUNT_SQL_POOL_RECYCLE` | 连接最大复用时间(秒),超时后自动重连 | `1800` |
-| `CONFIG_LOCAL_PATH` | `local` 模式运行时配置文件路径 | `${DATA_DIR}/config.toml` |
+| `CONFIG_STORAGE` | 运行时配置存储后端;默认本地 TOML,不跟随 `ACCOUNT_STORAGE` | `local` |
+| `CONFIG_LOCAL_PATH` | 本地运行时配置文件路径 | `${DATA_DIR}/config.toml` |
运行时配置也支持 `GROK_` 前缀环境变量覆盖,例如 `GROK_APP_API_KEY` 会覆盖 `app.api_key`,`GROK_FEATURES_STREAM` 会覆盖 `features.stream`。
@@ -362,15 +363,31 @@ curl http://localhost:8000/v1/chat/completions \
"model": "grok-imagine-video",
"stream": true,
"messages": [
- {"role":"user","content":"霓虹雨夜街头,电影感慢镜头追拍"}
+ {"role":"user","content":"霓虹雨夜街头,电影感慢镜头追拍"},
+ {"role":"user","content":"镜头穿过霓虹招牌,人物回头看向远处车灯"}
],
- "video_config": {
+ "metadata": {
+ "video_config": {
+ "seconds": 16,
+ "size": "1792x1024",
+ "resolution_name": "720p",
+ "preset": "normal"
+ }
+ }
+ }'
+```
+
+`image_config` / `video_config` 也可以放在请求顶层;顶层字段优先于 `metadata` 内的同名配置。
+
+```json
+{
+ "video_config": {
"seconds": 10,
"size": "1792x1024",
"resolution_name": "720p",
"preset": "normal"
- }
- }'
+ }
+}
```
@@ -390,10 +407,11 @@ curl http://localhost:8000/v1/chat/completions \
| \|_ `size` | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` |
| \|_ `response_format` | `url`, `b64_json` |
| `video_config` | 视频模型参数 |
-| \|_ `seconds` | `6`, `10`, `12`, `16`, `20` |
+| \|_ `seconds` | `6`, `10`, `12`, `16`, `20`, `22`, `26`, `30`, `32`, `36`, `40`;长视频需要按分段数量提供同数量 user 提示词 |
| \|_ `size` | `720x1280`, `1280x720`, `1024x1024`, `1024x1792`, `1792x1024` |
| \|_ `resolution_name` | `480p`, `720p` |
| \|_ `preset` | `fun`, `normal`, `spicy`, `custom` |
+| `metadata.image_config` / `metadata.video_config` | `image_config` / `video_config` 的兼容位置;顶层配置优先 |
@@ -589,7 +607,7 @@ curl -L http://localhost:8000/v1/videos//content \
| :-- | :-- |
| `model` | 视频模型,目前为 `grok-imagine-video` |
| `prompt` | 视频生成提示词 |
-| `seconds` | 视频长度:`6`, `10`, `12`, `16`, `20` |
+| `seconds` | 视频长度:`6`, `10`;更长视频请使用 `/v1/chat/completions` |
| `size` | 支持 `720x1280`, `1280x720`, `1024x1024`, `1024x1792`, `1792x1024` |
| `resolution_name` | `480p` 或 `720p` |
| `preset` | `fun`, `normal`, `spicy`, `custom` |
diff --git a/app/control/account/backends/local.py b/app/control/account/backends/local.py
index 85eb3971f..ffd7f6274 100644
--- a/app/control/account/backends/local.py
+++ b/app/control/account/backends/local.py
@@ -55,6 +55,7 @@ def _init_sync(self) -> None:
CREATE TABLE IF NOT EXISTS {_TBL} (
token TEXT NOT NULL PRIMARY KEY,
+ account_id TEXT,
pool TEXT NOT NULL DEFAULT 'basic',
status TEXT NOT NULL DEFAULT 'active',
created_at INTEGER NOT NULL,
@@ -85,7 +86,13 @@ def _init_sync(self) -> None:
CREATE INDEX IF NOT EXISTS idx_acc_deleted
ON {_TBL} (deleted_at) WHERE deleted_at IS NOT NULL;
""")
+ # Migration: add account_id column if missing.
+ self._ensure_column_sync(conn, "account_id", "TEXT")
self._ensure_column_sync(conn, "quota_grok_4_3", "TEXT NOT NULL DEFAULT '{}'")
+ conn.execute(
+ "CREATE INDEX IF NOT EXISTS idx_acc_account_id "
+ f"ON {_TBL} (account_id) WHERE account_id IS NOT NULL"
+ )
conn.commit()
@staticmethod
@@ -112,6 +119,7 @@ def _get_revision_sync(self, conn: sqlite3.Connection) -> int:
@staticmethod
def _row_to_record(row: sqlite3.Row) -> AccountRecord:
d = dict(row)
+ d["account_id"] = d.get("account_id") or None
d["tags"] = json.loads(d.get("tags") or "[]")
heavy_raw = d.pop("quota_heavy", "{}") or "{}"
grok_4_3_raw = d.pop("quota_grok_4_3", "{}") or "{}"
@@ -132,6 +140,7 @@ def _record_to_row(record: AccountRecord, revision: int) -> dict[str, Any]:
qs = record.quota_set()
return {
"token": record.token,
+ "account_id": record.account_id,
"pool": record.pool,
"status": record.status.value,
"created_at": record.created_at,
@@ -171,19 +180,36 @@ def _upsert_sync(
continue
pool = item.pool if item.pool in ("basic", "super", "heavy") else "basic"
qs = default_quota_set(pool)
+ account_id = item.account_id or None
+
+ # Dedup by account_id: soft-delete old token records with same account_id.
+ if account_id:
+ dups = conn.execute(
+ f"SELECT token FROM {_TBL} "
+ f"WHERE account_id = ? AND token != ? AND deleted_at IS NULL",
+ (account_id, token),
+ ).fetchall()
+ for dup in dups:
+ conn.execute(
+ f"UPDATE {_TBL} SET deleted_at = ?, updated_at = ?, revision = ? "
+ f"WHERE token = ?",
+ (ts, ts, revision, dup[0]),
+ )
+
conn.execute(
f"""
INSERT INTO {_TBL} (
- token, pool, status, created_at, updated_at,
+ token, account_id, pool, status, created_at, updated_at,
tags, quota_auto, quota_fast, quota_expert, quota_heavy, quota_grok_4_3,
usage_use_count, usage_fail_count, usage_sync_count,
ext, revision
) VALUES (
- :token, :pool, 'active', :ts, :ts,
+ :token, :account_id, :pool, 'active', :ts, :ts,
:tags, :qa, :qf, :qe, :qh, :qg,
0, 0, 0, :ext, :rev
)
ON CONFLICT(token) DO UPDATE SET
+ account_id = COALESCE(excluded.account_id, account_id),
pool = excluded.pool,
status = 'active',
deleted_at = NULL,
@@ -193,17 +219,18 @@ def _upsert_sync(
revision = excluded.revision
""",
{
- "token": token,
- "pool": pool,
- "ts": ts,
- "tags": json.dumps(item.tags),
- "qa": json.dumps(qs.auto.to_dict()),
- "qf": json.dumps(qs.fast.to_dict()),
- "qe": json.dumps(qs.expert.to_dict()),
- "qh": json.dumps(qs.heavy.to_dict()) if qs.heavy else "{}",
- "qg": json.dumps(qs.grok_4_3.to_dict()) if qs.grok_4_3 else "{}",
- "ext": json.dumps(item.ext),
- "rev": revision,
+ "token": token,
+ "account_id": account_id,
+ "pool": pool,
+ "ts": ts,
+ "tags": json.dumps(item.tags),
+ "qa": json.dumps(qs.auto.to_dict()),
+ "qf": json.dumps(qs.fast.to_dict()),
+ "qe": json.dumps(qs.expert.to_dict()),
+ "qh": json.dumps(qs.heavy.to_dict()) if qs.heavy else "{}",
+ "qg": json.dumps(qs.grok_4_3.to_dict()) if qs.grok_4_3 else "{}",
+ "ext": json.dumps(item.ext),
+ "rev": revision,
},
)
count += conn.execute("SELECT changes()").fetchone()[0]
@@ -229,6 +256,8 @@ def _patch_sync(
sets: dict[str, Any] = {"updated_at": ts, "revision": revision}
+ if patch.account_id is not None:
+ sets["account_id"] = patch.account_id
if patch.pool is not None:
sets["pool"] = patch.pool
if patch.status is not None:
diff --git a/app/control/account/backends/redis.py b/app/control/account/backends/redis.py
index 68c4d4b49..3c3aece09 100644
--- a/app/control/account/backends/redis.py
+++ b/app/control/account/backends/redis.py
@@ -55,6 +55,7 @@ def __init__(self, redis: "Redis") -> None:
def _to_hash(record: AccountRecord, revision: int) -> dict[str, str]:
qs = record.quota_set()
return {
+ "account_id": record.account_id or "",
"pool": record.pool,
"status": record.status.value,
"created_at": str(record.created_at),
@@ -64,6 +65,7 @@ def _to_hash(record: AccountRecord, revision: int) -> dict[str, str]:
"quota_fast": json.dumps(qs.fast.to_dict()),
"quota_expert": json.dumps(qs.expert.to_dict()),
"quota_heavy": json.dumps(qs.heavy.to_dict()) if qs.heavy else "{}",
+ "quota_grok_4_3": json.dumps(qs.grok_4_3.to_dict()) if qs.grok_4_3 else "{}",
"usage_use_count": str(record.usage_use_count),
"usage_fail_count": str(record.usage_fail_count),
"usage_sync_count": str(record.usage_sync_count),
@@ -88,8 +90,10 @@ def _i(k: str) -> int | None:
v = _s(k)
return int(v) if v else None
+ aid = _s("account_id")
return AccountRecord.model_validate({
"token": token,
+ "account_id": aid if aid else None,
"pool": _s("pool") or "basic",
"status": _s("status") or "active",
"created_at": _i("created_at") or now_ms(),
@@ -102,6 +106,9 @@ def _i(k: str) -> int | None:
**({
"heavy": json.loads(_s("quota_heavy"))
} if _s("quota_heavy") and _s("quota_heavy") != "{}" else {}),
+ **({
+ "grok_4_3": json.loads(_s("quota_grok_4_3"))
+ } if _s("quota_grok_4_3") and _s("quota_grok_4_3") != "{}" else {}),
},
"usage_use_count": int(_s("usage_use_count") or 0),
"usage_fail_count": int(_s("usage_fail_count") or 0),
@@ -207,14 +214,40 @@ async def upsert_accounts(
pool = item.pool if item.pool in ("basic", "super", "heavy") else "basic"
qs = default_quota_set(pool)
ts = now_ms()
+ account_id = item.account_id or None
+
+ # Dedup by account_id: scan for existing records with same account_id.
+ if account_id:
+ async for key in self._r.scan_iter("accounts:record:*"):
+ h = await self._r.hgetall(key)
+ if not h:
+ continue
+ existing_aid = (h.get(b"account_id") or h.get("account_id") or b"")
+ if isinstance(existing_aid, bytes):
+ existing_aid = existing_aid.decode()
+ if existing_aid == account_id:
+ dup_token = (key.decode() if isinstance(key, bytes) else key).split(":", 2)[-1]
+ if dup_token != token:
+ await self._r.hset(key, mapping={
+ "deleted_at": str(ts),
+ "updated_at": str(ts),
+ "revision": str(rev),
+ })
+ # Remove from pool set and add to deleted pool tracking.
+ dup_pool = (h.get(b"pool") or h.get("pool") or b"basic")
+ if isinstance(dup_pool, bytes):
+ dup_pool = dup_pool.decode()
+ await self._r.srem(_pool_key(dup_pool), dup_token)
+
record = AccountRecord(
- token = token,
- pool = pool,
- tags = item.tags,
- ext = item.ext,
- quota = qs.to_dict(),
- created_at = ts,
- updated_at = ts,
+ token = token,
+ account_id = account_id,
+ pool = pool,
+ tags = item.tags,
+ ext = item.ext,
+ quota = qs.to_dict(),
+ created_at = ts,
+ updated_at = ts,
)
key = _record_key(token)
await self._r.hset(key, mapping=self._to_hash(record, rev))
@@ -258,6 +291,8 @@ async def patch_accounts(
updates["last_sync_at"] = str(patch.last_sync_at)
if patch.last_clear_at is not None:
updates["last_clear_at"] = str(patch.last_clear_at)
+ if patch.account_id is not None:
+ updates["account_id"] = patch.account_id
if patch.pool is not None:
updates["pool"] = patch.pool
if patch.quota_auto is not None:
@@ -268,6 +303,8 @@ async def patch_accounts(
updates["quota_expert"] = json.dumps(patch.quota_expert)
if patch.quota_heavy is not None:
updates["quota_heavy"] = json.dumps(patch.quota_heavy)
+ if patch.quota_grok_4_3 is not None:
+ updates["quota_grok_4_3"] = json.dumps(patch.quota_grok_4_3)
# Usage counters.
if patch.usage_use_delta is not None:
diff --git a/app/control/account/backends/sql.py b/app/control/account/backends/sql.py
index 66b605bb6..aca46705a 100644
--- a/app/control/account/backends/sql.py
+++ b/app/control/account/backends/sql.py
@@ -37,6 +37,7 @@
_TBL_ACCOUNTS,
metadata,
sa.Column("token", sa.String(512), primary_key=True),
+ sa.Column("account_id", sa.String(64), nullable=True, index=True), # xaiUserId UUID for dedup
sa.Column("pool", sa.Text, nullable=False, default="basic"),
sa.Column("status", sa.Text, nullable=False, default="active"),
sa.Column("created_at", sa.BigInteger, nullable=False),
@@ -404,6 +405,8 @@ def _evict_cached_engine(engine: AsyncEngine) -> None:
def _row_to_record(row: Any) -> AccountRecord:
d = dict(row._mapping)
+ # Normalise account_id — stored as string, can be empty.
+ d["account_id"] = d.get("account_id") or None
d["tags"] = json.loads(d.get("tags") or "[]")
heavy_raw = d.pop("quota_heavy", "{}") or "{}"
grok_4_3_raw = d.pop("quota_grok_4_3", "{}") or "{}"
@@ -522,6 +525,23 @@ async def _do_initialize(self) -> None:
async def _ensure_columns(self, conn: Any) -> None:
"""Idempotent ALTER TABLE migrations for columns added after the initial schema."""
existing = await self._table_columns(conn, _TBL_ACCOUNTS)
+ if "account_id" not in existing:
+ if self._dialect == "mysql":
+ await conn.exec_driver_sql(
+ f"ALTER TABLE {_TBL_ACCOUNTS} "
+ f"ADD COLUMN account_id VARCHAR(64) NULL"
+ )
+ await conn.exec_driver_sql(
+ f"CREATE INDEX idx_accounts_account_id ON {_TBL_ACCOUNTS} (account_id)"
+ )
+ else:
+ await conn.exec_driver_sql(
+ f"ALTER TABLE {_TBL_ACCOUNTS} "
+ f"ADD COLUMN account_id VARCHAR(64)"
+ )
+ await conn.exec_driver_sql(
+ f"CREATE INDEX idx_accounts_account_id ON {_TBL_ACCOUNTS} (account_id)"
+ )
if "quota_grok_4_3" not in existing:
if self._dialect == "mysql":
# MySQL forbids DEFAULT values on TEXT/BLOB columns;
@@ -629,8 +649,10 @@ async def upsert_accounts(
continue
pool = item.pool if item.pool in ("basic", "super", "heavy") else "basic"
qs = default_quota_set(pool)
+ account_id = item.account_id or None
row = {
"token": token,
+ "account_id": account_id,
"pool": pool,
"status": "active",
"created_at": ts,
@@ -648,6 +670,29 @@ async def upsert_accounts(
"ext": json.dumps(item.ext),
"revision": rev,
}
+ # Dedup by account_id: if this account_id already exists under a
+ # *different* token, soft-delete the old token record so we
+ # maintain one canonical token per xaiUserId.
+ if account_id:
+ existing = (await conn.execute(
+ sa.select(accounts_table.c.token).where(
+ accounts_table.c.account_id == account_id,
+ accounts_table.c.token != token,
+ accounts_table.c.deleted_at.is_(None),
+ )
+ )).fetchall()
+ for dup in existing:
+ dup_token = dup[0]
+ await conn.execute(
+ accounts_table.update()
+ .where(accounts_table.c.token == dup_token)
+ .values(deleted_at=ts, updated_at=ts, revision=rev)
+ )
+ logger.info(
+ "account deduplicated by account_id: old_token={}... new_token={}... account_id={}",
+ dup_token[:10], token[:10], account_id,
+ )
+
await conn.execute(self._build_upsert(row))
count += 1
return AccountMutationResult(upserted=count, revision=rev)
@@ -672,6 +717,8 @@ async def patch_accounts(
record = _row_to_record(row)
updates: dict[str, Any] = {"updated_at": ts, "revision": rev}
+ if patch.account_id is not None:
+ updates["account_id"] = patch.account_id
if patch.pool is not None:
updates["pool"] = patch.pool
if patch.status is not None:
diff --git a/app/control/account/commands.py b/app/control/account/commands.py
index ba5011432..c62d103dc 100644
--- a/app/control/account/commands.py
+++ b/app/control/account/commands.py
@@ -14,6 +14,7 @@ class AccountUpsert(BaseModel):
"""
token: str
+ account_id: str | None = None # Grok xaiUserId for deduplication
pool: str = "basic"
tags: list[str] = Field(default_factory=list)
ext: dict[str, Any] = Field(default_factory=dict)
@@ -26,6 +27,7 @@ class AccountPatch(BaseModel):
"""
token: str
+ account_id: str | None = None # update account ID (from subscription API)
pool: str | None = None # update pool type when inferred
status: AccountStatus | None = None
tags: list[str] | None = None
diff --git a/app/control/account/models.py b/app/control/account/models.py
index cb8e8e17f..dc6288d50 100644
--- a/app/control/account/models.py
+++ b/app/control/account/models.py
@@ -176,6 +176,7 @@ class AccountRecord(BaseModel):
"""
token: str
+ account_id: str | None = None # Grok's xaiUserId — UUID, used for deduplication
pool: str = "basic"
status: AccountStatus = AccountStatus.ACTIVE
created_at: int = Field(default_factory=now_ms)
diff --git a/app/control/account/quota_defaults.py b/app/control/account/quota_defaults.py
index eb3250c8d..a8e913316 100644
--- a/app/control/account/quota_defaults.py
+++ b/app/control/account/quota_defaults.py
@@ -75,6 +75,7 @@ def _w(remaining: int, total: int, window_seconds: int) -> QuotaWindow:
"basic": frozenset((1,)),
"super": frozenset((0, 1, 2, 4)),
"heavy": frozenset((0, 1, 2, 3, 4)),
+ "auto": frozenset((0, 1, 2, 3, 4)), # 探测用全集,确保 infer_pool 能拿到 auto.total
}
# ---------------------------------------------------------------------------
diff --git a/app/control/account/refresh.py b/app/control/account/refresh.py
index f95a1baa6..71313984b 100644
--- a/app/control/account/refresh.py
+++ b/app/control/account/refresh.py
@@ -2,7 +2,7 @@
import asyncio
from dataclasses import dataclass
-from typing import TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from app.platform.errors import UpstreamError
from app.platform.config.snapshot import get_config
@@ -126,17 +126,91 @@ async def _fetch_mode_quota(
)
return None
+ # ------------------------------------------------------------------
+ # Subscription API fetch (account type + account_id)
+ # ------------------------------------------------------------------
+
+ async def _fetch_subscription(self, token: str):
+ """Fetch subscription info (account type + xaiUserId) for *token*."""
+ try:
+ from app.dataplane.reverse.protocol.xai_subscription import (
+ fetch_subscription,
+ )
+ return await fetch_subscription(token)
+ except UpstreamError:
+ raise
+ except Exception as exc:
+ logger.debug(
+ "subscription fetch failed: token={}... error={}",
+ token[:10], exc,
+ )
+ return None
+
+ async def _refresh_subscription(self, record: AccountRecord) -> None:
+ """Fetch subscription info, persist account_id, and store tier hint for Phase 2.
+
+ Does NOT set pool — that's Phase 2's job (multi-factor: rate-limits + subscription).
+ Stores ``sub_tier`` and ``sub_active`` in ext so _refresh_one can cross-validate.
+ """
+ if record.is_deleted():
+ return
+ try:
+ info = await self._fetch_subscription(record.token)
+ except UpstreamError as exc:
+ if await self._expire_invalid_credentials(record, exc):
+ return
+ raise
+ except Exception:
+ return
+
+ if info is None:
+ return
+
+ from .commands import AccountPatch
+
+ patch_fields: dict[str, Any] = {}
+ ext_merge: dict[str, Any] = {
+ "sub_tier": info.tier,
+ "sub_active": info.is_active,
+ }
+
+ if info.xai_user_id and not record.account_id:
+ patch_fields["account_id"] = info.xai_user_id
+
+ if patch_fields or ext_merge:
+ patch_fields["ext_merge"] = ext_merge
+ await self._repo.patch_accounts([
+ AccountPatch(token=record.token, **patch_fields)
+ ])
+ logger.info(
+ "subscription info updated: token={}... account_id={} tier={} active={}",
+ record.token[:10],
+ info.xai_user_id or record.account_id,
+ info.tier, info.is_active,
+ )
+
# ------------------------------------------------------------------
# Core refresh logic
# ------------------------------------------------------------------
async def refresh_on_import(self, tokens: list[str]) -> RefreshResult:
- """Called after bulk import — sync real quotas for all accounts."""
+ """Called after bulk import — sync real quotas and subscription info for all accounts."""
records = await self._repo.get_accounts(tokens)
active = [r for r in records if is_manageable(r)]
if not active:
return RefreshResult(checked=len(records))
+ # Phase 1: Fetch subscription info (account_id + tier) for accounts missing it.
+ sub_results = await run_batch(
+ [r for r in active if not r.account_id],
+ lambda r: self._refresh_subscription(r),
+ concurrency=get_config("account.refresh.usage_concurrency", 50),
+ )
+ # Refresh the records after subscription updates.
+ records = await self._repo.get_accounts(tokens)
+ active = [r for r in records if is_manageable(r)]
+
+ # Phase 2: Fetch quota windows.
concurrency = get_config("account.refresh.usage_concurrency", 50)
results = await run_batch(
active,
@@ -145,7 +219,8 @@ async def refresh_on_import(self, tokens: list[str]) -> RefreshResult:
)
agg = RefreshResult(checked=len(records))
for r in results:
- agg.merge(r)
+ if r:
+ agg.merge(r)
return agg
async def refresh_call_async(self, token: str, mode_id: int) -> None:
@@ -236,8 +311,11 @@ async def _refresh_one(
if record.is_deleted():
return RefreshResult()
+ # For basic accounts, fetch all modes (like "auto" pool) so infer_pool
+ # can see auto.total and correctly upgrade to super/heavy when warranted.
+ fetch_pool = "auto" if record.pool == "basic" else record.pool
try:
- windows = await self._fetch_all_quotas(record.token, record.pool)
+ windows = await self._fetch_all_quotas(record.token, fetch_pool)
except UpstreamError as exc:
if await self._expire_invalid_credentials(record, exc):
return RefreshResult(checked=1, expired=1, failed=0)
@@ -256,10 +334,17 @@ async def _refresh_one(
patches: dict[str, dict] = {}
refreshed = False
+ # Use the inferred pool for quota normalization when it represents
+ # an upgrade (e.g. basic→super). Without this, auto/expert/grok_4_3
+ # quota data fetched for basic accounts gets discarded by
+ # normalize_quota_window (basic pool doesn't support those modes).
+ inferred = infer_pool(windows) # type: ignore[arg-type]
+ effective_pool = inferred if inferred != "basic" else record.pool
+
for mode in ALL_MODES_FULL:
mode_id = int(mode)
if mode_id in windows:
- window = normalize_quota_window(record.pool, mode_id, windows[mode_id])
+ window = normalize_quota_window(effective_pool, mode_id, windows[mode_id])
if window is None:
continue
patches[_MODE_KEYS[mode_id]] = window.to_dict()
@@ -293,12 +378,38 @@ async def _refresh_one(
if not patches:
return RefreshResult(checked=1, failed=0 if refreshed else 1)
- # Infer pool type from live quota data and patch if it changed.
- inferred = infer_pool(windows) # type: ignore[arg-type]
+ # ── Multi-factor pool inference ──────────────────────────────
+ # 'inferred' was set above from rate-limits auto.total (primary).
+ # Now cross-validate with subscription data (fallback).
+ #
+ # Scenarios handled:
+ # auto=50, any sub status → super (rate-limits wins)
+ # auto=7, INACTIVE sub → basic (genuinely downgraded)
+ # auto=?, ACTIVE sub → super (subscription wins, API fluke)
+ # auto=7, ACTIVE sub (anomaly) → super (active sub > anomalous quota)
+ if inferred == "basic":
+ sub_tier = record.ext.get("sub_tier", "")
+ sub_active = record.ext.get("sub_active", False)
+ if sub_active and sub_tier in (
+ "SUBSCRIPTION_TIER_GROK_PRO",
+ "SUBSCRIPTION_TIER_SUPER_GROK",
+ "SUBSCRIPTION_TIER_SUPER_GROK_PRO",
+ "SUBSCRIPTION_TIER_GROK_PRO_HEAVY",
+ "SUBSCRIPTION_TIER_SUPER_GROK_LITE",
+ "SUBSCRIPTION_TIER_GROK_TEAMS",
+ ):
+ from app.dataplane.reverse.protocol.xai_subscription import tier_to_pool as _t2p
+ inferred = _t2p(sub_tier)
+ logger.info(
+ "account pool upgraded by subscription: token={}... "
+ "rate_limits_pool=basic subscription_tier={} -> pool={}",
+ record.token[:10], sub_tier, inferred,
+ )
+
pool_patch = inferred if inferred != record.pool else None
if pool_patch:
logger.info(
- "account pool updated from live quota: token={}... previous_pool={} current_pool={}",
+ "account pool updated: token={}... previous_pool={} current_pool={}",
record.token[:10],
record.pool,
inferred,
diff --git a/app/control/proxy/__init__.py b/app/control/proxy/__init__.py
index 9d7581d60..1af2dcd55 100644
--- a/app/control/proxy/__init__.py
+++ b/app/control/proxy/__init__.py
@@ -228,12 +228,8 @@ async def _pick_proxy_url(self, resource: bool = False) -> str | None:
if self._egress_mode == EgressMode.DIRECT:
return None
async with self._lock:
- # Prefer resource-specific nodes when available; fall back to base nodes.
- nodes = (
- self._resource_nodes
- if resource and self._resource_nodes
- else self._nodes
- )
+ # Resource downloads are direct unless a resource proxy is explicit.
+ nodes = self._resource_nodes if resource else self._nodes
if not nodes:
return None
if self._egress_mode == EgressMode.SINGLE_PROXY:
diff --git a/app/dataplane/reverse/protocol/xai_subscription.py b/app/dataplane/reverse/protocol/xai_subscription.py
new file mode 100644
index 000000000..e3753c641
--- /dev/null
+++ b/app/dataplane/reverse/protocol/xai_subscription.py
@@ -0,0 +1,259 @@
+"""XAI subscription / account-type protocol — fetch subscription tier and account ID.
+
+Provides the official account type query by calling ``GET /rest/subscriptions``.
+Returns the user's subscription tier (basic / super / heavy) and unique account
+ID (``xaiUserId``) for deduplication purposes.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from app.platform.errors import UpstreamError
+from app.platform.logging.logger import logger
+
+
+# ---------------------------------------------------------------------------
+# Tier mapping — Grok subscription tiers → Grok2API pool names
+# ---------------------------------------------------------------------------
+
+# Known subscription tier strings returned by grok.com/rest/subscriptions.
+_TIER_TO_POOL: dict[str, str] = {
+ "SUBSCRIPTION_TIER_UNKNOWN": "basic",
+ "SUBSCRIPTION_TIER_FREE": "basic",
+ "SUBSCRIPTION_TIER_GROK_PRO": "super",
+ "SUBSCRIPTION_TIER_SUPER_GROK": "super",
+ "SUBSCRIPTION_TIER_SUPER_GROK_PRO": "heavy",
+ "SUBSCRIPTION_TIER_GROK_PRO_HEAVY": "heavy",
+ "SUBSCRIPTION_TIER_SUPER_GROK_LITE": "super",
+ "SUBSCRIPTION_TIER_GROK_TEAMS": "super",
+}
+
+# Active subscription statuses.
+_ACTIVE_STATUSES: frozenset[str] = frozenset({
+ "SUBSCRIPTION_STATUS_ACTIVE",
+ "SUBSCRIPTION_STATUS_TRIAL",
+ "SUBSCRIPTION_STATUS_TRIALING",
+})
+
+
+# ---------------------------------------------------------------------------
+# Result type
+# ---------------------------------------------------------------------------
+
+
+@dataclass(slots=True)
+class SubscriptionInfo:
+ """Parsed account-type information from the subscription API.
+
+ ``xai_user_id`` — unique account identifier (UUID); ``None`` if unavailable.
+ ``pool`` — Grok2API pool name (``"basic"`` / ``"super"`` / ``"heavy"``).
+ ``tier`` — raw upstream tier string (e.g. ``"SUBSCRIPTION_TIER_GROK_PRO"``).
+ ``is_active`` — whether at least one subscription is currently active.
+ ``raw`` — the full parsed JSON dict for debugging / extension use.
+ """
+
+ xai_user_id: str | None
+ pool: str
+ tier: str
+ is_active: bool
+ raw: dict | None
+
+
+def tier_to_pool(tier: str) -> str:
+ """Map an upstream subscription tier string to a Grok2API pool name."""
+ return _TIER_TO_POOL.get(tier, "basic")
+
+
+def is_active_status(status: str) -> bool:
+ """Return True if *status* represents an active subscription."""
+ return status in _ACTIVE_STATUSES
+
+
+# ---------------------------------------------------------------------------
+# Response parser
+# ---------------------------------------------------------------------------
+
+
+def parse_subscription(body: dict) -> SubscriptionInfo:
+ """Parse the ``/rest/subscriptions`` response body.
+
+ Expected format::
+
+ {
+ "subscriptions": [
+ {
+ "xaiUserId": "uuid-here",
+ "tier": "SUBSCRIPTION_TIER_GROK_PRO",
+ "status": "SUBSCRIPTION_STATUS_ACTIVE",
+ ...
+ },
+ ...
+ ]
+ }
+
+ Returns a ``SubscriptionInfo`` populated from the most recent (last in
+ list) subscription entry. If the subscriptions array is empty, returns
+ a ``"basic"`` pool with no account ID.
+ """
+ subs: list[dict] = body.get("subscriptions", []) if body else []
+
+ if not subs:
+ return SubscriptionInfo(
+ xai_user_id=None,
+ pool="basic",
+ tier="",
+ is_active=False,
+ raw=body,
+ )
+
+ # Use the *last* subscription (typically most recent).
+ last = subs[-1]
+
+ xai_user_id = last.get("xaiUserId") or None
+ tier = last.get("tier", "")
+ status = last.get("status", "")
+
+ pool = tier_to_pool(tier)
+
+ # If the most recent subscription is not active but a later entry is, use
+ # the best active tier we can find. Also collect account ID from any entry.
+ if not xai_user_id:
+ for sub in subs:
+ xai_user_id = sub.get("xaiUserId") or xai_user_id
+ if xai_user_id:
+ break
+
+ is_active = False
+ best_active_pool = pool
+ for sub in subs:
+ if is_active_status(sub.get("status", "")):
+ is_active = True
+ active_tier = sub.get("tier", "")
+ active_pool = tier_to_pool(active_tier)
+ if _pool_rank(active_pool) > _pool_rank(best_active_pool):
+ best_active_pool = active_pool
+
+ if is_active:
+ pool = best_active_pool
+
+ # If all subscriptions are inactive (expired/cancelled), the account
+ # reverts to basic tier on the Grok side, but we keep the historical tier
+ # for information and let the rate-limits API refine it.
+ if not is_active and pool != "basic":
+ logger.debug(
+ "subscription inactive — may have reverted to basic: tier={} status={}",
+ tier, status,
+ )
+
+ return SubscriptionInfo(
+ xai_user_id=xai_user_id,
+ pool=pool,
+ tier=tier,
+ is_active=is_active,
+ raw=body,
+ )
+
+
+def _pool_rank(pool: str) -> int:
+ """Return an ordinal rank for a pool name (higher = better)."""
+ return {"basic": 0, "super": 1, "heavy": 2}.get(pool, 0)
+
+
+# ---------------------------------------------------------------------------
+# HTTP fetch
+# ---------------------------------------------------------------------------
+
+
+async def _do_fetch(token: str) -> dict:
+ """GET the subscriptions endpoint and return parsed JSON body."""
+ from app.dataplane.reverse.transport.http import get_json
+ from app.dataplane.proxy import get_proxy_runtime
+ from app.control.proxy.models import ProxyFeedback, ProxyFeedbackKind
+ from app.dataplane.reverse.runtime.endpoint_table import SUBSCRIPTIONS
+
+ proxy = await get_proxy_runtime()
+ lease = await proxy.acquire()
+ try:
+ body = await get_json(
+ SUBSCRIPTIONS,
+ token,
+ lease=lease,
+ timeout_s=20.0,
+ )
+ await proxy.feedback(
+ lease, ProxyFeedback(kind=ProxyFeedbackKind.SUCCESS, status_code=200)
+ )
+ return body
+ except Exception as exc:
+ status = getattr(exc, "status", None) or getattr(exc, "status_code", None)
+ from app.dataplane.reverse.protocol.xai_usage import _proxy_feedback_kind_for_error
+ kind = _proxy_feedback_kind_for_error(exc, status=status)
+ await proxy.feedback(lease, ProxyFeedback(kind=kind, status_code=status))
+ raise
+
+
+async def fetch_subscription(token: str) -> SubscriptionInfo | None:
+ """Fetch account-type / subscription information for *token*.
+
+ Returns a ``SubscriptionInfo``, or ``None`` if the endpoint is unreachable
+ (network timeout, 5xx, etc.).
+ """
+ import asyncio
+
+ try:
+ body = await asyncio.wait_for(_do_fetch(token), timeout=25.0)
+ except asyncio.TimeoutError:
+ logger.debug(
+ "subscription fetch timed out: token={}...", token[:10]
+ )
+ return None
+ except UpstreamError as exc:
+ if getattr(exc, "status", None) in (401, 403):
+ # Invalid token — let caller handle.
+ raise
+ logger.debug(
+ "subscription fetch failed: token={}... status={}",
+ token[:10], exc.status if hasattr(exc, "status") else "?",
+ )
+ return None
+ except Exception as exc:
+ logger.debug(
+ "subscription fetch error: token={}... error={}",
+ token[:10], exc,
+ )
+ return None
+
+ return parse_subscription(body)
+
+
+async def fetch_subscription_for_import(token: str) -> SubscriptionInfo:
+ """Fetch subscription info during account import.
+
+ Always returns a ``SubscriptionInfo`` — falls back to ``"basic"`` with
+ no account ID when the API is unreachable.
+
+ Raises ``UpstreamError`` for 401/403 (invalid credentials).
+ """
+ result = await fetch_subscription(token)
+ if result is not None:
+ return result
+
+ # API unreachable — return a safe default.
+ return SubscriptionInfo(
+ xai_user_id=None,
+ pool="basic",
+ tier="",
+ is_active=False,
+ raw=None,
+ )
+
+
+__all__ = [
+ "SubscriptionInfo",
+ "parse_subscription",
+ "fetch_subscription",
+ "fetch_subscription_for_import",
+ "tier_to_pool",
+ "is_active_status",
+]
diff --git a/app/dataplane/reverse/runtime/endpoint_table.py b/app/dataplane/reverse/runtime/endpoint_table.py
index bb390e503..13da5fc41 100644
--- a/app/dataplane/reverse/runtime/endpoint_table.py
+++ b/app/dataplane/reverse/runtime/endpoint_table.py
@@ -23,6 +23,11 @@
# ── Rate limits (usage / quota sync) ─────────────────────────────────────
RATE_LIMITS = f"{BASE}/rest/rate-limits" # POST
+# ── Subscription / account type ─────────────────────────────────────────
+SUBSCRIPTIONS = f"{BASE}/rest/subscriptions" # GET
+PRODUCTS = f"{BASE}/rest/products" # GET ?provider=SUBSCRIPTION_PROVIDER_STRIPE
+MODES = f"{BASE}/rest/modes" # POST
+
# ── gRPC-Web endpoints ──────────────────────────────────────────────────
ACCEPT_TOS = "https://accounts.x.ai/auth_mgmt.AuthManagement/SetTosAcceptedVersion"
NSFW_MGMT = f"{BASE}/auth_mgmt.AuthManagement/UpdateUserFeatureControls"
@@ -47,7 +52,7 @@
"BASE", "ASSETS_CDN",
"CHAT",
"ASSETS_UPLOAD", "ASSETS_LIST", "ASSETS_DELETE", "ASSETS_DOWNLOAD",
- "RATE_LIMITS",
+ "RATE_LIMITS", "SUBSCRIPTIONS", "PRODUCTS", "MODES",
"ACCEPT_TOS", "NSFW_MGMT", "SET_BIRTH",
"MEDIA_POST", "MEDIA_POST_LINK", "VIDEO_UPSCALE",
"WS_IMAGINE", "WS_LIVEKIT", "LIVEKIT_TOKENS",
diff --git a/app/dataplane/reverse/transport/asset_upload.py b/app/dataplane/reverse/transport/asset_upload.py
index 891594424..dc3ce32ab 100644
--- a/app/dataplane/reverse/transport/asset_upload.py
+++ b/app/dataplane/reverse/transport/asset_upload.py
@@ -184,7 +184,7 @@ async def upload_from_input(token: str, file_input: str) -> tuple[str, str]:
if _is_url(file_input):
# Fetch the remote URL and re-upload as base64.
proxy = await get_proxy_runtime()
- lease = await proxy.acquire()
+ lease = await proxy.acquire(resource=True)
try:
headers = build_http_headers(token, lease=lease)
kwargs = build_session_kwargs(lease=lease)
diff --git a/app/dataplane/reverse/transport/assets.py b/app/dataplane/reverse/transport/assets.py
index b786447ff..faa7135a0 100644
--- a/app/dataplane/reverse/transport/assets.py
+++ b/app/dataplane/reverse/transport/assets.py
@@ -194,7 +194,11 @@ async def download_asset(
}
proxy = await get_proxy_runtime()
- lease = await proxy.acquire(scope=ProxyScope.ASSET, kind=RequestKind.HTTP)
+ lease = await proxy.acquire(
+ scope=ProxyScope.ASSET,
+ kind=RequestKind.HTTP,
+ resource=True,
+ )
try:
stream = await get_bytes_stream(
diff --git a/app/platform/config/backends/factory.py b/app/platform/config/backends/factory.py
index 258b385eb..eda5e53b5 100644
--- a/app/platform/config/backends/factory.py
+++ b/app/platform/config/backends/factory.py
@@ -1,4 +1,4 @@
-"""Config backend factory — follows ACCOUNT_STORAGE automatically."""
+"""Config backend factory."""
import os
from pathlib import Path
@@ -8,19 +8,24 @@
def get_config_backend_name() -> str:
- """Return the active config backend name (mirrors ACCOUNT_STORAGE)."""
- return os.getenv("ACCOUNT_STORAGE", "local").strip().lower()
+ """Return the active config backend name.
+
+ Runtime configuration is local by default even when account data is stored
+ in Redis or SQL. Set CONFIG_STORAGE explicitly to opt into a remote config
+ backend.
+ """
+ return os.getenv("CONFIG_STORAGE", "local").strip().lower()
def create_config_backend() -> ConfigBackend:
- """Instantiate the config backend that matches the account storage backend.
+ """Instantiate the configured runtime config backend.
- ``ACCOUNT_STORAGE=local`` → TOML file (``${DATA_DIR}/config.toml``)
- ``ACCOUNT_STORAGE=redis`` → Redis (ACCOUNT_REDIS_URL)
- ``ACCOUNT_STORAGE=mysql`` → MySQL (ACCOUNT_MYSQL_URL)
- ``ACCOUNT_STORAGE=postgresql`` → PostgreSQL (ACCOUNT_POSTGRESQL_URL)
+ ``CONFIG_STORAGE=local`` → TOML file (``${DATA_DIR}/config.toml``)
+ ``CONFIG_STORAGE=redis`` → Redis (ACCOUNT_REDIS_URL)
+ ``CONFIG_STORAGE=mysql`` → MySQL (ACCOUNT_MYSQL_URL)
+ ``CONFIG_STORAGE=postgresql`` → PostgreSQL (ACCOUNT_POSTGRESQL_URL)
- No extra env vars needed — reuses the same connection settings as accounts.
+ Remote config backends reuse the same connection settings as accounts.
"""
backend = get_config_backend_name()
@@ -31,7 +36,7 @@ def create_config_backend() -> ConfigBackend:
if backend in ("mysql", "postgresql"):
return _make_sql(backend)
- raise ValueError(f"Unknown account storage backend: {backend!r}")
+ raise ValueError(f"Unknown config storage backend: {backend!r}")
def _make_toml() -> ConfigBackend:
diff --git a/app/platform/config/snapshot.py b/app/platform/config/snapshot.py
index f489f01ad..8b14b4575 100644
--- a/app/platform/config/snapshot.py
+++ b/app/platform/config/snapshot.py
@@ -22,6 +22,25 @@ def _mtime(path: Path) -> float:
return 0.0
+def _apply_legacy_cache_overrides(user_overrides: dict[str, Any]) -> dict[str, Any]:
+ """Map legacy [storage] cache limits into [cache.local] at load time."""
+ storage = user_overrides.get("storage")
+ if not isinstance(storage, dict):
+ return user_overrides
+
+ cache = user_overrides.setdefault("cache", {})
+ if not isinstance(cache, dict):
+ return user_overrides
+ local = cache.setdefault("local", {})
+ if not isinstance(local, dict):
+ return user_overrides
+
+ for key in ("image_max_mb", "video_max_mb"):
+ if key not in local and key in storage:
+ local[key] = storage[key]
+ return user_overrides
+
+
class ConfigSnapshot:
"""Immutable view over the loaded configuration dict.
@@ -74,6 +93,7 @@ async def load(self, defaults_path: Path | None = None) -> None:
defaults = await asyncio.to_thread(load_toml, dp)
user_overrides = await backend.load()
+ user_overrides = _apply_legacy_cache_overrides(user_overrides)
self._data = _deep_merge(defaults, user_overrides)
self._data = _apply_env(self._data)
diff --git a/app/platform/startup/migration.py b/app/platform/startup/migration.py
index e8cb7fc79..063b2be31 100644
--- a/app/platform/startup/migration.py
+++ b/app/platform/startup/migration.py
@@ -2,9 +2,10 @@
Config migration
----------------
-local : seeds ``${DATA_DIR}/config.toml`` from ``config.defaults.toml`` if
+local : default; seeds ``${DATA_DIR}/config.toml`` from ``config.defaults.toml`` if
the file does not exist yet — gives users an editable copy on first run.
-redis / sql : if the backend is empty (version == 0) AND
+redis / sql : only when CONFIG_STORAGE is explicitly set to a remote backend.
+ If the backend is empty (version == 0) AND
``${DATA_DIR}/config.toml`` exists, migrates the user overrides into
the DB backend. If it does not exist either, nothing is written
(defaults are always loaded from ``config.defaults.toml`` at runtime).
diff --git a/app/platform/storage/__init__.py b/app/platform/storage/__init__.py
index 895f9a58e..c57f78cae 100644
--- a/app/platform/storage/__init__.py
+++ b/app/platform/storage/__init__.py
@@ -3,6 +3,8 @@
from .media_cache import (
clear_local_media_files,
delete_local_media_file,
+ list_local_media_files,
+ local_media_stats,
reconcile_local_media_cache_async,
save_local_image,
save_local_video,
@@ -13,6 +15,8 @@
"clear_local_media_files",
"delete_local_media_file",
"image_files_dir",
+ "list_local_media_files",
+ "local_media_stats",
"reconcile_local_media_cache_async",
"save_local_image",
"save_local_video",
diff --git a/app/platform/storage/media_cache.py b/app/platform/storage/media_cache.py
index f91b06d3f..171cf2466 100644
--- a/app/platform/storage/media_cache.py
+++ b/app/platform/storage/media_cache.py
@@ -88,6 +88,57 @@ def reconcile(self, media_type: MediaType) -> None:
self._enforce_limit_locked(conn, media_type)
conn.commit()
+ def limit_mb(self, media_type: MediaType) -> int:
+ """Return the configured per-type cache limit in MB."""
+ cfg = self._config_provider()
+ limit_mb = max(0, int(cfg.get_int(f"cache.local.{media_type}_max_mb", 0)))
+ return limit_mb
+
+ def stats(self, media_type: MediaType) -> dict[str, Any]:
+ """Return count, size, and configured limit for one media type."""
+ files = list(self._iter_files(media_type))
+ total_size = sum(path.stat().st_size for path in files)
+ limit_mb = self.limit_mb(media_type)
+ limit_bytes = limit_mb * 1024 * 1024
+ usage_ratio = (total_size / limit_bytes) if limit_bytes > 0 else None
+ usage_percent = round(usage_ratio * 100, 1) if usage_ratio is not None else None
+ return {
+ "count": len(files),
+ "size_mb": round(total_size / 1024 / 1024, 2),
+ "size_bytes": total_size,
+ "limit_mb": limit_mb,
+ "limit_bytes": limit_bytes,
+ "limited": limit_bytes > 0,
+ "usage_ratio": round(usage_ratio, 4) if usage_ratio is not None else None,
+ "usage_percent": usage_percent,
+ }
+
+ def list_files(
+ self,
+ media_type: MediaType,
+ *,
+ page: int = 1,
+ page_size: int = 1000,
+ ) -> dict[str, Any]:
+ """Return paginated local media files newest first."""
+ files = sorted(
+ self._iter_files(media_type),
+ key=lambda path: path.stat().st_mtime,
+ reverse=True,
+ )
+ total = len(files)
+ start = max(0, page - 1) * page_size
+ chunk = files[start : start + page_size]
+ items = []
+ for path in chunk:
+ stat = path.stat()
+ items.append({
+ "name": path.name,
+ "size_bytes": stat.st_size,
+ "modified_at": stat.st_mtime,
+ })
+ return {"total": total, "page": page, "page_size": page_size, "items": items}
+
def delete(self, media_type: MediaType, name: str) -> bool:
"""Delete a single local media file and keep the index consistent."""
safe_name = self._validate_name(media_type, name)
@@ -138,9 +189,7 @@ def _save(
return path
def _limit_bytes(self, media_type: MediaType) -> int:
- cfg = self._config_provider()
- limit_mb = max(0, int(cfg.get_int(f"cache.local.{media_type}_max_mb", 0)))
- return limit_mb * 1024 * 1024
+ return self.limit_mb(media_type) * 1024 * 1024
def _target_bytes(self, max_bytes: int) -> int:
return max(0, int(max_bytes * _LOW_WATERMARK_RATIO))
@@ -479,6 +528,21 @@ def delete_local_media_file(media_type: MediaType, name: str) -> bool:
return local_media_cache.delete(media_type, name)
+def local_media_stats(media_type: MediaType) -> dict[str, Any]:
+ """Return count, size, and configured limit for local media files."""
+ return local_media_cache.stats(media_type)
+
+
+def list_local_media_files(
+ media_type: MediaType,
+ *,
+ page: int = 1,
+ page_size: int = 1000,
+) -> dict[str, Any]:
+ """Return paginated local media files newest first."""
+ return local_media_cache.list_files(media_type, page=page, page_size=page_size)
+
+
async def reconcile_local_media_cache_async(
media_type: MediaType | None = None,
) -> None:
@@ -494,7 +558,9 @@ async def reconcile_local_media_cache_async(
"LocalMediaCacheStore",
"clear_local_media_files",
"delete_local_media_file",
+ "list_local_media_files",
"local_media_cache",
+ "local_media_stats",
"reconcile_local_media_cache_async",
"save_local_image",
"save_local_video",
diff --git a/app/products/anthropic/messages.py b/app/products/anthropic/messages.py
index 42ba98a17..be01b35e1 100644
--- a/app/products/anthropic/messages.py
+++ b/app/products/anthropic/messages.py
@@ -31,7 +31,8 @@
from app.products.openai.chat import (
_stream_chat, _extract_message, _resolve_image,
_quota_sync, _fail_sync, _parse_retry_codes, _feedback_kind, _log_task_exception,
- _configured_retry_codes, _should_retry_upstream,
+ _adapter_has_visible_output, _configured_retry_codes,
+ _empty_upstream_response_error, _should_retry_upstream, _StreamStartGate,
)
from app.products._account_selection import reserve_account, selection_max_retries
from app.products.openai._tool_sieve import ToolSieve
@@ -345,11 +346,12 @@ async def _run_stream() -> AsyncGenerator[str, None]:
tool_output_tokens = 0
block_index = 0 # tracks next content_block index
collected_annotations: list[dict] = []
+ gate = _StreamStartGate()
try:
try:
# message_start
- yield _sse("message_start", {
+ for out in gate.emit(_sse("message_start", {
"type": "message_start",
"message": {
"id": msg_id,
@@ -360,8 +362,10 @@ async def _run_stream() -> AsyncGenerator[str, None]:
"stop_reason": None,
"usage": {"input_tokens": estimate_prompt_tokens(internal_message), "output_tokens": 0},
},
- })
- yield _sse("ping", {"type": "ping"})
+ })):
+ yield out
+ for out in gate.emit(_sse("ping", {"type": "ping"})):
+ yield out
ended = False
async for line in _stream_chat(
@@ -385,35 +389,38 @@ async def _run_stream() -> AsyncGenerator[str, None]:
if ev.kind == "thinking" and emit_think and not think_closed:
if not think_started:
think_started = True
- yield _sse("content_block_start", {
+ for out in gate.emit(_sse("content_block_start", {
"type": "content_block_start",
"index": block_index,
"content_block": {"type": "thinking", "thinking": ""},
- })
+ })):
+ yield out
think_buf.append(ev.content)
- yield _sse("content_block_delta", {
+ for out in gate.emit(_sse("content_block_delta", {
"type": "content_block_delta",
"index": block_index,
"delta": {"type": "thinking_delta", "thinking": ev.content},
- })
+ })):
+ yield out
elif ev.kind == "text":
# Close thinking block if open
if think_started and not think_closed:
think_closed = True
- yield _sse("content_block_stop", {
+ for out in gate.emit(_sse("content_block_stop", {
"type": "content_block_stop",
"index": block_index,
- })
+ })):
+ yield out
block_index += 1
# Feed through ToolSieve if tools active
if sieve is not None:
safe_text, calls = sieve.feed(ev.content)
- if calls is not None:
+ if calls:
# Emit tool_use blocks
for call in calls:
- yield _sse("content_block_start", {
+ for out in gate.emit(_sse("content_block_start", {
"type": "content_block_start",
"index": block_index,
"content_block": {
@@ -422,19 +429,22 @@ async def _run_stream() -> AsyncGenerator[str, None]:
"name": call.name,
"input": {},
},
- })
- yield _sse("content_block_delta", {
+ }), visible=True):
+ yield out
+ for out in gate.emit(_sse("content_block_delta", {
"type": "content_block_delta",
"index": block_index,
"delta": {
"type": "input_json_delta",
"partial_json": call.arguments,
},
- })
- yield _sse("content_block_stop", {
+ })):
+ yield out
+ for out in gate.emit(_sse("content_block_stop", {
"type": "content_block_stop",
"index": block_index,
- })
+ })):
+ yield out
block_index += 1
tool_output_tokens = estimate_tool_call_tokens(calls)
tool_calls_emitted = True
@@ -447,17 +457,19 @@ async def _run_stream() -> AsyncGenerator[str, None]:
if text_chunk:
if not text_started:
text_started = True
- yield _sse("content_block_start", {
+ for out in gate.emit(_sse("content_block_start", {
"type": "content_block_start",
"index": block_index,
"content_block": {"type": "text", "text": ""},
- })
+ })):
+ yield out
text_buf.append(text_chunk)
- yield _sse("content_block_delta", {
+ for out in gate.emit(_sse("content_block_delta", {
"type": "content_block_delta",
"index": block_index,
"delta": {"type": "text_delta", "text": text_chunk},
- })
+ }), visible=bool(text_chunk.strip())):
+ yield out
elif ev.kind == "annotation" and ev.annotation_data:
collected_annotations.append(ev.annotation_data)
@@ -475,14 +487,15 @@ async def _run_stream() -> AsyncGenerator[str, None]:
if calls:
# Close text block if open
if text_started:
- yield _sse("content_block_stop", {
+ for out in gate.emit(_sse("content_block_stop", {
"type": "content_block_stop",
"index": block_index,
- })
+ })):
+ yield out
block_index += 1
text_started = False
for call in calls:
- yield _sse("content_block_start", {
+ for out in gate.emit(_sse("content_block_start", {
"type": "content_block_start",
"index": block_index,
"content_block": {
@@ -491,19 +504,22 @@ async def _run_stream() -> AsyncGenerator[str, None]:
"name": call.name,
"input": {},
},
- })
- yield _sse("content_block_delta", {
+ }), visible=True):
+ yield out
+ for out in gate.emit(_sse("content_block_delta", {
"type": "content_block_delta",
"index": block_index,
"delta": {
"type": "input_json_delta",
"partial_json": call.arguments,
},
- })
- yield _sse("content_block_stop", {
+ })):
+ yield out
+ for out in gate.emit(_sse("content_block_stop", {
"type": "content_block_stop",
"index": block_index,
- })
+ })):
+ yield out
block_index += 1
tool_output_tokens = estimate_tool_call_tokens(calls)
tool_calls_emitted = True
@@ -514,13 +530,16 @@ async def _run_stream() -> AsyncGenerator[str, None]:
sources = adapter.search_sources_list()
if sources:
tool_delta["search_sources"] = sources
- yield _sse("message_delta", {
+ for out in gate.emit(_sse("message_delta", {
"type": "message_delta",
"delta": tool_delta,
"usage": {"output_tokens": tool_output_tokens},
- })
- yield _sse("message_stop", {"type": "message_stop"})
- yield "data: [DONE]\n\n"
+ }), visible=True):
+ yield out
+ for out in gate.emit(_sse("message_stop", {"type": "message_stop"})):
+ yield out
+ for out in gate.emit("data: [DONE]\n\n"):
+ yield out
success = True
logger.info("messages stream tool_calls: attempt={}/{} model={}",
attempt + 1, max_retries + 1, model)
@@ -532,37 +551,43 @@ async def _run_stream() -> AsyncGenerator[str, None]:
chunk = img_text + "\n"
text_buf.append(chunk)
if text_started:
- yield _sse("content_block_delta", {
+ for out in gate.emit(_sse("content_block_delta", {
"type": "content_block_delta",
"index": block_index,
"delta": {"type": "text_delta", "text": chunk},
- })
+ }), visible=bool(img_text.strip())):
+ yield out
references = adapter.references_suffix()
if references:
text_buf.append(references)
if text_started:
- yield _sse("content_block_delta", {
+ for out in gate.emit(_sse("content_block_delta", {
"type": "content_block_delta",
"index": block_index,
"delta": {"type": "text_delta", "text": references},
- })
+ }), visible=True):
+ yield out
# Close open blocks
if think_started and not think_closed:
- yield _sse("content_block_stop", {
+ for out in gate.emit(_sse("content_block_stop", {
"type": "content_block_stop",
"index": block_index,
- })
+ })):
+ yield out
block_index += 1
if text_started:
- yield _sse("content_block_stop", {
+ for out in gate.emit(_sse("content_block_stop", {
"type": "content_block_stop",
"index": block_index,
- })
+ })):
+ yield out
full_text = "".join(text_buf)
+ if not _adapter_has_visible_output(adapter, extra_text=full_text):
+ raise _empty_upstream_response_error()
full_think = "".join(think_buf)
out_tokens = estimate_tokens(full_text)
if full_think:
@@ -575,13 +600,16 @@ async def _run_stream() -> AsyncGenerator[str, None]:
msg_delta["search_sources"] = sources
if collected_annotations:
msg_delta["annotations"] = collected_annotations
- yield _sse("message_delta", {
+ for out in gate.emit(_sse("message_delta", {
"type": "message_delta",
"delta": msg_delta,
"usage": {"output_tokens": out_tokens},
- })
- yield _sse("message_stop", {"type": "message_stop"})
- yield "data: [DONE]\n\n"
+ }), visible=True):
+ yield out
+ for out in gate.emit(_sse("message_stop", {"type": "message_stop"})):
+ yield out
+ for out in gate.emit("data: [DONE]\n\n"):
+ yield out
success = True
logger.info(
"messages stream completed: attempt={}/{} model={} text_len={} think_len={} images={}",
@@ -664,6 +692,8 @@ async def _run_stream() -> AsyncGenerator[str, None]:
break
if ended:
break
+ if not _adapter_has_visible_output(adapter):
+ raise _empty_upstream_response_error()
success = True
except UpstreamError as exc:
diff --git a/app/products/anthropic/router.py b/app/products/anthropic/router.py
index 34b1f2126..cacfc3575 100644
--- a/app/products/anthropic/router.py
+++ b/app/products/anthropic/router.py
@@ -1,6 +1,6 @@
"""Anthropic Messages API router (/v1/messages)."""
-from typing import Any
+from typing import Any, AsyncGenerator, AsyncIterable
import orjson
from fastapi import APIRouter, Depends, Request
@@ -8,7 +8,7 @@
from pydantic import BaseModel, Field
from app.platform.auth.middleware import verify_api_key
-from app.platform.errors import AppError, ValidationError
+from app.platform.errors import AppError, UpstreamError, ValidationError
from app.platform.logging.logger import logger
from app.control.model import registry as model_registry
@@ -72,6 +72,26 @@ async def _safe_sse_anthropic(stream):
yield "data: [DONE]\n\n"
+async def _prime_sse(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]:
+ """Read the first chunk before StreamingResponse sends HTTP 200."""
+ iterator = stream.__aiter__()
+ try:
+ first = await anext(iterator)
+ except StopAsyncIteration as exc:
+ raise UpstreamError(
+ "Upstream returned empty response",
+ status=429,
+ body="empty_response",
+ ) from exc
+
+ async def _primed() -> AsyncGenerator[str, None]:
+ yield first
+ async for chunk in iterator:
+ yield chunk
+
+ return _primed()
+
+
# ---------------------------------------------------------------------------
# /v1/messages
# ---------------------------------------------------------------------------
@@ -118,6 +138,7 @@ async def messages_endpoint(req: MessagesRequest):
if isinstance(result, dict):
return JSONResponse(result)
+ result = await _prime_sse(result)
return StreamingResponse(
_safe_sse_anthropic(result),
media_type = "text/event-stream",
diff --git a/app/products/openai/__init__.py b/app/products/openai/__init__.py
index 5bc0c2e6d..c3904ab36 100644
--- a/app/products/openai/__init__.py
+++ b/app/products/openai/__init__.py
@@ -1,3 +1,10 @@
-from .router import router
+"""OpenAI product package exports."""
__all__ = ["router"]
+
+
+def __getattr__(name: str):
+ if name == "router":
+ from .router import router
+ return router
+ raise AttributeError(name)
diff --git a/app/products/openai/chat.py b/app/products/openai/chat.py
index 7a551b13e..b87f35390 100644
--- a/app/products/openai/chat.py
+++ b/app/products/openai/chat.py
@@ -192,6 +192,64 @@ def _feedback_kind(exc: BaseException) -> "FeedbackKind":
return feedback_kind_for_error(exc)
+EMPTY_UPSTREAM_BODY = "empty_response"
+
+
+def _empty_upstream_response_error() -> UpstreamError:
+ return UpstreamError(
+ "Upstream returned empty response",
+ status=429,
+ body=EMPTY_UPSTREAM_BODY,
+ )
+
+
+def _adapter_has_visible_output(
+ adapter: StreamAdapter,
+ *,
+ extra_text: str = "",
+ has_tool_calls: bool = False,
+ include_thinking: bool = False,
+) -> bool:
+ """Return True when the adapter has user-visible response content.
+
+ Thinking/reasoning buffers intentionally do not count here.
+ """
+ if has_tool_calls:
+ return True
+ if include_thinking and "".join(adapter.thinking_buf).strip():
+ return True
+ if (extra_text or "".join(adapter.text_buf)).strip():
+ return True
+ if adapter.image_urls:
+ return True
+ if adapter.references_suffix().strip():
+ return True
+ if adapter.search_sources_list():
+ return True
+ return False
+
+
+class _StreamStartGate:
+ """Buffer stream events until a visible payload is available."""
+
+ __slots__ = ("_pending", "visible")
+
+ def __init__(self) -> None:
+ self._pending: list[str] = []
+ self.visible = False
+
+ def emit(self, chunk: str, *, visible: bool = False) -> list[str]:
+ if visible and not self.visible:
+ self.visible = True
+ out = self._pending + [chunk]
+ self._pending = []
+ return out
+ if self.visible:
+ return [chunk]
+ self._pending.append(chunk)
+ return []
+
+
async def _download_image_bytes(token: str, url: str) -> tuple[bytes, str]:
"""Download image bytes via the shared asset transport used by /v1/images."""
from app.dataplane.reverse.protocol.xai_assets import infer_content_type
@@ -527,6 +585,7 @@ async def _run_stream() -> AsyncGenerator[str, None]:
ended = False
sieve = ToolSieve(tool_names)
tool_calls_emitted = False
+ gate = _StreamStartGate()
async for line in _stream_chat(
token=token,
mode_id=ModeId(selected_mode_id),
@@ -552,8 +611,15 @@ async def _run_stream() -> AsyncGenerator[str, None]:
chunk = make_stream_chunk(
response_id, model, safe_text
)
- yield f"data: {orjson.dumps(chunk).decode()}\n\n"
- if parsed_calls is not None:
+ payload = f"data: {orjson.dumps(chunk).decode()}\n\n"
+ for out in gate.emit(
+ payload, visible=bool(safe_text.strip())
+ ):
+ yield out
+ if parsed_calls:
+ for out in gate.emit("", visible=True):
+ if out:
+ yield out
for i, tc in enumerate(parsed_calls):
chunk = make_tool_call_chunk(
response_id,
@@ -585,12 +651,18 @@ async def _run_stream() -> AsyncGenerator[str, None]:
chunk = make_stream_chunk(
response_id, model, ev.content
)
- yield f"data: {orjson.dumps(chunk).decode()}\n\n"
+ payload = f"data: {orjson.dumps(chunk).decode()}\n\n"
+ for out in gate.emit(
+ payload, visible=bool(ev.content.strip())
+ ):
+ yield out
elif ev.kind == "thinking" and emit_think:
chunk = make_thinking_chunk(
response_id, model, ev.content
)
- yield f"data: {orjson.dumps(chunk).decode()}\n\n"
+ payload = f"data: {orjson.dumps(chunk).decode()}\n\n"
+ for out in gate.emit(payload, visible=True):
+ yield out
elif ev.kind == "annotation" and ev.annotation_data:
collected_annotations.append(ev.annotation_data)
elif ev.kind == "soft_stop":
@@ -603,6 +675,9 @@ async def _run_stream() -> AsyncGenerator[str, None]:
# Stream ended — flush sieve for any buffered XML
flushed_calls = sieve.flush()
if flushed_calls:
+ for out in gate.emit("", visible=True):
+ if out:
+ yield out
for i, tc in enumerate(flushed_calls):
chunk = make_tool_call_chunk(
response_id,
@@ -637,14 +712,27 @@ async def _run_stream() -> AsyncGenerator[str, None]:
chunk = make_stream_chunk(
response_id, model, img_text + "\n"
)
- yield f"data: {orjson.dumps(chunk).decode()}\n\n"
+ payload = f"data: {orjson.dumps(chunk).decode()}\n\n"
+ for out in gate.emit(
+ payload, visible=bool(img_text.strip())
+ ):
+ yield out
references = adapter.references_suffix()
if references:
chunk = make_stream_chunk(
response_id, model, references
)
- yield f"data: {orjson.dumps(chunk).decode()}\n\n"
+ payload = f"data: {orjson.dumps(chunk).decode()}\n\n"
+ for out in gate.emit(payload, visible=True):
+ yield out
+
+ if not _adapter_has_visible_output(
+ adapter,
+ has_tool_calls=tool_calls_emitted,
+ include_thinking=emit_think,
+ ):
+ raise _empty_upstream_response_error()
chat_anns = _to_chat_annotations(collected_annotations)
final = make_stream_chunk(
@@ -658,8 +746,11 @@ async def _run_stream() -> AsyncGenerator[str, None]:
sources = adapter.search_sources_list()
if sources:
final["search_sources"] = sources
- yield f"data: {orjson.dumps(final).decode()}\n\n"
- yield "data: [DONE]\n\n"
+ final_payload = f"data: {orjson.dumps(final).decode()}\n\n"
+ for out in gate.emit(final_payload, visible=True):
+ yield out
+ for out in gate.emit("data: [DONE]\n\n"):
+ yield out
success = True
logger.info(
"chat stream completed: attempt={}/{} model={} image_count={}",
@@ -764,6 +855,8 @@ async def _run_stream() -> AsyncGenerator[str, None]:
break
if ended:
break
+ if not _adapter_has_visible_output(adapter):
+ raise _empty_upstream_response_error()
success = True
except UpstreamError as exc:
@@ -884,6 +977,10 @@ async def _run_stream() -> AsyncGenerator[str, None]:
__all__ = [
"completions",
+ "EMPTY_UPSTREAM_BODY",
+ "_adapter_has_visible_output",
"_configured_retry_codes",
+ "_empty_upstream_response_error",
"_should_retry_upstream",
+ "_StreamStartGate",
]
diff --git a/app/products/openai/responses.py b/app/products/openai/responses.py
index d816c7a92..9b94b881e 100644
--- a/app/products/openai/responses.py
+++ b/app/products/openai/responses.py
@@ -7,8 +7,6 @@
import asyncio
from typing import Any, AsyncGenerator
-import orjson
-
from app.platform.logging.logger import logger
from app.platform.config.snapshot import get_config
from app.platform.errors import RateLimitError, UpstreamError
@@ -20,13 +18,14 @@
from app.dataplane.reverse.protocol.xai_chat import classify_line, StreamAdapter
from app.products._account_selection import reserve_account, selection_max_retries
-from .chat import _stream_chat, _extract_message, _resolve_image, _quota_sync, _fail_sync, _parse_retry_codes, _feedback_kind, _log_task_exception, _upstream_body_excerpt
+from .chat import _stream_chat, _extract_message, _resolve_image, _quota_sync, _fail_sync, _feedback_kind, _log_task_exception, _upstream_body_excerpt
+from .chat import _adapter_has_visible_output, _empty_upstream_response_error, _StreamStartGate
from .chat import _configured_retry_codes, _should_retry_upstream
from ._format import (
make_resp_id, build_resp_usage, make_resp_object, format_sse,
)
from app.dataplane.reverse.protocol.tool_prompt import (
- build_tool_system_prompt, extract_tool_names, inject_into_message, tool_calls_to_xml,
+ build_tool_system_prompt, extract_tool_names, inject_into_message,
)
from app.dataplane.reverse.protocol.tool_parser import parse_tool_calls
from ._tool_sieve import ToolSieve
@@ -221,7 +220,6 @@ async def create(
cfg = get_config()
spec = resolve_model(model)
- mode_id = int(spec.mode_id) # cast once, reuse everywhere
messages: list[dict] = []
if instructions:
@@ -283,13 +281,15 @@ async def _run_stream() -> AsyncGenerator[str, None]:
tool_calls_emitted = False
detected_fc_items: list[dict] = []
collected_annotations: list[dict] = []
+ gate = _StreamStartGate()
try:
try:
- yield format_sse("response.created", {
+ for out in gate.emit(format_sse("response.created", {
"type": "response.created",
"response": make_resp_object(response_id, model, "in_progress", []),
- })
+ })):
+ yield out
ended = False
async for line in _stream_chat(
@@ -313,49 +313,54 @@ async def _run_stream() -> AsyncGenerator[str, None]:
if ev.kind == "thinking" and emit_think and not reasoning_closed:
if not reasoning_started:
reasoning_started = True
- yield format_sse("response.output_item.added", {
+ for out in gate.emit(format_sse("response.output_item.added", {
"type": "response.output_item.added",
"output_index": 0,
"item": {
"id": reasoning_id, "type": "reasoning",
"summary": [], "status": "in_progress",
},
- })
- yield format_sse("response.reasoning_summary_part.added", {
+ }), visible=True):
+ yield out
+ for out in gate.emit(format_sse("response.reasoning_summary_part.added", {
"type": "response.reasoning_summary_part.added",
"item_id": reasoning_id,
"output_index": 0,
"summary_index": 0,
"part": {"type": "summary_text", "text": ""},
- })
+ }), visible=True):
+ yield out
think_buf.append(ev.content)
- yield format_sse("response.reasoning_summary_text.delta", {
+ for out in gate.emit(format_sse("response.reasoning_summary_text.delta", {
"type": "response.reasoning_summary_text.delta",
"item_id": reasoning_id,
"output_index": 0,
"summary_index": 0,
"delta": ev.content,
- })
+ }), visible=True):
+ yield out
elif ev.kind == "text":
if reasoning_started and not reasoning_closed:
reasoning_closed = True
full_think = "".join(think_buf)
- yield format_sse("response.reasoning_summary_text.done", {
+ for out in gate.emit(format_sse("response.reasoning_summary_text.done", {
"type": "response.reasoning_summary_text.done",
"item_id": reasoning_id,
"output_index": 0,
"summary_index": 0,
"text": full_think,
- })
- yield format_sse("response.reasoning_summary_part.done", {
+ })):
+ yield out
+ for out in gate.emit(format_sse("response.reasoning_summary_part.done", {
"type": "response.reasoning_summary_part.done",
"item_id": reasoning_id,
"output_index": 0,
"summary_index": 0,
"part": {"type": "summary_text", "text": full_think},
- })
- yield format_sse("response.output_item.done", {
+ })):
+ yield out
+ for out in gate.emit(format_sse("response.output_item.done", {
"type": "response.output_item.done",
"output_index": 0,
"item": {
@@ -364,17 +369,19 @@ async def _run_stream() -> AsyncGenerator[str, None]:
"summary": [{"type": "summary_text", "text": full_think}],
"status": "completed",
},
- })
+ })):
+ yield out
# Feed through ToolSieve if tools are active
if sieve is not None:
safe_text, calls = sieve.feed(ev.content)
- if calls is not None:
+ if calls:
fc_items = _build_fc_items(calls)
detected_fc_items = fc_items
base_idx = 1 if reasoning_started else 0
async for evt in _emit_fc_events(fc_items, base_idx):
- yield evt
+ for out in gate.emit(evt, visible=True):
+ yield out
tool_calls_emitted = True
ended = True
break
@@ -386,43 +393,47 @@ async def _run_stream() -> AsyncGenerator[str, None]:
msg_idx = 1 if reasoning_started else 0
if not message_started:
message_started = True
- yield format_sse("response.output_item.added", {
+ for out in gate.emit(format_sse("response.output_item.added", {
"type": "response.output_item.added",
"output_index": msg_idx,
"item": {
"id": message_id, "type": "message",
"role": "assistant", "content": [], "status": "in_progress",
},
- })
- yield format_sse("response.content_part.added", {
+ })):
+ yield out
+ for out in gate.emit(format_sse("response.content_part.added", {
"type": "response.content_part.added",
"item_id": message_id,
"output_index": msg_idx,
"content_index": 0,
"part": {"type": "output_text", "text": "", "annotations": []},
- })
+ })):
+ yield out
text_buf.append(text_chunk)
- yield format_sse("response.output_text.delta", {
+ for out in gate.emit(format_sse("response.output_text.delta", {
"type": "response.output_text.delta",
"item_id": message_id,
"output_index": msg_idx,
"content_index": 0,
"delta": text_chunk,
- })
+ }), visible=bool(text_chunk.strip())):
+ yield out
elif ev.kind == "annotation" and ev.annotation_data:
if message_started:
collected_annotations.append(ev.annotation_data)
msg_idx = 1 if reasoning_started else 0
- yield format_sse("response.output_text.annotation.added", {
+ for out in gate.emit(format_sse("response.output_text.annotation.added", {
"type": "response.output_text.annotation.added",
"item_id": message_id,
"output_index": msg_idx,
"content_index": 0,
"annotation_index": len(collected_annotations) - 1,
"annotation": ev.annotation_data,
- })
+ })):
+ yield out
elif ev.kind == "soft_stop":
ended = True
@@ -439,7 +450,8 @@ async def _run_stream() -> AsyncGenerator[str, None]:
detected_fc_items = fc_items
base_idx = 1 if reasoning_started else 0
async for evt in _emit_fc_events(fc_items, base_idx):
- yield evt
+ for out in gate.emit(evt, visible=True):
+ yield out
tool_calls_emitted = True
if tool_calls_emitted:
@@ -457,14 +469,16 @@ async def _run_stream() -> AsyncGenerator[str, None]:
pt = estimate_prompt_tokens(message)
ct = estimate_tool_call_tokens(detected_fc_items)
rt = estimate_tokens(full_think) if full_think else 0
- yield format_sse("response.completed", {
+ for out in gate.emit(format_sse("response.completed", {
"type": "response.completed",
"response": make_resp_object(
response_id, model, "completed", output,
build_resp_usage(pt, ct + rt, rt),
),
- })
- yield "data: [DONE]\n\n"
+ }), visible=True):
+ yield out
+ for out in gate.emit("data: [DONE]\n\n"):
+ yield out
success = True
logger.info("responses stream tool_calls: attempt={}/{} model={}",
attempt + 1, max_retries + 1, model)
@@ -476,42 +490,52 @@ async def _run_stream() -> AsyncGenerator[str, None]:
img_md = img_text + "\n"
text_buf.append(img_md)
if message_started:
- yield format_sse("response.output_text.delta", {
+ for out in gate.emit(format_sse("response.output_text.delta", {
"type": "response.output_text.delta",
"item_id": message_id,
"output_index": msg_idx,
"content_index": 0,
"delta": img_md,
- })
+ }), visible=bool(img_text.strip())):
+ yield out
references = adapter.references_suffix()
if references:
text_buf.append(references)
if message_started:
- yield format_sse("response.output_text.delta", {
+ for out in gate.emit(format_sse("response.output_text.delta", {
"type": "response.output_text.delta",
"item_id": message_id,
"output_index": msg_idx,
"content_index": 0,
"delta": references,
- })
+ }), visible=True):
+ yield out
full_text = "".join(text_buf)
+ if not _adapter_has_visible_output(
+ adapter,
+ extra_text=full_text,
+ include_thinking=emit_think,
+ ):
+ raise _empty_upstream_response_error()
if message_started:
- yield format_sse("response.output_text.done", {
+ for out in gate.emit(format_sse("response.output_text.done", {
"type": "response.output_text.done",
"item_id": message_id,
"output_index": msg_idx,
"content_index": 0,
"text": full_text,
- })
- yield format_sse("response.content_part.done", {
+ })):
+ yield out
+ for out in gate.emit(format_sse("response.content_part.done", {
"type": "response.content_part.done",
"item_id": message_id,
"output_index": msg_idx,
"content_index": 0,
"part": {"type": "output_text", "text": full_text, "annotations": collected_annotations},
- })
+ })):
+ yield out
# 构建 message item(流式 output_item.done + response.completed 共用)
sources = adapter.search_sources_list()
msg_item: dict = {
@@ -523,11 +547,12 @@ async def _run_stream() -> AsyncGenerator[str, None]:
}
if sources:
msg_item["search_sources"] = sources
- yield format_sse("response.output_item.done", {
+ for out in gate.emit(format_sse("response.output_item.done", {
"type": "response.output_item.done",
"output_index": msg_idx,
"item": msg_item,
- })
+ })):
+ yield out
full_think = "".join(think_buf)
output = []
@@ -550,19 +575,22 @@ async def _run_stream() -> AsyncGenerator[str, None]:
sources = adapter.search_sources_list()
if sources:
msg_item["search_sources"] = sources
- output.append(msg_item)
+ if message_started or full_text:
+ output.append(msg_item)
pt = estimate_prompt_tokens(message)
ct = estimate_tokens(full_text)
rt = estimate_tokens(full_think) if full_think else 0
- yield format_sse("response.completed", {
+ for out in gate.emit(format_sse("response.completed", {
"type": "response.completed",
"response": make_resp_object(
response_id, model, "completed", output,
build_resp_usage(pt, ct + rt, rt),
),
- })
- yield "data: [DONE]\n\n"
+ }), visible=True):
+ yield out
+ for out in gate.emit("data: [DONE]\n\n"):
+ yield out
success = True
logger.info("responses stream completed: attempt={}/{} model={} text_len={} reasoning_len={} image_count={}",
attempt + 1, max_retries + 1, model,
@@ -644,6 +672,8 @@ async def _run_stream() -> AsyncGenerator[str, None]:
break
if ended:
break
+ if not _adapter_has_visible_output(adapter):
+ raise _empty_upstream_response_error()
success = True
except UpstreamError as exc:
diff --git a/app/products/openai/router.py b/app/products/openai/router.py
index 01a27504a..731654dff 100644
--- a/app/products/openai/router.py
+++ b/app/products/openai/router.py
@@ -3,15 +3,16 @@
import base64
import binascii
import mimetypes
-from typing import Annotated, AsyncGenerator, AsyncIterable, Literal
+from typing import Annotated, AsyncGenerator, AsyncIterable, Literal, TypeVar
import orjson
from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
+from pydantic import ValidationError as PydanticValidationError
from app.control.account.state_machine import is_manageable
from app.platform.auth.middleware import verify_api_key
-from app.platform.errors import AppError, ValidationError
+from app.platform.errors import AppError, UpstreamError, ValidationError
from app.platform.logging.logger import logger
from app.platform.storage import image_files_dir, video_files_dir
from app.control.model import registry as model_registry
@@ -131,6 +132,26 @@ async def _safe_sse(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]:
yield "data: [DONE]\n\n"
+async def _prime_sse(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]:
+ """Read the first chunk before StreamingResponse sends HTTP 200."""
+ iterator = stream.__aiter__()
+ try:
+ first = await anext(iterator)
+ except StopAsyncIteration as exc:
+ raise UpstreamError(
+ "Upstream returned empty response",
+ status=429,
+ body="empty_response",
+ ) from exc
+
+ async def _primed() -> AsyncGenerator[str, None]:
+ yield first
+ async for chunk in iterator:
+ yield chunk
+
+ return _primed()
+
+
_SSE_HEADERS = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
@@ -143,6 +164,48 @@ async def _safe_sse(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]:
_ALLOWED_SIZES = {"1280x720", "720x1280", "1792x1024", "1024x1792", "1024x1024"}
_EFFORT_VALUES = {"none", "minimal", "low", "medium", "high", "xhigh"}
_LITE_IMAGE_MODELS = {"grok-imagine-image-lite"}
+_ConfigT = TypeVar("_ConfigT", ImageConfig, VideoConfig)
+
+
+def _metadata_config(
+ req: ChatCompletionRequest,
+ key: str,
+ model_cls: type[_ConfigT],
+) -> _ConfigT | None:
+ metadata = req.metadata
+ if not isinstance(metadata, dict) or key not in metadata:
+ return None
+ value = metadata.get(key)
+ if value is None:
+ return None
+ if not isinstance(value, dict):
+ raise ValidationError(
+ f"metadata.{key} must be an object",
+ param=f"metadata.{key}",
+ )
+ try:
+ return model_cls.model_validate(value)
+ except PydanticValidationError as exc:
+ raise ValidationError(
+ f"metadata.{key} is invalid: {exc.errors()[0].get('msg', 'invalid value')}",
+ param=f"metadata.{key}",
+ ) from exc
+
+
+def _resolve_image_config(req: ChatCompletionRequest) -> ImageConfig:
+ return (
+ req.image_config
+ or _metadata_config(req, "image_config", ImageConfig)
+ or ImageConfig()
+ )
+
+
+def _resolve_video_config(req: ChatCompletionRequest) -> VideoConfig:
+ return (
+ req.video_config
+ or _metadata_config(req, "video_config", VideoConfig)
+ or VideoConfig()
+ )
def _validate_chat(req: ChatCompletionRequest) -> None:
@@ -236,7 +299,7 @@ async def chat_completions_endpoint(req: ChatCompletionRequest):
if spec.is_image_edit():
from .images import edit as img_edit
- cfg = req.image_config or ImageConfig()
+ cfg = _resolve_image_config(req)
_validate_image_edit_n(cfg.n or 1, param="image_config.n")
result = await img_edit(
model=req.model,
@@ -251,7 +314,7 @@ async def chat_completions_endpoint(req: ChatCompletionRequest):
elif spec.is_image():
from .images import generate as img_gen
- cfg = req.image_config or ImageConfig()
+ cfg = _resolve_image_config(req)
size = cfg.size or "1024x1024"
fmt = cfg.response_format or "url"
n = cfg.n or 1
@@ -280,7 +343,7 @@ async def chat_completions_endpoint(req: ChatCompletionRequest):
elif spec.is_video():
from .video import completions as vid_comp
- vcfg = req.video_config or VideoConfig()
+ vcfg = _resolve_video_config(req)
from .video import validate_video_length as _validate_video_length
_validate_video_length(vcfg.seconds or 6)
@@ -339,6 +402,8 @@ async def _err_stream():
if isinstance(result, dict):
return JSONResponse(result)
+ if not (spec.is_image_edit() or spec.is_image()):
+ result = await _prime_sse(result)
return StreamingResponse(
_safe_sse(result), media_type="text/event-stream", headers=_SSE_HEADERS
)
@@ -418,6 +483,7 @@ async def responses_endpoint(req: ResponsesCreateRequest):
if isinstance(result, dict):
return JSONResponse(result)
+ result = await _prime_sse(result)
return StreamingResponse(
_safe_sse_responses(result),
media_type = "text/event-stream",
@@ -480,9 +546,14 @@ async def videos_create(
references_payload = None
if input_reference:
+ if len(input_reference) > 7:
+ raise ValidationError(
+ "Video generation supports at most 7 reference images",
+ param="input_reference",
+ )
references_payload = [
{"image_url": await _upload_to_data_uri(f, param="input_reference")}
- for f in input_reference[:7]
+ for f in input_reference
]
result = await create_video(
diff --git a/app/products/openai/schemas.py b/app/products/openai/schemas.py
index f4caa0fee..a42ff1c2e 100644
--- a/app/products/openai/schemas.py
+++ b/app/products/openai/schemas.py
@@ -39,6 +39,7 @@ class ChatCompletionRequest(BaseModel):
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = True
max_tokens: int | None = None
+ metadata: dict[str, Any] | None = None
class ImageGenerationRequest(BaseModel):
diff --git a/app/products/openai/video.py b/app/products/openai/video.py
index f4e75baca..1240d0d7b 100644
--- a/app/products/openai/video.py
+++ b/app/products/openai/video.py
@@ -29,8 +29,10 @@
from app.platform.runtime.clock import now_s
from app.platform.storage import save_local_video
from app.control.account.enums import FeedbackKind
+from app.control.account.runtime import get_refresh_service
from app.control.model import registry as model_registry
from app.control.model.registry import resolve as resolve_model
+from app.dataplane.account.selector import current_strategy
from app.dataplane.proxy import get_proxy_runtime
from app.dataplane.proxy.adapters.headers import build_http_headers
from app.dataplane.proxy.adapters.session import ResettableSession, build_session_kwargs
@@ -52,7 +54,16 @@
make_stream_chunk,
make_thinking_chunk,
)
-from .chat import _fail_sync, _quota_sync, _feedback_kind
+from .chat import (
+ EMPTY_UPSTREAM_BODY,
+ _configured_retry_codes,
+ _fail_sync,
+ _feedback_kind,
+ _log_task_exception,
+ _quota_sync,
+ _should_retry_upstream,
+)
+from app.products._account_selection import selection_max_retries
_IMAGE_MEDIA_TYPE = "MEDIA_POST_TYPE_IMAGE"
_VIDEO_MEDIA_TYPE = "MEDIA_POST_TYPE_VIDEO"
@@ -61,7 +72,23 @@
_VIDEO_OBJECT = "video"
_VIDEO_JOB_TTL_S = 3600
_VIDEO_EXTENSION_REF_TYPE = "ORIGINAL_REF_TYPE_VIDEO_EXTENSION"
-_SUPPORTED_VIDEO_LENGTHS = frozenset({6, 10, 12, 16, 20})
+_VIDEO_MAX_REFERENCES = 7
+_VIDEO_ALLOWED_POOL_IDS = frozenset((1, 2)) # super / heavy only
+_SUPPORTED_VIDEO_LENGTHS = frozenset({6, 10, 12, 16, 20, 22, 26, 30, 32, 36, 40})
+_ASYNC_VIDEO_LENGTHS = frozenset({6, 10})
+_VIDEO_SEGMENT_LENGTHS: dict[int, list[int]] = {
+ 6: [6],
+ 10: [10],
+ 12: [6, 6],
+ 16: [10, 6],
+ 20: [10, 10],
+ 22: [10, 6, 6],
+ 26: [10, 10, 6],
+ 30: [10, 10, 10],
+ 32: [10, 10, 6, 6],
+ 36: [10, 10, 10, 6],
+ 40: [10, 10, 10, 10],
+}
_VIDEO_SIZE_MAP: dict[str, tuple[str, str]] = {
"720x1280": ("9:16", "720p"),
"1280x720": ("16:9", "720p"),
@@ -108,6 +135,7 @@ class _VideoJob:
remixed_from_video_id: str | None = None
video_url: str = ""
content_path: str = ""
+ content_file_id: str = ""
def to_dict(self) -> dict[str, Any]:
payload: dict[str, Any] = {
@@ -128,6 +156,11 @@ def to_dict(self) -> dict[str, Any]:
payload["error"] = self.error
if self.remixed_from_video_id:
payload["remixed_from_video_id"] = self.remixed_from_video_id
+ if self.status == "completed":
+ if self.content_file_id:
+ payload["metadata"] = {"url": _local_video_url(self.content_file_id)}
+ elif self.content_path:
+ payload["metadata"] = {"url": _video_content_url(self.id)}
return payload
@@ -139,6 +172,12 @@ def _build_message(prompt: str, preset: str) -> str:
return f"{prompt} {_PRESET_FLAGS.get(preset, '--mode=custom')}".strip()
+def _video_content_url(video_id: str) -> str:
+ app_url = get_config().get_str("app.app_url", "").rstrip("/")
+ path = f"/v1/videos/{video_id}/content"
+ return f"{app_url}{path}" if app_url else path
+
+
def _progress_reason(progress: int) -> str:
return f"视频正在生成 {max(0, min(100, int(progress)))}%"
@@ -169,6 +208,16 @@ def validate_video_length(seconds: int) -> None:
raise ValidationError(f"seconds must be one of [{allowed}]", param="seconds")
+def validate_async_video_length(seconds: int) -> None:
+ if seconds not in _ASYNC_VIDEO_LENGTHS:
+ allowed = ", ".join(str(item) for item in sorted(_ASYNC_VIDEO_LENGTHS))
+ raise ValidationError(
+ f"seconds must be one of [{allowed}] for /v1/videos; "
+ "use /v1/chat/completions for longer videos",
+ param="seconds",
+ )
+
+
def _resolve_video_size(size: str) -> tuple[str, str]:
normalized = (size or "720x1280").strip()
config = _VIDEO_SIZE_MAP.get(normalized)
@@ -196,20 +245,62 @@ def _resolve_video_preset(value: str | None, *, default: str = "custom") -> str:
def _build_segment_lengths(seconds: int) -> list[int]:
- if seconds == 6:
- return [6]
- if seconds == 10:
- return [10]
- if seconds == 12:
- return [6, 6]
- if seconds == 16:
- return [10, 6]
- if seconds == 20:
- return [10, 10]
+ segments = _VIDEO_SEGMENT_LENGTHS.get(seconds)
+ if segments is not None:
+ return list(segments)
validate_video_length(seconds)
raise AssertionError("unreachable")
+def _normalize_segment_prompts(
+ prompt: str,
+ segment_lengths: list[int],
+ segment_prompts: list[str] | None = None,
+ *,
+ strict: bool = False,
+) -> list[str]:
+ prompts = [p.strip() for p in (segment_prompts or [prompt]) if p and p.strip()]
+ if not prompts:
+ raise ValidationError("Video prompt cannot be empty", param="messages")
+ if strict and len(prompts) != len(segment_lengths):
+ raise ValidationError(
+ "Video generation requires exactly "
+ f"{len(segment_lengths)} prompt segment(s) for this duration",
+ param="messages",
+ )
+ if len(prompts) > len(segment_lengths):
+ raise ValidationError(
+ f"Video generation uses {len(segment_lengths)} prompt segment(s) for this duration",
+ param="messages",
+ )
+ while len(prompts) < len(segment_lengths):
+ prompts.append(prompts[-1])
+ return prompts
+
+
+def _video_pool_candidates(spec) -> tuple[int, ...]:
+ return tuple(
+ pool_id
+ for pool_id in spec.pool_candidates()
+ if pool_id in _VIDEO_ALLOWED_POOL_IDS
+ )
+
+
+def _empty_video_response_error() -> UpstreamError:
+ return UpstreamError(
+ "Video upstream returned empty response",
+ status=429,
+ body=EMPTY_UPSTREAM_BODY,
+ )
+
+
+def _transport_upstream_error(exc: BaseException, *, context: str) -> UpstreamError:
+ if isinstance(exc, UpstreamError):
+ return exc
+ body = str(exc).replace("\n", "\\n")[:400]
+ return UpstreamError(f"{context}: {exc}", status=502, body=body)
+
+
def _video_create_payload(
*,
prompt: str,
@@ -337,13 +428,18 @@ async def _stream_video_request(
kwargs = build_session_kwargs(lease=lease)
async with ResettableSession(**kwargs) as session:
- response = await session.post(
- CHAT,
- headers=headers,
- data=orjson.dumps(payload),
- timeout=timeout_s,
- stream=True,
- )
+ try:
+ response = await session.post(
+ CHAT,
+ headers=headers,
+ data=orjson.dumps(payload),
+ timeout=timeout_s,
+ stream=True,
+ )
+ except Exception as exc:
+ raise _transport_upstream_error(
+ exc, context="Video transport failed"
+ ) from exc
if response.status_code != 200:
body = response.content.decode("utf-8", "replace")[:300]
raise UpstreamError(
@@ -351,8 +447,13 @@ async def _stream_video_request(
status=response.status_code,
body=body,
)
- async for line in response.aiter_lines():
- yield line
+ try:
+ async for line in response.aiter_lines():
+ yield line
+ except Exception as exc:
+ raise _transport_upstream_error(
+ exc, context="Video stream read failed"
+ ) from exc
def _absolutize_video_url(url: str) -> str:
@@ -436,6 +537,12 @@ async def _prepare_video_references(
input_references: list[dict[str, Any]],
) -> list[_VideoReference]:
"""Upload multiple video references concurrently and preserve order."""
+ if len(input_references) > _VIDEO_MAX_REFERENCES:
+ raise ValidationError(
+ f"Video generation supports at most {_VIDEO_MAX_REFERENCES} reference images",
+ param="input_reference",
+ )
+
tasks = [
_prepare_video_reference(token, ref)
for ref in input_references
@@ -490,6 +597,7 @@ async def _collect_video_segment(
final_thumbnail = ""
video_post_id = ""
stream_data_items: list[str] = []
+ saw_video_signal = False
async for line in _stream_video_request(
token,
@@ -511,6 +619,7 @@ async def _collect_video_segment(
stream = _extract_streaming_video_response(obj)
if stream:
+ saw_video_signal = True
try:
progress = int(stream.get("progress") or 0)
except (TypeError, ValueError):
@@ -537,6 +646,8 @@ async def _collect_video_segment(
final_thumbnail = _absolutize_video_url(thumbnail)
attachments = _extract_model_response_file_attachments(obj)
+ if attachments:
+ saw_video_signal = True
if attachments and not final_asset_id:
final_asset_id = attachments[0]
@@ -549,6 +660,8 @@ async def _collect_video_segment(
body="\n".join(stream_data_items),
)
if not final_url:
+ if not stream_data_items or not saw_video_signal:
+ raise _empty_video_response_error()
raise UpstreamError(
"Video generation returned no final video URL",
body="\n".join(stream_data_items),
@@ -638,8 +751,11 @@ async def _generate_video_with_token(
preset: str,
timeout_s: float,
input_references: list[dict[str, Any]] | None = None,
+ segment_prompts: list[str] | None = None,
progress_cb: Callable[[int], Awaitable[None]] | None = None,
) -> _VideoArtifact:
+ segments = _build_segment_lengths(seconds)
+ prompts = _normalize_segment_prompts(prompt, segments, segment_prompts)
references: list[_VideoReference] = []
if input_references:
references = await _prepare_video_references(token, input_references)
@@ -648,7 +764,7 @@ async def _generate_video_with_token(
post = await create_media_post(
token,
media_type=_VIDEO_MEDIA_TYPE,
- prompt=prompt,
+ prompt=prompts[0],
referer="https://grok.com/imagine",
)
post_data = post.get("post")
@@ -658,16 +774,16 @@ async def _generate_video_with_token(
if not parent_post_id:
raise UpstreamError("Video create-post returned no post id")
- segments = _build_segment_lengths(seconds)
total_segments = len(segments)
artifact: _VideoArtifact | None = None
extend_post_id = parent_post_id
elapsed_seconds = 0
for index, segment_length in enumerate(segments):
+ segment_prompt = prompts[index]
if index == 0:
payload = _video_create_payload(
- prompt=prompt,
+ prompt=segment_prompt,
parent_post_id=parent_post_id,
aspect_ratio=aspect_ratio,
resolution_name=resolution_name,
@@ -680,7 +796,7 @@ async def _generate_video_with_token(
referer = "https://grok.com/imagine"
else:
payload = _video_extend_payload(
- prompt=prompt,
+ prompt=segment_prompt,
parent_post_id=parent_post_id,
extend_post_id=extend_post_id,
aspect_ratio=aspect_ratio,
@@ -725,6 +841,7 @@ async def _run_video_generation(
seconds: int,
preset: str = "custom",
input_references: list[dict[str, Any]] | None = None,
+ segment_prompts: list[str] | None = None,
progress_cb: Callable[[int], Awaitable[None]] | None = None,
) -> _VideoArtifact:
async def _runner(token: str, timeout_s: float) -> _VideoArtifact:
@@ -737,6 +854,7 @@ async def _runner(token: str, timeout_s: float) -> _VideoArtifact:
preset=preset,
timeout_s=timeout_s,
input_references=input_references,
+ segment_prompts=segment_prompts,
progress_cb=progress_cb,
)
@@ -753,44 +871,94 @@ async def _run_video_with_account(
spec = resolve_model(model)
if not spec.is_video():
raise ValidationError(f"Model {model!r} is not a video model", param="model")
+ pool_candidates = _video_pool_candidates(spec)
+ if not pool_candidates:
+ raise RateLimitError("No super/heavy accounts available for video generation")
from app.dataplane.account import _directory as _acct_dir
if _acct_dir is None:
raise RateLimitError("Account directory not initialised")
- acct = await _acct_dir.reserve(
- pool_candidates=spec.pool_candidates(),
- mode_id=int(spec.mode_id),
- now_s_override=now_s(),
- )
- if acct is None:
- raise RateLimitError("No available accounts for video generation")
+ max_retries = selection_max_retries()
+ retry_codes = _configured_retry_codes(cfg)
+ excluded: list[str] = []
+ mode_id = int(spec.mode_id)
- token = acct.token
- success = False
- fail_exc: BaseException | None = None
- try:
- artifact = await runner(token, timeout_s)
- success = True
- return artifact
- except BaseException as exc:
- fail_exc = exc
- raise
- finally:
- await _acct_dir.release(acct)
- kind = (
- FeedbackKind.SUCCESS
- if success
- else _feedback_kind(fail_exc)
- if fail_exc
- else FeedbackKind.SERVER_ERROR
+ async def _reserve():
+ acct = await _acct_dir.reserve(
+ pool_candidates=pool_candidates,
+ mode_id=mode_id,
+ now_s_override=now_s(),
+ exclude_tokens=excluded or None,
)
- await _acct_dir.feedback(token, kind, int(spec.mode_id))
- if success:
- asyncio.create_task(_quota_sync(token, int(spec.mode_id)))
- else:
- asyncio.create_task(_fail_sync(token, int(spec.mode_id), fail_exc))
+ if acct is not None or current_strategy() == "random":
+ return acct
+ refresh_svc = get_refresh_service()
+ if refresh_svc is not None:
+ await refresh_svc.refresh_on_demand()
+ acct = await _acct_dir.reserve(
+ pool_candidates=pool_candidates,
+ mode_id=mode_id,
+ now_s_override=now_s(),
+ exclude_tokens=excluded or None,
+ )
+ return acct
+
+ for attempt in range(max_retries + 1):
+ acct = await _reserve()
+ if acct is None:
+ raise RateLimitError("No available super/heavy accounts for video generation")
+
+ token = acct.token
+ success = False
+ retry = False
+ fail_exc: BaseException | None = None
+ try:
+ artifact = await runner(token, timeout_s)
+ success = True
+ return artifact
+ except UpstreamError as exc:
+ fail_exc = exc
+ if _should_retry_upstream(exc, retry_codes) and attempt < max_retries:
+ retry = True
+ logger.warning(
+ "video retry scheduled: attempt={}/{} status={} token={}...",
+ attempt + 1,
+ max_retries,
+ exc.status,
+ token[:8],
+ )
+ else:
+ raise
+ except BaseException as exc:
+ fail_exc = exc
+ raise
+ finally:
+ await _acct_dir.release(acct)
+ kind = (
+ FeedbackKind.SUCCESS
+ if success
+ else _feedback_kind(fail_exc)
+ if fail_exc
+ else FeedbackKind.SERVER_ERROR
+ )
+ await _acct_dir.feedback(token, kind, mode_id)
+ if success:
+ asyncio.create_task(_quota_sync(token, mode_id)).add_done_callback(
+ _log_task_exception
+ )
+ else:
+ asyncio.create_task(_fail_sync(token, mode_id, fail_exc)).add_done_callback(
+ _log_task_exception
+ )
+
+ if retry:
+ excluded.append(token)
+ continue
+ break
+
+ raise RateLimitError("No available super/heavy accounts for video generation")
async def _put_video_job(job: _VideoJob) -> None:
@@ -840,33 +1008,11 @@ async def _run_video_job(
default=default_resolution_name,
)
resolved_preset = _resolve_video_preset(preset)
- spec = resolve_model(job.model)
-
- from app.dataplane.account import _directory as _acct_dir
-
- if _acct_dir is None:
- raise RateLimitError("Account directory not initialised")
-
- acct = await _acct_dir.reserve(
- pool_candidates=spec.pool_candidates(),
- mode_id=int(spec.mode_id),
- now_s_override=now_s(),
- )
- if acct is None:
- raise RateLimitError("No available accounts for video generation")
- token = acct.token
- success = False
- fail_exc: BaseException | None = None
- try:
- cfg = get_config()
- timeout_s = cfg.get_float("video.timeout", 180.0)
-
- async def _progress(progress: int) -> None:
- await _set_job_status(
- job, status="in_progress", progress=max(1, progress)
- )
+ async def _progress(progress: int) -> None:
+ await _set_job_status(job, status="in_progress", progress=max(1, progress))
+ async def _runner(token: str, timeout_s: float) -> tuple[_VideoArtifact, bytes]:
artifact = await _generate_video_with_token(
token=token,
prompt=prompt,
@@ -879,32 +1025,19 @@ async def _progress(progress: int) -> None:
progress_cb=_progress,
)
raw, _mime = await _download_video_bytes(token, artifact.video_url)
- success = True
- except BaseException as exc:
- fail_exc = exc
- raise
- finally:
- await _acct_dir.release(acct)
- kind = (
- FeedbackKind.SUCCESS
- if success
- else _feedback_kind(fail_exc)
- if fail_exc
- else FeedbackKind.SERVER_ERROR
- )
- await _acct_dir.feedback(token, kind, int(spec.mode_id))
- if success:
- asyncio.create_task(_quota_sync(token, int(spec.mode_id)))
- else:
- asyncio.create_task(_fail_sync(token, int(spec.mode_id), fail_exc))
+ return artifact, raw
- path = _save_video_bytes(raw, job.id)
+ artifact, raw = await _run_video_with_account(model=job.model, runner=_runner)
+
+ file_id = hashlib.sha1(artifact.video_url.encode("utf-8")).hexdigest()[:32]
+ path = _save_video_bytes(raw, file_id)
async with _VIDEO_JOBS_LOCK:
job.status = "completed"
job.progress = 100
job.completed_at = int(time.time())
job.video_url = artifact.video_url
job.content_path = str(path)
+ job.content_file_id = file_id
job.remixed_from_video_id = artifact.remixed_from_video_id
except Exception as exc:
logger.exception("video job failed: job_id={} error={}", job.id, exc)
@@ -932,7 +1065,7 @@ async def create_video(
raise ValidationError("prompt cannot be empty", param="prompt")
normalized_seconds = _coerce_seconds(seconds)
- validate_video_length(normalized_seconds)
+ validate_async_video_length(normalized_seconds)
normalized_size = (size or "720x1280").strip()
_aspect_ratio, default_resolution_name = _resolve_video_size(normalized_size)
_resolve_video_resolution_name(resolution_name, default=default_resolution_name)
@@ -992,52 +1125,68 @@ async def content_path(video_id: str) -> Path:
def _extract_video_prompt_and_reference(
messages: list[dict],
) -> tuple[str, list[dict[str, Any]] | None]:
- prompt = ""
- reference_urls: list[str] = []
+ prompts, input_references = _extract_video_segment_prompts_and_references(messages)
+ return prompts[-1], input_references
- for msg in reversed(messages):
- content = msg.get("content", "")
- if isinstance(content, str) and content.strip():
- prompt = content.strip()
- if prompt:
- break
- continue
- if not isinstance(content, list):
+
+def _text_and_references_from_content(
+ content: str | list[dict[str, Any]] | None,
+) -> tuple[str, list[str]]:
+ if isinstance(content, str):
+ return content.strip(), []
+ if not isinstance(content, list):
+ return "", []
+
+ text_parts: list[str] = []
+ reference_urls: list[str] = []
+ for item in content:
+ if not isinstance(item, dict):
continue
+ item_type = item.get("type")
+ if item_type == "text":
+ text = str(item.get("text") or "").strip()
+ if text:
+ text_parts.append(text)
+ elif item_type == "image_url":
+ image_url = item.get("image_url")
+ if isinstance(image_url, dict):
+ url = str(image_url.get("url") or "").strip()
+ if url:
+ reference_urls.append(url)
+ elif isinstance(image_url, str) and image_url.strip():
+ reference_urls.append(image_url.strip())
+ return " ".join(text_parts).strip(), reference_urls
+
+
+def _extract_video_segment_prompts_and_references(
+ messages: list[dict],
+) -> tuple[list[str], list[dict[str, Any]] | None]:
+ prompts: list[str] = []
+ reference_urls: list[str] = []
- text_parts: list[str] = []
- block_references: list[str] = []
- for item in content:
- if not isinstance(item, dict):
- continue
- item_type = item.get("type")
- if item_type == "text":
- text = str(item.get("text") or "").strip()
- if text:
- text_parts.append(text)
- elif item_type == "image_url":
- image_url = item.get("image_url")
- if isinstance(image_url, dict):
- url = str(image_url.get("url") or "").strip()
- if url:
- block_references.append(url)
- elif isinstance(image_url, str) and image_url.strip():
- block_references.append(image_url.strip())
-
- if text_parts:
- prompt = " ".join(text_parts)
+ for msg in messages:
+ if msg.get("role", "user") != "user":
+ continue
+ prompt, block_references = _text_and_references_from_content(
+ msg.get("content", "")
+ )
+ if prompt:
+ prompts.append(prompt)
if block_references and not reference_urls:
reference_urls = block_references
- if prompt:
- break
- if not prompt:
+ if not prompts:
raise ValidationError("Video prompt cannot be empty", param="messages")
input_references: list[dict[str, Any]] | None = None
if reference_urls:
- input_references = [{"image_url": url} for url in reference_urls[:7]]
- return prompt, input_references
+ if len(reference_urls) > _VIDEO_MAX_REFERENCES:
+ raise ValidationError(
+ f"Video generation supports at most {_VIDEO_MAX_REFERENCES} reference images",
+ param="messages",
+ )
+ input_references = [{"image_url": url} for url in reference_urls]
+ return prompts, input_references
async def completions(
@@ -1058,7 +1207,17 @@ async def completions(
default=default_resolution_name,
)
resolved_preset = _resolve_video_preset(preset)
- prompt, input_references = _extract_video_prompt_and_reference(messages)
+ segments = _build_segment_lengths(seconds)
+ segment_prompts, input_references = _extract_video_segment_prompts_and_references(
+ messages
+ )
+ normalized_segment_prompts = _normalize_segment_prompts(
+ segment_prompts[0],
+ segments,
+ segment_prompts,
+ strict=True,
+ )
+ prompt = normalized_segment_prompts[0]
cfg = get_config()
is_stream = stream if stream is not None else cfg.get_bool("features.stream", False)
@@ -1075,6 +1234,7 @@ async def _runner(token: str, timeout_s: float) -> str:
preset=resolved_preset,
timeout_s=timeout_s,
input_references=input_references,
+ segment_prompts=normalized_segment_prompts,
progress_cb=progress_cb,
)
file_id = hashlib.sha1(artifact.video_url.encode("utf-8")).hexdigest()[:32]
@@ -1140,7 +1300,12 @@ async def _progress(progress: int) -> None:
"retrieve",
"content_path",
"validate_video_length",
+ "validate_async_video_length",
"completions",
"_build_segment_lengths",
+ "_empty_video_response_error",
+ "_extract_video_segment_prompts_and_references",
+ "_normalize_segment_prompts",
"_resolve_video_size",
+ "_video_pool_candidates",
]
diff --git a/app/products/web/admin/__init__.py b/app/products/web/admin/__init__.py
index 88093af78..313789485 100644
--- a/app/products/web/admin/__init__.py
+++ b/app/products/web/admin/__init__.py
@@ -196,7 +196,10 @@ async def update_config(req: ConfigPatchRequest):
patch = _sanitize_proxy_config(req.root)
_ensure_runtime_patch_allowed(patch)
- cache_local_changed = _patch_touches_prefix(patch, "cache.local")
+ cache_local_changed = (
+ _patch_touches_prefix(patch, "cache.local")
+ or _patch_touches_prefix(patch, "storage")
+ )
await config.update(patch)
# config.update() only writes to the backend and invalidates the in-memory
# snapshot (_version = None); it does not refresh the data. load() is
diff --git a/app/products/web/admin/cache.py b/app/products/web/admin/cache.py
index fc49ab597..c6abb504a 100644
--- a/app/products/web/admin/cache.py
+++ b/app/products/web/admin/cache.py
@@ -1,29 +1,21 @@
"""Local media cache management — stats, list, clear, delete."""
import asyncio
-from pathlib import Path
-from typing import Any, Literal
+from typing import Literal
from fastapi import APIRouter, Query
from pydantic import BaseModel
-from app.platform.config.snapshot import get_config
from app.platform.errors import AppError, ErrorKind
from app.platform.storage import (
clear_local_media_files,
delete_local_media_file,
- image_files_dir,
- video_files_dir,
+ list_local_media_files,
+ local_media_stats,
)
router = APIRouter(prefix="/cache", tags=["Admin - Cache"])
-# ---------------------------------------------------------------------------
-# Lightweight local media cache service.
-# ---------------------------------------------------------------------------
-_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
-_VIDEO_EXTS = {".mp4", ".mov", ".m4v", ".webm", ".avi", ".mkv"}
-
class ClearCacheRequest(BaseModel):
type: Literal["image", "video"] = "image"
@@ -39,67 +31,6 @@ class DeleteCacheItemsRequest(BaseModel):
names: list[str]
-def _dir(media_type: str) -> Path:
- return image_files_dir() if media_type == "image" else video_files_dir()
-
-
-def _exts(media_type: str):
- return _IMAGE_EXTS if media_type == "image" else _VIDEO_EXTS
-
-
-def _limit_mb(media_type: str) -> int:
- cfg = get_config()
- return max(0, int(cfg.get_int(f"cache.local.{media_type}_max_mb", 0)))
-
-
-def _stats(media_type: str) -> dict[str, Any]:
- d = _dir(media_type)
- files = []
- if d.exists():
- allowed = _exts(media_type)
- files = [f for f in d.glob("*") if f.is_file() and f.suffix.lower() in allowed]
-
- total_size = sum(f.stat().st_size for f in files)
- limit_mb = _limit_mb(media_type)
- limit_bytes = limit_mb * 1024 * 1024
- usage_ratio = (total_size / limit_bytes) if limit_bytes > 0 else None
- usage_percent = round(usage_ratio * 100, 1) if usage_ratio is not None else None
- return {
- "count": len(files),
- "size_mb": round(total_size / 1024 / 1024, 2),
- "size_bytes": total_size,
- "limit_mb": limit_mb,
- "limit_bytes": limit_bytes,
- "limited": limit_bytes > 0,
- "usage_ratio": round(usage_ratio, 4) if usage_ratio is not None else None,
- "usage_percent": usage_percent,
- }
-
-
-def _list_files(media_type: str, page: int, page_size: int) -> dict[str, Any]:
- d = _dir(media_type)
- if not d.exists():
- return {"total": 0, "page": page, "page_size": page_size, "items": []}
- allowed = _exts(media_type)
- files = sorted(
- (f for f in d.glob("*") if f.is_file() and f.suffix.lower() in allowed),
- key=lambda f: f.stat().st_mtime,
- reverse=True,
- )
- total = len(files)
- start = (page - 1) * page_size
- chunk = files[start : start + page_size]
- items = []
- for f in chunk:
- st = f.stat()
- items.append({
- "name": f.name,
- "size_bytes": st.st_size,
- "modified_at": st.st_mtime,
- })
- return {"total": total, "page": page, "page_size": page_size, "items": items}
-
-
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@@ -107,8 +38,8 @@ def _list_files(media_type: str, page: int, page_size: int) -> dict[str, Any]:
@router.get("")
async def cache_stats():
return {
- "local_image": _stats("image"),
- "local_video": _stats("video"),
+ "local_image": local_media_stats("image"),
+ "local_video": local_media_stats("video"),
}
@@ -120,7 +51,10 @@ async def list_local(
page_size: int = 1000,
):
media_type = type_ or cache_type
- return {"status": "success", **_list_files(media_type, page, page_size)}
+ return {
+ "status": "success",
+ **list_local_media_files(media_type, page=page, page_size=page_size),
+ }
@router.post("/clear")
diff --git a/app/products/web/admin/tokens.py b/app/products/web/admin/tokens.py
index bf5bf10f6..70e535b5e 100644
--- a/app/products/web/admin/tokens.py
+++ b/app/products/web/admin/tokens.py
@@ -121,6 +121,7 @@ def _quota_brief(q: dict) -> dict:
def _serialize_record(r) -> dict:
return {
"token": r.token,
+ "account_id": r.account_id,
"pool": r.pool or "basic",
"status": r.status,
"quota": _quota_brief(r.quota) if isinstance(r.quota, dict) else {},
diff --git a/config.defaults.toml b/config.defaults.toml
index f7a9b9437..8c132b639 100644
--- a/config.defaults.toml
+++ b/config.defaults.toml
@@ -78,9 +78,9 @@ mode = "direct"
proxy_url = ""
# 基础代理池(API 流量,proxy_pool 模式必填)
proxy_pool = []
-# 资源代理 URL(图片/视频下载,未配置则回落到 proxy_url)
+# 资源代理 URL(图片/视频下载,未配置则直连)
resource_proxy_url = ""
-# 资源代理池(图片/视频下载,未配置则回落到 proxy_pool)
+# 资源代理池(图片/视频下载,未配置则直连)
resource_proxy_pool = []
# 跳过代理 SSL 证书验证(代理使用自签名证书时启用)
skip_ssl_verify = false
diff --git a/docs/README.en.md b/docs/README.en.md
index 1cf8a7a6d..2968baf61 100644
--- a/docs/README.en.md
+++ b/docs/README.en.md
@@ -111,7 +111,7 @@ docker compose up -d
### Vercel
-[](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,LOG_DIR,ACCOUNT_STORAGE,ACCOUNT_REDIS_URL,ACCOUNT_MYSQL_URL,ACCOUNT_POSTGRESQL_URL)
+[](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,LOG_DIR,ACCOUNT_STORAGE,ACCOUNT_REDIS_URL,ACCOUNT_MYSQL_URL,ACCOUNT_POSTGRESQL_URL,CONFIG_STORAGE,CONFIG_LOCAL_PATH)
### Render
@@ -186,7 +186,8 @@ docker compose up -d
| `ACCOUNT_SQL_MAX_OVERFLOW` | Maximum overflow connections above pool size | `10` |
| `ACCOUNT_SQL_POOL_TIMEOUT` | Seconds to wait for a free connection from the pool | `30` |
| `ACCOUNT_SQL_POOL_RECYCLE` | Max connection lifetime in seconds before reconnect | `1800` |
-| `CONFIG_LOCAL_PATH` | Runtime config file path for `local` config storage | `${DATA_DIR}/config.toml` |
+| `CONFIG_STORAGE` | Runtime config storage backend; defaults to local TOML and does not follow `ACCOUNT_STORAGE` | `local` |
+| `CONFIG_LOCAL_PATH` | Local runtime config file path | `${DATA_DIR}/config.toml` |
Runtime config can also be overridden with `GROK_`-prefixed environment variables. For example, `GROK_APP_API_KEY` overrides `app.api_key`, and `GROK_FEATURES_STREAM` overrides `features.stream`.
@@ -361,17 +362,33 @@ curl http://localhost:8000/v1/chat/completions \
"model": "grok-imagine-video",
"stream": true,
"messages": [
- {"role":"user","content":"A neon rainy street at night, cinematic slow tracking shot"}
+ {"role":"user","content":"A neon rainy street at night, cinematic slow tracking shot"},
+ {"role":"user","content":"The camera passes under glowing signs as the character turns toward distant headlights"}
],
- "video_config": {
- "seconds": 10,
- "size": "1792x1024",
- "resolution_name": "720p",
- "preset": "normal"
+ "metadata": {
+ "video_config": {
+ "seconds": 16,
+ "size": "1792x1024",
+ "resolution_name": "720p",
+ "preset": "normal"
+ }
}
}'
```
+`image_config` / `video_config` may also be sent at the request top level; top-level fields take precedence over matching `metadata` fields.
+
+```json
+{
+ "video_config": {
+ "seconds": 10,
+ "size": "1792x1024",
+ "resolution_name": "720p",
+ "preset": "normal"
+ }
+}
+```
+
Field Notes
@@ -389,10 +406,11 @@ curl http://localhost:8000/v1/chat/completions \
| \|_ `size` | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` |
| \|_ `response_format` | `url`, `b64_json` |
| `video_config` | Video model parameters |
-| \|_ `seconds` | `6`, `10`, `12`, `16`, `20` |
+| \|_ `seconds` | `6`, `10`, `12`, `16`, `20`, `22`, `26`, `30`, `32`, `36`, `40`; longer videos require one user prompt per segment |
| \|_ `size` | `720x1280`, `1280x720`, `1024x1024`, `1024x1792`, `1792x1024` |
| \|_ `resolution_name` | `480p`, `720p` |
| \|_ `preset` | `fun`, `normal`, `spicy`, `custom` |
+| `metadata.image_config` / `metadata.video_config` | Compatibility location for `image_config` / `video_config`; top-level config wins |
@@ -588,7 +606,7 @@ curl -L http://localhost:8000/v1/videos//content \
| :-- | :-- |
| `model` | Video model, currently `grok-imagine-video` |
| `prompt` | Video generation prompt |
-| `seconds` | Video length: `6`, `10`, `12`, `16`, `20` |
+| `seconds` | Video length: `6`, `10`; use `/v1/chat/completions` for longer videos |
| `size` | Supports `720x1280`, `1280x720`, `1024x1024`, `1024x1792`, `1792x1024` |
| `resolution_name` | `480p` or `720p` |
| `preset` | `fun`, `normal`, `spicy`, `custom` |
diff --git a/scripts/inject_cookie.py b/scripts/inject_cookie.py
new file mode 100644
index 000000000..a78776559
--- /dev/null
+++ b/scripts/inject_cookie.py
@@ -0,0 +1,198 @@
+#!/usr/bin/env python3
+"""Grok cookie injector — set the SSO cookie in a browser and open Grok.
+
+Usage:
+ python inject_cookie.py
+ python inject_cookie.py --url https://grok.com
+ python inject_cookie.py --browser firefox
+
+The cookie value can be:
+ - A raw JWT: eyJ0eXAiOiJKV1Qi...
+ - With sso= prefix: sso=eyJ0eXAiOiJKV1Qi...
+ - A full cookie header: sso=eyJ...; Domain=.grok.com; Path=/
+
+Requirements:
+ pip install browser-cookie3 (optional, for reading existing cookies)
+
+The script uses webbrowser to open the URL and prints instructions for
+manual cookie injection into Chrome DevTools.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import re
+import sys
+import webbrowser
+
+# ── JWT decode (no signature verification) ──────────────────────────────
+
+
+def decode_jwt_payload(raw: str) -> dict:
+ import base64
+
+ parts = raw.strip().split(".")
+ if len(parts) != 3:
+ raise ValueError("Not a valid JWT (expected 3 segments)")
+
+ payload_b64 = parts[1]
+ payload_b64 += "=" * (4 - len(payload_b64) % 4)
+ try:
+ payload = json.loads(base64.urlsafe_b64decode(payload_b64))
+ except Exception as exc:
+ raise ValueError(f"Cannot decode JWT payload: {exc}") from exc
+ return payload
+
+
+def parse_cookie(raw: str) -> tuple[str, str, dict]:
+ """Parse a Grok SSO cookie value.
+
+ Returns (cookie_name, cookie_value, jwt_payload).
+ """
+ text = raw.strip()
+
+ # Strip browser export format: "Cookie: sso=..."
+ if text.lower().startswith("cookie:"):
+ text = text[len("cookie:"):].strip()
+
+ # Find the JWT by looking for sso= prefix
+ cookie_value = text
+ if text.startswith("sso="):
+ cookie_value = text[4:]
+ if ";" in cookie_value:
+ cookie_value = cookie_value.split(";")[0].strip()
+
+ # If no sso= prefix, assume raw JWT
+ cookie_value = cookie_value.strip()
+ payload = decode_jwt_payload(cookie_value)
+ return ("sso", cookie_value, payload)
+
+
+# ── Browser injection helpers ───────────────────────────────────────────
+
+
+def inject_via_playwright(cookie_value: str, url: str) -> int:
+ """Use Playwright to inject the cookie and open the browser."""
+ try:
+ from playwright.sync_api import sync_playwright
+ except ImportError:
+ print("[!] playwright not installed. Install with: pip install playwright && playwright install chromium")
+ return 1
+
+ with sync_playwright() as p:
+ browser = p.chromium.launch(headless=False)
+ context = browser.new_context()
+ context.add_cookies([
+ {
+ "name": "sso",
+ "value": cookie_value,
+ "domain": ".grok.com",
+ "path": "/",
+ "httpOnly": False,
+ "secure": True,
+ "sameSite": "Lax",
+ }
+ ])
+ page = context.new_page()
+ page.goto(url)
+ print(f"[+] Cookie injected — browser opened at {url}")
+ print("[+] Press Ctrl+C to close the browser...")
+ try:
+ page.wait_for_timeout(60_000) # 60s before auto-close
+ except KeyboardInterrupt:
+ pass
+ browser.close()
+ return 0
+
+
+def print_manual_instructions(cookie_value: str, url: str) -> None:
+ """Print manual DevTools injection instructions."""
+ js = (
+ "document.cookie = 'sso=" + cookie_value + "; "
+ "domain=.grok.com; path=/; SameSite=Lax; Secure';"
+ )
+
+ print(f"""
+╔══════════════════════════════════════════════════════════════╗
+║ Grok Cookie Injection — Manual Instructions ║
+╠══════════════════════════════════════════════════════════════╣
+║ ║
+║ 1. Open {url} in your browser ║
+║ 2. Press F12 to open DevTools ║
+║ 3. Go to the Console tab ║
+║ 4. Paste this command: ║
+║ ║
+║ {js}
+║ ║
+║ 5. Press Enter (the page will reload) ║
+║ 6. You should be logged in ║
+║ ║
+╠══════════════════════════════════════════════════════════════╣
+║ Alternative — Application tab: ║
+║ 1. F12 → Application → Cookies → grok.com ║
+║ 2. Add cookie: name=sso, value=, path=/ ║
+║ 3. Refresh the page ║
+║ ║
+╠══════════════════════════════════════════════════════════════╣
+║ For grok2api import: ║
+║ Use the token value directly: ║
+║ curl -X POST .../admin/api/tokens/add ║
+║ -d '{{"tokens": ["{cookie_value}"], "pool": "auto"}}' ║
+╚══════════════════════════════════════════════════════════════╝
+""")
+
+
+# ── Main ─────────────────────────────────────────────────────────────────
+
+
+def main() -> int:
+ parser = argparse.ArgumentParser(
+ description="Inject a Grok SSO cookie into a browser and open Grok",
+ )
+ parser.add_argument(
+ "cookie",
+ help="Grok SSO cookie value (JWT, optionally with 'sso=' prefix)",
+ )
+ parser.add_argument(
+ "--url", default="https://grok.com",
+ help="Target URL (default: https://grok.com)",
+ )
+ parser.add_argument(
+ "--manual", action="store_true",
+ help="Print manual DevTools injection instructions (no automation)",
+ )
+ parser.add_argument(
+ "--playwright", action="store_true",
+ help="Use Playwright to automate cookie injection",
+ )
+ args = parser.parse_args()
+
+ try:
+ name, value, payload = parse_cookie(args.cookie)
+ except ValueError as exc:
+ print(f"[!] Cookie parse error: {exc}", file=sys.stderr)
+ return 1
+
+ session_id = payload.get("session_id", "unknown")
+ print(f"[+] Parsed cookie: name={name}, session_id={session_id}")
+ print(f"[+] Full token: {value[:40]}...{value[-20:]}")
+
+ if args.playwright:
+ return inject_via_playwright(value, args.url)
+
+ # Default: print instructions and offer to open browser.
+ print_manual_instructions(value, args.url)
+
+ try:
+ choice = input("\n[?] Open browser now? (y/N): ").strip().lower()
+ if choice in ("y", "yes"):
+ webbrowser.open(args.url)
+ except (KeyboardInterrupt, EOFError):
+ pass
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/tests/test_chat_metadata_config.py b/tests/test_chat_metadata_config.py
new file mode 100644
index 000000000..9ec196727
--- /dev/null
+++ b/tests/test_chat_metadata_config.py
@@ -0,0 +1,84 @@
+import unittest
+
+from app.platform.errors import ValidationError
+from app.products.openai import router
+from app.products.openai.schemas import ChatCompletionRequest, ImageConfig, VideoConfig
+
+
+def _request(**kwargs) -> ChatCompletionRequest:
+ payload = {
+ "model": "grok-imagine-video",
+ "messages": [{"role": "user", "content": "hello"}],
+ }
+ payload.update(kwargs)
+ return ChatCompletionRequest.model_validate(payload)
+
+
+class ChatMetadataConfigTests(unittest.TestCase):
+ def test_metadata_video_config_is_used_as_fallback(self) -> None:
+ req = _request(
+ metadata={
+ "video_config": {
+ "seconds": 16,
+ "size": "1792x1024",
+ "resolution_name": "720p",
+ "preset": "normal",
+ }
+ }
+ )
+
+ cfg = router._resolve_video_config(req)
+
+ self.assertEqual(cfg.seconds, 16)
+ self.assertEqual(cfg.size, "1792x1024")
+ self.assertEqual(cfg.resolution_name, "720p")
+ self.assertEqual(cfg.preset, "normal")
+
+ def test_top_level_video_config_wins_over_metadata(self) -> None:
+ req = _request(
+ video_config=VideoConfig(seconds=6, size="720x1280"),
+ metadata={"video_config": {"seconds": 16, "size": "1792x1024"}},
+ )
+
+ cfg = router._resolve_video_config(req)
+
+ self.assertEqual(cfg.seconds, 6)
+ self.assertEqual(cfg.size, "720x1280")
+
+ def test_metadata_image_config_is_used_as_fallback(self) -> None:
+ req = _request(
+ metadata={
+ "image_config": {
+ "n": 3,
+ "size": "1792x1024",
+ "response_format": "url",
+ }
+ }
+ )
+
+ cfg = router._resolve_image_config(req)
+
+ self.assertEqual(cfg.n, 3)
+ self.assertEqual(cfg.size, "1792x1024")
+ self.assertEqual(cfg.response_format, "url")
+
+ def test_top_level_image_config_wins_over_metadata(self) -> None:
+ req = _request(
+ image_config=ImageConfig(n=1, size="1024x1024"),
+ metadata={"image_config": {"n": 3, "size": "1792x1024"}},
+ )
+
+ cfg = router._resolve_image_config(req)
+
+ self.assertEqual(cfg.n, 1)
+ self.assertEqual(cfg.size, "1024x1024")
+
+ def test_metadata_config_must_be_object(self) -> None:
+ req = _request(metadata={"video_config": "bad"})
+
+ with self.assertRaises(ValidationError):
+ router._resolve_video_config(req)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_config_backend_factory.py b/tests/test_config_backend_factory.py
new file mode 100644
index 000000000..35ef254cf
--- /dev/null
+++ b/tests/test_config_backend_factory.py
@@ -0,0 +1,63 @@
+import os
+import tempfile
+import unittest
+from pathlib import Path
+from unittest.mock import patch
+
+from app.platform.config.backends.factory import (
+ create_config_backend,
+ get_config_backend_name,
+)
+from app.platform.config.backends.sql import SqlConfigBackend
+from app.platform.config.backends.toml import TomlConfigBackend
+
+
+class ConfigBackendFactoryTests(unittest.TestCase):
+ def test_config_backend_defaults_to_local_when_account_storage_is_mysql(self) -> None:
+ with patch.dict(os.environ, {"ACCOUNT_STORAGE": "mysql"}, clear=True):
+ self.assertEqual(get_config_backend_name(), "local")
+
+ def test_config_backend_defaults_to_local_when_account_storage_is_redis(self) -> None:
+ with patch.dict(os.environ, {"ACCOUNT_STORAGE": "redis"}, clear=True):
+ self.assertEqual(get_config_backend_name(), "local")
+
+ def test_create_config_backend_uses_toml_for_explicit_local(self) -> None:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ config_path = Path(tmpdir) / "config.toml"
+ with patch.dict(
+ os.environ,
+ {
+ "ACCOUNT_STORAGE": "mysql",
+ "CONFIG_STORAGE": "local",
+ "CONFIG_LOCAL_PATH": str(config_path),
+ },
+ clear=True,
+ ):
+ backend = create_config_backend()
+
+ self.assertIsInstance(backend, TomlConfigBackend)
+
+ def test_create_config_backend_uses_sql_only_when_config_storage_is_mysql(self) -> None:
+ sentinel_engine = object()
+ with patch.dict(
+ os.environ,
+ {
+ "ACCOUNT_STORAGE": "local",
+ "CONFIG_STORAGE": "mysql",
+ "ACCOUNT_MYSQL_URL": "mysql://user:pass@example.com/db",
+ },
+ clear=True,
+ ):
+ with patch(
+ "app.control.account.backends.sql.create_mysql_engine",
+ return_value=sentinel_engine,
+ ):
+ backend = create_config_backend()
+
+ self.assertIsInstance(backend, SqlConfigBackend)
+ self.assertIs(backend._engine, sentinel_engine)
+ self.assertEqual(backend._dialect, "mysql")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_empty_response_handling.py b/tests/test_empty_response_handling.py
new file mode 100644
index 000000000..769e8673a
--- /dev/null
+++ b/tests/test_empty_response_handling.py
@@ -0,0 +1,430 @@
+import asyncio
+import json
+import unittest
+from types import SimpleNamespace
+from unittest.mock import patch
+
+from app.control.account.enums import FeedbackKind
+from app.platform.errors import UpstreamError
+from app.products.anthropic import messages as anthropic_messages
+from app.products.anthropic import router as anthropic_router
+from app.products.openai import chat, responses, video
+from app.products.openai import router as openai_router
+
+
+class _FakeConfig:
+ def get(self, key: str, default=None):
+ return default
+
+ def get_bool(self, key: str, default: bool = False) -> bool:
+ return default
+
+ def get_float(self, key: str, default: float = 0.0) -> float:
+ return default
+
+ def get_str(self, key: str, default: str = "") -> str:
+ return default
+
+
+class _FakeDirectory:
+ def __init__(self) -> None:
+ self.feedbacks: list[tuple[str, FeedbackKind, int]] = []
+
+ async def release(self, acct) -> None:
+ return None
+
+ async def feedback(
+ self,
+ token: str,
+ kind: FeedbackKind,
+ selected_mode_id: int,
+ now_s_val=None,
+ ) -> None:
+ self.feedbacks.append((token, kind, selected_mode_id))
+
+
+class _FakeAdapter:
+ def __init__(
+ self,
+ *,
+ text: str = "",
+ images: list[tuple[str, str]] | None = None,
+ references: str = "",
+ sources: list[dict] | None = None,
+ thinking: str = "",
+ ) -> None:
+ self.text_buf = [text] if text else []
+ self.thinking_buf = [thinking] if thinking else []
+ self.image_urls = images or []
+ self._references = references
+ self._sources = sources
+
+ def references_suffix(self) -> str:
+ return self._references
+
+ def search_sources_list(self):
+ return self._sources
+
+
+async def _empty_stream(**kwargs):
+ yield "data: [DONE]"
+
+
+def _chat_frame(response: dict) -> str:
+ return "data: " + json.dumps({"result": {"response": response}}) + "\n\n"
+
+
+async def _thinking_stream(**kwargs):
+ yield _chat_frame({"token": "thinking one", "isThinking": True})
+ yield "data: [DONE]"
+
+
+async def _thinking_then_text_stream(**kwargs):
+ yield _chat_frame({"token": "thinking one", "isThinking": True})
+ yield _chat_frame({"token": "hello", "messageTag": "final"})
+ yield "data: [DONE]"
+
+
+async def _empty_video_stream(*args, **kwargs):
+ yield "data: [DONE]"
+
+
+async def _noop_async(*args, **kwargs):
+ return None
+
+
+async def _empty_generator():
+ if False:
+ yield ""
+
+
+class EmptyResponseHandlingTests(unittest.TestCase):
+ def test_visible_output_helper_ignores_thinking_only(self) -> None:
+ self.assertFalse(
+ chat._adapter_has_visible_output(_FakeAdapter(thinking="reasoning only"))
+ )
+
+ def test_visible_output_helper_accepts_text_tool_images_and_sources(self) -> None:
+ self.assertTrue(chat._adapter_has_visible_output(_FakeAdapter(text="hello")))
+ self.assertTrue(
+ chat._adapter_has_visible_output(
+ _FakeAdapter(images=[("https://example.com/a.png", "img-1")])
+ )
+ )
+ self.assertTrue(
+ chat._adapter_has_visible_output(
+ _FakeAdapter(sources=[{"url": "https://example.com", "title": "x"}])
+ )
+ )
+ self.assertTrue(
+ chat._adapter_has_visible_output(_FakeAdapter(), has_tool_calls=True)
+ )
+
+ def test_empty_response_error_is_retryable_429(self) -> None:
+ exc = chat._empty_upstream_response_error()
+
+ self.assertEqual(exc.status, 429)
+ self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY)
+
+ def test_video_empty_response_error_is_retryable_429(self) -> None:
+ exc = video._empty_video_response_error()
+
+ self.assertEqual(exc.status, 429)
+ self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY)
+
+ def test_video_segment_empty_stream_raises_429(self) -> None:
+ async def _run() -> None:
+ await video._collect_video_segment(
+ token="tok-test",
+ payload={},
+ referer="https://grok.com/imagine",
+ timeout_s=1,
+ )
+
+ with self.assertRaises(UpstreamError) as ctx:
+ with patch.object(
+ video, "_stream_video_request", side_effect=_empty_video_stream
+ ):
+ asyncio.run(_run())
+
+ self.assertEqual(ctx.exception.status, 429)
+ self.assertEqual(ctx.exception.details["body"], chat.EMPTY_UPSTREAM_BODY)
+
+ def test_stream_prime_turns_no_first_chunk_into_429(self) -> None:
+ async def _run() -> None:
+ await openai_router._prime_sse(_empty_generator())
+
+ with self.assertRaises(UpstreamError) as ctx:
+ asyncio.run(_run())
+
+ self.assertEqual(ctx.exception.status, 429)
+
+ def test_anthropic_stream_prime_turns_no_first_chunk_into_429(self) -> None:
+ async def _run() -> None:
+ await anthropic_router._prime_sse(_empty_generator())
+
+ with self.assertRaises(UpstreamError) as ctx:
+ asyncio.run(_run())
+
+ self.assertEqual(ctx.exception.status, 429)
+
+ def test_chat_non_stream_empty_response_raises_429_feedback(self) -> None:
+ directory = _FakeDirectory()
+ exc = self._run_chat_non_stream_empty(directory)
+
+ self.assertEqual(exc.status, 429)
+ self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY)
+ self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED)
+
+ def test_chat_stream_empty_response_raises_429_before_first_chunk(self) -> None:
+ directory = _FakeDirectory()
+
+ async def _run() -> None:
+ stream = await chat.completions(
+ model="grok-test",
+ messages=[{"role": "user", "content": "hello"}],
+ stream=True,
+ )
+ await openai_router._prime_sse(stream)
+
+ with self.assertRaises(UpstreamError) as ctx:
+ with self._patch_common(chat, directory):
+ asyncio.run(_run())
+
+ self.assertEqual(ctx.exception.status, 429)
+ self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED)
+
+ def test_chat_stream_thinking_flushes_before_text(self) -> None:
+ directory = _FakeDirectory()
+
+ async def _run() -> str:
+ stream = await chat.completions(
+ model="grok-test",
+ messages=[{"role": "user", "content": "hello"}],
+ stream=True,
+ emit_think=True,
+ )
+ return await anext(stream)
+
+ with self._patch_common(chat, directory, stream_func=_thinking_then_text_stream):
+ first = asyncio.run(_run())
+
+ self.assertIn("reasoning_content", first)
+ self.assertIn("thinking one", first)
+
+ def test_chat_stream_thinking_only_is_not_empty_when_enabled(self) -> None:
+ directory = _FakeDirectory()
+
+ async def _run() -> list[str]:
+ stream = await chat.completions(
+ model="grok-test",
+ messages=[{"role": "user", "content": "hello"}],
+ stream=True,
+ emit_think=True,
+ )
+ chunks = []
+ async for chunk in stream:
+ chunks.append(chunk)
+ return chunks
+
+ with self._patch_common(chat, directory, stream_func=_thinking_stream):
+ chunks = asyncio.run(_run())
+
+ joined = "".join(chunks)
+ self.assertIn("reasoning_content", joined)
+ self.assertIn("thinking one", joined)
+ self.assertIn("data: [DONE]", joined)
+ self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.SUCCESS)
+
+ def test_chat_stream_thinking_only_raises_429_when_disabled(self) -> None:
+ directory = _FakeDirectory()
+
+ async def _run() -> None:
+ stream = await chat.completions(
+ model="grok-test",
+ messages=[{"role": "user", "content": "hello"}],
+ stream=True,
+ emit_think=False,
+ )
+ async for _chunk in stream:
+ pass
+
+ with self.assertRaises(UpstreamError) as ctx:
+ with self._patch_common(chat, directory, stream_func=_thinking_stream):
+ asyncio.run(_run())
+
+ self.assertEqual(ctx.exception.status, 429)
+ self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED)
+
+ def test_responses_stream_reasoning_flushes_before_text(self) -> None:
+ directory = _FakeDirectory()
+
+ async def _run() -> list[str]:
+ stream = await responses.create(
+ model="grok-test",
+ input_val="hello",
+ instructions=None,
+ stream=True,
+ emit_think=True,
+ temperature=0.8,
+ top_p=0.95,
+ )
+ return [await anext(stream), await anext(stream), await anext(stream)]
+
+ with self._patch_common(responses, directory, stream_func=_thinking_then_text_stream):
+ chunks = asyncio.run(_run())
+
+ self.assertIn("response.created", chunks[0])
+ self.assertIn("response.output_item.added", chunks[1])
+ self.assertIn("response.reasoning_summary_part.added", chunks[2])
+
+ def test_responses_stream_reasoning_only_is_not_empty_when_enabled(self) -> None:
+ directory = _FakeDirectory()
+
+ async def _run() -> list[str]:
+ stream = await responses.create(
+ model="grok-test",
+ input_val="hello",
+ instructions=None,
+ stream=True,
+ emit_think=True,
+ temperature=0.8,
+ top_p=0.95,
+ )
+ chunks = []
+ async for chunk in stream:
+ chunks.append(chunk)
+ return chunks
+
+ with self._patch_common(responses, directory, stream_func=_thinking_stream):
+ chunks = asyncio.run(_run())
+
+ joined = "".join(chunks)
+ self.assertIn("response.reasoning_summary_text.delta", joined)
+ self.assertIn("response.completed", joined)
+ self.assertIn("data: [DONE]", joined)
+ self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.SUCCESS)
+
+ def test_responses_stream_reasoning_only_raises_429_when_disabled(self) -> None:
+ directory = _FakeDirectory()
+
+ async def _run() -> None:
+ stream = await responses.create(
+ model="grok-test",
+ input_val="hello",
+ instructions=None,
+ stream=True,
+ emit_think=False,
+ temperature=0.8,
+ top_p=0.95,
+ )
+ async for _chunk in stream:
+ pass
+
+ with self.assertRaises(UpstreamError) as ctx:
+ with self._patch_common(responses, directory, stream_func=_thinking_stream):
+ asyncio.run(_run())
+
+ self.assertEqual(ctx.exception.status, 429)
+ self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED)
+
+ def test_responses_non_stream_empty_response_raises_429_feedback(self) -> None:
+ directory = _FakeDirectory()
+ exc = self._run_responses_non_stream_empty(directory)
+
+ self.assertEqual(exc.status, 429)
+ self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY)
+ self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED)
+
+ def test_anthropic_non_stream_empty_response_raises_429_feedback(self) -> None:
+ directory = _FakeDirectory()
+ exc = self._run_anthropic_non_stream_empty(directory)
+
+ self.assertEqual(exc.status, 429)
+ self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY)
+ self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED)
+
+ def _run_chat_non_stream_empty(self, directory: _FakeDirectory) -> UpstreamError:
+ async def _run() -> None:
+ await chat.completions(
+ model="grok-test",
+ messages=[{"role": "user", "content": "hello"}],
+ stream=False,
+ )
+
+ with self.assertRaises(UpstreamError) as ctx:
+ with self._patch_common(chat, directory):
+ asyncio.run(_run())
+ return ctx.exception
+
+ def _run_responses_non_stream_empty(self, directory: _FakeDirectory) -> UpstreamError:
+ async def _run() -> None:
+ await responses.create(
+ model="grok-test",
+ input_val="hello",
+ instructions=None,
+ stream=False,
+ emit_think=True,
+ temperature=0.8,
+ top_p=0.95,
+ )
+
+ with self.assertRaises(UpstreamError) as ctx:
+ with self._patch_common(responses, directory):
+ asyncio.run(_run())
+ return ctx.exception
+
+ def _run_anthropic_non_stream_empty(
+ self, directory: _FakeDirectory
+ ) -> UpstreamError:
+ async def _run() -> None:
+ await anthropic_messages.create(
+ model="grok-test",
+ messages=[{"role": "user", "content": "hello"}],
+ stream=False,
+ emit_think=True,
+ temperature=0.8,
+ top_p=0.95,
+ )
+
+ with self.assertRaises(UpstreamError) as ctx:
+ with self._patch_common(anthropic_messages, directory):
+ asyncio.run(_run())
+ return ctx.exception
+
+ def _patch_common(
+ self,
+ module,
+ directory: _FakeDirectory,
+ *,
+ stream_func=_empty_stream,
+ ):
+ import contextlib
+
+ @contextlib.contextmanager
+ def _ctx():
+ with patch("app.dataplane.account._directory", directory):
+ with patch.object(module, "get_config", return_value=_FakeConfig()):
+ with patch.object(
+ module,
+ "resolve_model",
+ return_value=SimpleNamespace(mode_id=0),
+ ):
+ with patch.object(module, "selection_max_retries", return_value=0):
+ with patch.object(
+ module,
+ "reserve_account",
+ return_value=(SimpleNamespace(token="tok-test"), 0),
+ ):
+ with patch.object(
+ module, "_stream_chat", side_effect=stream_func
+ ):
+ with patch.object(module, "_fail_sync", _noop_async):
+ with patch.object(module, "_quota_sync", _noop_async):
+ yield
+
+ return _ctx()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_image_output_format.py b/tests/test_image_output_format.py
new file mode 100644
index 000000000..dffc082df
--- /dev/null
+++ b/tests/test_image_output_format.py
@@ -0,0 +1,87 @@
+import asyncio
+import base64
+import tempfile
+import unittest
+from pathlib import Path
+from urllib.parse import parse_qs, urlparse
+from unittest.mock import patch
+
+from app.products.openai import images
+
+
+class _StubConfig:
+ def __init__(self, *, image_format: str, app_url: str) -> None:
+ self.image_format = image_format
+ self.app_url = app_url
+
+ def get_str(self, key: str, default: str = "") -> str:
+ if key == "features.image_format":
+ return self.image_format
+ if key == "app.app_url":
+ return self.app_url
+ return default
+
+
+class ImageOutputFormatTests(unittest.TestCase):
+ def test_resolve_image_output_uses_local_proxy_when_configured(self) -> None:
+ blob_b64 = base64.b64encode(b"fake-image-bytes").decode("ascii")
+ asset_id = "12345678-1234-1234-1234-123456789abc"
+ with tempfile.TemporaryDirectory() as tmpdir:
+ config = _StubConfig(
+ image_format="local_url",
+ app_url="https://app.example.com",
+ )
+
+ def _save_local_image(raw: bytes, mime: str, file_id: str) -> str:
+ suffix = ".png" if "png" in mime else ".jpg"
+ (Path(tmpdir) / f"{file_id}{suffix}").write_bytes(raw)
+ return file_id
+
+ with patch.object(images, "get_config", return_value=config):
+ with patch.object(images, "save_local_image", side_effect=_save_local_image):
+ result = asyncio.run(
+ images._resolve_image_output(
+ token="unused",
+ url=f"https://assets.grok.com/users/user-1/{asset_id}.png",
+ response_format="url",
+ blob_b64=blob_b64,
+ )
+ )
+ file_id = parse_qs(urlparse(result.api_value).query)["id"][0]
+ self.assertTrue((Path(tmpdir) / f"{file_id}.png").exists())
+
+ self.assertEqual(
+ result.api_value,
+ f"https://app.example.com/v1/files/image?id={asset_id}",
+ )
+ self.assertEqual(
+ result.markdown_value,
+ f"",
+ )
+
+ def test_resolve_image_output_keeps_upstream_url_when_configured(self) -> None:
+ config = _StubConfig(
+ image_format="grok_url",
+ app_url="",
+ )
+ with patch.object(images, "get_config", return_value=config):
+ result = asyncio.run(
+ images._resolve_image_output(
+ token="unused",
+ url="https://assets.grok.com/users/user-1/file-abc123/content.png",
+ response_format="url",
+ )
+ )
+
+ self.assertEqual(
+ result.api_value,
+ "https://assets.grok.com/users/user-1/file-abc123/content.png",
+ )
+ self.assertEqual(
+ result.markdown_value,
+ "",
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_legacy_cache_config.py b/tests/test_legacy_cache_config.py
new file mode 100644
index 000000000..8b3bf186c
--- /dev/null
+++ b/tests/test_legacy_cache_config.py
@@ -0,0 +1,66 @@
+import asyncio
+import tempfile
+import unittest
+from pathlib import Path
+from typing import Any
+
+from app.platform.config.snapshot import ConfigSnapshot
+
+
+class _Backend:
+ def __init__(self, data: dict[str, Any]) -> None:
+ self.data = data
+
+ async def load(self) -> dict[str, Any]:
+ return self.data
+
+ async def apply_patch(self, patch: dict[str, Any]) -> None:
+ self.data.update(patch)
+
+ async def version(self) -> object:
+ return 1
+
+
+class LegacyCacheConfigTests(unittest.TestCase):
+ def test_legacy_storage_cache_limits_map_to_cache_local(self) -> None:
+ cfg = asyncio.run(self._load({
+ "storage": {
+ "image_max_mb": 12,
+ "video_max_mb": 34,
+ },
+ }))
+
+ self.assertEqual(cfg.get_int("cache.local.image_max_mb"), 12)
+ self.assertEqual(cfg.get_int("cache.local.video_max_mb"), 34)
+
+ def test_cache_local_limits_win_over_legacy_storage(self) -> None:
+ cfg = asyncio.run(self._load({
+ "storage": {
+ "image_max_mb": 12,
+ "video_max_mb": 34,
+ },
+ "cache": {
+ "local": {
+ "image_max_mb": 56,
+ "video_max_mb": 78,
+ },
+ },
+ }))
+
+ self.assertEqual(cfg.get_int("cache.local.image_max_mb"), 56)
+ self.assertEqual(cfg.get_int("cache.local.video_max_mb"), 78)
+
+ async def _load(self, overrides: dict[str, Any]) -> ConfigSnapshot:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ defaults = Path(tmpdir) / "config.defaults.toml"
+ defaults.write_text(
+ "[cache.local]\nimage_max_mb = 0\nvideo_max_mb = 0\n",
+ encoding="utf-8",
+ )
+ cfg = ConfigSnapshot(backend=_Backend(overrides))
+ await cfg.load(defaults)
+ return cfg
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_media_cache_limits.py b/tests/test_media_cache_limits.py
new file mode 100644
index 000000000..a8477517d
--- /dev/null
+++ b/tests/test_media_cache_limits.py
@@ -0,0 +1,113 @@
+import tempfile
+import unittest
+import os
+from pathlib import Path
+from unittest.mock import patch
+
+from app.platform.storage.media_cache import LocalMediaCacheStore
+
+
+class _StubConfig:
+ def __init__(self, *, image_max_mb: int = 0, video_max_mb: int = 0) -> None:
+ self._ints = {
+ "cache.local.image_max_mb": image_max_mb,
+ "cache.local.video_max_mb": video_max_mb,
+ }
+
+ def get_int(self, key: str, default: int = 0) -> int:
+ return self._ints.get(key, default)
+
+
+class MediaCacheLimitTests(unittest.TestCase):
+ def test_save_image_prunes_oldest_file_when_type_limit_exceeded(self) -> None:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ root = Path(tmpdir)
+ image_dir = root / "images"
+ video_dir = root / "videos"
+ image_dir.mkdir()
+ video_dir.mkdir()
+
+ store = LocalMediaCacheStore(
+ config_provider=lambda: _StubConfig(image_max_mb=1)
+ )
+ with patch("app.platform.storage.media_cache.image_files_dir", return_value=image_dir):
+ with patch("app.platform.storage.media_cache.video_files_dir", return_value=video_dir):
+ with patch(
+ "app.platform.storage.media_cache.local_media_cache_db_path",
+ return_value=root / "cache.db",
+ ):
+ with patch(
+ "app.platform.storage.media_cache.local_media_lock_path",
+ return_value=root / "cache.lock",
+ ):
+ store.save_image(b"a" * 800_000, "image/png", "old")
+ file_id = store.save_image(b"b" * 400_000, "image/png", "new")
+
+ self.assertEqual(file_id, "new")
+ self.assertTrue((image_dir / "new.png").exists())
+ self.assertFalse((image_dir / "old.png").exists())
+
+ def test_save_video_prunes_oldest_video_only(self) -> None:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ root = Path(tmpdir)
+ image_dir = root / "images"
+ video_dir = root / "videos"
+ image_dir.mkdir()
+ video_dir.mkdir()
+
+ image_path = image_dir / "keep.png"
+ image_path.write_bytes(b"i" * 800_000)
+
+ store = LocalMediaCacheStore(
+ config_provider=lambda: _StubConfig(video_max_mb=1)
+ )
+ with patch("app.platform.storage.media_cache.image_files_dir", return_value=image_dir):
+ with patch("app.platform.storage.media_cache.video_files_dir", return_value=video_dir):
+ with patch(
+ "app.platform.storage.media_cache.local_media_cache_db_path",
+ return_value=root / "cache.db",
+ ):
+ with patch(
+ "app.platform.storage.media_cache.local_media_lock_path",
+ return_value=root / "cache.lock",
+ ):
+ store.save_video(b"a" * 800_000, "old")
+ new_video = store.save_video(b"b" * 400_000, "new")
+
+ self.assertTrue(new_video.exists())
+ self.assertFalse((video_dir / "old.mp4").exists())
+ self.assertTrue(image_path.exists())
+
+ def test_stats_and_list_files_use_shared_cache_rules(self) -> None:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ root = Path(tmpdir)
+ image_dir = root / "images"
+ video_dir = root / "videos"
+ image_dir.mkdir()
+ video_dir.mkdir()
+
+ older = image_dir / "older.png"
+ newer = image_dir / "newer.jpg"
+ ignored = image_dir / "ignored.txt"
+ older.write_bytes(b"a" * 128)
+ newer.write_bytes(b"b" * 256)
+ ignored.write_text("skip")
+ os.utime(older, (100, 100))
+ os.utime(newer, (200, 200))
+
+ store = LocalMediaCacheStore(
+ config_provider=lambda: _StubConfig(image_max_mb=1)
+ )
+ with patch("app.platform.storage.media_cache.image_files_dir", return_value=image_dir):
+ with patch("app.platform.storage.media_cache.video_files_dir", return_value=video_dir):
+ stats = store.stats("image")
+ listing = store.list_files("image", page=1, page_size=10)
+
+ self.assertEqual(stats["count"], 2)
+ self.assertEqual(stats["size_bytes"], 384)
+ self.assertEqual(stats["limit_mb"], 1)
+ self.assertEqual([item["name"] for item in listing["items"]], ["newer.jpg", "older.png"])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_startup_config_migration.py b/tests/test_startup_config_migration.py
new file mode 100644
index 000000000..39415ec55
--- /dev/null
+++ b/tests/test_startup_config_migration.py
@@ -0,0 +1,91 @@
+import asyncio
+import os
+import tempfile
+import unittest
+from pathlib import Path
+from unittest.mock import patch
+
+from app.platform.startup import migration
+
+
+class _Backend:
+ def __init__(self, version: int = 0) -> None:
+ self.version_calls = 0
+ self.applied: list[dict] = []
+ self._version = version
+
+ async def version(self) -> int:
+ self.version_calls += 1
+ return self._version
+
+ async def apply_patch(self, patch: dict) -> None:
+ self.applied.append(patch)
+
+
+class StartupConfigMigrationTests(unittest.TestCase):
+ def test_account_storage_mysql_does_not_migrate_config_when_config_storage_unset(self) -> None:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ root = Path(tmpdir)
+ defaults = root / "config.defaults.toml"
+ user_config = root / "config.toml"
+ defaults.write_text("[app]\napi_key = \"default\"\n", encoding="utf-8")
+ user_config.write_text("[app]\napi_key = \"local\"\n", encoding="utf-8")
+ backend = _Backend()
+
+ with patch.dict(os.environ, {"ACCOUNT_STORAGE": "mysql"}, clear=True):
+ with patch.object(migration, "_DEFAULTS_PATH", defaults):
+ with patch.object(migration, "_USER_CFG_PATH", user_config):
+ asyncio.run(migration._migrate_config(backend))
+
+ self.assertEqual(backend.version_calls, 0)
+ self.assertEqual(backend.applied, [])
+ self.assertEqual(
+ user_config.read_text(encoding="utf-8"),
+ "[app]\napi_key = \"local\"\n",
+ )
+
+ def test_local_config_is_seeded_from_defaults_when_missing(self) -> None:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ root = Path(tmpdir)
+ defaults = root / "config.defaults.toml"
+ user_config = root / "config.toml"
+ defaults.write_text("[app]\napi_key = \"default\"\n", encoding="utf-8")
+ backend = _Backend()
+
+ with patch.dict(os.environ, {"ACCOUNT_STORAGE": "postgresql"}, clear=True):
+ with patch.object(migration, "_DEFAULTS_PATH", defaults):
+ with patch.object(migration, "_USER_CFG_PATH", user_config):
+ asyncio.run(migration._migrate_config(backend))
+
+ self.assertTrue(user_config.exists())
+ self.assertEqual(
+ user_config.read_text(encoding="utf-8"),
+ defaults.read_text(encoding="utf-8"),
+ )
+ self.assertEqual(backend.version_calls, 0)
+ self.assertEqual(backend.applied, [])
+
+ def test_explicit_remote_config_storage_still_migrates_local_config(self) -> None:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ root = Path(tmpdir)
+ defaults = root / "config.defaults.toml"
+ user_config = root / "config.toml"
+ defaults.write_text("[app]\napi_key = \"default\"\n", encoding="utf-8")
+ user_config.write_text("[app]\napi_key = \"local\"\n", encoding="utf-8")
+ backend = _Backend()
+
+ with patch.dict(
+ os.environ,
+ {"ACCOUNT_STORAGE": "local", "CONFIG_STORAGE": "mysql"},
+ clear=True,
+ ):
+ with patch.object(migration, "_DEFAULTS_PATH", defaults):
+ with patch.object(migration, "_USER_CFG_PATH", user_config):
+ asyncio.run(migration._migrate_config(backend))
+
+ self.assertEqual(backend.version_calls, 1)
+ self.assertEqual(backend.applied, [{"app": {"api_key": "local"}}])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_video_proxy_and_newapi.py b/tests/test_video_proxy_and_newapi.py
new file mode 100644
index 000000000..fffd1d201
--- /dev/null
+++ b/tests/test_video_proxy_and_newapi.py
@@ -0,0 +1,264 @@
+import asyncio
+import hashlib
+import tempfile
+import time
+import unittest
+from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import patch
+
+from app.control.proxy import ProxyDirectory
+from app.dataplane.reverse.transport import assets
+from app.products.openai import video
+
+
+class _ProxyConfig:
+ def __init__(self, *, resource_proxy_url: str = "") -> None:
+ self.resource_proxy_url = resource_proxy_url
+
+ def get_str(self, key: str, default: str = "") -> str:
+ values = {
+ "proxy.egress.mode": "single_proxy",
+ "proxy.clearance.mode": "none",
+ "proxy.egress.proxy_url": "http://proxy.example:8080",
+ "proxy.egress.resource_proxy_url": self.resource_proxy_url,
+ }
+ return values.get(key, default)
+
+ def get_list(self, key: str, default=None):
+ return default or []
+
+ def get_int(self, key: str, default: int = 0) -> int:
+ return default
+
+
+class _AppConfig:
+ def __init__(self, *, app_url: str = "") -> None:
+ self.app_url = app_url
+
+ def get_str(self, key: str, default: str = "") -> str:
+ if key == "app.app_url":
+ return self.app_url
+ return default
+
+
+class _AssetConfig:
+ def get_float(self, key: str, default: float = 0.0) -> float:
+ return default
+
+
+class _VideoConfig:
+ def get_float(self, key: str, default: float = 0.0) -> float:
+ return default
+
+ def get(self, key: str, default=None):
+ return default
+
+
+class _FakeProxyRuntime:
+ def __init__(self) -> None:
+ self.acquire_kwargs = []
+
+ async def acquire(self, **kwargs):
+ self.acquire_kwargs.append(kwargs)
+ return SimpleNamespace(proxy_url=None)
+
+ async def feedback(self, lease, result) -> None:
+ return None
+
+
+class _FakeAccountDirectory:
+ def __init__(self) -> None:
+ self.reserve_calls = []
+ self.feedbacks = []
+
+ async def reserve(self, **kwargs):
+ self.reserve_calls.append(kwargs)
+ return SimpleNamespace(token="tok-video")
+
+ async def release(self, acct) -> None:
+ return None
+
+ async def feedback(self, token, kind, mode_id) -> None:
+ self.feedbacks.append((token, kind, mode_id))
+
+
+async def _empty_bytes_stream(*args, **kwargs):
+ if False:
+ yield b""
+
+
+async def _noop_async(*args, **kwargs):
+ return None
+
+
+class VideoProxyAndNewApiTests(unittest.TestCase):
+ def test_resource_download_is_direct_without_resource_proxy(self) -> None:
+ async def _run():
+ directory = ProxyDirectory()
+ with patch("app.control.proxy.get_config", return_value=_ProxyConfig()):
+ await directory.load()
+ return await directory.acquire(resource=True)
+
+ lease = asyncio.run(_run())
+
+ self.assertIsNone(lease.proxy_url)
+
+ def test_resource_download_uses_explicit_resource_proxy(self) -> None:
+ async def _run():
+ directory = ProxyDirectory()
+ cfg = _ProxyConfig(resource_proxy_url="http://res-proxy.example:8080")
+ with patch("app.control.proxy.get_config", return_value=cfg):
+ await directory.load()
+ return await directory.acquire(resource=True)
+
+ lease = asyncio.run(_run())
+
+ self.assertEqual(lease.proxy_url, "http://res-proxy.example:8080")
+
+ def test_download_asset_acquires_resource_lease(self) -> None:
+ async def _run():
+ proxy = _FakeProxyRuntime()
+ with patch.object(assets, "get_proxy_runtime", return_value=proxy):
+ with patch.object(assets, "get_config", return_value=_AssetConfig()):
+ with patch.object(
+ assets, "get_bytes_stream", return_value=_empty_bytes_stream()
+ ):
+ await assets.download_asset(
+ "tok-test",
+ "https://assets.grok.com/users/u/file.mp4",
+ )
+ return proxy.acquire_kwargs
+
+ calls = asyncio.run(_run())
+
+ self.assertEqual(calls[-1]["resource"], True)
+
+ def test_completed_video_job_returns_newapi_metadata_url(self) -> None:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ path = Path(tmpdir) / "video.mp4"
+ path.write_bytes(b"fake-video")
+ job = video._VideoJob(
+ id="video_test",
+ model="grok-imagine-video",
+ prompt="hello",
+ seconds="6",
+ size="720x1280",
+ quality="standard",
+ created_at=int(time.time()),
+ status="completed",
+ progress=100,
+ completed_at=int(time.time()),
+ content_path=str(path),
+ content_file_id="08b8238c88e9bbe8423c692cbd04ec52",
+ )
+
+ with patch.object(video, "get_config", return_value=_AppConfig(app_url="https://api.example.com")):
+ payload = job.to_dict()
+
+ self.assertEqual(
+ payload["metadata"]["url"],
+ "https://api.example.com/v1/files/video?id=08b8238c88e9bbe8423c692cbd04ec52",
+ )
+
+ def test_video_job_saves_with_sync_file_id_and_metadata_url(self) -> None:
+ async def _run(tmpdir: str):
+ artifact = video._VideoArtifact(
+ video_url="https://assets.grok.com/users/u/generated.mp4",
+ video_post_id="post_1",
+ asset_id="asset_1",
+ thumbnail_url="https://assets.grok.com/thumb.jpg",
+ )
+ raw = b"fake-video"
+ saved_ids: list[str] = []
+ job = video._VideoJob(
+ id="video_async",
+ model="grok-imagine-video",
+ prompt="hello",
+ seconds="6",
+ size="720x1280",
+ quality="standard",
+ created_at=int(time.time()),
+ )
+
+ async def _fake_run_video_with_account(*, model, runner):
+ return artifact, raw
+
+ def _fake_save_video_bytes(raw_bytes: bytes, file_id: str) -> Path:
+ saved_ids.append(file_id)
+ path = Path(tmpdir) / f"{file_id}.mp4"
+ path.write_bytes(raw_bytes)
+ return path
+
+ with patch.object(
+ video, "_run_video_with_account", _fake_run_video_with_account
+ ):
+ with patch.object(video, "_save_video_bytes", _fake_save_video_bytes):
+ await video._run_video_job(
+ job,
+ size="720x1280",
+ resolution_name=None,
+ prompt="hello",
+ seconds=6,
+ preset=None,
+ )
+
+ with patch.object(
+ video,
+ "get_config",
+ return_value=_AppConfig(app_url="https://api.example.com"),
+ ):
+ payload = job.to_dict()
+
+ return job, payload, saved_ids
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ job, payload, saved_ids = asyncio.run(_run(tmpdir))
+
+ expected_file_id = hashlib.sha1(
+ "https://assets.grok.com/users/u/generated.mp4".encode("utf-8")
+ ).hexdigest()[:32]
+ self.assertEqual(saved_ids, [expected_file_id])
+ self.assertEqual(job.content_file_id, expected_file_id)
+ self.assertTrue(job.content_path.endswith(f"{expected_file_id}.mp4"))
+ self.assertEqual(
+ payload["metadata"]["url"],
+ f"https://api.example.com/v1/files/video?id={expected_file_id}",
+ )
+
+ def test_video_pool_candidates_exclude_basic(self) -> None:
+ spec = SimpleNamespace(pool_candidates=lambda: (0, 1, 2))
+
+ self.assertEqual(video._video_pool_candidates(spec), (1, 2))
+
+ def test_video_account_runner_reserves_only_super_or_heavy(self) -> None:
+ async def _run():
+ directory = _FakeAccountDirectory()
+ spec = SimpleNamespace(
+ is_video=lambda: True,
+ mode_id=0,
+ pool_candidates=lambda: (0, 1, 2),
+ )
+
+ async def _runner(token: str, timeout_s: float) -> str:
+ return token
+
+ with patch("app.dataplane.account._directory", directory):
+ with patch.object(video, "get_config", return_value=_VideoConfig()):
+ with patch.object(video, "resolve_model", return_value=spec):
+ with patch.object(video, "selection_max_retries", return_value=0):
+ with patch.object(video, "_quota_sync", _noop_async):
+ result = await video._run_video_with_account(
+ model="grok-imagine-video",
+ runner=_runner,
+ )
+ return result, directory.reserve_calls
+
+ result, calls = asyncio.run(_run())
+
+ self.assertEqual(result, "tok-video")
+ self.assertEqual(calls[0]["pool_candidates"], (1, 2))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_video_reference_helpers.py b/tests/test_video_reference_helpers.py
new file mode 100644
index 000000000..e47c1ad33
--- /dev/null
+++ b/tests/test_video_reference_helpers.py
@@ -0,0 +1,151 @@
+import asyncio
+import unittest
+
+from app.platform.errors import ValidationError
+from app.products.openai import router, video
+
+
+def _message_with_references(count: int) -> list[dict]:
+ return [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "生成一个参考图视频"},
+ *[
+ {
+ "type": "image_url",
+ "image_url": {"url": f"https://example.com/ref-{idx}.png"},
+ }
+ for idx in range(count)
+ ],
+ ],
+ }
+ ]
+
+
+class VideoReferenceHelperTests(unittest.TestCase):
+ def test_chat_video_prompt_allows_seven_references(self) -> None:
+ prompt, refs = video._extract_video_prompt_and_reference(
+ _message_with_references(7)
+ )
+
+ self.assertEqual(prompt, "生成一个参考图视频")
+ self.assertEqual(
+ refs,
+ [
+ {"image_url": f"https://example.com/ref-{idx}.png"}
+ for idx in range(7)
+ ],
+ )
+
+ def test_chat_video_prompt_rejects_more_than_seven_references(self) -> None:
+ with self.assertRaises(ValidationError):
+ video._extract_video_prompt_and_reference(_message_with_references(8))
+
+ def test_prepare_video_references_rejects_more_than_seven_references(self) -> None:
+ refs = [{"image_url": f"https://example.com/ref-{idx}.png"} for idx in range(8)]
+
+ async def _run() -> None:
+ await video._prepare_video_references("token", refs)
+
+ with self.assertRaises(ValidationError):
+ asyncio.run(_run())
+
+ def test_videos_create_rejects_more_than_seven_multipart_references(self) -> None:
+ async def _run() -> None:
+ await router.videos_create(
+ model="grok-imagine-video",
+ prompt="生成一个参考图视频",
+ input_reference=[object() for _ in range(8)], # type: ignore[list-item]
+ )
+
+ with self.assertRaises(ValidationError):
+ asyncio.run(_run())
+
+ def test_chat_video_segment_prompts_follow_user_order(self) -> None:
+ prompts, refs = video._extract_video_segment_prompts_and_references(
+ [
+ {"role": "system", "content": "ignore"},
+ {"role": "user", "content": "第一段"},
+ {"role": "assistant", "content": "ignore"},
+ {"role": "user", "content": "第二段"},
+ ]
+ )
+
+ self.assertEqual(prompts, ["第一段", "第二段"])
+ self.assertIsNone(refs)
+
+ def test_segment_prompts_reuse_last_prompt_when_short(self) -> None:
+ self.assertEqual(
+ video._normalize_segment_prompts("第一段", [10, 6], ["第一段"]),
+ ["第一段", "第一段"],
+ )
+
+ def test_segment_prompts_strict_rejects_missing_prompts(self) -> None:
+ with self.assertRaises(ValidationError):
+ video._normalize_segment_prompts(
+ "第一段",
+ [10, 6],
+ ["第一段"],
+ strict=True,
+ )
+
+ def test_segment_prompts_reject_extra_prompts(self) -> None:
+ with self.assertRaises(ValidationError):
+ video._normalize_segment_prompts("一", [6], ["一", "二"])
+
+ def test_video_segment_lengths_prefer_ten_second_segments(self) -> None:
+ cases = {
+ 6: [6],
+ 10: [10],
+ 12: [6, 6],
+ 16: [10, 6],
+ 20: [10, 10],
+ 22: [10, 6, 6],
+ 26: [10, 10, 6],
+ 30: [10, 10, 10],
+ 32: [10, 10, 6, 6],
+ 36: [10, 10, 10, 6],
+ 40: [10, 10, 10, 10],
+ }
+
+ for seconds, expected in cases.items():
+ with self.subTest(seconds=seconds):
+ self.assertEqual(video._build_segment_lengths(seconds), expected)
+
+ def test_async_video_length_rejects_long_videos(self) -> None:
+ video.validate_async_video_length(6)
+ video.validate_async_video_length(10)
+ with self.assertRaises(ValidationError):
+ video.validate_async_video_length(12)
+
+ def test_async_video_create_rejects_long_videos(self) -> None:
+ async def _run() -> None:
+ await video.create_video(
+ model="grok-imagine-video",
+ prompt="长视频",
+ seconds=12,
+ )
+
+ with self.assertRaises(ValidationError):
+ asyncio.run(_run())
+
+ def test_chat_video_length_rejects_non_exact_duration(self) -> None:
+ with self.assertRaises(ValidationError):
+ video.validate_video_length(28)
+
+ def test_chat_video_completions_requires_prompt_for_each_segment(self) -> None:
+ async def _run() -> None:
+ await video.completions(
+ model="grok-imagine-video",
+ messages=[{"role": "user", "content": "第一段"}],
+ stream=False,
+ seconds=16,
+ )
+
+ with self.assertRaises(ValidationError):
+ asyncio.run(_run())
+
+
+if __name__ == "__main__":
+ unittest.main()