From cfe44d59fce7e0121bd4e3109fcdc119518efa0c Mon Sep 17 00:00:00 2001 From: yangyang Date: Tue, 14 Apr 2026 18:01:45 +0800 Subject: [PATCH 01/14] fix: restore multi-reference video inputs and local image proxy urls --- app/products/openai/__init__.py | 9 +- app/products/openai/images.py | 45 +++++- app/products/openai/router.py | 20 ++- app/products/openai/video.py | 220 +++++++++++++++++--------- tests/test_image_output_format.py | 81 ++++++++++ tests/test_video_reference_helpers.py | 79 +++++++++ 6 files changed, 363 insertions(+), 91 deletions(-) create mode 100644 tests/test_image_output_format.py create mode 100644 tests/test_video_reference_helpers.py diff --git a/app/products/openai/__init__.py b/app/products/openai/__init__.py index 5bc0c2e6d..c3904ab36 100644 --- a/app/products/openai/__init__.py +++ b/app/products/openai/__init__.py @@ -1,3 +1,10 @@ -from .router import router +"""OpenAI product package exports.""" __all__ = ["router"] + + +def __getattr__(name: str): + if name == "router": + from .router import router + return router + raise AttributeError(name) diff --git a/app/products/openai/images.py b/app/products/openai/images.py index e48de961a..876a15cb6 100644 --- a/app/products/openai/images.py +++ b/app/products/openai/images.py @@ -8,6 +8,7 @@ import time from dataclasses import dataclass from typing import Any, AsyncGenerator, Awaitable, Callable +from urllib.parse import urlparse import orjson @@ -146,6 +147,16 @@ def _normalize_response_format(response_format: str) -> str: return fmt +def _normalize_configured_image_format(value: str | None) -> str: + fmt = (value or "grok_url").strip().lower() + if fmt not in {"grok_url", "local_url", "grok_md", "local_md", "base64"}: + raise ValidationError( + "image_format must be one of [grok_url, local_url, grok_md, local_md, base64]", + param="features.image_format", + ) + return fmt + + def _app_url() -> str: return get_config().get_str("app.app_url", "").rstrip("/") @@ -156,10 +167,17 @@ def _local_image_url(file_id: str) -> str: def _extract_image_file_id(url: str) -> str: - parts = [part for part in url.split("/") if part] + parsed = urlparse(url) + parts = [part for part in parsed.path.split("/") if part] + if parsed.scheme == "https" and parsed.netloc == "assets.grok.com" and len(parts) >= 2: + last_stem = parts[-1].split(".", 1)[0].lower() + if last_stem == "content": + asset_id = parts[-2].strip() + if asset_id: + return asset_id for part in reversed(parts): stem = part.split(".", 1)[0] - if stem and stem not in {"image", "original", "thumbnail"}: + if stem and stem.lower() not in {"image", "original", "thumbnail", "content"}: return stem return hashlib.sha1(url.encode("utf-8")).hexdigest()[:32] @@ -194,7 +212,24 @@ async def _resolve_image_output( blob_b64: str | None = None, ) -> _ImageOutput: fmt = _normalize_response_format(response_format) - if fmt == "url" and not _app_url(): + if fmt == "b64_json": + mime = infer_content_type(url) or "image/jpeg" + if blob_b64 is not None: + try: + raw = base64.b64decode(blob_b64) + except (ValueError, TypeError, binascii.Error) as exc: + raise UpstreamError(f"Invalid upstream image blob: {exc}") from exc + else: + raw, mime = await _download_image_bytes(token, url) + + b64 = blob_b64 or base64.b64encode(raw).decode() + data_uri = f"data:{mime};base64,{b64}" + return _ImageOutput(api_value=b64, markdown_value=f"![image]({data_uri})") + + image_format = _normalize_configured_image_format( + get_config().get_str("features.image_format", "grok_url") + ) + if image_format in {"grok_url", "grok_md"}: return _ImageOutput(api_value=url, markdown_value=f"![image]({url})") mime = infer_content_type(url) or "image/jpeg" @@ -206,10 +241,10 @@ async def _resolve_image_output( else: raw, mime = await _download_image_bytes(token, url) - if fmt == "b64_json": + if image_format == "base64": b64 = blob_b64 or base64.b64encode(raw).decode() data_uri = f"data:{mime};base64,{b64}" - return _ImageOutput(api_value=b64, markdown_value=f"![image]({data_uri})") + return _ImageOutput(api_value=data_uri, markdown_value=f"![image]({data_uri})") file_id = _save_image(raw, mime, _extract_image_file_id(url)) local_url = _local_image_url(file_id) diff --git a/app/products/openai/router.py b/app/products/openai/router.py index 9020b2f4f..084685d6b 100644 --- a/app/products/openai/router.py +++ b/app/products/openai/router.py @@ -399,15 +399,21 @@ async def videos_create( size: Annotated[Literal["720x1280", "1280x720", "1024x1024", "1024x1792", "1792x1024"], Form()] = "720x1280", resolution_name: Annotated[Literal["480p", "720p"] | None, Form()] = None, preset: Annotated[Literal["fun", "normal", "spicy", "custom"] | None, Form()] = None, - input_reference: Annotated[UploadFile | None, File()] = None, + input_reference: Annotated[list[UploadFile] | None, File()] = None, + input_reference_array: Annotated[list[UploadFile] | None, File(alias="input_reference[]")] = None, ): from .video import create_video - reference_payload = None - if input_reference is not None: - reference_payload = { - "image_url": await _upload_to_data_uri(input_reference, param="input_reference"), - } + reference_payloads: list[dict[str, str]] = [] + for index, upload in enumerate(input_reference or []): + reference_payloads.append({ + "image_url": await _upload_to_data_uri(upload, param=f"input_reference.{index}"), + }) + offset = len(reference_payloads) + for index, upload in enumerate(input_reference_array or []): + reference_payloads.append({ + "image_url": await _upload_to_data_uri(upload, param=f"input_reference.{offset + index}"), + }) result = await create_video( model=model or "grok-video", @@ -416,7 +422,7 @@ async def videos_create( size=size or "720x1280", resolution_name=resolution_name, preset=preset, - input_reference=reference_payload, + input_references=reference_payloads or None, ) return JSONResponse(result) diff --git a/app/products/openai/video.py b/app/products/openai/video.py index 46ea6f789..ba1e41d47 100644 --- a/app/products/openai/video.py +++ b/app/products/openai/video.py @@ -8,6 +8,7 @@ import asyncio import hashlib import html +import re import time import uuid from dataclasses import dataclass @@ -52,6 +53,7 @@ _VIDEO_OBJECT = "video" _VIDEO_JOB_TTL_S = 3600 _VIDEO_EXTENSION_REF_TYPE = "ORIGINAL_REF_TYPE_VIDEO_EXTENSION" +_VIDEO_MAX_REFERENCES = 7 _SUPPORTED_VIDEO_LENGTHS = frozenset({6, 10, 12, 16, 20}) _VIDEO_SIZE_MAP: dict[str, tuple[str, str]] = { "720x1280": ("9:16", "720p"), @@ -66,6 +68,7 @@ "spicy": "--mode=extremely-spicy-or-crazy", "custom": "--mode=custom", } +_REFERENCE_PLACEHOLDER_RE = re.compile(r"@(?:(?:图|image|img)\s*(\d+))", re.IGNORECASE) @dataclass(slots=True) @@ -79,8 +82,8 @@ class _VideoArtifact: @dataclass(slots=True) class _VideoReference: + asset_id: str content_url: str - post_id: str @dataclass(slots=True) @@ -126,10 +129,8 @@ def to_dict(self) -> dict[str, Any]: _VIDEO_JOBS_LOCK = asyncio.Lock() -def _build_message(prompt: str, preset: str, *, reference_content_url: str | None = None) -> str: +def _build_message(prompt: str, preset: str) -> str: message = f"{prompt} {_PRESET_FLAGS.get(preset, '--mode=custom')}".strip() - if reference_content_url: - return f"{reference_content_url} {message}" return message @@ -204,25 +205,30 @@ def _video_create_payload( resolution_name: str, video_length: int, preset: str, - reference_content_url: str | None = None, + image_references: list[str] | None = None, file_attachments: list[str] | None = None, ) -> dict[str, Any]: + video_config: dict[str, Any] = { + "parentPostId": parent_post_id, + "aspectRatio": aspect_ratio, + "videoLength": video_length, + "resolutionName": resolution_name, + } + if image_references: + video_config["imageReferences"] = list(image_references) + video_config["isReferenceToVideo"] = True + payload = { "temporary": True, "modelName": _VIDEO_MODEL_NAME, - "message": _build_message(prompt, preset, reference_content_url=reference_content_url), + "message": _build_message(prompt, preset), "toolOverrides": {"videoGen": True}, "enableSideBySide": True, "responseMetadata": { "experiments": [], "modelConfigOverride": { "modelMap": { - "videoGenModelConfig": { - "parentPostId": parent_post_id, - "aspectRatio": aspect_ratio, - "videoLength": video_length, - "resolutionName": resolution_name, - } + "videoGenModelConfig": video_config } }, }, @@ -356,48 +362,104 @@ def _is_upstream_asset_content_url(value: str) -> bool: ) -async def _prepare_video_reference(token: str, input_reference: dict[str, Any]) -> _VideoReference: +def _extract_asset_id_from_content_url(value: str) -> str: + parsed = urlparse(value) + if parsed.scheme != "https" or parsed.netloc != "assets.grok.com": + return "" + parts = [part for part in parsed.path.split("/") if part] + if len(parts) >= 2 and parts[-1] == "content": + return parts[-2] + return "" + + +def _replace_reference_placeholders(prompt: str, asset_ids: list[str]) -> str: + """Replace @图N / @imageN placeholders with uploaded asset IDs.""" + + def _replace(match: re.Match[str]) -> str: + index = int(match.group(1)) - 1 + if index < 0 or index >= len(asset_ids): + raise ValidationError( + f"Reference placeholder {match.group(0)} has no matching uploaded image", + param="prompt", + ) + return f"@{asset_ids[index]}" + + return _REFERENCE_PLACEHOLDER_RE.sub(_replace, prompt) + + +async def _prepare_video_reference( + token: str, + input_reference: dict[str, Any], + *, + index: int, +) -> _VideoReference: file_id = str(input_reference.get("file_id") or "").strip() image_input = str(input_reference.get("image_url") or "").strip() + param_base = f"input_reference.{index}" if index >= 0 else "input_reference" if file_id and image_input: - raise ValidationError("input_reference accepts only one of file_id or image_url", param="input_reference") + raise ValidationError( + "input_reference accepts only one of file_id or image_url", + param=param_base, + ) if file_id: - raise ValidationError("input_reference.file_id is not supported yet", param="input_reference.file_id") + raise ValidationError( + "input_reference.file_id is not supported yet", + param=f"{param_base}.file_id", + ) if not image_input: - raise ValidationError("input_reference.image_url is required", param="input_reference.image_url") + raise ValidationError( + "input_reference.image_url is required", + param=f"{param_base}.image_url", + ) if _is_upstream_asset_content_url(image_input): content_url = image_input + asset_id = _extract_asset_id_from_content_url(image_input) + if not asset_id: + raise ValidationError( + "input_reference.image_url must include a valid upstream asset ID", + param=f"{param_base}.image_url", + ) else: try: uploaded_file_id, uploaded_file_uri = await upload_from_input(token, image_input) content_url = resolve_uploaded_asset_reference(token, uploaded_file_id, uploaded_file_uri) + asset_id = uploaded_file_id except ValidationError as exc: - raise ValidationError(exc.message, param="input_reference.image_url") from exc + raise ValidationError(exc.message, param=f"{param_base}.image_url") from exc except UpstreamError as exc: raise UpstreamError( - f"Video input reference upload failed: {exc.message}", + f"Video input reference {index + 1} upload failed: {exc.message}", status=exc.status, body=exc.details.get("body", ""), ) from exc except Exception as exc: - raise UpstreamError(f"Video input reference upload failed: {exc}") from exc + raise UpstreamError(f"Video input reference {index + 1} upload failed: {exc}") from exc - post = await create_media_post( - token, - media_type=_IMAGE_MEDIA_TYPE, - media_url=content_url, - prompt="", - referer="https://grok.com/imagine", - ) - post_data = post.get("post") - if not isinstance(post_data, dict): - raise UpstreamError("Video image reference create-post returned no post payload") - post_id = str(post_data.get("id") or "").strip() - if not post_id: - raise UpstreamError("Video image reference create-post returned no post id") - return _VideoReference(content_url=content_url, post_id=post_id) + return _VideoReference(asset_id=asset_id, content_url=content_url) + + +async def _prepare_video_references( + token: str, + input_references: list[dict[str, Any]], +) -> list[_VideoReference]: + if len(input_references) > _VIDEO_MAX_REFERENCES: + raise ValidationError( + f"Video generation supports at most {_VIDEO_MAX_REFERENCES} reference images", + param="input_reference", + ) + + results: list[_VideoReference | None] = [None] * len(input_references) + + async def _runner(index: int, input_reference: dict[str, Any]) -> None: + results[index] = await _prepare_video_reference(token, input_reference, index=index) + + async with asyncio.TaskGroup() as tg: + for index, input_reference in enumerate(input_references): + tg.create_task(_runner(index, input_reference), name=f"video-reference-{index}") + + return [result for result in results if result is not None] async def _collect_video_segment( @@ -544,26 +606,30 @@ async def _generate_video_with_token( seconds: int, preset: str, timeout_s: float, - input_reference: dict[str, Any] | None = None, + input_references: list[dict[str, Any]] | None = None, progress_cb: Callable[[int], Awaitable[None]] | None = None, ) -> _VideoArtifact: - reference: _VideoReference | None = None - if input_reference: - reference = await _prepare_video_reference(token, input_reference) - parent_post_id = reference.post_id - else: - post = await create_media_post( - token, - media_type=_VIDEO_MEDIA_TYPE, - prompt=prompt, - referer="https://grok.com/imagine", + references: list[_VideoReference] = [] + prompt_text = prompt + if input_references: + references = await _prepare_video_references(token, input_references) + prompt_text = _replace_reference_placeholders( + prompt_text, + [reference.asset_id for reference in references], ) - post_data = post.get("post") - if not isinstance(post_data, dict): - raise UpstreamError("Video create-post returned no post payload") - parent_post_id = str(post_data.get("id") or "").strip() - if not parent_post_id: - raise UpstreamError("Video create-post returned no post id") + + post = await create_media_post( + token, + media_type=_VIDEO_MEDIA_TYPE, + prompt=prompt_text, + referer="https://grok.com/imagine", + ) + post_data = post.get("post") + if not isinstance(post_data, dict): + raise UpstreamError("Video create-post returned no post payload") + parent_post_id = str(post_data.get("id") or "").strip() + if not parent_post_id: + raise UpstreamError("Video create-post returned no post id") segments = _build_segment_lengths(seconds) total_segments = len(segments) @@ -574,19 +640,19 @@ async def _generate_video_with_token( for index, segment_length in enumerate(segments): if index == 0: payload = _video_create_payload( - prompt=prompt, + prompt=prompt_text, parent_post_id=parent_post_id, aspect_ratio=aspect_ratio, resolution_name=resolution_name, video_length=segment_length, preset=preset, - reference_content_url=reference.content_url if reference is not None else None, - file_attachments=[reference.post_id] if reference is not None else None, + image_references=[reference.content_url for reference in references] if references else None, + file_attachments=[reference.asset_id for reference in references] if references else None, ) referer = "https://grok.com/imagine" else: payload = _video_extend_payload( - prompt=prompt, + prompt=prompt_text, parent_post_id=parent_post_id, extend_post_id=extend_post_id, aspect_ratio=aspect_ratio, @@ -628,7 +694,7 @@ async def _run_video_generation( resolution_name: str, seconds: int, preset: str = "custom", - input_reference: dict[str, Any] | None = None, + input_references: list[dict[str, Any]] | None = None, progress_cb: Callable[[int], Awaitable[None]] | None = None, ) -> _VideoArtifact: async def _runner(token: str, timeout_s: float) -> _VideoArtifact: @@ -640,7 +706,7 @@ async def _runner(token: str, timeout_s: float) -> _VideoArtifact: seconds=seconds, preset=preset, timeout_s=timeout_s, - input_reference=input_reference, + input_references=input_references, progress_cb=progress_cb, ) @@ -725,7 +791,7 @@ async def _run_video_job( prompt: str, seconds: int, preset: str | None, - input_reference: dict[str, Any] | None = None, + input_references: list[dict[str, Any]] | None = None, ) -> None: try: await _set_job_status(job, status="in_progress", progress=1) @@ -767,7 +833,7 @@ async def _progress(progress: int) -> None: seconds=seconds, preset=resolved_preset, timeout_s=timeout_s, - input_reference=input_reference, + input_references=input_references, progress_cb=_progress, ) raw, _mime = await _download_video_bytes(token, artifact.video_url) @@ -807,7 +873,7 @@ async def create_video( size: str | None = None, resolution_name: str | None = None, preset: str | None = None, - input_reference: dict[str, Any] | None = None, + input_references: list[dict[str, Any]] | None = None, ) -> dict[str, Any]: spec = model_registry.get(model) if spec is None or not spec.enabled or not spec.is_video(): @@ -842,7 +908,7 @@ async def create_video( prompt=cleaned_prompt, seconds=normalized_seconds, preset=preset, - input_reference=input_reference, + input_references=input_references, ) ) asyncio.create_task(_expire_video_job(job.id)) @@ -873,22 +939,20 @@ async def content_path(video_id: str) -> Path: return path -def _extract_video_prompt_and_reference(messages: list[dict]) -> tuple[str, dict[str, Any] | None]: +def _extract_video_prompt_and_references(messages: list[dict]) -> tuple[str, list[dict[str, Any]] | None]: prompt = "" - reference_url = "" + reference_urls: list[str] = [] - for msg in reversed(messages): + for msg in messages: content = msg.get("content", "") if isinstance(content, str) and content.strip(): prompt = content.strip() - if prompt: - break continue if not isinstance(content, list): continue text_parts: list[str] = [] - block_reference = "" + block_references: list[str] = [] for item in content: if not isinstance(item, dict): continue @@ -900,24 +964,24 @@ def _extract_video_prompt_and_reference(messages: list[dict]) -> tuple[str, dict elif item_type == "image_url": image_url = item.get("image_url") if isinstance(image_url, dict): - block_reference = str(image_url.get("url") or "").strip() or block_reference + url = str(image_url.get("url") or "").strip() elif isinstance(image_url, str): - block_reference = image_url.strip() or block_reference + url = image_url.strip() + else: + url = "" + if url: + block_references.append(url) if text_parts: prompt = " ".join(text_parts) - if block_reference and not reference_url: - reference_url = block_reference - if prompt: - break + if block_references: + reference_urls.extend(block_references) if not prompt: raise ValidationError("Video prompt cannot be empty", param="messages") - input_reference: dict[str, Any] | None = None - if reference_url: - input_reference = {"image_url": reference_url} - return prompt, input_reference + input_references = [{"image_url": reference_url} for reference_url in reference_urls] or None + return prompt, input_references async def completions( @@ -938,7 +1002,7 @@ async def completions( default=default_resolution_name, ) resolved_preset = _resolve_video_preset(preset) - prompt, input_reference = _extract_video_prompt_and_reference(messages) + prompt, input_references = _extract_video_prompt_and_references(messages) cfg = get_config() is_stream = stream if stream is not None else cfg.get_bool("features.stream", False) @@ -954,7 +1018,7 @@ async def _runner(token: str, timeout_s: float) -> str: seconds=seconds, preset=resolved_preset, timeout_s=timeout_s, - input_reference=input_reference, + input_references=input_references, progress_cb=progress_cb, ) file_id = hashlib.sha1(artifact.video_url.encode("utf-8")).hexdigest()[:32] diff --git a/tests/test_image_output_format.py b/tests/test_image_output_format.py new file mode 100644 index 000000000..eaf8cc379 --- /dev/null +++ b/tests/test_image_output_format.py @@ -0,0 +1,81 @@ +import asyncio +import base64 +import tempfile +import unittest +from pathlib import Path +from urllib.parse import parse_qs, urlparse +from unittest.mock import patch + +from app.products.openai import images + + +class _StubConfig: + def __init__(self, *, image_format: str, app_url: str) -> None: + self.image_format = image_format + self.app_url = app_url + + def get_str(self, key: str, default: str = "") -> str: + if key == "features.image_format": + return self.image_format + if key == "app.app_url": + return self.app_url + return default + + +class ImageOutputFormatTests(unittest.TestCase): + def test_resolve_image_output_uses_local_proxy_when_configured(self) -> None: + blob_b64 = base64.b64encode(b"fake-image-bytes").decode("ascii") + asset_id = "12345678-1234-1234-1234-123456789abc" + with tempfile.TemporaryDirectory() as tmpdir: + config = _StubConfig( + image_format="local_url", + app_url="https://app.example.com", + ) + with patch.object(images, "get_config", return_value=config): + with patch.object(images, "image_files_dir", return_value=Path(tmpdir)): + result = asyncio.run( + images._resolve_image_output( + token="unused", + url=f"https://assets.grok.com/users/user-1/{asset_id}/content.png", + response_format="url", + blob_b64=blob_b64, + ) + ) + file_id = parse_qs(urlparse(result.api_value).query)["id"][0] + self.assertTrue((Path(tmpdir) / f"{file_id}.png").exists()) + + self.assertEqual( + result.api_value, + f"https://app.example.com/v1/files/image?id={asset_id}", + ) + self.assertEqual( + result.markdown_value, + f"![image](https://app.example.com/v1/files/image?id={asset_id})", + ) + + def test_resolve_image_output_keeps_upstream_url_when_configured(self) -> None: + config = _StubConfig( + image_format="grok_url", + app_url="https://app.example.com", + ) + with patch.object(images, "get_config", return_value=config): + result = asyncio.run( + images._resolve_image_output( + token="unused", + url="https://assets.grok.com/users/user-1/file-abc123/content.png", + response_format="url", + ) + ) + + self.assertEqual( + result.api_value, + "https://assets.grok.com/users/user-1/file-abc123/content.png", + ) + self.assertEqual( + result.markdown_value, + "![image](https://assets.grok.com/users/user-1/file-abc123/content.png)", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_video_reference_helpers.py b/tests/test_video_reference_helpers.py new file mode 100644 index 000000000..99514b833 --- /dev/null +++ b/tests/test_video_reference_helpers.py @@ -0,0 +1,79 @@ +import unittest + +from app.platform.errors import ValidationError +from app.products.openai import video + + +class VideoReferenceHelperTests(unittest.TestCase): + def test_replace_reference_placeholders_supports_cn_and_en_aliases(self) -> None: + prompt = "先展示@图1里的角色,再切到@image2的场景,最后回到@img1" + replaced = video._replace_reference_placeholders( + prompt, + ["asset_one", "asset_two"], + ) + self.assertEqual( + replaced, + "先展示@asset_one里的角色,再切到@asset_two的场景,最后回到@asset_one", + ) + + def test_replace_reference_placeholders_rejects_missing_index(self) -> None: + with self.assertRaises(ValidationError): + video._replace_reference_placeholders("参考@图2", ["asset_one"]) + + def test_extract_video_prompt_and_references_collects_all_images(self) -> None: + prompt, refs = video._extract_video_prompt_and_references( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "第一段提示"}, + {"type": "image_url", "image_url": {"url": "https://example.com/ref-1.png"}}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "最终提示词"}, + {"type": "image_url", "image_url": {"url": "https://example.com/ref-2.png"}}, + {"type": "image_url", "image_url": {"url": "https://example.com/ref-3.png"}}, + ], + }, + ] + ) + + self.assertEqual(prompt, "最终提示词") + self.assertEqual( + refs, + [ + {"image_url": "https://example.com/ref-1.png"}, + {"image_url": "https://example.com/ref-2.png"}, + {"image_url": "https://example.com/ref-3.png"}, + ], + ) + + def test_video_create_payload_includes_multi_reference_fields(self) -> None: + payload = video._video_create_payload( + prompt="test prompt", + parent_post_id="post_123", + aspect_ratio="16:9", + resolution_name="720p", + video_length=10, + preset="normal", + image_references=["https://assets.grok.com/users/u/ref-1/content"], + file_attachments=["asset_ref_1"], + ) + + config = payload["responseMetadata"]["modelConfigOverride"]["modelMap"]["videoGenModelConfig"] + self.assertEqual(config["imageReferences"], ["https://assets.grok.com/users/u/ref-1/content"]) + self.assertTrue(config["isReferenceToVideo"]) + self.assertEqual(payload["fileAttachments"], ["asset_ref_1"]) + + def test_extract_asset_id_from_content_url(self) -> None: + asset_id = video._extract_asset_id_from_content_url( + "https://assets.grok.com/users/user-123/asset-456/content" + ) + self.assertEqual(asset_id, "asset-456") + + +if __name__ == "__main__": + unittest.main() From 54d78779604ce3665add22c91a8f062f20ba4195 Mon Sep 17 00:00:00 2001 From: yangyang Date: Wed, 15 Apr 2026 23:37:43 +0800 Subject: [PATCH 02/14] feat: add local media cache eviction --- .github/workflows/docker.yml | 9 ++- .gitignore | 1 + app/platform/storage/__init__.py | 3 +- app/platform/storage/media_cache.py | 103 ++++++++++++++++++++++++++++ app/products/openai/chat.py | 5 +- app/products/openai/images.py | 5 +- app/products/openai/video.py | 6 +- config.defaults.toml | 10 +++ tests/test_media_cache_limits.py | 84 +++++++++++++++++++++++ 9 files changed, 212 insertions(+), 14 deletions(-) create mode 100644 app/platform/storage/media_cache.py create mode 100644 tests/test_media_cache_limits.py diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index c049c5179..b3fe502f3 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -12,7 +12,6 @@ on: env: REGISTRY: ghcr.io - IMAGE_NAME: ${{ github.repository }} jobs: build-and-push: @@ -40,11 +39,15 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Normalize image name + id: image + run: echo "name=${GITHUB_REPOSITORY,,}" >> "$GITHUB_OUTPUT" + - name: Extract metadata id: meta uses: docker/metadata-action@v5 with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + images: ${{ env.REGISTRY }}/${{ steps.image.outputs.name }} tags: | type=raw,value=latest,enable=${{ github.ref_type == 'branch' }} type=ref,event=tag @@ -60,4 +63,4 @@ jobs: labels: ${{ steps.meta.outputs.labels }} cache-from: type=gha cache-to: type=gha,mode=max - pull: true \ No newline at end of file + pull: true diff --git a/.gitignore b/.gitignore index e94a37e41..652d6743b 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,4 @@ htmlcov/ # Project specific .ace-tool/ +.tmp_live/ diff --git a/app/platform/storage/__init__.py b/app/platform/storage/__init__.py index 65fd31df3..545753f19 100644 --- a/app/platform/storage/__init__.py +++ b/app/platform/storage/__init__.py @@ -1,5 +1,6 @@ """Platform storage helpers.""" from .media_paths import image_files_dir, video_files_dir +from .media_cache import save_media_bytes -__all__ = ["image_files_dir", "video_files_dir"] +__all__ = ["image_files_dir", "video_files_dir", "save_media_bytes"] diff --git a/app/platform/storage/media_cache.py b/app/platform/storage/media_cache.py new file mode 100644 index 000000000..2204ff305 --- /dev/null +++ b/app/platform/storage/media_cache.py @@ -0,0 +1,103 @@ +"""Helpers for writing local media cache files with size-based eviction.""" + +from pathlib import Path +from threading import Lock +from typing import Literal + +from app.platform.config.snapshot import get_config +from app.platform.logging.logger import logger + +from .media_paths import image_files_dir, video_files_dir + +MediaType = Literal["image", "video"] + +_CACHE_LOCK = Lock() +_MB = 1024 * 1024 + + +def _media_dirs() -> dict[MediaType, Path]: + return { + "image": image_files_dir(), + "video": video_files_dir(), + } + + +def _read_limit_mb(key: str) -> float: + value = get_config().get_float(key, 0.0) + return value if value > 0 else 0.0 + + +def _specific_limit_bytes(media_type: MediaType) -> int: + limit_mb = _read_limit_mb(f"storage.{media_type}_max_mb") + return int(limit_mb * _MB) if limit_mb > 0 else 0 + + +def _total_limit_bytes() -> int: + limit_mb = _read_limit_mb("storage.media_max_mb") + return int(limit_mb * _MB) if limit_mb > 0 else 0 + + +def _list_files(media_type: MediaType | None = None) -> list[Path]: + dirs = _media_dirs() + selected = [dirs[media_type]] if media_type else list(dirs.values()) + files: list[Path] = [] + for directory in selected: + files.extend(path for path in directory.glob("*") if path.is_file()) + return files + + +def _prune_paths(paths: list[Path], limit_bytes: int) -> list[Path]: + if limit_bytes <= 0: + return [] + + file_rows: list[tuple[float, str, int, Path]] = [] + total_size = 0 + for path in paths: + try: + stat = path.stat() + except FileNotFoundError: + continue + file_rows.append((stat.st_mtime, path.name, stat.st_size, path)) + total_size += stat.st_size + + if total_size <= limit_bytes: + return [] + + removed: list[Path] = [] + for _mtime, _name, size, path in sorted(file_rows): + try: + path.unlink(missing_ok=True) + except OSError as exc: + logger.warning("media cache eviction failed: path={} error={}", path, exc) + continue + removed.append(path) + total_size -= size + if total_size <= limit_bytes: + break + return removed + + +def _enforce_limits_locked(media_type: MediaType) -> None: + removed = _prune_paths(_list_files(media_type), _specific_limit_bytes(media_type)) + if removed: + logger.info( + "media cache pruned: scope={} removed_count={}", + media_type, + len(removed), + ) + + removed = _prune_paths(_list_files(), _total_limit_bytes()) + if removed: + logger.info("media cache pruned: scope=all removed_count={}", len(removed)) + + +def save_media_bytes(raw: bytes, path: Path, *, media_type: MediaType) -> Path: + path.parent.mkdir(parents=True, exist_ok=True) + with _CACHE_LOCK: + if not path.exists(): + path.write_bytes(raw) + _enforce_limits_locked(media_type) + return path + + +__all__ = ["MediaType", "save_media_bytes"] diff --git a/app/products/openai/chat.py b/app/products/openai/chat.py index f704040ae..5f9a5b55b 100644 --- a/app/products/openai/chat.py +++ b/app/products/openai/chat.py @@ -15,7 +15,7 @@ estimate_tokens, estimate_tool_call_tokens, ) -from app.platform.storage import image_files_dir +from app.platform.storage import image_files_dir, save_media_bytes from app.control.account.runtime import get_refresh_service from app.control.account.invalid_credentials import feedback_kind_for_error from app.control.model.registry import resolve as resolve_model @@ -148,8 +148,7 @@ def _save_image(raw: bytes, mime: str, image_id: str) -> str: img_dir = image_files_dir() ext = ".png" if "png" in mime else ".jpg" path = img_dir / f"{image_id}{ext}" - if not path.exists(): - path.write_bytes(raw) + save_media_bytes(raw, path, media_type="image") return image_id diff --git a/app/products/openai/images.py b/app/products/openai/images.py index 876a15cb6..c7ad8f9ea 100644 --- a/app/products/openai/images.py +++ b/app/products/openai/images.py @@ -16,7 +16,7 @@ from app.platform.config.snapshot import get_config from app.platform.errors import RateLimitError, UpstreamError, ValidationError from app.platform.runtime.clock import now_s -from app.platform.storage import image_files_dir +from app.platform.storage import image_files_dir, save_media_bytes from app.control.model.registry import resolve as resolve_model from app.control.model.enums import ModeId from app.control.model.spec import ModelSpec @@ -186,8 +186,7 @@ def _save_image(raw: bytes, mime: str, file_id: str) -> str: img_dir = image_files_dir() ext = ".png" if "png" in mime else ".jpg" path = img_dir / f"{file_id}{ext}" - if not path.exists(): - path.write_bytes(raw) + save_media_bytes(raw, path, media_type="image") return file_id diff --git a/app/products/openai/video.py b/app/products/openai/video.py index ba1e41d47..83f11f903 100644 --- a/app/products/openai/video.py +++ b/app/products/openai/video.py @@ -22,7 +22,7 @@ from app.platform.errors import AppError, ErrorKind, RateLimitError, UpstreamError, ValidationError from app.platform.logging.logger import logger from app.platform.runtime.clock import now_s -from app.platform.storage import video_files_dir +from app.platform.storage import video_files_dir, save_media_bytes from app.control.account.enums import FeedbackKind from app.control.model import registry as model_registry from app.control.model.registry import resolve as resolve_model @@ -554,9 +554,7 @@ async def _download_video_bytes(token: str, url: str) -> tuple[bytes, str]: def _save_video_bytes(raw: bytes, file_id: str) -> Path: out_dir = video_files_dir() path = out_dir / f"{file_id}.mp4" - if not path.exists(): - path.write_bytes(raw) - return path + return save_media_bytes(raw, path, media_type="video") def _local_video_url(file_id: str) -> str: diff --git a/config.defaults.toml b/config.defaults.toml index 639cf26d2..4c380dcc5 100644 --- a/config.defaults.toml +++ b/config.defaults.toml @@ -20,6 +20,16 @@ file_level = "INFO" max_files = 7 +# ==================== 本地存储配置 ==================== +[storage] +# 本地图片+视频缓存总上限(MB),<= 0 表示不限制 +media_max_mb = 0 +# 本地图片缓存上限(MB),<= 0 表示不限制 +image_max_mb = 0 +# 本地视频缓存上限(MB),<= 0 表示不限制 +video_max_mb = 0 + + # ==================== 应用功能 ==================== [features] # 是否启用临时对话 diff --git a/tests/test_media_cache_limits.py b/tests/test_media_cache_limits.py new file mode 100644 index 000000000..ee2f2c733 --- /dev/null +++ b/tests/test_media_cache_limits.py @@ -0,0 +1,84 @@ +import os +import tempfile +import time +import unittest +from pathlib import Path +from unittest.mock import patch + +from app.platform.storage.media_cache import save_media_bytes + + +class _StubConfig: + def __init__( + self, + *, + media_max_mb: float = 0.0, + image_max_mb: float = 0.0, + video_max_mb: float = 0.0, + ) -> None: + self._floats = { + "storage.media_max_mb": media_max_mb, + "storage.image_max_mb": image_max_mb, + "storage.video_max_mb": video_max_mb, + } + + def get_float(self, key: str, default: float = 0.0) -> float: + return self._floats.get(key, default) + + +class MediaCacheLimitTests(unittest.TestCase): + def test_save_media_bytes_prunes_oldest_file_when_type_limit_exceeded(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + image_dir = Path(tmpdir) / "images" + video_dir = Path(tmpdir) / "videos" + image_dir.mkdir() + video_dir.mkdir() + + old_path = image_dir / "old.png" + old_path.write_bytes(b"a" * 80) + old_mtime = time.time() - 60 + os.utime(old_path, (old_mtime, old_mtime)) + + config = _StubConfig(image_max_mb=100 / (1024 * 1024)) + with patch("app.platform.storage.media_cache.get_config", return_value=config): + with patch("app.platform.storage.media_cache.image_files_dir", return_value=image_dir): + with patch("app.platform.storage.media_cache.video_files_dir", return_value=video_dir): + new_path = save_media_bytes( + b"b" * 40, + image_dir / "new.png", + media_type="image", + ) + + self.assertTrue(new_path.exists()) + self.assertFalse(old_path.exists()) + + self.assertEqual(new_path.name, "new.png") + + def test_save_media_bytes_prunes_across_media_types_when_total_limit_exceeded(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + image_dir = Path(tmpdir) / "images" + video_dir = Path(tmpdir) / "videos" + image_dir.mkdir() + video_dir.mkdir() + + old_image = image_dir / "old.png" + old_image.write_bytes(b"a" * 80) + old_mtime = time.time() - 60 + os.utime(old_image, (old_mtime, old_mtime)) + + config = _StubConfig(media_max_mb=100 / (1024 * 1024)) + with patch("app.platform.storage.media_cache.get_config", return_value=config): + with patch("app.platform.storage.media_cache.image_files_dir", return_value=image_dir): + with patch("app.platform.storage.media_cache.video_files_dir", return_value=video_dir): + new_video = save_media_bytes( + b"b" * 40, + video_dir / "new.mp4", + media_type="video", + ) + + self.assertTrue(new_video.exists()) + self.assertFalse(old_image.exists()) + + +if __name__ == "__main__": + unittest.main() From b56436647c6eed9b9a0eac637a3c9e20efe805e0 Mon Sep 17 00:00:00 2001 From: yangyang Date: Wed, 15 Apr 2026 23:47:07 +0800 Subject: [PATCH 03/14] chore: trigger docker workflow From 9eae1f2c067e23787d0fadc72b237dbc4dc1a1d2 Mon Sep 17 00:00:00 2001 From: yangyang Date: Thu, 16 Apr 2026 00:17:32 +0800 Subject: [PATCH 04/14] feat: expose storage cache limits in admin config --- app/statics/admin/config.html | 30 ++++++++++++++++++++++++++++++ app/statics/i18n/en.json | 15 ++++++++++++++- app/statics/i18n/zh.json | 15 ++++++++++++++- 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/app/statics/admin/config.html b/app/statics/admin/config.html index bcf2df6aa..5955a2c73 100644 --- a/app/statics/admin/config.html +++ b/app/statics/admin/config.html @@ -369,6 +369,36 @@ }, ] }, + { + title: '本地存储', titleKey: 'config.schema.groups.storage', + section: 'storage', + fields: [ + { + key: 'media_max_mb', + label: '媒体缓存总上限(MB)', + labelKey: 'config.schema.fields.mediaCacheMaxMb.label', + type: 'number', + desc: '本地图片与视频缓存共享的总容量上限,单位 MB。设置为 0 或负数表示不限制。', + descKey: 'config.schema.fields.mediaCacheMaxMb.desc', + }, + { + key: 'image_max_mb', + label: '图片缓存上限(MB)', + labelKey: 'config.schema.fields.imageCacheMaxMb.label', + type: 'number', + desc: '本地图片缓存的独立容量上限,单位 MB。设置为 0 或负数表示不限制。', + descKey: 'config.schema.fields.imageCacheMaxMb.desc', + }, + { + key: 'video_max_mb', + label: '视频缓存上限(MB)', + labelKey: 'config.schema.fields.videoCacheMaxMb.label', + type: 'number', + desc: '本地视频缓存的独立容量上限,单位 MB。设置为 0 或负数表示不限制。', + descKey: 'config.schema.fields.videoCacheMaxMb.desc', + }, + ] + }, ] }, { diff --git a/app/statics/i18n/en.json b/app/statics/i18n/en.json index fd54dfd12..f649d1ae9 100644 --- a/app/statics/i18n/en.json +++ b/app/statics/i18n/en.json @@ -335,7 +335,8 @@ "retry": "Retry Policy", "timeouts": "Request Timeouts", "imageProtocol": "NSFW and Image Streaming", - "logging": "Logging" + "logging": "Logging", + "storage": "Local Storage" }, "options": { "imageFormat": { @@ -578,6 +579,18 @@ "logMaxFiles": { "label": "Log Retention (days)", "desc": "Maximum number of daily log files to keep during rotation. Default: 7." + }, + "mediaCacheMaxMb": { + "label": "Total Media Cache Limit (MB)", + "desc": "Shared size limit for local image and video cache, in MB. Set to 0 or a negative value to disable the limit." + }, + "imageCacheMaxMb": { + "label": "Image Cache Limit (MB)", + "desc": "Dedicated size limit for local image cache, in MB. Set to 0 or a negative value to disable the limit." + }, + "videoCacheMaxMb": { + "label": "Video Cache Limit (MB)", + "desc": "Dedicated size limit for local video cache, in MB. Set to 0 or a negative value to disable the limit." } } } diff --git a/app/statics/i18n/zh.json b/app/statics/i18n/zh.json index 052dac571..c6309c5d8 100644 --- a/app/statics/i18n/zh.json +++ b/app/statics/i18n/zh.json @@ -335,7 +335,8 @@ "retry": "重试策略", "timeouts": "请求超时设置", "imageProtocol": "NSFW 与图像协议", - "logging": "日志配置" + "logging": "日志配置", + "storage": "本地存储" }, "options": { "imageFormat": { @@ -578,6 +579,18 @@ "logMaxFiles": { "label": "日志保留天数", "desc": "按天轮转时最多保留的日志文件数量。默认 7。" + }, + "mediaCacheMaxMb": { + "label": "媒体缓存总上限(MB)", + "desc": "本地图片与视频缓存共享的总容量上限,单位 MB。设置为 0 或负数表示不限制。" + }, + "imageCacheMaxMb": { + "label": "图片缓存上限(MB)", + "desc": "本地图片缓存的独立容量上限,单位 MB。设置为 0 或负数表示不限制。" + }, + "videoCacheMaxMb": { + "label": "视频缓存上限(MB)", + "desc": "本地视频缓存的独立容量上限,单位 MB。设置为 0 或负数表示不限制。" } } } From 32ce23ea55f3db50027bb8a714faaec2723d9989 Mon Sep 17 00:00:00 2001 From: yangyang Date: Sat, 25 Apr 2026 20:54:10 +0800 Subject: [PATCH 05/14] Default runtime config to local TOML --- README.md | 5 +- app/platform/config/backends/factory.py | 25 ++++--- app/platform/startup/migration.py | 5 +- docs/README.en.md | 5 +- tests/test_config_backend_factory.py | 63 +++++++++++++++++ tests/test_startup_config_migration.py | 91 +++++++++++++++++++++++++ 6 files changed, 178 insertions(+), 16 deletions(-) create mode 100644 tests/test_config_backend_factory.py create mode 100644 tests/test_startup_config_migration.py diff --git a/README.md b/README.md index 35d49b133..b5c663dae 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,7 @@ docker compose up -d ### Vercel -[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,LOG_DIR,ACCOUNT_STORAGE,ACCOUNT_REDIS_URL,ACCOUNT_MYSQL_URL,ACCOUNT_POSTGRESQL_URL) +[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,LOG_DIR,ACCOUNT_STORAGE,ACCOUNT_REDIS_URL,ACCOUNT_MYSQL_URL,ACCOUNT_POSTGRESQL_URL,CONFIG_STORAGE,CONFIG_LOCAL_PATH) ### Render @@ -187,7 +187,8 @@ docker compose up -d | `ACCOUNT_SQL_MAX_OVERFLOW` | SQL 连接池最大溢出连接数 | `10` | | `ACCOUNT_SQL_POOL_TIMEOUT` | 等待连接池空闲连接的超时时间(秒) | `30` | | `ACCOUNT_SQL_POOL_RECYCLE` | 连接最大复用时间(秒),超时后自动重连 | `1800` | -| `CONFIG_LOCAL_PATH` | `local` 模式运行时配置文件路径 | `${DATA_DIR}/config.toml` | +| `CONFIG_STORAGE` | 运行时配置存储后端;默认本地 TOML,不跟随 `ACCOUNT_STORAGE` | `local` | +| `CONFIG_LOCAL_PATH` | 本地运行时配置文件路径 | `${DATA_DIR}/config.toml` | 运行时配置也支持 `GROK_` 前缀环境变量覆盖,例如 `GROK_APP_API_KEY` 会覆盖 `app.api_key`,`GROK_FEATURES_STREAM` 会覆盖 `features.stream`。 diff --git a/app/platform/config/backends/factory.py b/app/platform/config/backends/factory.py index 258b385eb..eda5e53b5 100644 --- a/app/platform/config/backends/factory.py +++ b/app/platform/config/backends/factory.py @@ -1,4 +1,4 @@ -"""Config backend factory — follows ACCOUNT_STORAGE automatically.""" +"""Config backend factory.""" import os from pathlib import Path @@ -8,19 +8,24 @@ def get_config_backend_name() -> str: - """Return the active config backend name (mirrors ACCOUNT_STORAGE).""" - return os.getenv("ACCOUNT_STORAGE", "local").strip().lower() + """Return the active config backend name. + + Runtime configuration is local by default even when account data is stored + in Redis or SQL. Set CONFIG_STORAGE explicitly to opt into a remote config + backend. + """ + return os.getenv("CONFIG_STORAGE", "local").strip().lower() def create_config_backend() -> ConfigBackend: - """Instantiate the config backend that matches the account storage backend. + """Instantiate the configured runtime config backend. - ``ACCOUNT_STORAGE=local`` → TOML file (``${DATA_DIR}/config.toml``) - ``ACCOUNT_STORAGE=redis`` → Redis (ACCOUNT_REDIS_URL) - ``ACCOUNT_STORAGE=mysql`` → MySQL (ACCOUNT_MYSQL_URL) - ``ACCOUNT_STORAGE=postgresql`` → PostgreSQL (ACCOUNT_POSTGRESQL_URL) + ``CONFIG_STORAGE=local`` → TOML file (``${DATA_DIR}/config.toml``) + ``CONFIG_STORAGE=redis`` → Redis (ACCOUNT_REDIS_URL) + ``CONFIG_STORAGE=mysql`` → MySQL (ACCOUNT_MYSQL_URL) + ``CONFIG_STORAGE=postgresql`` → PostgreSQL (ACCOUNT_POSTGRESQL_URL) - No extra env vars needed — reuses the same connection settings as accounts. + Remote config backends reuse the same connection settings as accounts. """ backend = get_config_backend_name() @@ -31,7 +36,7 @@ def create_config_backend() -> ConfigBackend: if backend in ("mysql", "postgresql"): return _make_sql(backend) - raise ValueError(f"Unknown account storage backend: {backend!r}") + raise ValueError(f"Unknown config storage backend: {backend!r}") def _make_toml() -> ConfigBackend: diff --git a/app/platform/startup/migration.py b/app/platform/startup/migration.py index 741f64253..35f5db4a3 100644 --- a/app/platform/startup/migration.py +++ b/app/platform/startup/migration.py @@ -2,9 +2,10 @@ Config migration ---------------- -local : seeds ``${DATA_DIR}/config.toml`` from ``config.defaults.toml`` if +local : default; seeds ``${DATA_DIR}/config.toml`` from ``config.defaults.toml`` if the file does not exist yet — gives users an editable copy on first run. -redis / sql : if the backend is empty (version == 0) AND +redis / sql : only when CONFIG_STORAGE is explicitly set to a remote backend. + If the backend is empty (version == 0) AND ``${DATA_DIR}/config.toml`` exists, migrates the user overrides into the DB backend. If it does not exist either, nothing is written (defaults are always loaded from ``config.defaults.toml`` at runtime). diff --git a/docs/README.en.md b/docs/README.en.md index 03930a027..f37581570 100644 --- a/docs/README.en.md +++ b/docs/README.en.md @@ -111,7 +111,7 @@ docker compose up -d ### Vercel -[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,LOG_DIR,ACCOUNT_STORAGE,ACCOUNT_REDIS_URL,ACCOUNT_MYSQL_URL,ACCOUNT_POSTGRESQL_URL) +[![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https://github.com/chenyme/grok2api&env=LOG_LEVEL,LOG_FILE_ENABLED,DATA_DIR,LOG_DIR,ACCOUNT_STORAGE,ACCOUNT_REDIS_URL,ACCOUNT_MYSQL_URL,ACCOUNT_POSTGRESQL_URL,CONFIG_STORAGE,CONFIG_LOCAL_PATH) ### Render @@ -186,7 +186,8 @@ docker compose up -d | `ACCOUNT_SQL_MAX_OVERFLOW` | Maximum overflow connections above pool size | `10` | | `ACCOUNT_SQL_POOL_TIMEOUT` | Seconds to wait for a free connection from the pool | `30` | | `ACCOUNT_SQL_POOL_RECYCLE` | Max connection lifetime in seconds before reconnect | `1800` | -| `CONFIG_LOCAL_PATH` | Runtime config file path for `local` config storage | `${DATA_DIR}/config.toml` | +| `CONFIG_STORAGE` | Runtime config storage backend; defaults to local TOML and does not follow `ACCOUNT_STORAGE` | `local` | +| `CONFIG_LOCAL_PATH` | Local runtime config file path | `${DATA_DIR}/config.toml` | Runtime config can also be overridden with `GROK_`-prefixed environment variables. For example, `GROK_APP_API_KEY` overrides `app.api_key`, and `GROK_FEATURES_STREAM` overrides `features.stream`. diff --git a/tests/test_config_backend_factory.py b/tests/test_config_backend_factory.py new file mode 100644 index 000000000..35ef254cf --- /dev/null +++ b/tests/test_config_backend_factory.py @@ -0,0 +1,63 @@ +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from app.platform.config.backends.factory import ( + create_config_backend, + get_config_backend_name, +) +from app.platform.config.backends.sql import SqlConfigBackend +from app.platform.config.backends.toml import TomlConfigBackend + + +class ConfigBackendFactoryTests(unittest.TestCase): + def test_config_backend_defaults_to_local_when_account_storage_is_mysql(self) -> None: + with patch.dict(os.environ, {"ACCOUNT_STORAGE": "mysql"}, clear=True): + self.assertEqual(get_config_backend_name(), "local") + + def test_config_backend_defaults_to_local_when_account_storage_is_redis(self) -> None: + with patch.dict(os.environ, {"ACCOUNT_STORAGE": "redis"}, clear=True): + self.assertEqual(get_config_backend_name(), "local") + + def test_create_config_backend_uses_toml_for_explicit_local(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + config_path = Path(tmpdir) / "config.toml" + with patch.dict( + os.environ, + { + "ACCOUNT_STORAGE": "mysql", + "CONFIG_STORAGE": "local", + "CONFIG_LOCAL_PATH": str(config_path), + }, + clear=True, + ): + backend = create_config_backend() + + self.assertIsInstance(backend, TomlConfigBackend) + + def test_create_config_backend_uses_sql_only_when_config_storage_is_mysql(self) -> None: + sentinel_engine = object() + with patch.dict( + os.environ, + { + "ACCOUNT_STORAGE": "local", + "CONFIG_STORAGE": "mysql", + "ACCOUNT_MYSQL_URL": "mysql://user:pass@example.com/db", + }, + clear=True, + ): + with patch( + "app.control.account.backends.sql.create_mysql_engine", + return_value=sentinel_engine, + ): + backend = create_config_backend() + + self.assertIsInstance(backend, SqlConfigBackend) + self.assertIs(backend._engine, sentinel_engine) + self.assertEqual(backend._dialect, "mysql") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_startup_config_migration.py b/tests/test_startup_config_migration.py new file mode 100644 index 000000000..39415ec55 --- /dev/null +++ b/tests/test_startup_config_migration.py @@ -0,0 +1,91 @@ +import asyncio +import os +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from app.platform.startup import migration + + +class _Backend: + def __init__(self, version: int = 0) -> None: + self.version_calls = 0 + self.applied: list[dict] = [] + self._version = version + + async def version(self) -> int: + self.version_calls += 1 + return self._version + + async def apply_patch(self, patch: dict) -> None: + self.applied.append(patch) + + +class StartupConfigMigrationTests(unittest.TestCase): + def test_account_storage_mysql_does_not_migrate_config_when_config_storage_unset(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + defaults = root / "config.defaults.toml" + user_config = root / "config.toml" + defaults.write_text("[app]\napi_key = \"default\"\n", encoding="utf-8") + user_config.write_text("[app]\napi_key = \"local\"\n", encoding="utf-8") + backend = _Backend() + + with patch.dict(os.environ, {"ACCOUNT_STORAGE": "mysql"}, clear=True): + with patch.object(migration, "_DEFAULTS_PATH", defaults): + with patch.object(migration, "_USER_CFG_PATH", user_config): + asyncio.run(migration._migrate_config(backend)) + + self.assertEqual(backend.version_calls, 0) + self.assertEqual(backend.applied, []) + self.assertEqual( + user_config.read_text(encoding="utf-8"), + "[app]\napi_key = \"local\"\n", + ) + + def test_local_config_is_seeded_from_defaults_when_missing(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + defaults = root / "config.defaults.toml" + user_config = root / "config.toml" + defaults.write_text("[app]\napi_key = \"default\"\n", encoding="utf-8") + backend = _Backend() + + with patch.dict(os.environ, {"ACCOUNT_STORAGE": "postgresql"}, clear=True): + with patch.object(migration, "_DEFAULTS_PATH", defaults): + with patch.object(migration, "_USER_CFG_PATH", user_config): + asyncio.run(migration._migrate_config(backend)) + + self.assertTrue(user_config.exists()) + self.assertEqual( + user_config.read_text(encoding="utf-8"), + defaults.read_text(encoding="utf-8"), + ) + self.assertEqual(backend.version_calls, 0) + self.assertEqual(backend.applied, []) + + def test_explicit_remote_config_storage_still_migrates_local_config(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + defaults = root / "config.defaults.toml" + user_config = root / "config.toml" + defaults.write_text("[app]\napi_key = \"default\"\n", encoding="utf-8") + user_config.write_text("[app]\napi_key = \"local\"\n", encoding="utf-8") + backend = _Backend() + + with patch.dict( + os.environ, + {"ACCOUNT_STORAGE": "local", "CONFIG_STORAGE": "mysql"}, + clear=True, + ): + with patch.object(migration, "_DEFAULTS_PATH", defaults): + with patch.object(migration, "_USER_CFG_PATH", user_config): + asyncio.run(migration._migrate_config(backend)) + + self.assertEqual(backend.version_calls, 1) + self.assertEqual(backend.applied, [{"app": {"api_key": "local"}}]) + + +if __name__ == "__main__": + unittest.main() From 074ce5ec6edea788a477bedfee4a56dcdf84b312 Mon Sep 17 00:00:00 2001 From: yangyang Date: Sun, 26 Apr 2026 12:19:06 +0800 Subject: [PATCH 06/14] Handle empty upstream responses and unify cache config --- app/platform/config/snapshot.py | 20 +++ app/platform/storage/__init__.py | 4 + app/platform/storage/media_cache.py | 72 +++++++- app/products/anthropic/messages.py | 120 ++++++++----- app/products/anthropic/router.py | 25 ++- app/products/openai/chat.py | 108 ++++++++++- app/products/openai/responses.py | 110 +++++++----- app/products/openai/router.py | 25 ++- app/products/web/admin/__init__.py | 5 +- app/products/web/admin/cache.py | 84 +-------- app/statics/admin/config.html | 30 ---- tests/test_empty_response_handling.py | 246 ++++++++++++++++++++++++++ tests/test_legacy_cache_config.py | 66 +++++++ tests/test_media_cache_limits.py | 31 ++++ 14 files changed, 740 insertions(+), 206 deletions(-) create mode 100644 tests/test_empty_response_handling.py create mode 100644 tests/test_legacy_cache_config.py diff --git a/app/platform/config/snapshot.py b/app/platform/config/snapshot.py index f489f01ad..8b14b4575 100644 --- a/app/platform/config/snapshot.py +++ b/app/platform/config/snapshot.py @@ -22,6 +22,25 @@ def _mtime(path: Path) -> float: return 0.0 +def _apply_legacy_cache_overrides(user_overrides: dict[str, Any]) -> dict[str, Any]: + """Map legacy [storage] cache limits into [cache.local] at load time.""" + storage = user_overrides.get("storage") + if not isinstance(storage, dict): + return user_overrides + + cache = user_overrides.setdefault("cache", {}) + if not isinstance(cache, dict): + return user_overrides + local = cache.setdefault("local", {}) + if not isinstance(local, dict): + return user_overrides + + for key in ("image_max_mb", "video_max_mb"): + if key not in local and key in storage: + local[key] = storage[key] + return user_overrides + + class ConfigSnapshot: """Immutable view over the loaded configuration dict. @@ -74,6 +93,7 @@ async def load(self, defaults_path: Path | None = None) -> None: defaults = await asyncio.to_thread(load_toml, dp) user_overrides = await backend.load() + user_overrides = _apply_legacy_cache_overrides(user_overrides) self._data = _deep_merge(defaults, user_overrides) self._data = _apply_env(self._data) diff --git a/app/platform/storage/__init__.py b/app/platform/storage/__init__.py index 895f9a58e..c57f78cae 100644 --- a/app/platform/storage/__init__.py +++ b/app/platform/storage/__init__.py @@ -3,6 +3,8 @@ from .media_cache import ( clear_local_media_files, delete_local_media_file, + list_local_media_files, + local_media_stats, reconcile_local_media_cache_async, save_local_image, save_local_video, @@ -13,6 +15,8 @@ "clear_local_media_files", "delete_local_media_file", "image_files_dir", + "list_local_media_files", + "local_media_stats", "reconcile_local_media_cache_async", "save_local_image", "save_local_video", diff --git a/app/platform/storage/media_cache.py b/app/platform/storage/media_cache.py index f91b06d3f..171cf2466 100644 --- a/app/platform/storage/media_cache.py +++ b/app/platform/storage/media_cache.py @@ -88,6 +88,57 @@ def reconcile(self, media_type: MediaType) -> None: self._enforce_limit_locked(conn, media_type) conn.commit() + def limit_mb(self, media_type: MediaType) -> int: + """Return the configured per-type cache limit in MB.""" + cfg = self._config_provider() + limit_mb = max(0, int(cfg.get_int(f"cache.local.{media_type}_max_mb", 0))) + return limit_mb + + def stats(self, media_type: MediaType) -> dict[str, Any]: + """Return count, size, and configured limit for one media type.""" + files = list(self._iter_files(media_type)) + total_size = sum(path.stat().st_size for path in files) + limit_mb = self.limit_mb(media_type) + limit_bytes = limit_mb * 1024 * 1024 + usage_ratio = (total_size / limit_bytes) if limit_bytes > 0 else None + usage_percent = round(usage_ratio * 100, 1) if usage_ratio is not None else None + return { + "count": len(files), + "size_mb": round(total_size / 1024 / 1024, 2), + "size_bytes": total_size, + "limit_mb": limit_mb, + "limit_bytes": limit_bytes, + "limited": limit_bytes > 0, + "usage_ratio": round(usage_ratio, 4) if usage_ratio is not None else None, + "usage_percent": usage_percent, + } + + def list_files( + self, + media_type: MediaType, + *, + page: int = 1, + page_size: int = 1000, + ) -> dict[str, Any]: + """Return paginated local media files newest first.""" + files = sorted( + self._iter_files(media_type), + key=lambda path: path.stat().st_mtime, + reverse=True, + ) + total = len(files) + start = max(0, page - 1) * page_size + chunk = files[start : start + page_size] + items = [] + for path in chunk: + stat = path.stat() + items.append({ + "name": path.name, + "size_bytes": stat.st_size, + "modified_at": stat.st_mtime, + }) + return {"total": total, "page": page, "page_size": page_size, "items": items} + def delete(self, media_type: MediaType, name: str) -> bool: """Delete a single local media file and keep the index consistent.""" safe_name = self._validate_name(media_type, name) @@ -138,9 +189,7 @@ def _save( return path def _limit_bytes(self, media_type: MediaType) -> int: - cfg = self._config_provider() - limit_mb = max(0, int(cfg.get_int(f"cache.local.{media_type}_max_mb", 0))) - return limit_mb * 1024 * 1024 + return self.limit_mb(media_type) * 1024 * 1024 def _target_bytes(self, max_bytes: int) -> int: return max(0, int(max_bytes * _LOW_WATERMARK_RATIO)) @@ -479,6 +528,21 @@ def delete_local_media_file(media_type: MediaType, name: str) -> bool: return local_media_cache.delete(media_type, name) +def local_media_stats(media_type: MediaType) -> dict[str, Any]: + """Return count, size, and configured limit for local media files.""" + return local_media_cache.stats(media_type) + + +def list_local_media_files( + media_type: MediaType, + *, + page: int = 1, + page_size: int = 1000, +) -> dict[str, Any]: + """Return paginated local media files newest first.""" + return local_media_cache.list_files(media_type, page=page, page_size=page_size) + + async def reconcile_local_media_cache_async( media_type: MediaType | None = None, ) -> None: @@ -494,7 +558,9 @@ async def reconcile_local_media_cache_async( "LocalMediaCacheStore", "clear_local_media_files", "delete_local_media_file", + "list_local_media_files", "local_media_cache", + "local_media_stats", "reconcile_local_media_cache_async", "save_local_image", "save_local_video", diff --git a/app/products/anthropic/messages.py b/app/products/anthropic/messages.py index 42ba98a17..be01b35e1 100644 --- a/app/products/anthropic/messages.py +++ b/app/products/anthropic/messages.py @@ -31,7 +31,8 @@ from app.products.openai.chat import ( _stream_chat, _extract_message, _resolve_image, _quota_sync, _fail_sync, _parse_retry_codes, _feedback_kind, _log_task_exception, - _configured_retry_codes, _should_retry_upstream, + _adapter_has_visible_output, _configured_retry_codes, + _empty_upstream_response_error, _should_retry_upstream, _StreamStartGate, ) from app.products._account_selection import reserve_account, selection_max_retries from app.products.openai._tool_sieve import ToolSieve @@ -345,11 +346,12 @@ async def _run_stream() -> AsyncGenerator[str, None]: tool_output_tokens = 0 block_index = 0 # tracks next content_block index collected_annotations: list[dict] = [] + gate = _StreamStartGate() try: try: # message_start - yield _sse("message_start", { + for out in gate.emit(_sse("message_start", { "type": "message_start", "message": { "id": msg_id, @@ -360,8 +362,10 @@ async def _run_stream() -> AsyncGenerator[str, None]: "stop_reason": None, "usage": {"input_tokens": estimate_prompt_tokens(internal_message), "output_tokens": 0}, }, - }) - yield _sse("ping", {"type": "ping"}) + })): + yield out + for out in gate.emit(_sse("ping", {"type": "ping"})): + yield out ended = False async for line in _stream_chat( @@ -385,35 +389,38 @@ async def _run_stream() -> AsyncGenerator[str, None]: if ev.kind == "thinking" and emit_think and not think_closed: if not think_started: think_started = True - yield _sse("content_block_start", { + for out in gate.emit(_sse("content_block_start", { "type": "content_block_start", "index": block_index, "content_block": {"type": "thinking", "thinking": ""}, - }) + })): + yield out think_buf.append(ev.content) - yield _sse("content_block_delta", { + for out in gate.emit(_sse("content_block_delta", { "type": "content_block_delta", "index": block_index, "delta": {"type": "thinking_delta", "thinking": ev.content}, - }) + })): + yield out elif ev.kind == "text": # Close thinking block if open if think_started and not think_closed: think_closed = True - yield _sse("content_block_stop", { + for out in gate.emit(_sse("content_block_stop", { "type": "content_block_stop", "index": block_index, - }) + })): + yield out block_index += 1 # Feed through ToolSieve if tools active if sieve is not None: safe_text, calls = sieve.feed(ev.content) - if calls is not None: + if calls: # Emit tool_use blocks for call in calls: - yield _sse("content_block_start", { + for out in gate.emit(_sse("content_block_start", { "type": "content_block_start", "index": block_index, "content_block": { @@ -422,19 +429,22 @@ async def _run_stream() -> AsyncGenerator[str, None]: "name": call.name, "input": {}, }, - }) - yield _sse("content_block_delta", { + }), visible=True): + yield out + for out in gate.emit(_sse("content_block_delta", { "type": "content_block_delta", "index": block_index, "delta": { "type": "input_json_delta", "partial_json": call.arguments, }, - }) - yield _sse("content_block_stop", { + })): + yield out + for out in gate.emit(_sse("content_block_stop", { "type": "content_block_stop", "index": block_index, - }) + })): + yield out block_index += 1 tool_output_tokens = estimate_tool_call_tokens(calls) tool_calls_emitted = True @@ -447,17 +457,19 @@ async def _run_stream() -> AsyncGenerator[str, None]: if text_chunk: if not text_started: text_started = True - yield _sse("content_block_start", { + for out in gate.emit(_sse("content_block_start", { "type": "content_block_start", "index": block_index, "content_block": {"type": "text", "text": ""}, - }) + })): + yield out text_buf.append(text_chunk) - yield _sse("content_block_delta", { + for out in gate.emit(_sse("content_block_delta", { "type": "content_block_delta", "index": block_index, "delta": {"type": "text_delta", "text": text_chunk}, - }) + }), visible=bool(text_chunk.strip())): + yield out elif ev.kind == "annotation" and ev.annotation_data: collected_annotations.append(ev.annotation_data) @@ -475,14 +487,15 @@ async def _run_stream() -> AsyncGenerator[str, None]: if calls: # Close text block if open if text_started: - yield _sse("content_block_stop", { + for out in gate.emit(_sse("content_block_stop", { "type": "content_block_stop", "index": block_index, - }) + })): + yield out block_index += 1 text_started = False for call in calls: - yield _sse("content_block_start", { + for out in gate.emit(_sse("content_block_start", { "type": "content_block_start", "index": block_index, "content_block": { @@ -491,19 +504,22 @@ async def _run_stream() -> AsyncGenerator[str, None]: "name": call.name, "input": {}, }, - }) - yield _sse("content_block_delta", { + }), visible=True): + yield out + for out in gate.emit(_sse("content_block_delta", { "type": "content_block_delta", "index": block_index, "delta": { "type": "input_json_delta", "partial_json": call.arguments, }, - }) - yield _sse("content_block_stop", { + })): + yield out + for out in gate.emit(_sse("content_block_stop", { "type": "content_block_stop", "index": block_index, - }) + })): + yield out block_index += 1 tool_output_tokens = estimate_tool_call_tokens(calls) tool_calls_emitted = True @@ -514,13 +530,16 @@ async def _run_stream() -> AsyncGenerator[str, None]: sources = adapter.search_sources_list() if sources: tool_delta["search_sources"] = sources - yield _sse("message_delta", { + for out in gate.emit(_sse("message_delta", { "type": "message_delta", "delta": tool_delta, "usage": {"output_tokens": tool_output_tokens}, - }) - yield _sse("message_stop", {"type": "message_stop"}) - yield "data: [DONE]\n\n" + }), visible=True): + yield out + for out in gate.emit(_sse("message_stop", {"type": "message_stop"})): + yield out + for out in gate.emit("data: [DONE]\n\n"): + yield out success = True logger.info("messages stream tool_calls: attempt={}/{} model={}", attempt + 1, max_retries + 1, model) @@ -532,37 +551,43 @@ async def _run_stream() -> AsyncGenerator[str, None]: chunk = img_text + "\n" text_buf.append(chunk) if text_started: - yield _sse("content_block_delta", { + for out in gate.emit(_sse("content_block_delta", { "type": "content_block_delta", "index": block_index, "delta": {"type": "text_delta", "text": chunk}, - }) + }), visible=bool(img_text.strip())): + yield out references = adapter.references_suffix() if references: text_buf.append(references) if text_started: - yield _sse("content_block_delta", { + for out in gate.emit(_sse("content_block_delta", { "type": "content_block_delta", "index": block_index, "delta": {"type": "text_delta", "text": references}, - }) + }), visible=True): + yield out # Close open blocks if think_started and not think_closed: - yield _sse("content_block_stop", { + for out in gate.emit(_sse("content_block_stop", { "type": "content_block_stop", "index": block_index, - }) + })): + yield out block_index += 1 if text_started: - yield _sse("content_block_stop", { + for out in gate.emit(_sse("content_block_stop", { "type": "content_block_stop", "index": block_index, - }) + })): + yield out full_text = "".join(text_buf) + if not _adapter_has_visible_output(adapter, extra_text=full_text): + raise _empty_upstream_response_error() full_think = "".join(think_buf) out_tokens = estimate_tokens(full_text) if full_think: @@ -575,13 +600,16 @@ async def _run_stream() -> AsyncGenerator[str, None]: msg_delta["search_sources"] = sources if collected_annotations: msg_delta["annotations"] = collected_annotations - yield _sse("message_delta", { + for out in gate.emit(_sse("message_delta", { "type": "message_delta", "delta": msg_delta, "usage": {"output_tokens": out_tokens}, - }) - yield _sse("message_stop", {"type": "message_stop"}) - yield "data: [DONE]\n\n" + }), visible=True): + yield out + for out in gate.emit(_sse("message_stop", {"type": "message_stop"})): + yield out + for out in gate.emit("data: [DONE]\n\n"): + yield out success = True logger.info( "messages stream completed: attempt={}/{} model={} text_len={} think_len={} images={}", @@ -664,6 +692,8 @@ async def _run_stream() -> AsyncGenerator[str, None]: break if ended: break + if not _adapter_has_visible_output(adapter): + raise _empty_upstream_response_error() success = True except UpstreamError as exc: diff --git a/app/products/anthropic/router.py b/app/products/anthropic/router.py index 34b1f2126..cacfc3575 100644 --- a/app/products/anthropic/router.py +++ b/app/products/anthropic/router.py @@ -1,6 +1,6 @@ """Anthropic Messages API router (/v1/messages).""" -from typing import Any +from typing import Any, AsyncGenerator, AsyncIterable import orjson from fastapi import APIRouter, Depends, Request @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field from app.platform.auth.middleware import verify_api_key -from app.platform.errors import AppError, ValidationError +from app.platform.errors import AppError, UpstreamError, ValidationError from app.platform.logging.logger import logger from app.control.model import registry as model_registry @@ -72,6 +72,26 @@ async def _safe_sse_anthropic(stream): yield "data: [DONE]\n\n" +async def _prime_sse(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]: + """Read the first chunk before StreamingResponse sends HTTP 200.""" + iterator = stream.__aiter__() + try: + first = await anext(iterator) + except StopAsyncIteration as exc: + raise UpstreamError( + "Upstream returned empty response", + status=429, + body="empty_response", + ) from exc + + async def _primed() -> AsyncGenerator[str, None]: + yield first + async for chunk in iterator: + yield chunk + + return _primed() + + # --------------------------------------------------------------------------- # /v1/messages # --------------------------------------------------------------------------- @@ -118,6 +138,7 @@ async def messages_endpoint(req: MessagesRequest): if isinstance(result, dict): return JSONResponse(result) + result = await _prime_sse(result) return StreamingResponse( _safe_sse_anthropic(result), media_type = "text/event-stream", diff --git a/app/products/openai/chat.py b/app/products/openai/chat.py index 457bac655..cca70f6d4 100644 --- a/app/products/openai/chat.py +++ b/app/products/openai/chat.py @@ -191,6 +191,61 @@ def _feedback_kind(exc: BaseException) -> "FeedbackKind": return feedback_kind_for_error(exc) +EMPTY_UPSTREAM_BODY = "empty_response" + + +def _empty_upstream_response_error() -> UpstreamError: + return UpstreamError( + "Upstream returned empty response", + status=429, + body=EMPTY_UPSTREAM_BODY, + ) + + +def _adapter_has_visible_output( + adapter: StreamAdapter, + *, + extra_text: str = "", + has_tool_calls: bool = False, +) -> bool: + """Return True when the adapter has user-visible response content. + + Thinking/reasoning buffers intentionally do not count here. + """ + if has_tool_calls: + return True + if (extra_text or "".join(adapter.text_buf)).strip(): + return True + if adapter.image_urls: + return True + if adapter.references_suffix().strip(): + return True + if adapter.search_sources_list(): + return True + return False + + +class _StreamStartGate: + """Buffer stream events until a visible payload is available.""" + + __slots__ = ("_pending", "visible") + + def __init__(self) -> None: + self._pending: list[str] = [] + self.visible = False + + def emit(self, chunk: str, *, visible: bool = False) -> list[str]: + if visible and not self.visible: + self.visible = True + out = self._pending + [chunk] + self._pending = [] + return out + if self.visible: + return [chunk] + self._pending.append(chunk) + return [] + + async def _download_image_bytes(token: str, url: str) -> tuple[bytes, str]: """Download image bytes via the shared asset transport used by /v1/images.""" from app.dataplane.reverse.protocol.xai_assets import infer_content_type @@ -513,6 +568,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: ended = False sieve = ToolSieve(tool_names) tool_calls_emitted = False + gate = _StreamStartGate() async for line in _stream_chat( token=token, mode_id=ModeId(selected_mode_id), @@ -538,8 +594,15 @@ async def _run_stream() -> AsyncGenerator[str, None]: chunk = make_stream_chunk( response_id, model, safe_text ) - yield f"data: {orjson.dumps(chunk).decode()}\n\n" - if parsed_calls is not None: + payload = f"data: {orjson.dumps(chunk).decode()}\n\n" + for out in gate.emit( + payload, visible=bool(safe_text.strip()) + ): + yield out + if parsed_calls: + for out in gate.emit("", visible=True): + if out: + yield out for i, tc in enumerate(parsed_calls): chunk = make_tool_call_chunk( response_id, @@ -571,12 +634,18 @@ async def _run_stream() -> AsyncGenerator[str, None]: chunk = make_stream_chunk( response_id, model, ev.content ) - yield f"data: {orjson.dumps(chunk).decode()}\n\n" + payload = f"data: {orjson.dumps(chunk).decode()}\n\n" + for out in gate.emit( + payload, visible=bool(ev.content.strip()) + ): + yield out elif ev.kind == "thinking" and emit_think: chunk = make_thinking_chunk( response_id, model, ev.content ) - yield f"data: {orjson.dumps(chunk).decode()}\n\n" + payload = f"data: {orjson.dumps(chunk).decode()}\n\n" + for out in gate.emit(payload): + yield out elif ev.kind == "annotation" and ev.annotation_data: collected_annotations.append(ev.annotation_data) elif ev.kind == "soft_stop": @@ -589,6 +658,9 @@ async def _run_stream() -> AsyncGenerator[str, None]: # Stream ended — flush sieve for any buffered XML flushed_calls = sieve.flush() if flushed_calls: + for out in gate.emit("", visible=True): + if out: + yield out for i, tc in enumerate(flushed_calls): chunk = make_tool_call_chunk( response_id, @@ -623,14 +695,25 @@ async def _run_stream() -> AsyncGenerator[str, None]: chunk = make_stream_chunk( response_id, model, img_text + "\n" ) - yield f"data: {orjson.dumps(chunk).decode()}\n\n" + payload = f"data: {orjson.dumps(chunk).decode()}\n\n" + for out in gate.emit( + payload, visible=bool(img_text.strip()) + ): + yield out references = adapter.references_suffix() if references: chunk = make_stream_chunk( response_id, model, references ) - yield f"data: {orjson.dumps(chunk).decode()}\n\n" + payload = f"data: {orjson.dumps(chunk).decode()}\n\n" + for out in gate.emit(payload, visible=True): + yield out + + if not _adapter_has_visible_output( + adapter, has_tool_calls=tool_calls_emitted + ): + raise _empty_upstream_response_error() chat_anns = _to_chat_annotations(collected_annotations) final = make_stream_chunk( @@ -644,8 +727,11 @@ async def _run_stream() -> AsyncGenerator[str, None]: sources = adapter.search_sources_list() if sources: final["search_sources"] = sources - yield f"data: {orjson.dumps(final).decode()}\n\n" - yield "data: [DONE]\n\n" + final_payload = f"data: {orjson.dumps(final).decode()}\n\n" + for out in gate.emit(final_payload, visible=True): + yield out + for out in gate.emit("data: [DONE]\n\n"): + yield out success = True logger.info( "chat stream completed: attempt={}/{} model={} image_count={}", @@ -750,6 +836,8 @@ async def _run_stream() -> AsyncGenerator[str, None]: break if ended: break + if not _adapter_has_visible_output(adapter): + raise _empty_upstream_response_error() success = True except UpstreamError as exc: @@ -870,6 +958,10 @@ async def _run_stream() -> AsyncGenerator[str, None]: __all__ = [ "completions", + "EMPTY_UPSTREAM_BODY", + "_adapter_has_visible_output", "_configured_retry_codes", + "_empty_upstream_response_error", "_should_retry_upstream", + "_StreamStartGate", ] diff --git a/app/products/openai/responses.py b/app/products/openai/responses.py index d816c7a92..1abe6a89d 100644 --- a/app/products/openai/responses.py +++ b/app/products/openai/responses.py @@ -21,6 +21,7 @@ from app.products._account_selection import reserve_account, selection_max_retries from .chat import _stream_chat, _extract_message, _resolve_image, _quota_sync, _fail_sync, _parse_retry_codes, _feedback_kind, _log_task_exception, _upstream_body_excerpt +from .chat import _adapter_has_visible_output, _empty_upstream_response_error, _StreamStartGate from .chat import _configured_retry_codes, _should_retry_upstream from ._format import ( make_resp_id, build_resp_usage, make_resp_object, format_sse, @@ -283,13 +284,15 @@ async def _run_stream() -> AsyncGenerator[str, None]: tool_calls_emitted = False detected_fc_items: list[dict] = [] collected_annotations: list[dict] = [] + gate = _StreamStartGate() try: try: - yield format_sse("response.created", { + for out in gate.emit(format_sse("response.created", { "type": "response.created", "response": make_resp_object(response_id, model, "in_progress", []), - }) + })): + yield out ended = False async for line in _stream_chat( @@ -313,49 +316,54 @@ async def _run_stream() -> AsyncGenerator[str, None]: if ev.kind == "thinking" and emit_think and not reasoning_closed: if not reasoning_started: reasoning_started = True - yield format_sse("response.output_item.added", { + for out in gate.emit(format_sse("response.output_item.added", { "type": "response.output_item.added", "output_index": 0, "item": { "id": reasoning_id, "type": "reasoning", "summary": [], "status": "in_progress", }, - }) - yield format_sse("response.reasoning_summary_part.added", { + })): + yield out + for out in gate.emit(format_sse("response.reasoning_summary_part.added", { "type": "response.reasoning_summary_part.added", "item_id": reasoning_id, "output_index": 0, "summary_index": 0, "part": {"type": "summary_text", "text": ""}, - }) + })): + yield out think_buf.append(ev.content) - yield format_sse("response.reasoning_summary_text.delta", { + for out in gate.emit(format_sse("response.reasoning_summary_text.delta", { "type": "response.reasoning_summary_text.delta", "item_id": reasoning_id, "output_index": 0, "summary_index": 0, "delta": ev.content, - }) + })): + yield out elif ev.kind == "text": if reasoning_started and not reasoning_closed: reasoning_closed = True full_think = "".join(think_buf) - yield format_sse("response.reasoning_summary_text.done", { + for out in gate.emit(format_sse("response.reasoning_summary_text.done", { "type": "response.reasoning_summary_text.done", "item_id": reasoning_id, "output_index": 0, "summary_index": 0, "text": full_think, - }) - yield format_sse("response.reasoning_summary_part.done", { + })): + yield out + for out in gate.emit(format_sse("response.reasoning_summary_part.done", { "type": "response.reasoning_summary_part.done", "item_id": reasoning_id, "output_index": 0, "summary_index": 0, "part": {"type": "summary_text", "text": full_think}, - }) - yield format_sse("response.output_item.done", { + })): + yield out + for out in gate.emit(format_sse("response.output_item.done", { "type": "response.output_item.done", "output_index": 0, "item": { @@ -364,17 +372,19 @@ async def _run_stream() -> AsyncGenerator[str, None]: "summary": [{"type": "summary_text", "text": full_think}], "status": "completed", }, - }) + })): + yield out # Feed through ToolSieve if tools are active if sieve is not None: safe_text, calls = sieve.feed(ev.content) - if calls is not None: + if calls: fc_items = _build_fc_items(calls) detected_fc_items = fc_items base_idx = 1 if reasoning_started else 0 async for evt in _emit_fc_events(fc_items, base_idx): - yield evt + for out in gate.emit(evt, visible=True): + yield out tool_calls_emitted = True ended = True break @@ -386,43 +396,47 @@ async def _run_stream() -> AsyncGenerator[str, None]: msg_idx = 1 if reasoning_started else 0 if not message_started: message_started = True - yield format_sse("response.output_item.added", { + for out in gate.emit(format_sse("response.output_item.added", { "type": "response.output_item.added", "output_index": msg_idx, "item": { "id": message_id, "type": "message", "role": "assistant", "content": [], "status": "in_progress", }, - }) - yield format_sse("response.content_part.added", { + })): + yield out + for out in gate.emit(format_sse("response.content_part.added", { "type": "response.content_part.added", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "part": {"type": "output_text", "text": "", "annotations": []}, - }) + })): + yield out text_buf.append(text_chunk) - yield format_sse("response.output_text.delta", { + for out in gate.emit(format_sse("response.output_text.delta", { "type": "response.output_text.delta", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "delta": text_chunk, - }) + }), visible=bool(text_chunk.strip())): + yield out elif ev.kind == "annotation" and ev.annotation_data: if message_started: collected_annotations.append(ev.annotation_data) msg_idx = 1 if reasoning_started else 0 - yield format_sse("response.output_text.annotation.added", { + for out in gate.emit(format_sse("response.output_text.annotation.added", { "type": "response.output_text.annotation.added", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "annotation_index": len(collected_annotations) - 1, "annotation": ev.annotation_data, - }) + })): + yield out elif ev.kind == "soft_stop": ended = True @@ -439,7 +453,8 @@ async def _run_stream() -> AsyncGenerator[str, None]: detected_fc_items = fc_items base_idx = 1 if reasoning_started else 0 async for evt in _emit_fc_events(fc_items, base_idx): - yield evt + for out in gate.emit(evt, visible=True): + yield out tool_calls_emitted = True if tool_calls_emitted: @@ -457,14 +472,16 @@ async def _run_stream() -> AsyncGenerator[str, None]: pt = estimate_prompt_tokens(message) ct = estimate_tool_call_tokens(detected_fc_items) rt = estimate_tokens(full_think) if full_think else 0 - yield format_sse("response.completed", { + for out in gate.emit(format_sse("response.completed", { "type": "response.completed", "response": make_resp_object( response_id, model, "completed", output, build_resp_usage(pt, ct + rt, rt), ), - }) - yield "data: [DONE]\n\n" + }), visible=True): + yield out + for out in gate.emit("data: [DONE]\n\n"): + yield out success = True logger.info("responses stream tool_calls: attempt={}/{} model={}", attempt + 1, max_retries + 1, model) @@ -476,42 +493,48 @@ async def _run_stream() -> AsyncGenerator[str, None]: img_md = img_text + "\n" text_buf.append(img_md) if message_started: - yield format_sse("response.output_text.delta", { + for out in gate.emit(format_sse("response.output_text.delta", { "type": "response.output_text.delta", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "delta": img_md, - }) + }), visible=bool(img_text.strip())): + yield out references = adapter.references_suffix() if references: text_buf.append(references) if message_started: - yield format_sse("response.output_text.delta", { + for out in gate.emit(format_sse("response.output_text.delta", { "type": "response.output_text.delta", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "delta": references, - }) + }), visible=True): + yield out full_text = "".join(text_buf) + if not _adapter_has_visible_output(adapter, extra_text=full_text): + raise _empty_upstream_response_error() if message_started: - yield format_sse("response.output_text.done", { + for out in gate.emit(format_sse("response.output_text.done", { "type": "response.output_text.done", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "text": full_text, - }) - yield format_sse("response.content_part.done", { + })): + yield out + for out in gate.emit(format_sse("response.content_part.done", { "type": "response.content_part.done", "item_id": message_id, "output_index": msg_idx, "content_index": 0, "part": {"type": "output_text", "text": full_text, "annotations": collected_annotations}, - }) + })): + yield out # 构建 message item(流式 output_item.done + response.completed 共用) sources = adapter.search_sources_list() msg_item: dict = { @@ -523,11 +546,12 @@ async def _run_stream() -> AsyncGenerator[str, None]: } if sources: msg_item["search_sources"] = sources - yield format_sse("response.output_item.done", { + for out in gate.emit(format_sse("response.output_item.done", { "type": "response.output_item.done", "output_index": msg_idx, "item": msg_item, - }) + })): + yield out full_think = "".join(think_buf) output = [] @@ -555,14 +579,16 @@ async def _run_stream() -> AsyncGenerator[str, None]: pt = estimate_prompt_tokens(message) ct = estimate_tokens(full_text) rt = estimate_tokens(full_think) if full_think else 0 - yield format_sse("response.completed", { + for out in gate.emit(format_sse("response.completed", { "type": "response.completed", "response": make_resp_object( response_id, model, "completed", output, build_resp_usage(pt, ct + rt, rt), ), - }) - yield "data: [DONE]\n\n" + }), visible=True): + yield out + for out in gate.emit("data: [DONE]\n\n"): + yield out success = True logger.info("responses stream completed: attempt={}/{} model={} text_len={} reasoning_len={} image_count={}", attempt + 1, max_retries + 1, model, @@ -644,6 +670,8 @@ async def _run_stream() -> AsyncGenerator[str, None]: break if ended: break + if not _adapter_has_visible_output(adapter): + raise _empty_upstream_response_error() success = True except UpstreamError as exc: diff --git a/app/products/openai/router.py b/app/products/openai/router.py index 11b82019f..737ad10de 100644 --- a/app/products/openai/router.py +++ b/app/products/openai/router.py @@ -11,7 +11,7 @@ from app.control.account.state_machine import is_manageable from app.platform.auth.middleware import verify_api_key -from app.platform.errors import AppError, ValidationError +from app.platform.errors import AppError, UpstreamError, ValidationError from app.platform.logging.logger import logger from app.platform.storage import image_files_dir, video_files_dir from app.control.model import registry as model_registry @@ -127,6 +127,26 @@ async def _safe_sse(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]: yield "data: [DONE]\n\n" +async def _prime_sse(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]: + """Read the first chunk before StreamingResponse sends HTTP 200.""" + iterator = stream.__aiter__() + try: + first = await anext(iterator) + except StopAsyncIteration as exc: + raise UpstreamError( + "Upstream returned empty response", + status=429, + body="empty_response", + ) from exc + + async def _primed() -> AsyncGenerator[str, None]: + yield first + async for chunk in iterator: + yield chunk + + return _primed() + + _SSE_HEADERS = {"Cache-Control": "no-cache", "Connection": "keep-alive"} @@ -335,6 +355,8 @@ async def _err_stream(): if isinstance(result, dict): return JSONResponse(result) + if not (spec.is_image_edit() or spec.is_image() or spec.is_video()): + result = await _prime_sse(result) return StreamingResponse( _safe_sse(result), media_type="text/event-stream", headers=_SSE_HEADERS ) @@ -414,6 +436,7 @@ async def responses_endpoint(req: ResponsesCreateRequest): if isinstance(result, dict): return JSONResponse(result) + result = await _prime_sse(result) return StreamingResponse( _safe_sse_responses(result), media_type = "text/event-stream", diff --git a/app/products/web/admin/__init__.py b/app/products/web/admin/__init__.py index 88093af78..313789485 100644 --- a/app/products/web/admin/__init__.py +++ b/app/products/web/admin/__init__.py @@ -196,7 +196,10 @@ async def update_config(req: ConfigPatchRequest): patch = _sanitize_proxy_config(req.root) _ensure_runtime_patch_allowed(patch) - cache_local_changed = _patch_touches_prefix(patch, "cache.local") + cache_local_changed = ( + _patch_touches_prefix(patch, "cache.local") + or _patch_touches_prefix(patch, "storage") + ) await config.update(patch) # config.update() only writes to the backend and invalidates the in-memory # snapshot (_version = None); it does not refresh the data. load() is diff --git a/app/products/web/admin/cache.py b/app/products/web/admin/cache.py index fc49ab597..c6abb504a 100644 --- a/app/products/web/admin/cache.py +++ b/app/products/web/admin/cache.py @@ -1,29 +1,21 @@ """Local media cache management — stats, list, clear, delete.""" import asyncio -from pathlib import Path -from typing import Any, Literal +from typing import Literal from fastapi import APIRouter, Query from pydantic import BaseModel -from app.platform.config.snapshot import get_config from app.platform.errors import AppError, ErrorKind from app.platform.storage import ( clear_local_media_files, delete_local_media_file, - image_files_dir, - video_files_dir, + list_local_media_files, + local_media_stats, ) router = APIRouter(prefix="/cache", tags=["Admin - Cache"]) -# --------------------------------------------------------------------------- -# Lightweight local media cache service. -# --------------------------------------------------------------------------- -_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} -_VIDEO_EXTS = {".mp4", ".mov", ".m4v", ".webm", ".avi", ".mkv"} - class ClearCacheRequest(BaseModel): type: Literal["image", "video"] = "image" @@ -39,67 +31,6 @@ class DeleteCacheItemsRequest(BaseModel): names: list[str] -def _dir(media_type: str) -> Path: - return image_files_dir() if media_type == "image" else video_files_dir() - - -def _exts(media_type: str): - return _IMAGE_EXTS if media_type == "image" else _VIDEO_EXTS - - -def _limit_mb(media_type: str) -> int: - cfg = get_config() - return max(0, int(cfg.get_int(f"cache.local.{media_type}_max_mb", 0))) - - -def _stats(media_type: str) -> dict[str, Any]: - d = _dir(media_type) - files = [] - if d.exists(): - allowed = _exts(media_type) - files = [f for f in d.glob("*") if f.is_file() and f.suffix.lower() in allowed] - - total_size = sum(f.stat().st_size for f in files) - limit_mb = _limit_mb(media_type) - limit_bytes = limit_mb * 1024 * 1024 - usage_ratio = (total_size / limit_bytes) if limit_bytes > 0 else None - usage_percent = round(usage_ratio * 100, 1) if usage_ratio is not None else None - return { - "count": len(files), - "size_mb": round(total_size / 1024 / 1024, 2), - "size_bytes": total_size, - "limit_mb": limit_mb, - "limit_bytes": limit_bytes, - "limited": limit_bytes > 0, - "usage_ratio": round(usage_ratio, 4) if usage_ratio is not None else None, - "usage_percent": usage_percent, - } - - -def _list_files(media_type: str, page: int, page_size: int) -> dict[str, Any]: - d = _dir(media_type) - if not d.exists(): - return {"total": 0, "page": page, "page_size": page_size, "items": []} - allowed = _exts(media_type) - files = sorted( - (f for f in d.glob("*") if f.is_file() and f.suffix.lower() in allowed), - key=lambda f: f.stat().st_mtime, - reverse=True, - ) - total = len(files) - start = (page - 1) * page_size - chunk = files[start : start + page_size] - items = [] - for f in chunk: - st = f.stat() - items.append({ - "name": f.name, - "size_bytes": st.st_size, - "modified_at": st.st_mtime, - }) - return {"total": total, "page": page, "page_size": page_size, "items": items} - - # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @@ -107,8 +38,8 @@ def _list_files(media_type: str, page: int, page_size: int) -> dict[str, Any]: @router.get("") async def cache_stats(): return { - "local_image": _stats("image"), - "local_video": _stats("video"), + "local_image": local_media_stats("image"), + "local_video": local_media_stats("video"), } @@ -120,7 +51,10 @@ async def list_local( page_size: int = 1000, ): media_type = type_ or cache_type - return {"status": "success", **_list_files(media_type, page, page_size)} + return { + "status": "success", + **list_local_media_files(media_type, page=page, page_size=page_size), + } @router.post("/clear") diff --git a/app/statics/admin/config.html b/app/statics/admin/config.html index fc6a16ed0..0f033ca5c 100644 --- a/app/statics/admin/config.html +++ b/app/statics/admin/config.html @@ -430,36 +430,6 @@ }, ] }, - { - title: '本地存储', titleKey: 'config.schema.groups.storage', - section: 'storage', - fields: [ - { - key: 'media_max_mb', - label: '媒体缓存总上限(MB)', - labelKey: 'config.schema.fields.mediaCacheMaxMb.label', - type: 'number', - desc: '本地图片与视频缓存共享的总容量上限,单位 MB。设置为 0 或负数表示不限制。', - descKey: 'config.schema.fields.mediaCacheMaxMb.desc', - }, - { - key: 'image_max_mb', - label: '图片缓存上限(MB)', - labelKey: 'config.schema.fields.imageCacheMaxMb.label', - type: 'number', - desc: '本地图片缓存的独立容量上限,单位 MB。设置为 0 或负数表示不限制。', - descKey: 'config.schema.fields.imageCacheMaxMb.desc', - }, - { - key: 'video_max_mb', - label: '视频缓存上限(MB)', - labelKey: 'config.schema.fields.videoCacheMaxMb.label', - type: 'number', - desc: '本地视频缓存的独立容量上限,单位 MB。设置为 0 或负数表示不限制。', - descKey: 'config.schema.fields.videoCacheMaxMb.desc', - }, - ] - }, ] }, { diff --git a/tests/test_empty_response_handling.py b/tests/test_empty_response_handling.py new file mode 100644 index 000000000..fec3d7472 --- /dev/null +++ b/tests/test_empty_response_handling.py @@ -0,0 +1,246 @@ +import asyncio +import unittest +from types import SimpleNamespace +from unittest.mock import patch + +from app.control.account.enums import FeedbackKind +from app.platform.errors import UpstreamError +from app.products.anthropic import messages as anthropic_messages +from app.products.anthropic import router as anthropic_router +from app.products.openai import chat, responses +from app.products.openai import router as openai_router + + +class _FakeConfig: + def get(self, key: str, default=None): + return default + + def get_bool(self, key: str, default: bool = False) -> bool: + return default + + def get_float(self, key: str, default: float = 0.0) -> float: + return default + + def get_str(self, key: str, default: str = "") -> str: + return default + + +class _FakeDirectory: + def __init__(self) -> None: + self.feedbacks: list[tuple[str, FeedbackKind, int]] = [] + + async def release(self, acct) -> None: + return None + + async def feedback( + self, + token: str, + kind: FeedbackKind, + selected_mode_id: int, + now_s_val=None, + ) -> None: + self.feedbacks.append((token, kind, selected_mode_id)) + + +class _FakeAdapter: + def __init__( + self, + *, + text: str = "", + images: list[tuple[str, str]] | None = None, + references: str = "", + sources: list[dict] | None = None, + thinking: str = "", + ) -> None: + self.text_buf = [text] if text else [] + self.thinking_buf = [thinking] if thinking else [] + self.image_urls = images or [] + self._references = references + self._sources = sources + + def references_suffix(self) -> str: + return self._references + + def search_sources_list(self): + return self._sources + + +async def _empty_stream(**kwargs): + yield "data: [DONE]" + + +async def _noop_async(*args, **kwargs): + return None + + +async def _empty_generator(): + if False: + yield "" + + +class EmptyResponseHandlingTests(unittest.TestCase): + def test_visible_output_helper_ignores_thinking_only(self) -> None: + self.assertFalse( + chat._adapter_has_visible_output(_FakeAdapter(thinking="reasoning only")) + ) + + def test_visible_output_helper_accepts_text_tool_images_and_sources(self) -> None: + self.assertTrue(chat._adapter_has_visible_output(_FakeAdapter(text="hello"))) + self.assertTrue( + chat._adapter_has_visible_output( + _FakeAdapter(images=[("https://example.com/a.png", "img-1")]) + ) + ) + self.assertTrue( + chat._adapter_has_visible_output( + _FakeAdapter(sources=[{"url": "https://example.com", "title": "x"}]) + ) + ) + self.assertTrue( + chat._adapter_has_visible_output(_FakeAdapter(), has_tool_calls=True) + ) + + def test_empty_response_error_is_retryable_429(self) -> None: + exc = chat._empty_upstream_response_error() + + self.assertEqual(exc.status, 429) + self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY) + + def test_stream_prime_turns_no_first_chunk_into_429(self) -> None: + async def _run() -> None: + await openai_router._prime_sse(_empty_generator()) + + with self.assertRaises(UpstreamError) as ctx: + asyncio.run(_run()) + + self.assertEqual(ctx.exception.status, 429) + + def test_anthropic_stream_prime_turns_no_first_chunk_into_429(self) -> None: + async def _run() -> None: + await anthropic_router._prime_sse(_empty_generator()) + + with self.assertRaises(UpstreamError) as ctx: + asyncio.run(_run()) + + self.assertEqual(ctx.exception.status, 429) + + def test_chat_non_stream_empty_response_raises_429_feedback(self) -> None: + directory = _FakeDirectory() + exc = self._run_chat_non_stream_empty(directory) + + self.assertEqual(exc.status, 429) + self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + + def test_chat_stream_empty_response_raises_429_before_first_chunk(self) -> None: + directory = _FakeDirectory() + + async def _run() -> None: + stream = await chat.completions( + model="grok-test", + messages=[{"role": "user", "content": "hello"}], + stream=True, + ) + await openai_router._prime_sse(stream) + + with self.assertRaises(UpstreamError) as ctx: + with self._patch_common(chat, directory): + asyncio.run(_run()) + + self.assertEqual(ctx.exception.status, 429) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + + def test_responses_non_stream_empty_response_raises_429_feedback(self) -> None: + directory = _FakeDirectory() + exc = self._run_responses_non_stream_empty(directory) + + self.assertEqual(exc.status, 429) + self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + + def test_anthropic_non_stream_empty_response_raises_429_feedback(self) -> None: + directory = _FakeDirectory() + exc = self._run_anthropic_non_stream_empty(directory) + + self.assertEqual(exc.status, 429) + self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + + def _run_chat_non_stream_empty(self, directory: _FakeDirectory) -> UpstreamError: + async def _run() -> None: + await chat.completions( + model="grok-test", + messages=[{"role": "user", "content": "hello"}], + stream=False, + ) + + with self.assertRaises(UpstreamError) as ctx: + with self._patch_common(chat, directory): + asyncio.run(_run()) + return ctx.exception + + def _run_responses_non_stream_empty(self, directory: _FakeDirectory) -> UpstreamError: + async def _run() -> None: + await responses.create( + model="grok-test", + input_val="hello", + instructions=None, + stream=False, + emit_think=True, + temperature=0.8, + top_p=0.95, + ) + + with self.assertRaises(UpstreamError) as ctx: + with self._patch_common(responses, directory): + asyncio.run(_run()) + return ctx.exception + + def _run_anthropic_non_stream_empty( + self, directory: _FakeDirectory + ) -> UpstreamError: + async def _run() -> None: + await anthropic_messages.create( + model="grok-test", + messages=[{"role": "user", "content": "hello"}], + stream=False, + emit_think=True, + temperature=0.8, + top_p=0.95, + ) + + with self.assertRaises(UpstreamError) as ctx: + with self._patch_common(anthropic_messages, directory): + asyncio.run(_run()) + return ctx.exception + + def _patch_common(self, module, directory: _FakeDirectory): + import contextlib + + @contextlib.contextmanager + def _ctx(): + with patch("app.dataplane.account._directory", directory): + with patch.object(module, "get_config", return_value=_FakeConfig()): + with patch.object( + module, + "resolve_model", + return_value=SimpleNamespace(mode_id=0), + ): + with patch.object(module, "selection_max_retries", return_value=0): + with patch.object( + module, + "reserve_account", + return_value=(SimpleNamespace(token="tok-test"), 0), + ): + with patch.object( + module, "_stream_chat", side_effect=_empty_stream + ): + with patch.object(module, "_fail_sync", _noop_async): + with patch.object(module, "_quota_sync", _noop_async): + yield + + return _ctx() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_legacy_cache_config.py b/tests/test_legacy_cache_config.py new file mode 100644 index 000000000..8b3bf186c --- /dev/null +++ b/tests/test_legacy_cache_config.py @@ -0,0 +1,66 @@ +import asyncio +import tempfile +import unittest +from pathlib import Path +from typing import Any + +from app.platform.config.snapshot import ConfigSnapshot + + +class _Backend: + def __init__(self, data: dict[str, Any]) -> None: + self.data = data + + async def load(self) -> dict[str, Any]: + return self.data + + async def apply_patch(self, patch: dict[str, Any]) -> None: + self.data.update(patch) + + async def version(self) -> object: + return 1 + + +class LegacyCacheConfigTests(unittest.TestCase): + def test_legacy_storage_cache_limits_map_to_cache_local(self) -> None: + cfg = asyncio.run(self._load({ + "storage": { + "image_max_mb": 12, + "video_max_mb": 34, + }, + })) + + self.assertEqual(cfg.get_int("cache.local.image_max_mb"), 12) + self.assertEqual(cfg.get_int("cache.local.video_max_mb"), 34) + + def test_cache_local_limits_win_over_legacy_storage(self) -> None: + cfg = asyncio.run(self._load({ + "storage": { + "image_max_mb": 12, + "video_max_mb": 34, + }, + "cache": { + "local": { + "image_max_mb": 56, + "video_max_mb": 78, + }, + }, + })) + + self.assertEqual(cfg.get_int("cache.local.image_max_mb"), 56) + self.assertEqual(cfg.get_int("cache.local.video_max_mb"), 78) + + async def _load(self, overrides: dict[str, Any]) -> ConfigSnapshot: + with tempfile.TemporaryDirectory() as tmpdir: + defaults = Path(tmpdir) / "config.defaults.toml" + defaults.write_text( + "[cache.local]\nimage_max_mb = 0\nvideo_max_mb = 0\n", + encoding="utf-8", + ) + cfg = ConfigSnapshot(backend=_Backend(overrides)) + await cfg.load(defaults) + return cfg + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_media_cache_limits.py b/tests/test_media_cache_limits.py index b1aa687dc..a8477517d 100644 --- a/tests/test_media_cache_limits.py +++ b/tests/test_media_cache_limits.py @@ -1,5 +1,6 @@ import tempfile import unittest +import os from pathlib import Path from unittest.mock import patch @@ -77,6 +78,36 @@ def test_save_video_prunes_oldest_video_only(self) -> None: self.assertFalse((video_dir / "old.mp4").exists()) self.assertTrue(image_path.exists()) + def test_stats_and_list_files_use_shared_cache_rules(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + image_dir = root / "images" + video_dir = root / "videos" + image_dir.mkdir() + video_dir.mkdir() + + older = image_dir / "older.png" + newer = image_dir / "newer.jpg" + ignored = image_dir / "ignored.txt" + older.write_bytes(b"a" * 128) + newer.write_bytes(b"b" * 256) + ignored.write_text("skip") + os.utime(older, (100, 100)) + os.utime(newer, (200, 200)) + + store = LocalMediaCacheStore( + config_provider=lambda: _StubConfig(image_max_mb=1) + ) + with patch("app.platform.storage.media_cache.image_files_dir", return_value=image_dir): + with patch("app.platform.storage.media_cache.video_files_dir", return_value=video_dir): + stats = store.stats("image") + listing = store.list_files("image", page=1, page_size=10) + + self.assertEqual(stats["count"], 2) + self.assertEqual(stats["size_bytes"], 384) + self.assertEqual(stats["limit_mb"], 1) + self.assertEqual([item["name"] for item in listing["items"]], ["newer.jpg", "older.png"]) + if __name__ == "__main__": unittest.main() From b5aa8f333db11ba6215708875adc9d3e555ba889 Mon Sep 17 00:00:00 2001 From: yangyang Date: Fri, 1 May 2026 00:24:48 +0800 Subject: [PATCH 07/14] Fix video retries and resource proxy handling --- app/control/proxy/__init__.py | 8 +- .../reverse/transport/asset_upload.py | 2 +- app/dataplane/reverse/transport/assets.py | 6 +- app/products/openai/video.py | 366 ++++++++++++------ config.defaults.toml | 4 +- tests/test_empty_response_handling.py | 30 +- tests/test_video_proxy_and_newapi.py | 197 ++++++++++ tests/test_video_reference_helpers.py | 41 +- 8 files changed, 511 insertions(+), 143 deletions(-) create mode 100644 tests/test_video_proxy_and_newapi.py diff --git a/app/control/proxy/__init__.py b/app/control/proxy/__init__.py index 9d7581d60..1af2dcd55 100644 --- a/app/control/proxy/__init__.py +++ b/app/control/proxy/__init__.py @@ -228,12 +228,8 @@ async def _pick_proxy_url(self, resource: bool = False) -> str | None: if self._egress_mode == EgressMode.DIRECT: return None async with self._lock: - # Prefer resource-specific nodes when available; fall back to base nodes. - nodes = ( - self._resource_nodes - if resource and self._resource_nodes - else self._nodes - ) + # Resource downloads are direct unless a resource proxy is explicit. + nodes = self._resource_nodes if resource else self._nodes if not nodes: return None if self._egress_mode == EgressMode.SINGLE_PROXY: diff --git a/app/dataplane/reverse/transport/asset_upload.py b/app/dataplane/reverse/transport/asset_upload.py index 891594424..dc3ce32ab 100644 --- a/app/dataplane/reverse/transport/asset_upload.py +++ b/app/dataplane/reverse/transport/asset_upload.py @@ -184,7 +184,7 @@ async def upload_from_input(token: str, file_input: str) -> tuple[str, str]: if _is_url(file_input): # Fetch the remote URL and re-upload as base64. proxy = await get_proxy_runtime() - lease = await proxy.acquire() + lease = await proxy.acquire(resource=True) try: headers = build_http_headers(token, lease=lease) kwargs = build_session_kwargs(lease=lease) diff --git a/app/dataplane/reverse/transport/assets.py b/app/dataplane/reverse/transport/assets.py index b786447ff..faa7135a0 100644 --- a/app/dataplane/reverse/transport/assets.py +++ b/app/dataplane/reverse/transport/assets.py @@ -194,7 +194,11 @@ async def download_asset( } proxy = await get_proxy_runtime() - lease = await proxy.acquire(scope=ProxyScope.ASSET, kind=RequestKind.HTTP) + lease = await proxy.acquire( + scope=ProxyScope.ASSET, + kind=RequestKind.HTTP, + resource=True, + ) try: stream = await get_bytes_stream( diff --git a/app/products/openai/video.py b/app/products/openai/video.py index 13438e1fd..4e4789f8c 100644 --- a/app/products/openai/video.py +++ b/app/products/openai/video.py @@ -29,8 +29,10 @@ from app.platform.runtime.clock import now_s from app.platform.storage import save_local_video from app.control.account.enums import FeedbackKind +from app.control.account.runtime import get_refresh_service from app.control.model import registry as model_registry from app.control.model.registry import resolve as resolve_model +from app.dataplane.account.selector import current_strategy from app.dataplane.proxy import get_proxy_runtime from app.dataplane.proxy.adapters.headers import build_http_headers from app.dataplane.proxy.adapters.session import ResettableSession, build_session_kwargs @@ -52,7 +54,16 @@ make_stream_chunk, make_thinking_chunk, ) -from .chat import _fail_sync, _quota_sync, _feedback_kind +from .chat import ( + EMPTY_UPSTREAM_BODY, + _configured_retry_codes, + _fail_sync, + _feedback_kind, + _log_task_exception, + _quota_sync, + _should_retry_upstream, +) +from app.products._account_selection import selection_max_retries _IMAGE_MEDIA_TYPE = "MEDIA_POST_TYPE_IMAGE" _VIDEO_MEDIA_TYPE = "MEDIA_POST_TYPE_VIDEO" @@ -62,6 +73,7 @@ _VIDEO_JOB_TTL_S = 3600 _VIDEO_EXTENSION_REF_TYPE = "ORIGINAL_REF_TYPE_VIDEO_EXTENSION" _VIDEO_MAX_REFERENCES = 7 +_VIDEO_ALLOWED_POOL_IDS = frozenset((1, 2)) # super / heavy only _SUPPORTED_VIDEO_LENGTHS = frozenset({6, 10, 12, 16, 20}) _VIDEO_SIZE_MAP: dict[str, tuple[str, str]] = { "720x1280": ("9:16", "720p"), @@ -129,6 +141,8 @@ def to_dict(self) -> dict[str, Any]: payload["error"] = self.error if self.remixed_from_video_id: payload["remixed_from_video_id"] = self.remixed_from_video_id + if self.status == "completed" and self.content_path: + payload["metadata"] = {"url": _video_content_url(self.id)} return payload @@ -140,6 +154,12 @@ def _build_message(prompt: str, preset: str) -> str: return f"{prompt} {_PRESET_FLAGS.get(preset, '--mode=custom')}".strip() +def _video_content_url(video_id: str) -> str: + app_url = get_config().get_str("app.app_url", "").rstrip("/") + path = f"/v1/videos/{video_id}/content" + return f"{app_url}{path}" if app_url else path + + def _progress_reason(progress: int) -> str: return f"视频正在生成 {max(0, min(100, int(progress)))}%" @@ -211,6 +231,47 @@ def _build_segment_lengths(seconds: int) -> list[int]: raise AssertionError("unreachable") +def _normalize_segment_prompts( + prompt: str, + segment_lengths: list[int], + segment_prompts: list[str] | None = None, +) -> list[str]: + prompts = [p.strip() for p in (segment_prompts or [prompt]) if p and p.strip()] + if not prompts: + raise ValidationError("Video prompt cannot be empty", param="messages") + if len(prompts) > len(segment_lengths): + raise ValidationError( + f"Video generation uses {len(segment_lengths)} prompt segment(s) for this duration", + param="messages", + ) + while len(prompts) < len(segment_lengths): + prompts.append(prompts[-1]) + return prompts + + +def _video_pool_candidates(spec) -> tuple[int, ...]: + return tuple( + pool_id + for pool_id in spec.pool_candidates() + if pool_id in _VIDEO_ALLOWED_POOL_IDS + ) + + +def _empty_video_response_error() -> UpstreamError: + return UpstreamError( + "Video upstream returned empty response", + status=429, + body=EMPTY_UPSTREAM_BODY, + ) + + +def _transport_upstream_error(exc: BaseException, *, context: str) -> UpstreamError: + if isinstance(exc, UpstreamError): + return exc + body = str(exc).replace("\n", "\\n")[:400] + return UpstreamError(f"{context}: {exc}", status=502, body=body) + + def _video_create_payload( *, prompt: str, @@ -338,13 +399,18 @@ async def _stream_video_request( kwargs = build_session_kwargs(lease=lease) async with ResettableSession(**kwargs) as session: - response = await session.post( - CHAT, - headers=headers, - data=orjson.dumps(payload), - timeout=timeout_s, - stream=True, - ) + try: + response = await session.post( + CHAT, + headers=headers, + data=orjson.dumps(payload), + timeout=timeout_s, + stream=True, + ) + except Exception as exc: + raise _transport_upstream_error( + exc, context="Video transport failed" + ) from exc if response.status_code != 200: body = response.content.decode("utf-8", "replace")[:300] raise UpstreamError( @@ -352,8 +418,13 @@ async def _stream_video_request( status=response.status_code, body=body, ) - async for line in response.aiter_lines(): - yield line + try: + async for line in response.aiter_lines(): + yield line + except Exception as exc: + raise _transport_upstream_error( + exc, context="Video stream read failed" + ) from exc def _absolutize_video_url(url: str) -> str: @@ -497,6 +568,7 @@ async def _collect_video_segment( final_thumbnail = "" video_post_id = "" stream_data_items: list[str] = [] + saw_video_signal = False async for line in _stream_video_request( token, @@ -518,6 +590,7 @@ async def _collect_video_segment( stream = _extract_streaming_video_response(obj) if stream: + saw_video_signal = True try: progress = int(stream.get("progress") or 0) except (TypeError, ValueError): @@ -544,6 +617,8 @@ async def _collect_video_segment( final_thumbnail = _absolutize_video_url(thumbnail) attachments = _extract_model_response_file_attachments(obj) + if attachments: + saw_video_signal = True if attachments and not final_asset_id: final_asset_id = attachments[0] @@ -556,6 +631,8 @@ async def _collect_video_segment( body="\n".join(stream_data_items), ) if not final_url: + if not stream_data_items or not saw_video_signal: + raise _empty_video_response_error() raise UpstreamError( "Video generation returned no final video URL", body="\n".join(stream_data_items), @@ -645,8 +722,11 @@ async def _generate_video_with_token( preset: str, timeout_s: float, input_references: list[dict[str, Any]] | None = None, + segment_prompts: list[str] | None = None, progress_cb: Callable[[int], Awaitable[None]] | None = None, ) -> _VideoArtifact: + segments = _build_segment_lengths(seconds) + prompts = _normalize_segment_prompts(prompt, segments, segment_prompts) references: list[_VideoReference] = [] if input_references: references = await _prepare_video_references(token, input_references) @@ -655,7 +735,7 @@ async def _generate_video_with_token( post = await create_media_post( token, media_type=_VIDEO_MEDIA_TYPE, - prompt=prompt, + prompt=prompts[0], referer="https://grok.com/imagine", ) post_data = post.get("post") @@ -665,16 +745,16 @@ async def _generate_video_with_token( if not parent_post_id: raise UpstreamError("Video create-post returned no post id") - segments = _build_segment_lengths(seconds) total_segments = len(segments) artifact: _VideoArtifact | None = None extend_post_id = parent_post_id elapsed_seconds = 0 for index, segment_length in enumerate(segments): + segment_prompt = prompts[index] if index == 0: payload = _video_create_payload( - prompt=prompt, + prompt=segment_prompt, parent_post_id=parent_post_id, aspect_ratio=aspect_ratio, resolution_name=resolution_name, @@ -687,7 +767,7 @@ async def _generate_video_with_token( referer = "https://grok.com/imagine" else: payload = _video_extend_payload( - prompt=prompt, + prompt=segment_prompt, parent_post_id=parent_post_id, extend_post_id=extend_post_id, aspect_ratio=aspect_ratio, @@ -732,6 +812,7 @@ async def _run_video_generation( seconds: int, preset: str = "custom", input_references: list[dict[str, Any]] | None = None, + segment_prompts: list[str] | None = None, progress_cb: Callable[[int], Awaitable[None]] | None = None, ) -> _VideoArtifact: async def _runner(token: str, timeout_s: float) -> _VideoArtifact: @@ -744,6 +825,7 @@ async def _runner(token: str, timeout_s: float) -> _VideoArtifact: preset=preset, timeout_s=timeout_s, input_references=input_references, + segment_prompts=segment_prompts, progress_cb=progress_cb, ) @@ -760,44 +842,94 @@ async def _run_video_with_account( spec = resolve_model(model) if not spec.is_video(): raise ValidationError(f"Model {model!r} is not a video model", param="model") + pool_candidates = _video_pool_candidates(spec) + if not pool_candidates: + raise RateLimitError("No super/heavy accounts available for video generation") from app.dataplane.account import _directory as _acct_dir if _acct_dir is None: raise RateLimitError("Account directory not initialised") - acct = await _acct_dir.reserve( - pool_candidates=spec.pool_candidates(), - mode_id=int(spec.mode_id), - now_s_override=now_s(), - ) - if acct is None: - raise RateLimitError("No available accounts for video generation") + max_retries = selection_max_retries() + retry_codes = _configured_retry_codes(cfg) + excluded: list[str] = [] + mode_id = int(spec.mode_id) - token = acct.token - success = False - fail_exc: BaseException | None = None - try: - artifact = await runner(token, timeout_s) - success = True - return artifact - except BaseException as exc: - fail_exc = exc - raise - finally: - await _acct_dir.release(acct) - kind = ( - FeedbackKind.SUCCESS - if success - else _feedback_kind(fail_exc) - if fail_exc - else FeedbackKind.SERVER_ERROR + async def _reserve(): + acct = await _acct_dir.reserve( + pool_candidates=pool_candidates, + mode_id=mode_id, + now_s_override=now_s(), + exclude_tokens=excluded or None, ) - await _acct_dir.feedback(token, kind, int(spec.mode_id)) - if success: - asyncio.create_task(_quota_sync(token, int(spec.mode_id))) - else: - asyncio.create_task(_fail_sync(token, int(spec.mode_id), fail_exc)) + if acct is not None or current_strategy() == "random": + return acct + refresh_svc = get_refresh_service() + if refresh_svc is not None: + await refresh_svc.refresh_on_demand() + acct = await _acct_dir.reserve( + pool_candidates=pool_candidates, + mode_id=mode_id, + now_s_override=now_s(), + exclude_tokens=excluded or None, + ) + return acct + + for attempt in range(max_retries + 1): + acct = await _reserve() + if acct is None: + raise RateLimitError("No available super/heavy accounts for video generation") + + token = acct.token + success = False + retry = False + fail_exc: BaseException | None = None + try: + artifact = await runner(token, timeout_s) + success = True + return artifact + except UpstreamError as exc: + fail_exc = exc + if _should_retry_upstream(exc, retry_codes) and attempt < max_retries: + retry = True + logger.warning( + "video retry scheduled: attempt={}/{} status={} token={}...", + attempt + 1, + max_retries, + exc.status, + token[:8], + ) + else: + raise + except BaseException as exc: + fail_exc = exc + raise + finally: + await _acct_dir.release(acct) + kind = ( + FeedbackKind.SUCCESS + if success + else _feedback_kind(fail_exc) + if fail_exc + else FeedbackKind.SERVER_ERROR + ) + await _acct_dir.feedback(token, kind, mode_id) + if success: + asyncio.create_task(_quota_sync(token, mode_id)).add_done_callback( + _log_task_exception + ) + else: + asyncio.create_task(_fail_sync(token, mode_id, fail_exc)).add_done_callback( + _log_task_exception + ) + + if retry: + excluded.append(token) + continue + break + + raise RateLimitError("No available super/heavy accounts for video generation") async def _put_video_job(job: _VideoJob) -> None: @@ -847,33 +979,11 @@ async def _run_video_job( default=default_resolution_name, ) resolved_preset = _resolve_video_preset(preset) - spec = resolve_model(job.model) - from app.dataplane.account import _directory as _acct_dir - - if _acct_dir is None: - raise RateLimitError("Account directory not initialised") - - acct = await _acct_dir.reserve( - pool_candidates=spec.pool_candidates(), - mode_id=int(spec.mode_id), - now_s_override=now_s(), - ) - if acct is None: - raise RateLimitError("No available accounts for video generation") - - token = acct.token - success = False - fail_exc: BaseException | None = None - try: - cfg = get_config() - timeout_s = cfg.get_float("video.timeout", 180.0) - - async def _progress(progress: int) -> None: - await _set_job_status( - job, status="in_progress", progress=max(1, progress) - ) + async def _progress(progress: int) -> None: + await _set_job_status(job, status="in_progress", progress=max(1, progress)) + async def _runner(token: str, timeout_s: float) -> tuple[_VideoArtifact, bytes]: artifact = await _generate_video_with_token( token=token, prompt=prompt, @@ -886,24 +996,9 @@ async def _progress(progress: int) -> None: progress_cb=_progress, ) raw, _mime = await _download_video_bytes(token, artifact.video_url) - success = True - except BaseException as exc: - fail_exc = exc - raise - finally: - await _acct_dir.release(acct) - kind = ( - FeedbackKind.SUCCESS - if success - else _feedback_kind(fail_exc) - if fail_exc - else FeedbackKind.SERVER_ERROR - ) - await _acct_dir.feedback(token, kind, int(spec.mode_id)) - if success: - asyncio.create_task(_quota_sync(token, int(spec.mode_id))) - else: - asyncio.create_task(_fail_sync(token, int(spec.mode_id), fail_exc)) + return artifact, raw + + artifact, raw = await _run_video_with_account(model=job.model, runner=_runner) path = _save_video_bytes(raw, job.id) async with _VIDEO_JOBS_LOCK: @@ -999,46 +1094,57 @@ async def content_path(video_id: str) -> Path: def _extract_video_prompt_and_reference( messages: list[dict], ) -> tuple[str, list[dict[str, Any]] | None]: - prompt = "" - reference_urls: list[str] = [] + prompts, input_references = _extract_video_segment_prompts_and_references(messages) + return prompts[-1], input_references - for msg in reversed(messages): - content = msg.get("content", "") - if isinstance(content, str) and content.strip(): - prompt = content.strip() - if prompt: - break - continue - if not isinstance(content, list): + +def _text_and_references_from_content( + content: str | list[dict[str, Any]] | None, +) -> tuple[str, list[str]]: + if isinstance(content, str): + return content.strip(), [] + if not isinstance(content, list): + return "", [] + + text_parts: list[str] = [] + reference_urls: list[str] = [] + for item in content: + if not isinstance(item, dict): continue + item_type = item.get("type") + if item_type == "text": + text = str(item.get("text") or "").strip() + if text: + text_parts.append(text) + elif item_type == "image_url": + image_url = item.get("image_url") + if isinstance(image_url, dict): + url = str(image_url.get("url") or "").strip() + if url: + reference_urls.append(url) + elif isinstance(image_url, str) and image_url.strip(): + reference_urls.append(image_url.strip()) + return " ".join(text_parts).strip(), reference_urls + + +def _extract_video_segment_prompts_and_references( + messages: list[dict], +) -> tuple[list[str], list[dict[str, Any]] | None]: + prompts: list[str] = [] + reference_urls: list[str] = [] - text_parts: list[str] = [] - block_references: list[str] = [] - for item in content: - if not isinstance(item, dict): - continue - item_type = item.get("type") - if item_type == "text": - text = str(item.get("text") or "").strip() - if text: - text_parts.append(text) - elif item_type == "image_url": - image_url = item.get("image_url") - if isinstance(image_url, dict): - url = str(image_url.get("url") or "").strip() - if url: - block_references.append(url) - elif isinstance(image_url, str) and image_url.strip(): - block_references.append(image_url.strip()) - - if text_parts: - prompt = " ".join(text_parts) + for msg in messages: + if msg.get("role", "user") != "user": + continue + prompt, block_references = _text_and_references_from_content( + msg.get("content", "") + ) + if prompt: + prompts.append(prompt) if block_references and not reference_urls: reference_urls = block_references - if prompt: - break - if not prompt: + if not prompts: raise ValidationError("Video prompt cannot be empty", param="messages") input_references: list[dict[str, Any]] | None = None @@ -1049,7 +1155,7 @@ def _extract_video_prompt_and_reference( param="messages", ) input_references = [{"image_url": url} for url in reference_urls] - return prompt, input_references + return prompts, input_references async def completions( @@ -1070,7 +1176,16 @@ async def completions( default=default_resolution_name, ) resolved_preset = _resolve_video_preset(preset) - prompt, input_references = _extract_video_prompt_and_reference(messages) + segments = _build_segment_lengths(seconds) + segment_prompts, input_references = _extract_video_segment_prompts_and_references( + messages + ) + normalized_segment_prompts = _normalize_segment_prompts( + segment_prompts[0], + segments, + segment_prompts, + ) + prompt = normalized_segment_prompts[0] cfg = get_config() is_stream = stream if stream is not None else cfg.get_bool("features.stream", False) @@ -1087,6 +1202,7 @@ async def _runner(token: str, timeout_s: float) -> str: preset=resolved_preset, timeout_s=timeout_s, input_references=input_references, + segment_prompts=normalized_segment_prompts, progress_cb=progress_cb, ) file_id = hashlib.sha1(artifact.video_url.encode("utf-8")).hexdigest()[:32] @@ -1154,5 +1270,9 @@ async def _progress(progress: int) -> None: "validate_video_length", "completions", "_build_segment_lengths", + "_empty_video_response_error", + "_extract_video_segment_prompts_and_references", + "_normalize_segment_prompts", "_resolve_video_size", + "_video_pool_candidates", ] diff --git a/config.defaults.toml b/config.defaults.toml index f7a9b9437..8c132b639 100644 --- a/config.defaults.toml +++ b/config.defaults.toml @@ -78,9 +78,9 @@ mode = "direct" proxy_url = "" # 基础代理池(API 流量,proxy_pool 模式必填) proxy_pool = [] -# 资源代理 URL(图片/视频下载,未配置则回落到 proxy_url) +# 资源代理 URL(图片/视频下载,未配置则直连) resource_proxy_url = "" -# 资源代理池(图片/视频下载,未配置则回落到 proxy_pool) +# 资源代理池(图片/视频下载,未配置则直连) resource_proxy_pool = [] # 跳过代理 SSL 证书验证(代理使用自签名证书时启用) skip_ssl_verify = false diff --git a/tests/test_empty_response_handling.py b/tests/test_empty_response_handling.py index fec3d7472..22d9c1bc8 100644 --- a/tests/test_empty_response_handling.py +++ b/tests/test_empty_response_handling.py @@ -7,7 +7,7 @@ from app.platform.errors import UpstreamError from app.products.anthropic import messages as anthropic_messages from app.products.anthropic import router as anthropic_router -from app.products.openai import chat, responses +from app.products.openai import chat, responses, video from app.products.openai import router as openai_router @@ -69,6 +69,10 @@ async def _empty_stream(**kwargs): yield "data: [DONE]" +async def _empty_video_stream(*args, **kwargs): + yield "data: [DONE]" + + async def _noop_async(*args, **kwargs): return None @@ -106,6 +110,30 @@ def test_empty_response_error_is_retryable_429(self) -> None: self.assertEqual(exc.status, 429) self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY) + def test_video_empty_response_error_is_retryable_429(self) -> None: + exc = video._empty_video_response_error() + + self.assertEqual(exc.status, 429) + self.assertEqual(exc.details["body"], chat.EMPTY_UPSTREAM_BODY) + + def test_video_segment_empty_stream_raises_429(self) -> None: + async def _run() -> None: + await video._collect_video_segment( + token="tok-test", + payload={}, + referer="https://grok.com/imagine", + timeout_s=1, + ) + + with self.assertRaises(UpstreamError) as ctx: + with patch.object( + video, "_stream_video_request", side_effect=_empty_video_stream + ): + asyncio.run(_run()) + + self.assertEqual(ctx.exception.status, 429) + self.assertEqual(ctx.exception.details["body"], chat.EMPTY_UPSTREAM_BODY) + def test_stream_prime_turns_no_first_chunk_into_429(self) -> None: async def _run() -> None: await openai_router._prime_sse(_empty_generator()) diff --git a/tests/test_video_proxy_and_newapi.py b/tests/test_video_proxy_and_newapi.py new file mode 100644 index 000000000..9afbafd0a --- /dev/null +++ b/tests/test_video_proxy_and_newapi.py @@ -0,0 +1,197 @@ +import asyncio +import tempfile +import time +import unittest +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch + +from app.control.proxy import ProxyDirectory +from app.dataplane.reverse.transport import assets +from app.products.openai import video + + +class _ProxyConfig: + def __init__(self, *, resource_proxy_url: str = "") -> None: + self.resource_proxy_url = resource_proxy_url + + def get_str(self, key: str, default: str = "") -> str: + values = { + "proxy.egress.mode": "single_proxy", + "proxy.clearance.mode": "none", + "proxy.egress.proxy_url": "http://proxy.example:8080", + "proxy.egress.resource_proxy_url": self.resource_proxy_url, + } + return values.get(key, default) + + def get_list(self, key: str, default=None): + return default or [] + + def get_int(self, key: str, default: int = 0) -> int: + return default + + +class _AppConfig: + def __init__(self, *, app_url: str = "") -> None: + self.app_url = app_url + + def get_str(self, key: str, default: str = "") -> str: + if key == "app.app_url": + return self.app_url + return default + + +class _AssetConfig: + def get_float(self, key: str, default: float = 0.0) -> float: + return default + + +class _VideoConfig: + def get_float(self, key: str, default: float = 0.0) -> float: + return default + + def get(self, key: str, default=None): + return default + + +class _FakeProxyRuntime: + def __init__(self) -> None: + self.acquire_kwargs = [] + + async def acquire(self, **kwargs): + self.acquire_kwargs.append(kwargs) + return SimpleNamespace(proxy_url=None) + + async def feedback(self, lease, result) -> None: + return None + + +class _FakeAccountDirectory: + def __init__(self) -> None: + self.reserve_calls = [] + self.feedbacks = [] + + async def reserve(self, **kwargs): + self.reserve_calls.append(kwargs) + return SimpleNamespace(token="tok-video") + + async def release(self, acct) -> None: + return None + + async def feedback(self, token, kind, mode_id) -> None: + self.feedbacks.append((token, kind, mode_id)) + + +async def _empty_bytes_stream(*args, **kwargs): + if False: + yield b"" + + +async def _noop_async(*args, **kwargs): + return None + + +class VideoProxyAndNewApiTests(unittest.TestCase): + def test_resource_download_is_direct_without_resource_proxy(self) -> None: + async def _run(): + directory = ProxyDirectory() + with patch("app.control.proxy.get_config", return_value=_ProxyConfig()): + await directory.load() + return await directory.acquire(resource=True) + + lease = asyncio.run(_run()) + + self.assertIsNone(lease.proxy_url) + + def test_resource_download_uses_explicit_resource_proxy(self) -> None: + async def _run(): + directory = ProxyDirectory() + cfg = _ProxyConfig(resource_proxy_url="http://res-proxy.example:8080") + with patch("app.control.proxy.get_config", return_value=cfg): + await directory.load() + return await directory.acquire(resource=True) + + lease = asyncio.run(_run()) + + self.assertEqual(lease.proxy_url, "http://res-proxy.example:8080") + + def test_download_asset_acquires_resource_lease(self) -> None: + async def _run(): + proxy = _FakeProxyRuntime() + with patch.object(assets, "get_proxy_runtime", return_value=proxy): + with patch.object(assets, "get_config", return_value=_AssetConfig()): + with patch.object( + assets, "get_bytes_stream", return_value=_empty_bytes_stream() + ): + await assets.download_asset( + "tok-test", + "https://assets.grok.com/users/u/file.mp4", + ) + return proxy.acquire_kwargs + + calls = asyncio.run(_run()) + + self.assertEqual(calls[-1]["resource"], True) + + def test_completed_video_job_returns_newapi_metadata_url(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "video.mp4" + path.write_bytes(b"fake-video") + job = video._VideoJob( + id="video_test", + model="grok-imagine-video", + prompt="hello", + seconds="6", + size="720x1280", + quality="standard", + created_at=int(time.time()), + status="completed", + progress=100, + completed_at=int(time.time()), + content_path=str(path), + ) + + with patch.object(video, "get_config", return_value=_AppConfig(app_url="https://api.example.com")): + payload = job.to_dict() + + self.assertEqual( + payload["metadata"]["url"], + "https://api.example.com/v1/videos/video_test/content", + ) + + def test_video_pool_candidates_exclude_basic(self) -> None: + spec = SimpleNamespace(pool_candidates=lambda: (0, 1, 2)) + + self.assertEqual(video._video_pool_candidates(spec), (1, 2)) + + def test_video_account_runner_reserves_only_super_or_heavy(self) -> None: + async def _run(): + directory = _FakeAccountDirectory() + spec = SimpleNamespace( + is_video=lambda: True, + mode_id=0, + pool_candidates=lambda: (0, 1, 2), + ) + + async def _runner(token: str, timeout_s: float) -> str: + return token + + with patch("app.dataplane.account._directory", directory): + with patch.object(video, "get_config", return_value=_VideoConfig()): + with patch.object(video, "resolve_model", return_value=spec): + with patch.object(video, "selection_max_retries", return_value=0): + with patch.object(video, "_quota_sync", _noop_async): + result = await video._run_video_with_account( + model="grok-imagine-video", + runner=_runner, + ) + return result, directory.reserve_calls + + result, calls = asyncio.run(_run()) + + self.assertEqual(result, "tok-video") + self.assertEqual(calls[0]["pool_candidates"], (1, 2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_video_reference_helpers.py b/tests/test_video_reference_helpers.py index a7d17141e..fcb4c6706 100644 --- a/tests/test_video_reference_helpers.py +++ b/tests/test_video_reference_helpers.py @@ -24,9 +24,9 @@ def _message_with_references(count: int) -> list[dict]: class VideoReferenceHelperTests(unittest.TestCase): - def test_chat_video_prompt_allows_five_references(self) -> None: + def test_chat_video_prompt_allows_seven_references(self) -> None: prompt, refs = video._extract_video_prompt_and_reference( - _message_with_references(5) + _message_with_references(7) ) self.assertEqual(prompt, "生成一个参考图视频") @@ -34,16 +34,16 @@ def test_chat_video_prompt_allows_five_references(self) -> None: refs, [ {"image_url": f"https://example.com/ref-{idx}.png"} - for idx in range(5) + for idx in range(7) ], ) - def test_chat_video_prompt_rejects_more_than_five_references(self) -> None: + def test_chat_video_prompt_rejects_more_than_seven_references(self) -> None: with self.assertRaises(ValidationError): - video._extract_video_prompt_and_reference(_message_with_references(6)) + video._extract_video_prompt_and_reference(_message_with_references(8)) - def test_prepare_video_references_rejects_more_than_five_references(self) -> None: - refs = [{"image_url": f"https://example.com/ref-{idx}.png"} for idx in range(6)] + def test_prepare_video_references_rejects_more_than_seven_references(self) -> None: + refs = [{"image_url": f"https://example.com/ref-{idx}.png"} for idx in range(8)] async def _run() -> None: await video._prepare_video_references("token", refs) @@ -51,17 +51,40 @@ async def _run() -> None: with self.assertRaises(ValidationError): asyncio.run(_run()) - def test_videos_create_rejects_more_than_five_multipart_references(self) -> None: + def test_videos_create_rejects_more_than_seven_multipart_references(self) -> None: async def _run() -> None: await router.videos_create( model="grok-imagine-video", prompt="生成一个参考图视频", - input_reference=[object() for _ in range(6)], # type: ignore[list-item] + input_reference=[object() for _ in range(8)], # type: ignore[list-item] ) with self.assertRaises(ValidationError): asyncio.run(_run()) + def test_chat_video_segment_prompts_follow_user_order(self) -> None: + prompts, refs = video._extract_video_segment_prompts_and_references( + [ + {"role": "system", "content": "ignore"}, + {"role": "user", "content": "第一段"}, + {"role": "assistant", "content": "ignore"}, + {"role": "user", "content": "第二段"}, + ] + ) + + self.assertEqual(prompts, ["第一段", "第二段"]) + self.assertIsNone(refs) + + def test_segment_prompts_reuse_last_prompt_when_short(self) -> None: + self.assertEqual( + video._normalize_segment_prompts("第一段", [10, 6], ["第一段"]), + ["第一段", "第一段"], + ) + + def test_segment_prompts_reject_extra_prompts(self) -> None: + with self.assertRaises(ValidationError): + video._normalize_segment_prompts("一", [6], ["一", "二"]) + if __name__ == "__main__": unittest.main() From 8c1f1f7c4f906689ea0f44c37cbcd643ea403c62 Mon Sep 17 00:00:00 2001 From: yangyang Date: Fri, 1 May 2026 12:51:03 +0800 Subject: [PATCH 08/14] Restore streaming reasoning output --- app/products/openai/chat.py | 9 +- app/products/openai/responses.py | 22 ++-- tests/test_empty_response_handling.py | 160 +++++++++++++++++++++++++- 3 files changed, 177 insertions(+), 14 deletions(-) diff --git a/app/products/openai/chat.py b/app/products/openai/chat.py index ef075e22c..b87f35390 100644 --- a/app/products/openai/chat.py +++ b/app/products/openai/chat.py @@ -208,6 +208,7 @@ def _adapter_has_visible_output( *, extra_text: str = "", has_tool_calls: bool = False, + include_thinking: bool = False, ) -> bool: """Return True when the adapter has user-visible response content. @@ -215,6 +216,8 @@ def _adapter_has_visible_output( """ if has_tool_calls: return True + if include_thinking and "".join(adapter.thinking_buf).strip(): + return True if (extra_text or "".join(adapter.text_buf)).strip(): return True if adapter.image_urls: @@ -658,7 +661,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: response_id, model, ev.content ) payload = f"data: {orjson.dumps(chunk).decode()}\n\n" - for out in gate.emit(payload): + for out in gate.emit(payload, visible=True): yield out elif ev.kind == "annotation" and ev.annotation_data: collected_annotations.append(ev.annotation_data) @@ -725,7 +728,9 @@ async def _run_stream() -> AsyncGenerator[str, None]: yield out if not _adapter_has_visible_output( - adapter, has_tool_calls=tool_calls_emitted + adapter, + has_tool_calls=tool_calls_emitted, + include_thinking=emit_think, ): raise _empty_upstream_response_error() diff --git a/app/products/openai/responses.py b/app/products/openai/responses.py index 1abe6a89d..9b94b881e 100644 --- a/app/products/openai/responses.py +++ b/app/products/openai/responses.py @@ -7,8 +7,6 @@ import asyncio from typing import Any, AsyncGenerator -import orjson - from app.platform.logging.logger import logger from app.platform.config.snapshot import get_config from app.platform.errors import RateLimitError, UpstreamError @@ -20,14 +18,14 @@ from app.dataplane.reverse.protocol.xai_chat import classify_line, StreamAdapter from app.products._account_selection import reserve_account, selection_max_retries -from .chat import _stream_chat, _extract_message, _resolve_image, _quota_sync, _fail_sync, _parse_retry_codes, _feedback_kind, _log_task_exception, _upstream_body_excerpt +from .chat import _stream_chat, _extract_message, _resolve_image, _quota_sync, _fail_sync, _feedback_kind, _log_task_exception, _upstream_body_excerpt from .chat import _adapter_has_visible_output, _empty_upstream_response_error, _StreamStartGate from .chat import _configured_retry_codes, _should_retry_upstream from ._format import ( make_resp_id, build_resp_usage, make_resp_object, format_sse, ) from app.dataplane.reverse.protocol.tool_prompt import ( - build_tool_system_prompt, extract_tool_names, inject_into_message, tool_calls_to_xml, + build_tool_system_prompt, extract_tool_names, inject_into_message, ) from app.dataplane.reverse.protocol.tool_parser import parse_tool_calls from ._tool_sieve import ToolSieve @@ -222,7 +220,6 @@ async def create( cfg = get_config() spec = resolve_model(model) - mode_id = int(spec.mode_id) # cast once, reuse everywhere messages: list[dict] = [] if instructions: @@ -323,7 +320,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: "id": reasoning_id, "type": "reasoning", "summary": [], "status": "in_progress", }, - })): + }), visible=True): yield out for out in gate.emit(format_sse("response.reasoning_summary_part.added", { "type": "response.reasoning_summary_part.added", @@ -331,7 +328,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: "output_index": 0, "summary_index": 0, "part": {"type": "summary_text", "text": ""}, - })): + }), visible=True): yield out think_buf.append(ev.content) for out in gate.emit(format_sse("response.reasoning_summary_text.delta", { @@ -340,7 +337,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: "output_index": 0, "summary_index": 0, "delta": ev.content, - })): + }), visible=True): yield out elif ev.kind == "text": @@ -516,7 +513,11 @@ async def _run_stream() -> AsyncGenerator[str, None]: yield out full_text = "".join(text_buf) - if not _adapter_has_visible_output(adapter, extra_text=full_text): + if not _adapter_has_visible_output( + adapter, + extra_text=full_text, + include_thinking=emit_think, + ): raise _empty_upstream_response_error() if message_started: for out in gate.emit(format_sse("response.output_text.done", { @@ -574,7 +575,8 @@ async def _run_stream() -> AsyncGenerator[str, None]: sources = adapter.search_sources_list() if sources: msg_item["search_sources"] = sources - output.append(msg_item) + if message_started or full_text: + output.append(msg_item) pt = estimate_prompt_tokens(message) ct = estimate_tokens(full_text) diff --git a/tests/test_empty_response_handling.py b/tests/test_empty_response_handling.py index 22d9c1bc8..769e8673a 100644 --- a/tests/test_empty_response_handling.py +++ b/tests/test_empty_response_handling.py @@ -1,4 +1,5 @@ import asyncio +import json import unittest from types import SimpleNamespace from unittest.mock import patch @@ -69,6 +70,21 @@ async def _empty_stream(**kwargs): yield "data: [DONE]" +def _chat_frame(response: dict) -> str: + return "data: " + json.dumps({"result": {"response": response}}) + "\n\n" + + +async def _thinking_stream(**kwargs): + yield _chat_frame({"token": "thinking one", "isThinking": True}) + yield "data: [DONE]" + + +async def _thinking_then_text_stream(**kwargs): + yield _chat_frame({"token": "thinking one", "isThinking": True}) + yield _chat_frame({"token": "hello", "messageTag": "final"}) + yield "data: [DONE]" + + async def _empty_video_stream(*args, **kwargs): yield "data: [DONE]" @@ -178,6 +194,140 @@ async def _run() -> None: self.assertEqual(ctx.exception.status, 429) self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + def test_chat_stream_thinking_flushes_before_text(self) -> None: + directory = _FakeDirectory() + + async def _run() -> str: + stream = await chat.completions( + model="grok-test", + messages=[{"role": "user", "content": "hello"}], + stream=True, + emit_think=True, + ) + return await anext(stream) + + with self._patch_common(chat, directory, stream_func=_thinking_then_text_stream): + first = asyncio.run(_run()) + + self.assertIn("reasoning_content", first) + self.assertIn("thinking one", first) + + def test_chat_stream_thinking_only_is_not_empty_when_enabled(self) -> None: + directory = _FakeDirectory() + + async def _run() -> list[str]: + stream = await chat.completions( + model="grok-test", + messages=[{"role": "user", "content": "hello"}], + stream=True, + emit_think=True, + ) + chunks = [] + async for chunk in stream: + chunks.append(chunk) + return chunks + + with self._patch_common(chat, directory, stream_func=_thinking_stream): + chunks = asyncio.run(_run()) + + joined = "".join(chunks) + self.assertIn("reasoning_content", joined) + self.assertIn("thinking one", joined) + self.assertIn("data: [DONE]", joined) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.SUCCESS) + + def test_chat_stream_thinking_only_raises_429_when_disabled(self) -> None: + directory = _FakeDirectory() + + async def _run() -> None: + stream = await chat.completions( + model="grok-test", + messages=[{"role": "user", "content": "hello"}], + stream=True, + emit_think=False, + ) + async for _chunk in stream: + pass + + with self.assertRaises(UpstreamError) as ctx: + with self._patch_common(chat, directory, stream_func=_thinking_stream): + asyncio.run(_run()) + + self.assertEqual(ctx.exception.status, 429) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + + def test_responses_stream_reasoning_flushes_before_text(self) -> None: + directory = _FakeDirectory() + + async def _run() -> list[str]: + stream = await responses.create( + model="grok-test", + input_val="hello", + instructions=None, + stream=True, + emit_think=True, + temperature=0.8, + top_p=0.95, + ) + return [await anext(stream), await anext(stream), await anext(stream)] + + with self._patch_common(responses, directory, stream_func=_thinking_then_text_stream): + chunks = asyncio.run(_run()) + + self.assertIn("response.created", chunks[0]) + self.assertIn("response.output_item.added", chunks[1]) + self.assertIn("response.reasoning_summary_part.added", chunks[2]) + + def test_responses_stream_reasoning_only_is_not_empty_when_enabled(self) -> None: + directory = _FakeDirectory() + + async def _run() -> list[str]: + stream = await responses.create( + model="grok-test", + input_val="hello", + instructions=None, + stream=True, + emit_think=True, + temperature=0.8, + top_p=0.95, + ) + chunks = [] + async for chunk in stream: + chunks.append(chunk) + return chunks + + with self._patch_common(responses, directory, stream_func=_thinking_stream): + chunks = asyncio.run(_run()) + + joined = "".join(chunks) + self.assertIn("response.reasoning_summary_text.delta", joined) + self.assertIn("response.completed", joined) + self.assertIn("data: [DONE]", joined) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.SUCCESS) + + def test_responses_stream_reasoning_only_raises_429_when_disabled(self) -> None: + directory = _FakeDirectory() + + async def _run() -> None: + stream = await responses.create( + model="grok-test", + input_val="hello", + instructions=None, + stream=True, + emit_think=False, + temperature=0.8, + top_p=0.95, + ) + async for _chunk in stream: + pass + + with self.assertRaises(UpstreamError) as ctx: + with self._patch_common(responses, directory, stream_func=_thinking_stream): + asyncio.run(_run()) + + self.assertEqual(ctx.exception.status, 429) + self.assertEqual(directory.feedbacks[-1][1], FeedbackKind.RATE_LIMITED) + def test_responses_non_stream_empty_response_raises_429_feedback(self) -> None: directory = _FakeDirectory() exc = self._run_responses_non_stream_empty(directory) @@ -242,7 +392,13 @@ async def _run() -> None: asyncio.run(_run()) return ctx.exception - def _patch_common(self, module, directory: _FakeDirectory): + def _patch_common( + self, + module, + directory: _FakeDirectory, + *, + stream_func=_empty_stream, + ): import contextlib @contextlib.contextmanager @@ -261,7 +417,7 @@ def _ctx(): return_value=(SimpleNamespace(token="tok-test"), 0), ): with patch.object( - module, "_stream_chat", side_effect=_empty_stream + module, "_stream_chat", side_effect=stream_func ): with patch.object(module, "_fail_sync", _noop_async): with patch.object(module, "_quota_sync", _noop_async): From bc1bb08118a6bdb51dc66c18dbeafc2bbd951b1a Mon Sep 17 00:00:00 2001 From: yangyang Date: Sat, 2 May 2026 14:36:39 +0800 Subject: [PATCH 09/14] Fix async video metadata file URL --- app/products/openai/video.py | 12 +++-- tests/test_video_proxy_and_newapi.py | 69 +++++++++++++++++++++++++++- 2 files changed, 77 insertions(+), 4 deletions(-) diff --git a/app/products/openai/video.py b/app/products/openai/video.py index 4e4789f8c..88c6fed1c 100644 --- a/app/products/openai/video.py +++ b/app/products/openai/video.py @@ -121,6 +121,7 @@ class _VideoJob: remixed_from_video_id: str | None = None video_url: str = "" content_path: str = "" + content_file_id: str = "" def to_dict(self) -> dict[str, Any]: payload: dict[str, Any] = { @@ -141,8 +142,11 @@ def to_dict(self) -> dict[str, Any]: payload["error"] = self.error if self.remixed_from_video_id: payload["remixed_from_video_id"] = self.remixed_from_video_id - if self.status == "completed" and self.content_path: - payload["metadata"] = {"url": _video_content_url(self.id)} + if self.status == "completed": + if self.content_file_id: + payload["metadata"] = {"url": _local_video_url(self.content_file_id)} + elif self.content_path: + payload["metadata"] = {"url": _video_content_url(self.id)} return payload @@ -1000,13 +1004,15 @@ async def _runner(token: str, timeout_s: float) -> tuple[_VideoArtifact, bytes]: artifact, raw = await _run_video_with_account(model=job.model, runner=_runner) - path = _save_video_bytes(raw, job.id) + file_id = hashlib.sha1(artifact.video_url.encode("utf-8")).hexdigest()[:32] + path = _save_video_bytes(raw, file_id) async with _VIDEO_JOBS_LOCK: job.status = "completed" job.progress = 100 job.completed_at = int(time.time()) job.video_url = artifact.video_url job.content_path = str(path) + job.content_file_id = file_id job.remixed_from_video_id = artifact.remixed_from_video_id except Exception as exc: logger.exception("video job failed: job_id={} error={}", job.id, exc) diff --git a/tests/test_video_proxy_and_newapi.py b/tests/test_video_proxy_and_newapi.py index 9afbafd0a..fffd1d201 100644 --- a/tests/test_video_proxy_and_newapi.py +++ b/tests/test_video_proxy_and_newapi.py @@ -1,4 +1,5 @@ import asyncio +import hashlib import tempfile import time import unittest @@ -149,6 +150,7 @@ def test_completed_video_job_returns_newapi_metadata_url(self) -> None: progress=100, completed_at=int(time.time()), content_path=str(path), + content_file_id="08b8238c88e9bbe8423c692cbd04ec52", ) with patch.object(video, "get_config", return_value=_AppConfig(app_url="https://api.example.com")): @@ -156,7 +158,72 @@ def test_completed_video_job_returns_newapi_metadata_url(self) -> None: self.assertEqual( payload["metadata"]["url"], - "https://api.example.com/v1/videos/video_test/content", + "https://api.example.com/v1/files/video?id=08b8238c88e9bbe8423c692cbd04ec52", + ) + + def test_video_job_saves_with_sync_file_id_and_metadata_url(self) -> None: + async def _run(tmpdir: str): + artifact = video._VideoArtifact( + video_url="https://assets.grok.com/users/u/generated.mp4", + video_post_id="post_1", + asset_id="asset_1", + thumbnail_url="https://assets.grok.com/thumb.jpg", + ) + raw = b"fake-video" + saved_ids: list[str] = [] + job = video._VideoJob( + id="video_async", + model="grok-imagine-video", + prompt="hello", + seconds="6", + size="720x1280", + quality="standard", + created_at=int(time.time()), + ) + + async def _fake_run_video_with_account(*, model, runner): + return artifact, raw + + def _fake_save_video_bytes(raw_bytes: bytes, file_id: str) -> Path: + saved_ids.append(file_id) + path = Path(tmpdir) / f"{file_id}.mp4" + path.write_bytes(raw_bytes) + return path + + with patch.object( + video, "_run_video_with_account", _fake_run_video_with_account + ): + with patch.object(video, "_save_video_bytes", _fake_save_video_bytes): + await video._run_video_job( + job, + size="720x1280", + resolution_name=None, + prompt="hello", + seconds=6, + preset=None, + ) + + with patch.object( + video, + "get_config", + return_value=_AppConfig(app_url="https://api.example.com"), + ): + payload = job.to_dict() + + return job, payload, saved_ids + + with tempfile.TemporaryDirectory() as tmpdir: + job, payload, saved_ids = asyncio.run(_run(tmpdir)) + + expected_file_id = hashlib.sha1( + "https://assets.grok.com/users/u/generated.mp4".encode("utf-8") + ).hexdigest()[:32] + self.assertEqual(saved_ids, [expected_file_id]) + self.assertEqual(job.content_file_id, expected_file_id) + self.assertTrue(job.content_path.endswith(f"{expected_file_id}.mp4")) + self.assertEqual( + payload["metadata"]["url"], + f"https://api.example.com/v1/files/video?id={expected_file_id}", ) def test_video_pool_candidates_exclude_basic(self) -> None: From 3e39e5b7d60b29e4c2630ffb4b0a6d0ad0e5b8aa Mon Sep 17 00:00:00 2001 From: yangyang Date: Sat, 9 May 2026 22:07:25 +0800 Subject: [PATCH 10/14] =?UTF-8?q?fix:=20auto=E5=AF=BC=E5=85=A5pool?= =?UTF-8?q?=E6=8E=A8=E6=96=AD=E9=94=99=E8=AF=AF=20&=20=E8=A7=86=E9=A2=91?= =?UTF-8?q?=E6=B5=81=E5=BC=8F=E7=A9=BA=E5=9B=9E=E4=B8=8D=E8=BF=94429?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. quota_defaults: _SUPPORTED_MODE_IDS_BY_POOL 新增 "auto" 键(全集0-4), 修复 pool="auto" 导入时只 fetch fast 模式导致 infer_pool 把 super/heavy 误判为 basic 的问题。 2. router: chat/completions 视频流式结果不再跳过 _prime_sse, 空生成器会正确抛出 429 而非静默返回 HTTP 200 空 SSE。 Co-Authored-By: Claude Opus 4.6 --- app/control/account/quota_defaults.py | 1 + app/products/openai/router.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/app/control/account/quota_defaults.py b/app/control/account/quota_defaults.py index eb3250c8d..a8e913316 100644 --- a/app/control/account/quota_defaults.py +++ b/app/control/account/quota_defaults.py @@ -75,6 +75,7 @@ def _w(remaining: int, total: int, window_seconds: int) -> QuotaWindow: "basic": frozenset((1,)), "super": frozenset((0, 1, 2, 4)), "heavy": frozenset((0, 1, 2, 3, 4)), + "auto": frozenset((0, 1, 2, 3, 4)), # 探测用全集,确保 infer_pool 能拿到 auto.total } # --------------------------------------------------------------------------- diff --git a/app/products/openai/router.py b/app/products/openai/router.py index a8c1102b4..4e90f65fb 100644 --- a/app/products/openai/router.py +++ b/app/products/openai/router.py @@ -359,7 +359,7 @@ async def _err_stream(): if isinstance(result, dict): return JSONResponse(result) - if not (spec.is_image_edit() or spec.is_image() or spec.is_video()): + if not (spec.is_image_edit() or spec.is_image()): result = await _prime_sse(result) return StreamingResponse( _safe_sse(result), media_type="text/event-stream", headers=_SSE_HEADERS From 52d27de3885b8a376e9ba92ac94df50496a9984d Mon Sep 17 00:00:00 2001 From: pigeonxian <470688162@qq.com> Date: Mon, 11 May 2026 15:26:24 +0000 Subject: [PATCH 11/14] feat: support metadata configs and video duration limits --- README.md | 29 +++++++-- app/products/openai/router.py | 51 ++++++++++++++-- app/products/openai/schemas.py | 1 + app/products/openai/video.py | 51 ++++++++++++---- docs/README.en.md | 33 ++++++++--- tests/test_chat_metadata_config.py | 84 +++++++++++++++++++++++++++ tests/test_video_reference_helpers.py | 61 +++++++++++++++++++ 7 files changed, 280 insertions(+), 30 deletions(-) create mode 100644 tests/test_chat_metadata_config.py diff --git a/README.md b/README.md index 64f1e5226..29a6131ed 100644 --- a/README.md +++ b/README.md @@ -363,15 +363,31 @@ curl http://localhost:8000/v1/chat/completions \ "model": "grok-imagine-video", "stream": true, "messages": [ - {"role":"user","content":"霓虹雨夜街头,电影感慢镜头追拍"} + {"role":"user","content":"霓虹雨夜街头,电影感慢镜头追拍"}, + {"role":"user","content":"镜头穿过霓虹招牌,人物回头看向远处车灯"} ], - "video_config": { + "metadata": { + "video_config": { + "seconds": 16, + "size": "1792x1024", + "resolution_name": "720p", + "preset": "normal" + } + } + }' +``` + +`image_config` / `video_config` 也可以放在请求顶层;顶层字段优先于 `metadata` 内的同名配置。 + +```json +{ + "video_config": { "seconds": 10, "size": "1792x1024", "resolution_name": "720p", "preset": "normal" - } - }' + } +} ```
@@ -391,10 +407,11 @@ curl http://localhost:8000/v1/chat/completions \ | \|_ `size` | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | | \|_ `response_format` | `url`, `b64_json` | | `video_config` | 视频模型参数 | -| \|_ `seconds` | `6`, `10`, `12`, `16`, `20` | +| \|_ `seconds` | `6`, `10`, `12`, `16`, `20`, `22`, `26`, `30`, `32`, `36`, `40`;长视频需要按分段数量提供同数量 user 提示词 | | \|_ `size` | `720x1280`, `1280x720`, `1024x1024`, `1024x1792`, `1792x1024` | | \|_ `resolution_name` | `480p`, `720p` | | \|_ `preset` | `fun`, `normal`, `spicy`, `custom` | +| `metadata.image_config` / `metadata.video_config` | `image_config` / `video_config` 的兼容位置;顶层配置优先 |
@@ -590,7 +607,7 @@ curl -L http://localhost:8000/v1/videos//content \ | :-- | :-- | | `model` | 视频模型,目前为 `grok-imagine-video` | | `prompt` | 视频生成提示词 | -| `seconds` | 视频长度:`6`, `10`, `12`, `16`, `20` | +| `seconds` | 视频长度:`6`, `10`;更长视频请使用 `/v1/chat/completions` | | `size` | 支持 `720x1280`, `1280x720`, `1024x1024`, `1024x1792`, `1792x1024` | | `resolution_name` | `480p` 或 `720p` | | `preset` | `fun`, `normal`, `spicy`, `custom` | diff --git a/app/products/openai/router.py b/app/products/openai/router.py index 4e90f65fb..731654dff 100644 --- a/app/products/openai/router.py +++ b/app/products/openai/router.py @@ -3,11 +3,12 @@ import base64 import binascii import mimetypes -from typing import Annotated, AsyncGenerator, AsyncIterable, Literal +from typing import Annotated, AsyncGenerator, AsyncIterable, Literal, TypeVar import orjson from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse, FileResponse +from pydantic import ValidationError as PydanticValidationError from app.control.account.state_machine import is_manageable from app.platform.auth.middleware import verify_api_key @@ -163,6 +164,48 @@ async def _primed() -> AsyncGenerator[str, None]: _ALLOWED_SIZES = {"1280x720", "720x1280", "1792x1024", "1024x1792", "1024x1024"} _EFFORT_VALUES = {"none", "minimal", "low", "medium", "high", "xhigh"} _LITE_IMAGE_MODELS = {"grok-imagine-image-lite"} +_ConfigT = TypeVar("_ConfigT", ImageConfig, VideoConfig) + + +def _metadata_config( + req: ChatCompletionRequest, + key: str, + model_cls: type[_ConfigT], +) -> _ConfigT | None: + metadata = req.metadata + if not isinstance(metadata, dict) or key not in metadata: + return None + value = metadata.get(key) + if value is None: + return None + if not isinstance(value, dict): + raise ValidationError( + f"metadata.{key} must be an object", + param=f"metadata.{key}", + ) + try: + return model_cls.model_validate(value) + except PydanticValidationError as exc: + raise ValidationError( + f"metadata.{key} is invalid: {exc.errors()[0].get('msg', 'invalid value')}", + param=f"metadata.{key}", + ) from exc + + +def _resolve_image_config(req: ChatCompletionRequest) -> ImageConfig: + return ( + req.image_config + or _metadata_config(req, "image_config", ImageConfig) + or ImageConfig() + ) + + +def _resolve_video_config(req: ChatCompletionRequest) -> VideoConfig: + return ( + req.video_config + or _metadata_config(req, "video_config", VideoConfig) + or VideoConfig() + ) def _validate_chat(req: ChatCompletionRequest) -> None: @@ -256,7 +299,7 @@ async def chat_completions_endpoint(req: ChatCompletionRequest): if spec.is_image_edit(): from .images import edit as img_edit - cfg = req.image_config or ImageConfig() + cfg = _resolve_image_config(req) _validate_image_edit_n(cfg.n or 1, param="image_config.n") result = await img_edit( model=req.model, @@ -271,7 +314,7 @@ async def chat_completions_endpoint(req: ChatCompletionRequest): elif spec.is_image(): from .images import generate as img_gen - cfg = req.image_config or ImageConfig() + cfg = _resolve_image_config(req) size = cfg.size or "1024x1024" fmt = cfg.response_format or "url" n = cfg.n or 1 @@ -300,7 +343,7 @@ async def chat_completions_endpoint(req: ChatCompletionRequest): elif spec.is_video(): from .video import completions as vid_comp - vcfg = req.video_config or VideoConfig() + vcfg = _resolve_video_config(req) from .video import validate_video_length as _validate_video_length _validate_video_length(vcfg.seconds or 6) diff --git a/app/products/openai/schemas.py b/app/products/openai/schemas.py index f4caa0fee..a42ff1c2e 100644 --- a/app/products/openai/schemas.py +++ b/app/products/openai/schemas.py @@ -39,6 +39,7 @@ class ChatCompletionRequest(BaseModel): tool_choice: str | dict[str, Any] | None = None parallel_tool_calls: bool | None = True max_tokens: int | None = None + metadata: dict[str, Any] | None = None class ImageGenerationRequest(BaseModel): diff --git a/app/products/openai/video.py b/app/products/openai/video.py index 88c6fed1c..1240d0d7b 100644 --- a/app/products/openai/video.py +++ b/app/products/openai/video.py @@ -74,7 +74,21 @@ _VIDEO_EXTENSION_REF_TYPE = "ORIGINAL_REF_TYPE_VIDEO_EXTENSION" _VIDEO_MAX_REFERENCES = 7 _VIDEO_ALLOWED_POOL_IDS = frozenset((1, 2)) # super / heavy only -_SUPPORTED_VIDEO_LENGTHS = frozenset({6, 10, 12, 16, 20}) +_SUPPORTED_VIDEO_LENGTHS = frozenset({6, 10, 12, 16, 20, 22, 26, 30, 32, 36, 40}) +_ASYNC_VIDEO_LENGTHS = frozenset({6, 10}) +_VIDEO_SEGMENT_LENGTHS: dict[int, list[int]] = { + 6: [6], + 10: [10], + 12: [6, 6], + 16: [10, 6], + 20: [10, 10], + 22: [10, 6, 6], + 26: [10, 10, 6], + 30: [10, 10, 10], + 32: [10, 10, 6, 6], + 36: [10, 10, 10, 6], + 40: [10, 10, 10, 10], +} _VIDEO_SIZE_MAP: dict[str, tuple[str, str]] = { "720x1280": ("9:16", "720p"), "1280x720": ("16:9", "720p"), @@ -194,6 +208,16 @@ def validate_video_length(seconds: int) -> None: raise ValidationError(f"seconds must be one of [{allowed}]", param="seconds") +def validate_async_video_length(seconds: int) -> None: + if seconds not in _ASYNC_VIDEO_LENGTHS: + allowed = ", ".join(str(item) for item in sorted(_ASYNC_VIDEO_LENGTHS)) + raise ValidationError( + f"seconds must be one of [{allowed}] for /v1/videos; " + "use /v1/chat/completions for longer videos", + param="seconds", + ) + + def _resolve_video_size(size: str) -> tuple[str, str]: normalized = (size or "720x1280").strip() config = _VIDEO_SIZE_MAP.get(normalized) @@ -221,16 +245,9 @@ def _resolve_video_preset(value: str | None, *, default: str = "custom") -> str: def _build_segment_lengths(seconds: int) -> list[int]: - if seconds == 6: - return [6] - if seconds == 10: - return [10] - if seconds == 12: - return [6, 6] - if seconds == 16: - return [10, 6] - if seconds == 20: - return [10, 10] + segments = _VIDEO_SEGMENT_LENGTHS.get(seconds) + if segments is not None: + return list(segments) validate_video_length(seconds) raise AssertionError("unreachable") @@ -239,10 +256,18 @@ def _normalize_segment_prompts( prompt: str, segment_lengths: list[int], segment_prompts: list[str] | None = None, + *, + strict: bool = False, ) -> list[str]: prompts = [p.strip() for p in (segment_prompts or [prompt]) if p and p.strip()] if not prompts: raise ValidationError("Video prompt cannot be empty", param="messages") + if strict and len(prompts) != len(segment_lengths): + raise ValidationError( + "Video generation requires exactly " + f"{len(segment_lengths)} prompt segment(s) for this duration", + param="messages", + ) if len(prompts) > len(segment_lengths): raise ValidationError( f"Video generation uses {len(segment_lengths)} prompt segment(s) for this duration", @@ -1040,7 +1065,7 @@ async def create_video( raise ValidationError("prompt cannot be empty", param="prompt") normalized_seconds = _coerce_seconds(seconds) - validate_video_length(normalized_seconds) + validate_async_video_length(normalized_seconds) normalized_size = (size or "720x1280").strip() _aspect_ratio, default_resolution_name = _resolve_video_size(normalized_size) _resolve_video_resolution_name(resolution_name, default=default_resolution_name) @@ -1190,6 +1215,7 @@ async def completions( segment_prompts[0], segments, segment_prompts, + strict=True, ) prompt = normalized_segment_prompts[0] @@ -1274,6 +1300,7 @@ async def _progress(progress: int) -> None: "retrieve", "content_path", "validate_video_length", + "validate_async_video_length", "completions", "_build_segment_lengths", "_empty_video_response_error", diff --git a/docs/README.en.md b/docs/README.en.md index f10680286..2968baf61 100644 --- a/docs/README.en.md +++ b/docs/README.en.md @@ -362,17 +362,33 @@ curl http://localhost:8000/v1/chat/completions \ "model": "grok-imagine-video", "stream": true, "messages": [ - {"role":"user","content":"A neon rainy street at night, cinematic slow tracking shot"} + {"role":"user","content":"A neon rainy street at night, cinematic slow tracking shot"}, + {"role":"user","content":"The camera passes under glowing signs as the character turns toward distant headlights"} ], - "video_config": { - "seconds": 10, - "size": "1792x1024", - "resolution_name": "720p", - "preset": "normal" + "metadata": { + "video_config": { + "seconds": 16, + "size": "1792x1024", + "resolution_name": "720p", + "preset": "normal" + } } }' ``` +`image_config` / `video_config` may also be sent at the request top level; top-level fields take precedence over matching `metadata` fields. + +```json +{ + "video_config": { + "seconds": 10, + "size": "1792x1024", + "resolution_name": "720p", + "preset": "normal" + } +} +``` +
Field Notes
@@ -390,10 +406,11 @@ curl http://localhost:8000/v1/chat/completions \ | \|_ `size` | `1280x720`, `720x1280`, `1792x1024`, `1024x1792`, `1024x1024` | | \|_ `response_format` | `url`, `b64_json` | | `video_config` | Video model parameters | -| \|_ `seconds` | `6`, `10`, `12`, `16`, `20` | +| \|_ `seconds` | `6`, `10`, `12`, `16`, `20`, `22`, `26`, `30`, `32`, `36`, `40`; longer videos require one user prompt per segment | | \|_ `size` | `720x1280`, `1280x720`, `1024x1024`, `1024x1792`, `1792x1024` | | \|_ `resolution_name` | `480p`, `720p` | | \|_ `preset` | `fun`, `normal`, `spicy`, `custom` | +| `metadata.image_config` / `metadata.video_config` | Compatibility location for `image_config` / `video_config`; top-level config wins |
@@ -589,7 +606,7 @@ curl -L http://localhost:8000/v1/videos//content \ | :-- | :-- | | `model` | Video model, currently `grok-imagine-video` | | `prompt` | Video generation prompt | -| `seconds` | Video length: `6`, `10`, `12`, `16`, `20` | +| `seconds` | Video length: `6`, `10`; use `/v1/chat/completions` for longer videos | | `size` | Supports `720x1280`, `1280x720`, `1024x1024`, `1024x1792`, `1792x1024` | | `resolution_name` | `480p` or `720p` | | `preset` | `fun`, `normal`, `spicy`, `custom` | diff --git a/tests/test_chat_metadata_config.py b/tests/test_chat_metadata_config.py new file mode 100644 index 000000000..9ec196727 --- /dev/null +++ b/tests/test_chat_metadata_config.py @@ -0,0 +1,84 @@ +import unittest + +from app.platform.errors import ValidationError +from app.products.openai import router +from app.products.openai.schemas import ChatCompletionRequest, ImageConfig, VideoConfig + + +def _request(**kwargs) -> ChatCompletionRequest: + payload = { + "model": "grok-imagine-video", + "messages": [{"role": "user", "content": "hello"}], + } + payload.update(kwargs) + return ChatCompletionRequest.model_validate(payload) + + +class ChatMetadataConfigTests(unittest.TestCase): + def test_metadata_video_config_is_used_as_fallback(self) -> None: + req = _request( + metadata={ + "video_config": { + "seconds": 16, + "size": "1792x1024", + "resolution_name": "720p", + "preset": "normal", + } + } + ) + + cfg = router._resolve_video_config(req) + + self.assertEqual(cfg.seconds, 16) + self.assertEqual(cfg.size, "1792x1024") + self.assertEqual(cfg.resolution_name, "720p") + self.assertEqual(cfg.preset, "normal") + + def test_top_level_video_config_wins_over_metadata(self) -> None: + req = _request( + video_config=VideoConfig(seconds=6, size="720x1280"), + metadata={"video_config": {"seconds": 16, "size": "1792x1024"}}, + ) + + cfg = router._resolve_video_config(req) + + self.assertEqual(cfg.seconds, 6) + self.assertEqual(cfg.size, "720x1280") + + def test_metadata_image_config_is_used_as_fallback(self) -> None: + req = _request( + metadata={ + "image_config": { + "n": 3, + "size": "1792x1024", + "response_format": "url", + } + } + ) + + cfg = router._resolve_image_config(req) + + self.assertEqual(cfg.n, 3) + self.assertEqual(cfg.size, "1792x1024") + self.assertEqual(cfg.response_format, "url") + + def test_top_level_image_config_wins_over_metadata(self) -> None: + req = _request( + image_config=ImageConfig(n=1, size="1024x1024"), + metadata={"image_config": {"n": 3, "size": "1792x1024"}}, + ) + + cfg = router._resolve_image_config(req) + + self.assertEqual(cfg.n, 1) + self.assertEqual(cfg.size, "1024x1024") + + def test_metadata_config_must_be_object(self) -> None: + req = _request(metadata={"video_config": "bad"}) + + with self.assertRaises(ValidationError): + router._resolve_video_config(req) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_video_reference_helpers.py b/tests/test_video_reference_helpers.py index fcb4c6706..e47c1ad33 100644 --- a/tests/test_video_reference_helpers.py +++ b/tests/test_video_reference_helpers.py @@ -81,10 +81,71 @@ def test_segment_prompts_reuse_last_prompt_when_short(self) -> None: ["第一段", "第一段"], ) + def test_segment_prompts_strict_rejects_missing_prompts(self) -> None: + with self.assertRaises(ValidationError): + video._normalize_segment_prompts( + "第一段", + [10, 6], + ["第一段"], + strict=True, + ) + def test_segment_prompts_reject_extra_prompts(self) -> None: with self.assertRaises(ValidationError): video._normalize_segment_prompts("一", [6], ["一", "二"]) + def test_video_segment_lengths_prefer_ten_second_segments(self) -> None: + cases = { + 6: [6], + 10: [10], + 12: [6, 6], + 16: [10, 6], + 20: [10, 10], + 22: [10, 6, 6], + 26: [10, 10, 6], + 30: [10, 10, 10], + 32: [10, 10, 6, 6], + 36: [10, 10, 10, 6], + 40: [10, 10, 10, 10], + } + + for seconds, expected in cases.items(): + with self.subTest(seconds=seconds): + self.assertEqual(video._build_segment_lengths(seconds), expected) + + def test_async_video_length_rejects_long_videos(self) -> None: + video.validate_async_video_length(6) + video.validate_async_video_length(10) + with self.assertRaises(ValidationError): + video.validate_async_video_length(12) + + def test_async_video_create_rejects_long_videos(self) -> None: + async def _run() -> None: + await video.create_video( + model="grok-imagine-video", + prompt="长视频", + seconds=12, + ) + + with self.assertRaises(ValidationError): + asyncio.run(_run()) + + def test_chat_video_length_rejects_non_exact_duration(self) -> None: + with self.assertRaises(ValidationError): + video.validate_video_length(28) + + def test_chat_video_completions_requires_prompt_for_each_segment(self) -> None: + async def _run() -> None: + await video.completions( + model="grok-imagine-video", + messages=[{"role": "user", "content": "第一段"}], + stream=False, + seconds=16, + ) + + with self.assertRaises(ValidationError): + asyncio.run(_run()) + if __name__ == "__main__": unittest.main() From fefcfb1ec35cd8b1ed59226a9b3acbcff3c7c2ef Mon Sep 17 00:00:00 2001 From: dongxuelian1010 <234438803@qq.com> Date: Sat, 30 May 2026 23:52:27 +0800 Subject: [PATCH 12/14] =?UTF-8?q?@=20feat:=20=E6=8C=89=20account=5Fid=20?= =?UTF-8?q?=E5=8E=BB=E9=87=8D=20+=20Grok=20=E5=AE=98=E6=96=B9=E8=B4=A6?= =?UTF-8?q?=E5=8F=B7=E7=B1=BB=E5=9E=8B=E6=9F=A5=E8=AF=A2=20+=20Cookie=20?= =?UTF-8?q?=E6=B3=A8=E5=85=A5=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 核心变更: - 添加 account_id (xaiUserId) 字段到 AccountRecord, 实现基于账号唯一 ID 的去重, 相同 xaiUserId 的多个 SSO token 自动合并, 只保留最新一个 - 新增 xai_subscription.py 协议模块, 调用 /rest/subscriptions 获取官方账号类型 (SUBSCRIPTION_TIER_GROK_PRO → super, GROK_PRO_HEAVY → heavy) - 导入时自动探测订阅层级和 account_id, 配合 rate-limits API 交叉验证 - 三个存储后端 (SQL/Redis/SQLite) 全部支持 account_id 列和去重逻辑 - Redis 后端补全缺失的 quota_grok_4_3 字段 - 添加 /rest/subscriptions, /rest/products, /rest/modes 端点 - Admin tokens API 输出 account_id - 附带 inject_cookie.py 浏览器 Cookie 注入工具 Grok API 探索发现: - GET /rest/subscriptions → xaiUserId + tier + status - POST /rest/modes → 基于订阅的模式可用性 - GET /rest/products → 可用产品层级列表 Co-Authored-By: Claude Opus 4.8 @ --- app/control/account/backends/local.py | 55 +++- app/control/account/backends/redis.py | 51 +++- app/control/account/backends/sql.py | 47 ++++ app/control/account/commands.py | 2 + app/control/account/models.py | 1 + app/control/account/refresh.py | 77 +++++- .../reverse/protocol/xai_subscription.py | 259 ++++++++++++++++++ .../reverse/runtime/endpoint_table.py | 7 +- app/products/web/admin/tokens.py | 1 + scripts/inject_cookie.py | 198 +++++++++++++ 10 files changed, 674 insertions(+), 24 deletions(-) create mode 100644 app/dataplane/reverse/protocol/xai_subscription.py create mode 100644 scripts/inject_cookie.py diff --git a/app/control/account/backends/local.py b/app/control/account/backends/local.py index 85eb3971f..ffd7f6274 100644 --- a/app/control/account/backends/local.py +++ b/app/control/account/backends/local.py @@ -55,6 +55,7 @@ def _init_sync(self) -> None: CREATE TABLE IF NOT EXISTS {_TBL} ( token TEXT NOT NULL PRIMARY KEY, + account_id TEXT, pool TEXT NOT NULL DEFAULT 'basic', status TEXT NOT NULL DEFAULT 'active', created_at INTEGER NOT NULL, @@ -85,7 +86,13 @@ def _init_sync(self) -> None: CREATE INDEX IF NOT EXISTS idx_acc_deleted ON {_TBL} (deleted_at) WHERE deleted_at IS NOT NULL; """) + # Migration: add account_id column if missing. + self._ensure_column_sync(conn, "account_id", "TEXT") self._ensure_column_sync(conn, "quota_grok_4_3", "TEXT NOT NULL DEFAULT '{}'") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_acc_account_id " + f"ON {_TBL} (account_id) WHERE account_id IS NOT NULL" + ) conn.commit() @staticmethod @@ -112,6 +119,7 @@ def _get_revision_sync(self, conn: sqlite3.Connection) -> int: @staticmethod def _row_to_record(row: sqlite3.Row) -> AccountRecord: d = dict(row) + d["account_id"] = d.get("account_id") or None d["tags"] = json.loads(d.get("tags") or "[]") heavy_raw = d.pop("quota_heavy", "{}") or "{}" grok_4_3_raw = d.pop("quota_grok_4_3", "{}") or "{}" @@ -132,6 +140,7 @@ def _record_to_row(record: AccountRecord, revision: int) -> dict[str, Any]: qs = record.quota_set() return { "token": record.token, + "account_id": record.account_id, "pool": record.pool, "status": record.status.value, "created_at": record.created_at, @@ -171,19 +180,36 @@ def _upsert_sync( continue pool = item.pool if item.pool in ("basic", "super", "heavy") else "basic" qs = default_quota_set(pool) + account_id = item.account_id or None + + # Dedup by account_id: soft-delete old token records with same account_id. + if account_id: + dups = conn.execute( + f"SELECT token FROM {_TBL} " + f"WHERE account_id = ? AND token != ? AND deleted_at IS NULL", + (account_id, token), + ).fetchall() + for dup in dups: + conn.execute( + f"UPDATE {_TBL} SET deleted_at = ?, updated_at = ?, revision = ? " + f"WHERE token = ?", + (ts, ts, revision, dup[0]), + ) + conn.execute( f""" INSERT INTO {_TBL} ( - token, pool, status, created_at, updated_at, + token, account_id, pool, status, created_at, updated_at, tags, quota_auto, quota_fast, quota_expert, quota_heavy, quota_grok_4_3, usage_use_count, usage_fail_count, usage_sync_count, ext, revision ) VALUES ( - :token, :pool, 'active', :ts, :ts, + :token, :account_id, :pool, 'active', :ts, :ts, :tags, :qa, :qf, :qe, :qh, :qg, 0, 0, 0, :ext, :rev ) ON CONFLICT(token) DO UPDATE SET + account_id = COALESCE(excluded.account_id, account_id), pool = excluded.pool, status = 'active', deleted_at = NULL, @@ -193,17 +219,18 @@ def _upsert_sync( revision = excluded.revision """, { - "token": token, - "pool": pool, - "ts": ts, - "tags": json.dumps(item.tags), - "qa": json.dumps(qs.auto.to_dict()), - "qf": json.dumps(qs.fast.to_dict()), - "qe": json.dumps(qs.expert.to_dict()), - "qh": json.dumps(qs.heavy.to_dict()) if qs.heavy else "{}", - "qg": json.dumps(qs.grok_4_3.to_dict()) if qs.grok_4_3 else "{}", - "ext": json.dumps(item.ext), - "rev": revision, + "token": token, + "account_id": account_id, + "pool": pool, + "ts": ts, + "tags": json.dumps(item.tags), + "qa": json.dumps(qs.auto.to_dict()), + "qf": json.dumps(qs.fast.to_dict()), + "qe": json.dumps(qs.expert.to_dict()), + "qh": json.dumps(qs.heavy.to_dict()) if qs.heavy else "{}", + "qg": json.dumps(qs.grok_4_3.to_dict()) if qs.grok_4_3 else "{}", + "ext": json.dumps(item.ext), + "rev": revision, }, ) count += conn.execute("SELECT changes()").fetchone()[0] @@ -229,6 +256,8 @@ def _patch_sync( sets: dict[str, Any] = {"updated_at": ts, "revision": revision} + if patch.account_id is not None: + sets["account_id"] = patch.account_id if patch.pool is not None: sets["pool"] = patch.pool if patch.status is not None: diff --git a/app/control/account/backends/redis.py b/app/control/account/backends/redis.py index 68c4d4b49..3c3aece09 100644 --- a/app/control/account/backends/redis.py +++ b/app/control/account/backends/redis.py @@ -55,6 +55,7 @@ def __init__(self, redis: "Redis") -> None: def _to_hash(record: AccountRecord, revision: int) -> dict[str, str]: qs = record.quota_set() return { + "account_id": record.account_id or "", "pool": record.pool, "status": record.status.value, "created_at": str(record.created_at), @@ -64,6 +65,7 @@ def _to_hash(record: AccountRecord, revision: int) -> dict[str, str]: "quota_fast": json.dumps(qs.fast.to_dict()), "quota_expert": json.dumps(qs.expert.to_dict()), "quota_heavy": json.dumps(qs.heavy.to_dict()) if qs.heavy else "{}", + "quota_grok_4_3": json.dumps(qs.grok_4_3.to_dict()) if qs.grok_4_3 else "{}", "usage_use_count": str(record.usage_use_count), "usage_fail_count": str(record.usage_fail_count), "usage_sync_count": str(record.usage_sync_count), @@ -88,8 +90,10 @@ def _i(k: str) -> int | None: v = _s(k) return int(v) if v else None + aid = _s("account_id") return AccountRecord.model_validate({ "token": token, + "account_id": aid if aid else None, "pool": _s("pool") or "basic", "status": _s("status") or "active", "created_at": _i("created_at") or now_ms(), @@ -102,6 +106,9 @@ def _i(k: str) -> int | None: **({ "heavy": json.loads(_s("quota_heavy")) } if _s("quota_heavy") and _s("quota_heavy") != "{}" else {}), + **({ + "grok_4_3": json.loads(_s("quota_grok_4_3")) + } if _s("quota_grok_4_3") and _s("quota_grok_4_3") != "{}" else {}), }, "usage_use_count": int(_s("usage_use_count") or 0), "usage_fail_count": int(_s("usage_fail_count") or 0), @@ -207,14 +214,40 @@ async def upsert_accounts( pool = item.pool if item.pool in ("basic", "super", "heavy") else "basic" qs = default_quota_set(pool) ts = now_ms() + account_id = item.account_id or None + + # Dedup by account_id: scan for existing records with same account_id. + if account_id: + async for key in self._r.scan_iter("accounts:record:*"): + h = await self._r.hgetall(key) + if not h: + continue + existing_aid = (h.get(b"account_id") or h.get("account_id") or b"") + if isinstance(existing_aid, bytes): + existing_aid = existing_aid.decode() + if existing_aid == account_id: + dup_token = (key.decode() if isinstance(key, bytes) else key).split(":", 2)[-1] + if dup_token != token: + await self._r.hset(key, mapping={ + "deleted_at": str(ts), + "updated_at": str(ts), + "revision": str(rev), + }) + # Remove from pool set and add to deleted pool tracking. + dup_pool = (h.get(b"pool") or h.get("pool") or b"basic") + if isinstance(dup_pool, bytes): + dup_pool = dup_pool.decode() + await self._r.srem(_pool_key(dup_pool), dup_token) + record = AccountRecord( - token = token, - pool = pool, - tags = item.tags, - ext = item.ext, - quota = qs.to_dict(), - created_at = ts, - updated_at = ts, + token = token, + account_id = account_id, + pool = pool, + tags = item.tags, + ext = item.ext, + quota = qs.to_dict(), + created_at = ts, + updated_at = ts, ) key = _record_key(token) await self._r.hset(key, mapping=self._to_hash(record, rev)) @@ -258,6 +291,8 @@ async def patch_accounts( updates["last_sync_at"] = str(patch.last_sync_at) if patch.last_clear_at is not None: updates["last_clear_at"] = str(patch.last_clear_at) + if patch.account_id is not None: + updates["account_id"] = patch.account_id if patch.pool is not None: updates["pool"] = patch.pool if patch.quota_auto is not None: @@ -268,6 +303,8 @@ async def patch_accounts( updates["quota_expert"] = json.dumps(patch.quota_expert) if patch.quota_heavy is not None: updates["quota_heavy"] = json.dumps(patch.quota_heavy) + if patch.quota_grok_4_3 is not None: + updates["quota_grok_4_3"] = json.dumps(patch.quota_grok_4_3) # Usage counters. if patch.usage_use_delta is not None: diff --git a/app/control/account/backends/sql.py b/app/control/account/backends/sql.py index 66b605bb6..aca46705a 100644 --- a/app/control/account/backends/sql.py +++ b/app/control/account/backends/sql.py @@ -37,6 +37,7 @@ _TBL_ACCOUNTS, metadata, sa.Column("token", sa.String(512), primary_key=True), + sa.Column("account_id", sa.String(64), nullable=True, index=True), # xaiUserId UUID for dedup sa.Column("pool", sa.Text, nullable=False, default="basic"), sa.Column("status", sa.Text, nullable=False, default="active"), sa.Column("created_at", sa.BigInteger, nullable=False), @@ -404,6 +405,8 @@ def _evict_cached_engine(engine: AsyncEngine) -> None: def _row_to_record(row: Any) -> AccountRecord: d = dict(row._mapping) + # Normalise account_id — stored as string, can be empty. + d["account_id"] = d.get("account_id") or None d["tags"] = json.loads(d.get("tags") or "[]") heavy_raw = d.pop("quota_heavy", "{}") or "{}" grok_4_3_raw = d.pop("quota_grok_4_3", "{}") or "{}" @@ -522,6 +525,23 @@ async def _do_initialize(self) -> None: async def _ensure_columns(self, conn: Any) -> None: """Idempotent ALTER TABLE migrations for columns added after the initial schema.""" existing = await self._table_columns(conn, _TBL_ACCOUNTS) + if "account_id" not in existing: + if self._dialect == "mysql": + await conn.exec_driver_sql( + f"ALTER TABLE {_TBL_ACCOUNTS} " + f"ADD COLUMN account_id VARCHAR(64) NULL" + ) + await conn.exec_driver_sql( + f"CREATE INDEX idx_accounts_account_id ON {_TBL_ACCOUNTS} (account_id)" + ) + else: + await conn.exec_driver_sql( + f"ALTER TABLE {_TBL_ACCOUNTS} " + f"ADD COLUMN account_id VARCHAR(64)" + ) + await conn.exec_driver_sql( + f"CREATE INDEX idx_accounts_account_id ON {_TBL_ACCOUNTS} (account_id)" + ) if "quota_grok_4_3" not in existing: if self._dialect == "mysql": # MySQL forbids DEFAULT values on TEXT/BLOB columns; @@ -629,8 +649,10 @@ async def upsert_accounts( continue pool = item.pool if item.pool in ("basic", "super", "heavy") else "basic" qs = default_quota_set(pool) + account_id = item.account_id or None row = { "token": token, + "account_id": account_id, "pool": pool, "status": "active", "created_at": ts, @@ -648,6 +670,29 @@ async def upsert_accounts( "ext": json.dumps(item.ext), "revision": rev, } + # Dedup by account_id: if this account_id already exists under a + # *different* token, soft-delete the old token record so we + # maintain one canonical token per xaiUserId. + if account_id: + existing = (await conn.execute( + sa.select(accounts_table.c.token).where( + accounts_table.c.account_id == account_id, + accounts_table.c.token != token, + accounts_table.c.deleted_at.is_(None), + ) + )).fetchall() + for dup in existing: + dup_token = dup[0] + await conn.execute( + accounts_table.update() + .where(accounts_table.c.token == dup_token) + .values(deleted_at=ts, updated_at=ts, revision=rev) + ) + logger.info( + "account deduplicated by account_id: old_token={}... new_token={}... account_id={}", + dup_token[:10], token[:10], account_id, + ) + await conn.execute(self._build_upsert(row)) count += 1 return AccountMutationResult(upserted=count, revision=rev) @@ -672,6 +717,8 @@ async def patch_accounts( record = _row_to_record(row) updates: dict[str, Any] = {"updated_at": ts, "revision": rev} + if patch.account_id is not None: + updates["account_id"] = patch.account_id if patch.pool is not None: updates["pool"] = patch.pool if patch.status is not None: diff --git a/app/control/account/commands.py b/app/control/account/commands.py index ba5011432..c62d103dc 100644 --- a/app/control/account/commands.py +++ b/app/control/account/commands.py @@ -14,6 +14,7 @@ class AccountUpsert(BaseModel): """ token: str + account_id: str | None = None # Grok xaiUserId for deduplication pool: str = "basic" tags: list[str] = Field(default_factory=list) ext: dict[str, Any] = Field(default_factory=dict) @@ -26,6 +27,7 @@ class AccountPatch(BaseModel): """ token: str + account_id: str | None = None # update account ID (from subscription API) pool: str | None = None # update pool type when inferred status: AccountStatus | None = None tags: list[str] | None = None diff --git a/app/control/account/models.py b/app/control/account/models.py index cb8e8e17f..dc6288d50 100644 --- a/app/control/account/models.py +++ b/app/control/account/models.py @@ -176,6 +176,7 @@ class AccountRecord(BaseModel): """ token: str + account_id: str | None = None # Grok's xaiUserId — UUID, used for deduplication pool: str = "basic" status: AccountStatus = AccountStatus.ACTIVE created_at: int = Field(default_factory=now_ms) diff --git a/app/control/account/refresh.py b/app/control/account/refresh.py index f95a1baa6..1d8939583 100644 --- a/app/control/account/refresh.py +++ b/app/control/account/refresh.py @@ -2,7 +2,7 @@ import asyncio from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Any, TYPE_CHECKING from app.platform.errors import UpstreamError from app.platform.config.snapshot import get_config @@ -126,17 +126,87 @@ async def _fetch_mode_quota( ) return None + # ------------------------------------------------------------------ + # Subscription API fetch (account type + account_id) + # ------------------------------------------------------------------ + + async def _fetch_subscription(self, token: str): + """Fetch subscription info (account type + xaiUserId) for *token*.""" + try: + from app.dataplane.reverse.protocol.xai_subscription import ( + fetch_subscription, + ) + return await fetch_subscription(token) + except UpstreamError: + raise + except Exception as exc: + logger.debug( + "subscription fetch failed: token={}... error={}", + token[:10], exc, + ) + return None + + async def _refresh_subscription(self, record: AccountRecord) -> None: + """Fetch subscription info for an account and persist account_id + pool. + + Only called for accounts that are missing ``account_id`` or whose pool + needs verification from the subscription API. + """ + if record.is_deleted(): + return + try: + info = await self._fetch_subscription(record.token) + except UpstreamError as exc: + if await self._expire_invalid_credentials(record, exc): + return + raise + except Exception: + return + + if info is None: + return + + from .commands import AccountPatch + + patch_fields: dict[str, Any] = {} + if info.xai_user_id and not record.account_id: + patch_fields["account_id"] = info.xai_user_id + if info.pool and info.pool != record.pool and info.is_active: + patch_fields["pool"] = info.pool + + if patch_fields: + await self._repo.patch_accounts([ + AccountPatch(token=record.token, **patch_fields) + ]) + logger.info( + "subscription info updated: token={}... account_id={} pool={}", + record.token[:10], + info.xai_user_id or record.account_id, + patch_fields.get("pool", record.pool), + ) + # ------------------------------------------------------------------ # Core refresh logic # ------------------------------------------------------------------ async def refresh_on_import(self, tokens: list[str]) -> RefreshResult: - """Called after bulk import — sync real quotas for all accounts.""" + """Called after bulk import — sync real quotas and subscription info for all accounts.""" records = await self._repo.get_accounts(tokens) active = [r for r in records if is_manageable(r)] if not active: return RefreshResult(checked=len(records)) + # Phase 1: Fetch subscription info (account_id + tier) for accounts missing it. + sub_results = await run_batch( + [r for r in active if not r.account_id], + lambda r: self._refresh_subscription(r), + concurrency=get_config("account.refresh.usage_concurrency", 50), + ) + # Refresh the records after subscription updates. + records = await self._repo.get_accounts(tokens) + active = [r for r in records if is_manageable(r)] + + # Phase 2: Fetch quota windows. concurrency = get_config("account.refresh.usage_concurrency", 50) results = await run_batch( active, @@ -145,7 +215,8 @@ async def refresh_on_import(self, tokens: list[str]) -> RefreshResult: ) agg = RefreshResult(checked=len(records)) for r in results: - agg.merge(r) + if r: + agg.merge(r) return agg async def refresh_call_async(self, token: str, mode_id: int) -> None: diff --git a/app/dataplane/reverse/protocol/xai_subscription.py b/app/dataplane/reverse/protocol/xai_subscription.py new file mode 100644 index 000000000..e3753c641 --- /dev/null +++ b/app/dataplane/reverse/protocol/xai_subscription.py @@ -0,0 +1,259 @@ +"""XAI subscription / account-type protocol — fetch subscription tier and account ID. + +Provides the official account type query by calling ``GET /rest/subscriptions``. +Returns the user's subscription tier (basic / super / heavy) and unique account +ID (``xaiUserId``) for deduplication purposes. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from app.platform.errors import UpstreamError +from app.platform.logging.logger import logger + + +# --------------------------------------------------------------------------- +# Tier mapping — Grok subscription tiers → Grok2API pool names +# --------------------------------------------------------------------------- + +# Known subscription tier strings returned by grok.com/rest/subscriptions. +_TIER_TO_POOL: dict[str, str] = { + "SUBSCRIPTION_TIER_UNKNOWN": "basic", + "SUBSCRIPTION_TIER_FREE": "basic", + "SUBSCRIPTION_TIER_GROK_PRO": "super", + "SUBSCRIPTION_TIER_SUPER_GROK": "super", + "SUBSCRIPTION_TIER_SUPER_GROK_PRO": "heavy", + "SUBSCRIPTION_TIER_GROK_PRO_HEAVY": "heavy", + "SUBSCRIPTION_TIER_SUPER_GROK_LITE": "super", + "SUBSCRIPTION_TIER_GROK_TEAMS": "super", +} + +# Active subscription statuses. +_ACTIVE_STATUSES: frozenset[str] = frozenset({ + "SUBSCRIPTION_STATUS_ACTIVE", + "SUBSCRIPTION_STATUS_TRIAL", + "SUBSCRIPTION_STATUS_TRIALING", +}) + + +# --------------------------------------------------------------------------- +# Result type +# --------------------------------------------------------------------------- + + +@dataclass(slots=True) +class SubscriptionInfo: + """Parsed account-type information from the subscription API. + + ``xai_user_id`` — unique account identifier (UUID); ``None`` if unavailable. + ``pool`` — Grok2API pool name (``"basic"`` / ``"super"`` / ``"heavy"``). + ``tier`` — raw upstream tier string (e.g. ``"SUBSCRIPTION_TIER_GROK_PRO"``). + ``is_active`` — whether at least one subscription is currently active. + ``raw`` — the full parsed JSON dict for debugging / extension use. + """ + + xai_user_id: str | None + pool: str + tier: str + is_active: bool + raw: dict | None + + +def tier_to_pool(tier: str) -> str: + """Map an upstream subscription tier string to a Grok2API pool name.""" + return _TIER_TO_POOL.get(tier, "basic") + + +def is_active_status(status: str) -> bool: + """Return True if *status* represents an active subscription.""" + return status in _ACTIVE_STATUSES + + +# --------------------------------------------------------------------------- +# Response parser +# --------------------------------------------------------------------------- + + +def parse_subscription(body: dict) -> SubscriptionInfo: + """Parse the ``/rest/subscriptions`` response body. + + Expected format:: + + { + "subscriptions": [ + { + "xaiUserId": "uuid-here", + "tier": "SUBSCRIPTION_TIER_GROK_PRO", + "status": "SUBSCRIPTION_STATUS_ACTIVE", + ... + }, + ... + ] + } + + Returns a ``SubscriptionInfo`` populated from the most recent (last in + list) subscription entry. If the subscriptions array is empty, returns + a ``"basic"`` pool with no account ID. + """ + subs: list[dict] = body.get("subscriptions", []) if body else [] + + if not subs: + return SubscriptionInfo( + xai_user_id=None, + pool="basic", + tier="", + is_active=False, + raw=body, + ) + + # Use the *last* subscription (typically most recent). + last = subs[-1] + + xai_user_id = last.get("xaiUserId") or None + tier = last.get("tier", "") + status = last.get("status", "") + + pool = tier_to_pool(tier) + + # If the most recent subscription is not active but a later entry is, use + # the best active tier we can find. Also collect account ID from any entry. + if not xai_user_id: + for sub in subs: + xai_user_id = sub.get("xaiUserId") or xai_user_id + if xai_user_id: + break + + is_active = False + best_active_pool = pool + for sub in subs: + if is_active_status(sub.get("status", "")): + is_active = True + active_tier = sub.get("tier", "") + active_pool = tier_to_pool(active_tier) + if _pool_rank(active_pool) > _pool_rank(best_active_pool): + best_active_pool = active_pool + + if is_active: + pool = best_active_pool + + # If all subscriptions are inactive (expired/cancelled), the account + # reverts to basic tier on the Grok side, but we keep the historical tier + # for information and let the rate-limits API refine it. + if not is_active and pool != "basic": + logger.debug( + "subscription inactive — may have reverted to basic: tier={} status={}", + tier, status, + ) + + return SubscriptionInfo( + xai_user_id=xai_user_id, + pool=pool, + tier=tier, + is_active=is_active, + raw=body, + ) + + +def _pool_rank(pool: str) -> int: + """Return an ordinal rank for a pool name (higher = better).""" + return {"basic": 0, "super": 1, "heavy": 2}.get(pool, 0) + + +# --------------------------------------------------------------------------- +# HTTP fetch +# --------------------------------------------------------------------------- + + +async def _do_fetch(token: str) -> dict: + """GET the subscriptions endpoint and return parsed JSON body.""" + from app.dataplane.reverse.transport.http import get_json + from app.dataplane.proxy import get_proxy_runtime + from app.control.proxy.models import ProxyFeedback, ProxyFeedbackKind + from app.dataplane.reverse.runtime.endpoint_table import SUBSCRIPTIONS + + proxy = await get_proxy_runtime() + lease = await proxy.acquire() + try: + body = await get_json( + SUBSCRIPTIONS, + token, + lease=lease, + timeout_s=20.0, + ) + await proxy.feedback( + lease, ProxyFeedback(kind=ProxyFeedbackKind.SUCCESS, status_code=200) + ) + return body + except Exception as exc: + status = getattr(exc, "status", None) or getattr(exc, "status_code", None) + from app.dataplane.reverse.protocol.xai_usage import _proxy_feedback_kind_for_error + kind = _proxy_feedback_kind_for_error(exc, status=status) + await proxy.feedback(lease, ProxyFeedback(kind=kind, status_code=status)) + raise + + +async def fetch_subscription(token: str) -> SubscriptionInfo | None: + """Fetch account-type / subscription information for *token*. + + Returns a ``SubscriptionInfo``, or ``None`` if the endpoint is unreachable + (network timeout, 5xx, etc.). + """ + import asyncio + + try: + body = await asyncio.wait_for(_do_fetch(token), timeout=25.0) + except asyncio.TimeoutError: + logger.debug( + "subscription fetch timed out: token={}...", token[:10] + ) + return None + except UpstreamError as exc: + if getattr(exc, "status", None) in (401, 403): + # Invalid token — let caller handle. + raise + logger.debug( + "subscription fetch failed: token={}... status={}", + token[:10], exc.status if hasattr(exc, "status") else "?", + ) + return None + except Exception as exc: + logger.debug( + "subscription fetch error: token={}... error={}", + token[:10], exc, + ) + return None + + return parse_subscription(body) + + +async def fetch_subscription_for_import(token: str) -> SubscriptionInfo: + """Fetch subscription info during account import. + + Always returns a ``SubscriptionInfo`` — falls back to ``"basic"`` with + no account ID when the API is unreachable. + + Raises ``UpstreamError`` for 401/403 (invalid credentials). + """ + result = await fetch_subscription(token) + if result is not None: + return result + + # API unreachable — return a safe default. + return SubscriptionInfo( + xai_user_id=None, + pool="basic", + tier="", + is_active=False, + raw=None, + ) + + +__all__ = [ + "SubscriptionInfo", + "parse_subscription", + "fetch_subscription", + "fetch_subscription_for_import", + "tier_to_pool", + "is_active_status", +] diff --git a/app/dataplane/reverse/runtime/endpoint_table.py b/app/dataplane/reverse/runtime/endpoint_table.py index bb390e503..13da5fc41 100644 --- a/app/dataplane/reverse/runtime/endpoint_table.py +++ b/app/dataplane/reverse/runtime/endpoint_table.py @@ -23,6 +23,11 @@ # ── Rate limits (usage / quota sync) ───────────────────────────────────── RATE_LIMITS = f"{BASE}/rest/rate-limits" # POST +# ── Subscription / account type ───────────────────────────────────────── +SUBSCRIPTIONS = f"{BASE}/rest/subscriptions" # GET +PRODUCTS = f"{BASE}/rest/products" # GET ?provider=SUBSCRIPTION_PROVIDER_STRIPE +MODES = f"{BASE}/rest/modes" # POST + # ── gRPC-Web endpoints ────────────────────────────────────────────────── ACCEPT_TOS = "https://accounts.x.ai/auth_mgmt.AuthManagement/SetTosAcceptedVersion" NSFW_MGMT = f"{BASE}/auth_mgmt.AuthManagement/UpdateUserFeatureControls" @@ -47,7 +52,7 @@ "BASE", "ASSETS_CDN", "CHAT", "ASSETS_UPLOAD", "ASSETS_LIST", "ASSETS_DELETE", "ASSETS_DOWNLOAD", - "RATE_LIMITS", + "RATE_LIMITS", "SUBSCRIPTIONS", "PRODUCTS", "MODES", "ACCEPT_TOS", "NSFW_MGMT", "SET_BIRTH", "MEDIA_POST", "MEDIA_POST_LINK", "VIDEO_UPSCALE", "WS_IMAGINE", "WS_LIVEKIT", "LIVEKIT_TOKENS", diff --git a/app/products/web/admin/tokens.py b/app/products/web/admin/tokens.py index bf5bf10f6..70e535b5e 100644 --- a/app/products/web/admin/tokens.py +++ b/app/products/web/admin/tokens.py @@ -121,6 +121,7 @@ def _quota_brief(q: dict) -> dict: def _serialize_record(r) -> dict: return { "token": r.token, + "account_id": r.account_id, "pool": r.pool or "basic", "status": r.status, "quota": _quota_brief(r.quota) if isinstance(r.quota, dict) else {}, diff --git a/scripts/inject_cookie.py b/scripts/inject_cookie.py new file mode 100644 index 000000000..a78776559 --- /dev/null +++ b/scripts/inject_cookie.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +"""Grok cookie injector — set the SSO cookie in a browser and open Grok. + +Usage: + python inject_cookie.py + python inject_cookie.py --url https://grok.com + python inject_cookie.py --browser firefox + +The cookie value can be: + - A raw JWT: eyJ0eXAiOiJKV1Qi... + - With sso= prefix: sso=eyJ0eXAiOiJKV1Qi... + - A full cookie header: sso=eyJ...; Domain=.grok.com; Path=/ + +Requirements: + pip install browser-cookie3 (optional, for reading existing cookies) + +The script uses webbrowser to open the URL and prints instructions for +manual cookie injection into Chrome DevTools. +""" + +from __future__ import annotations + +import argparse +import json +import re +import sys +import webbrowser + +# ── JWT decode (no signature verification) ────────────────────────────── + + +def decode_jwt_payload(raw: str) -> dict: + import base64 + + parts = raw.strip().split(".") + if len(parts) != 3: + raise ValueError("Not a valid JWT (expected 3 segments)") + + payload_b64 = parts[1] + payload_b64 += "=" * (4 - len(payload_b64) % 4) + try: + payload = json.loads(base64.urlsafe_b64decode(payload_b64)) + except Exception as exc: + raise ValueError(f"Cannot decode JWT payload: {exc}") from exc + return payload + + +def parse_cookie(raw: str) -> tuple[str, str, dict]: + """Parse a Grok SSO cookie value. + + Returns (cookie_name, cookie_value, jwt_payload). + """ + text = raw.strip() + + # Strip browser export format: "Cookie: sso=..." + if text.lower().startswith("cookie:"): + text = text[len("cookie:"):].strip() + + # Find the JWT by looking for sso= prefix + cookie_value = text + if text.startswith("sso="): + cookie_value = text[4:] + if ";" in cookie_value: + cookie_value = cookie_value.split(";")[0].strip() + + # If no sso= prefix, assume raw JWT + cookie_value = cookie_value.strip() + payload = decode_jwt_payload(cookie_value) + return ("sso", cookie_value, payload) + + +# ── Browser injection helpers ─────────────────────────────────────────── + + +def inject_via_playwright(cookie_value: str, url: str) -> int: + """Use Playwright to inject the cookie and open the browser.""" + try: + from playwright.sync_api import sync_playwright + except ImportError: + print("[!] playwright not installed. Install with: pip install playwright && playwright install chromium") + return 1 + + with sync_playwright() as p: + browser = p.chromium.launch(headless=False) + context = browser.new_context() + context.add_cookies([ + { + "name": "sso", + "value": cookie_value, + "domain": ".grok.com", + "path": "/", + "httpOnly": False, + "secure": True, + "sameSite": "Lax", + } + ]) + page = context.new_page() + page.goto(url) + print(f"[+] Cookie injected — browser opened at {url}") + print("[+] Press Ctrl+C to close the browser...") + try: + page.wait_for_timeout(60_000) # 60s before auto-close + except KeyboardInterrupt: + pass + browser.close() + return 0 + + +def print_manual_instructions(cookie_value: str, url: str) -> None: + """Print manual DevTools injection instructions.""" + js = ( + "document.cookie = 'sso=" + cookie_value + "; " + "domain=.grok.com; path=/; SameSite=Lax; Secure';" + ) + + print(f""" +╔══════════════════════════════════════════════════════════════╗ +║ Grok Cookie Injection — Manual Instructions ║ +╠══════════════════════════════════════════════════════════════╣ +║ ║ +║ 1. Open {url} in your browser ║ +║ 2. Press F12 to open DevTools ║ +║ 3. Go to the Console tab ║ +║ 4. Paste this command: ║ +║ ║ +║ {js} +║ ║ +║ 5. Press Enter (the page will reload) ║ +║ 6. You should be logged in ║ +║ ║ +╠══════════════════════════════════════════════════════════════╣ +║ Alternative — Application tab: ║ +║ 1. F12 → Application → Cookies → grok.com ║ +║ 2. Add cookie: name=sso, value=, path=/ ║ +║ 3. Refresh the page ║ +║ ║ +╠══════════════════════════════════════════════════════════════╣ +║ For grok2api import: ║ +║ Use the token value directly: ║ +║ curl -X POST .../admin/api/tokens/add ║ +║ -d '{{"tokens": ["{cookie_value}"], "pool": "auto"}}' ║ +╚══════════════════════════════════════════════════════════════╝ +""") + + +# ── Main ───────────────────────────────────────────────────────────────── + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Inject a Grok SSO cookie into a browser and open Grok", + ) + parser.add_argument( + "cookie", + help="Grok SSO cookie value (JWT, optionally with 'sso=' prefix)", + ) + parser.add_argument( + "--url", default="https://grok.com", + help="Target URL (default: https://grok.com)", + ) + parser.add_argument( + "--manual", action="store_true", + help="Print manual DevTools injection instructions (no automation)", + ) + parser.add_argument( + "--playwright", action="store_true", + help="Use Playwright to automate cookie injection", + ) + args = parser.parse_args() + + try: + name, value, payload = parse_cookie(args.cookie) + except ValueError as exc: + print(f"[!] Cookie parse error: {exc}", file=sys.stderr) + return 1 + + session_id = payload.get("session_id", "unknown") + print(f"[+] Parsed cookie: name={name}, session_id={session_id}") + print(f"[+] Full token: {value[:40]}...{value[-20:]}") + + if args.playwright: + return inject_via_playwright(value, args.url) + + # Default: print instructions and offer to open browser. + print_manual_instructions(value, args.url) + + try: + choice = input("\n[?] Open browser now? (y/N): ").strip().lower() + if choice in ("y", "yes"): + webbrowser.open(args.url) + except (KeyboardInterrupt, EOFError): + pass + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) From dc0c599258d5f00ec4132e8d7d75b47fc3468df5 Mon Sep 17 00:00:00 2001 From: dongxuelian1010 <234438803@qq.com> Date: Sun, 31 May 2026 00:38:34 +0800 Subject: [PATCH 13/14] feat: multi-factor pool classification (rate-limits + subscription) Phase 1 (_refresh_subscription): set account_id + store sub_tier/sub_active in ext Phase 2 (_refresh_one): multi-factor inference: 1. Primary: infer_pool from rate-limits auto.total (ground truth) 2. Fallback: if rate-limits says basic but active subscription says super/heavy, subscription wins (paying customer = real super account) 3. For basic accounts, always fetch ALL modes to see auto.total Scenarios: auto=50, any sub -> super (rate-limits) auto=7, INACTIVE -> basic (genuinely downgraded) auto=N/A, ACTIVE sub -> super (API fluke, subscription wins) auto=7, ACTIVE sub -> super (active sub > anomalous quota) Co-Authored-By: Claude Opus 4.8 --- app/control/account/refresh.py | 57 +++++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/app/control/account/refresh.py b/app/control/account/refresh.py index 1d8939583..a606eee43 100644 --- a/app/control/account/refresh.py +++ b/app/control/account/refresh.py @@ -147,10 +147,10 @@ async def _fetch_subscription(self, token: str): return None async def _refresh_subscription(self, record: AccountRecord) -> None: - """Fetch subscription info for an account and persist account_id + pool. + """Fetch subscription info, persist account_id, and store tier hint for Phase 2. - Only called for accounts that are missing ``account_id`` or whose pool - needs verification from the subscription API. + Does NOT set pool — that's Phase 2's job (multi-factor: rate-limits + subscription). + Stores ``sub_tier`` and ``sub_active`` in ext so _refresh_one can cross-validate. """ if record.is_deleted(): return @@ -169,20 +169,24 @@ async def _refresh_subscription(self, record: AccountRecord) -> None: from .commands import AccountPatch patch_fields: dict[str, Any] = {} + ext_merge: dict[str, Any] = { + "sub_tier": info.tier, + "sub_active": info.is_active, + } + if info.xai_user_id and not record.account_id: patch_fields["account_id"] = info.xai_user_id - if info.pool and info.pool != record.pool and info.is_active: - patch_fields["pool"] = info.pool - if patch_fields: + if patch_fields or ext_merge: + patch_fields["ext_merge"] = ext_merge await self._repo.patch_accounts([ AccountPatch(token=record.token, **patch_fields) ]) logger.info( - "subscription info updated: token={}... account_id={} pool={}", + "subscription info updated: token={}... account_id={} tier={} active={}", record.token[:10], info.xai_user_id or record.account_id, - patch_fields.get("pool", record.pool), + info.tier, info.is_active, ) # ------------------------------------------------------------------ @@ -307,8 +311,11 @@ async def _refresh_one( if record.is_deleted(): return RefreshResult() + # For basic accounts, fetch all modes (like "auto" pool) so infer_pool + # can see auto.total and correctly upgrade to super/heavy when warranted. + fetch_pool = "auto" if record.pool == "basic" else record.pool try: - windows = await self._fetch_all_quotas(record.token, record.pool) + windows = await self._fetch_all_quotas(record.token, fetch_pool) except UpstreamError as exc: if await self._expire_invalid_credentials(record, exc): return RefreshResult(checked=1, expired=1, failed=0) @@ -364,12 +371,40 @@ async def _refresh_one( if not patches: return RefreshResult(checked=1, failed=0 if refreshed else 1) - # Infer pool type from live quota data and patch if it changed. + # ── Multi-factor pool inference ────────────────────────────── + # Primary: rate-limits auto.total (ground truth of current quota) + # Fallback: subscription API (active subscription = paying customer) + # + # Scenarios handled: + # auto=50, any sub status → super (rate-limits wins) + # auto=7, INACTIVE sub → basic (genuinely downgraded) + # auto=?, ACTIVE sub → super (subscription wins, API fluke) + # auto=7, ACTIVE sub (anomaly) → super (active sub > anomalous quota) inferred = infer_pool(windows) # type: ignore[arg-type] + + if inferred == "basic": + sub_tier = record.ext.get("sub_tier", "") + sub_active = record.ext.get("sub_active", False) + if sub_active and sub_tier in ( + "SUBSCRIPTION_TIER_GROK_PRO", + "SUBSCRIPTION_TIER_SUPER_GROK", + "SUBSCRIPTION_TIER_SUPER_GROK_PRO", + "SUBSCRIPTION_TIER_GROK_PRO_HEAVY", + "SUBSCRIPTION_TIER_SUPER_GROK_LITE", + "SUBSCRIPTION_TIER_GROK_TEAMS", + ): + from app.dataplane.reverse.protocol.xai_subscription import tier_to_pool as _t2p + inferred = _t2p(sub_tier) + logger.info( + "account pool upgraded by subscription: token={}... " + "rate_limits_pool=basic subscription_tier={} → pool={}", + record.token[:10], sub_tier, inferred, + ) + pool_patch = inferred if inferred != record.pool else None if pool_patch: logger.info( - "account pool updated from live quota: token={}... previous_pool={} current_pool={}", + "account pool updated: token={}... previous_pool={} current_pool={}", record.token[:10], record.pool, inferred, From c0e31ad83314590f15c9d4c44afe0cf603f92634 Mon Sep 17 00:00:00 2001 From: dongxuelian1010 <234438803@qq.com> Date: Sun, 31 May 2026 00:55:53 +0800 Subject: [PATCH 14/14] fix: preserve auto/expert quota when upgrading pool from basic to super Use inferred pool for quota normalization so auto/expert/grok_4_3 modes fetched for basic accounts aren't discarded by normalize_quota_window (basic pool doesn't support those modes). Before: super accounts had auto=0 expert=0 (data fetched but thrown away) After: super accounts have auto=19137 expert=15212 (correctly preserved) Co-Authored-By: Claude Opus 4.8 --- app/control/account/refresh.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/app/control/account/refresh.py b/app/control/account/refresh.py index a606eee43..71313984b 100644 --- a/app/control/account/refresh.py +++ b/app/control/account/refresh.py @@ -334,10 +334,17 @@ async def _refresh_one( patches: dict[str, dict] = {} refreshed = False + # Use the inferred pool for quota normalization when it represents + # an upgrade (e.g. basic→super). Without this, auto/expert/grok_4_3 + # quota data fetched for basic accounts gets discarded by + # normalize_quota_window (basic pool doesn't support those modes). + inferred = infer_pool(windows) # type: ignore[arg-type] + effective_pool = inferred if inferred != "basic" else record.pool + for mode in ALL_MODES_FULL: mode_id = int(mode) if mode_id in windows: - window = normalize_quota_window(record.pool, mode_id, windows[mode_id]) + window = normalize_quota_window(effective_pool, mode_id, windows[mode_id]) if window is None: continue patches[_MODE_KEYS[mode_id]] = window.to_dict() @@ -372,16 +379,14 @@ async def _refresh_one( return RefreshResult(checked=1, failed=0 if refreshed else 1) # ── Multi-factor pool inference ────────────────────────────── - # Primary: rate-limits auto.total (ground truth of current quota) - # Fallback: subscription API (active subscription = paying customer) + # 'inferred' was set above from rate-limits auto.total (primary). + # Now cross-validate with subscription data (fallback). # # Scenarios handled: # auto=50, any sub status → super (rate-limits wins) # auto=7, INACTIVE sub → basic (genuinely downgraded) # auto=?, ACTIVE sub → super (subscription wins, API fluke) # auto=7, ACTIVE sub (anomaly) → super (active sub > anomalous quota) - inferred = infer_pool(windows) # type: ignore[arg-type] - if inferred == "basic": sub_tier = record.ext.get("sub_tier", "") sub_active = record.ext.get("sub_active", False) @@ -397,7 +402,7 @@ async def _refresh_one( inferred = _t2p(sub_tier) logger.info( "account pool upgraded by subscription: token={}... " - "rate_limits_pool=basic subscription_tier={} → pool={}", + "rate_limits_pool=basic subscription_tier={} -> pool={}", record.token[:10], sub_tier, inferred, )