diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..5b2df7d --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,35 @@ +name: CI + +on: + pull_request: + branches: [main] + push: + branches: [main] + +permissions: + contents: read + +jobs: + go-quality: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + + - name: Check formatting + run: test -z "$(gofmt -l .)" + + - name: Vet + run: go vet ./... + + - name: Test + run: go test ./... + + - name: Build + run: go build ./... diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml new file mode 100644 index 0000000..caf8add --- /dev/null +++ b/.github/workflows/docker.yml @@ -0,0 +1,57 @@ +name: Docker + +on: + workflow_dispatch: + pull_request: + branches: + - main + push: + branches: + - main + tags: + - 'v*' + +permissions: + contents: read + packages: write + +jobs: + docker: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to GitHub Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract Docker metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ghcr.io/${{ github.repository }} + tags: | + type=raw,value=latest,enable={{is_default_branch}} + type=sha + type=ref,event=tag + + - name: Build and push Docker image + uses: docker/build-push-action@v6 + with: + context: . + file: ./Dockerfile + platforms: linux/amd64,linux/arm64 + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 9ca3815..932b669 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -21,7 +21,7 @@ jobs: - name: ⚙️ 设置 Go 环境 uses: actions/setup-go@v4 with: - go-version: '1.21' + go-version: '1.25.5' cache: true - name: 📦 下载依赖 diff --git a/.gitignore b/.gitignore index b793827..f6cafd7 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,14 @@ main.zip # Local config config.json +state.json # Temporary/test files test.txt -*.exe \ No newline at end of file +*.exe +*.exe~ +*.har +*.ndjson +__pycache__/ +browser-profile/ +capture_gemini_mitm.py diff --git a/Dockerfile b/Dockerfile index bf7fc5b..22d1685 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,12 +14,20 @@ FROM alpine:3.21 WORKDIR /app -RUN apk add --no-cache ca-certificates tzdata && update-ca-certificates +RUN apk add --no-cache ca-certificates tzdata wget su-exec && update-ca-certificates \ + && addgroup -S app && adduser -S -G app app \ + && mkdir -p /app && chown -R app:app /app COPY --from=builder /out/geminiweb2api /app/geminiweb2api +COPY docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh + +RUN chown app:app /app/geminiweb2api \ + && chmod +x /usr/local/bin/docker-entrypoint.sh EXPOSE 8080 VOLUME ["/app"] -CMD ["/app/geminiweb2api"] +HEALTHCHECK --interval=30s --timeout=10s --start-period=45s --retries=3 CMD sh -c 'wget -q -O /dev/null http://127.0.0.1:8080/api/telemetry || exit 1' + +ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"] diff --git a/README.md b/README.md index 3807d42..42ffb0c 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,18 @@ docker run -d \ geminiweb2api:latest ``` +Health check inside the container: + +```bash +curl -s http://127.0.0.1:8080/api/telemetry +``` + +The image prefers running as a non-root user and includes a Docker `HEALTHCHECK` based on `/api/telemetry`. + +If the mounted `/app` volume is not writable by the non-root user in your Docker environment, the entrypoint automatically falls back to root so the container can still start instead of crashing on permissions. + +Use `GET /healthz` for upstream business health checks. It returns `503` when the service is up but no account is currently healthy. + Docker Compose: ```bash @@ -96,6 +108,11 @@ cp config.json.example config.json docker compose up -d --build ``` +Optional environment overrides in `docker compose`: + +- `GEMINIWEB2API_API_KEY` +- `GEMINIWEB2API_PUBLIC_ACCOUNT_STATUS` + ### Configuration Use `config.json` in the project root. You can start from `config.json.example`: @@ -128,8 +145,14 @@ Use `config.json` in the project root. You can start from `config.json.example`: Gemini web cookie string. Recommended when anonymous access is unstable or the environment requires sign-in state. - `tokens` Reserved field. Currently unused. +- `accounts` + Optional multi-account pool. When present, requests are assigned by session binding plus round-robin selection across healthy accounts. Each account supports `id`, `email`, `cookies`, `token`, `proxy`, `enabled`, and `weight`. - `proxy` Explicit proxy such as `http://127.0.0.1:7890`. The app also respects `HTTP_PROXY`, `HTTPS_PROXY`, and `ALL_PROXY`. +- `models` + Optional model ID list returned by `GET /v1/models`. If empty, the built-in default Gemini model list is used. +- `model_aliases` + Optional request model alias map. Example: map `gpt-4.1` to `gemini-3-pro` for upstream panels such as NewAPI. - `gemini_url` Override for the Gemini generation endpoint in reverse-proxy setups. - `gemini_home_url` @@ -140,9 +163,30 @@ Use `config.json` in the project root. You can start from `config.json.example`: Log output path. Empty means stdout. - `log_level` `debug`, `info`, `warn`, or `error`. +- `public_account_status` + Defaults to `false`. When `false`, `GET /api/accounts` and `GET /api/accounts/bindings` require `Authorization: Bearer `. Set to `true` only for trusted local deployments where unauthenticated read-only status is acceptable. - `note` Free-form note strings surfaced by `/api/telemetry` and the WebUI. +### Environment Variables + +Production deployments can override selected `config.json` values with environment variables: + +- `GEMINIWEB2API_API_KEY` +- `GEMINIWEB2API_PROXY` +- `GEMINIWEB2API_PORT` +- `GEMINIWEB2API_LOG_LEVEL` +- `GEMINIWEB2API_PUBLIC_ACCOUNT_STATUS` + +Environment values take precedence over `config.json` at load time. + +### Security Notes + +- Do not commit `config.json`; it can contain API keys, Google cookies, tokens, and proxies. +- Keep `public_account_status` disabled for public or production deployments. +- Management APIs that mutate accounts always require `Authorization: Bearer `. +- The authenticated account details endpoint can return full cookies and tokens; only expose the service behind trusted networks or authentication layers. + ### Hot Reload The process checks `config.json` every 5 seconds and reloads it automatically when the file changes. You do not need to restart the service after editing the config. @@ -192,6 +236,109 @@ SID=...; APISID=...; SAPISID=...; ... Do not commit real cookies to a public repository. +### Multi-Account Pool + +You can now run the proxy in multi-account mode by filling `accounts` in `config.json`. + +Example: + +```json +{ + "api_key": "your-api-key-here", + "accounts": [ + { + "id": "acc-1", + "email": "first@example.com", + "cookies": "SID=...; APISID=...", + "token": "", + "proxy": "", + "enabled": true, + "weight": 1 + }, + { + "id": "acc-2", + "email": "second@example.com", + "cookies": "SID=...; APISID=...", + "token": "", + "proxy": "http://user:pass@proxy-host:port", + "enabled": true, + "weight": 1 + } + ] +} +``` + +Behavior: + +- The same `X-Session-ID` stays bound to the same account while that account is healthy. +- New sessions are assigned by round-robin across healthy accounts. +- Failed accounts enter exponential backoff starting at 30 seconds, doubling up to 30 minutes. +- If an account has `proxy`, token refresh and Gemini requests for that account use that proxy. +- If account `proxy` is empty, the service falls back to the global `proxy` setting or the machine's proxy environment. +- If `accounts` is empty, the service falls back to the legacy single-account `cookies` and `token` fields. + +### Session Binding Persistence + +Session-to-account bindings are persisted in `state.json` beside `config.json`. + +- Persisted: session/account binding, bind time, last used time +- Not persisted: short-lived runtime page tokens like `SNlM0e`, `BL`, `f.sid` + +On restart, bindings are restored when the referenced account still exists. + +### Account Pool APIs + +- `GET /api/accounts` + Returns configured accounts and runtime state. +- `POST /api/accounts` + Creates or updates an account. +- `GET /api/accounts/bindings` + Returns current session-to-account bindings. +- `POST /api/accounts/{id}/enable` + Enables an account. +- `POST /api/accounts/{id}/disable` + Disables an account. +- `POST /api/accounts/{id}/refresh` + Refreshes token state for one account immediately. + +All account APIs require `Authorization: Bearer `. + +### Google Account Manager v1.8 Compatibility + +The legacy Google account manager can keep using its existing Gemini session callback: + +```http +POST /api/session/cookies +Authorization: Bearer +Content-Type: application/json +``` + +Body: + +```json +{ + "email": "account@gmail.com", + "cookies": "SID=...; __Secure-1PSID=...; ...", + "proxy": "http://user:pass@proxy-host:port", + "persist": true +} +``` + +When `email` is present, this endpoint now upserts the cookie into the multi-account pool instead of only updating the legacy single-account `cookies` field. The generated account ID uses the email directly, for example: + +```text +account@gmail.com +``` + +In the Google account manager settings, set: + +- `GEMINIWEB2API_URL` to this service, for example `http://127.0.0.1:8080` +- `GEMINIWEB2API_KEY` to this service's `api_key` +- `GEMINIWEB2API_PERSIST` to `true` if you want updates written to `config.json` +- Optional `GEMINIWEB2API_ACCOUNT_PROXY` if all callbacks from that manager should use the same outbound proxy in this service + +Then use its existing `抓 Session` / `批量抓 Session` action. Successful callbacks should show the imported account in this service's account pool. + ### Usage Examples #### Health check @@ -240,6 +387,51 @@ curl -N "http://127.0.0.1:8080/v1/chat/completions" \ }' ``` +### Use Behind NewAPI + +If you run a NewAPI panel or any OpenAI-compatible gateway, the recommended topology is: + +1. Google cookie -> `geminiweb2api` +2. NewAPI upstream -> `geminiweb2api` +3. End users -> NewAPI + +Recommended upstream settings in NewAPI: + +- Base URL: `http://your-geminiweb2api-host:8080/v1` +- API Key: the `api_key` from `config.json` +- Model discovery: `GET /v1/models` +- Chat endpoint: `POST /v1/chat/completions` +- Responses endpoint: `POST /v1/responses` +- Health check: `GET /healthz` + +Notes: + +- `GET /v1/models` also requires `Authorization: Bearer `. +- `POST /v1/responses` is supported as a minimal compatibility layer and is internally translated into `/v1/chat/completions` for text input. +- Streaming is supported with SSE and ends with `data: [DONE]`. The current implementation streams incremental chunks from the final Gemini content instead of a true token-by-token upstream stream. +- `stream_options.include_usage` is supported. +- `model_aliases` can be used to align NewAPI/OpenAI-style model names with Gemini model IDs. +- Common OpenAI/NewAPI fields such as `max_completion_tokens`, `top_p`, `presence_penalty`, `frequency_penalty`, `response_format`, and `user` are accepted for compatibility. Some are pass-through compatibility fields and may not materially change Gemini Web behavior. + +Recommended model names for upstream mapping: + +- `gemini-3-flash` +- `gemini-3` +- `gemini-3-pro` +- `gemini-2.5-flash` +- `gemini-2.5-pro` + +Suggested first-choice default: + +- `gemini-3-flash` + +Example NewAPI health probe: + +```bash +curl -s "http://127.0.0.1:8080/v1/models" \ + -H "Authorization: Bearer your-api-key-here" +``` + ### Session Continuity - Keep `X-Session-ID` stable for the same user or conversation. diff --git a/capture_gemini_mitm.py b/capture_gemini_mitm.py new file mode 100644 index 0000000..efce611 --- /dev/null +++ b/capture_gemini_mitm.py @@ -0,0 +1,122 @@ +from mitmproxy import http +from pathlib import Path +import json +import time + +OUT = Path("mitm-gemini-tools.ndjson") +SENSITIVE_HEADERS = {"cookie", "authorization", "x-client-data"} +BODY_LIMIT = 500 * 1024 +MULTIPART_PART_LIMIT = 16 * 1024 + + +def scrub_headers(headers): + result = {} + for key, value in headers.items(): + if key.lower() in SENSITIVE_HEADERS: + result[key] = f"" + else: + result[key] = value + return result + + +def body_text(content, content_type=""): + if not content: + return "" + if "multipart/form-data" in (content_type or "").lower(): + return multipart_body_text(content) + text = content.decode("utf-8", errors="replace") + if len(text) > BODY_LIMIT: + return text[:BODY_LIMIT] + "\n" + return text + + +def multipart_body_text(content): + text = content.decode("utf-8", errors="replace") + lines = text.splitlines(keepends=True) + result = [] + in_binary_part = False + binary_size = 0 + trimmed_binary = False + + for line in lines: + stripped = line.strip() + if stripped.startswith("--"): + if trimmed_binary: + result.append(f"\r\n") + result.append(line) + in_binary_part = False + binary_size = 0 + trimmed_binary = False + continue + if line.lower().startswith("content-type: image/") or line.lower().startswith("content-type: video/"): + in_binary_part = True + result.append(line) + continue + if in_binary_part and stripped and not line.lower().startswith("content-"): + binary_size += len(line) + if binary_size <= MULTIPART_PART_LIMIT: + result.append(line) + else: + trimmed_binary = True + continue + result.append(line) + + captured = "".join(result) + if len(captured) > BODY_LIMIT: + return captured[:BODY_LIMIT] + "\n" + return captured + + +def header_value(headers, name): + for key, value in headers.items(): + if key.lower() == name: + return value + return "" + + +def content_length(headers): + value = header_value(headers, "content-length") + if not value: + return 0 + try: + return int(value) + except ValueError: + return 0 + + +def append(event): + with OUT.open("a", encoding="utf-8") as f: + f.write(json.dumps(event, ensure_ascii=False) + "\n") + + +def request(flow: http.HTTPFlow): + if "gemini.google.com" not in (flow.request.pretty_host or ""): + return + flow.metadata["capture_gemini"] = True + content_type = header_value(flow.request.headers, "content-type") + append({ + "ts": time.time(), + "type": "request", + "id": flow.id, + "method": flow.request.method, + "url": flow.request.pretty_url, + "content_type": content_type, + "content_length": content_length(flow.request.headers), + "headers": scrub_headers(flow.request.headers), + "body": body_text(flow.request.raw_content, content_type), + }) + + +def response(flow: http.HTTPFlow): + if not flow.metadata.get("capture_gemini"): + return + content_type = header_value(flow.response.headers, "content-type") + append({ + "ts": time.time(), + "type": "response", + "id": flow.id, + "status_code": flow.response.status_code, + "content_type": content_type, + "headers": scrub_headers(flow.response.headers), + "body": body_text(flow.response.content, content_type), + }) diff --git a/config.json.example b/config.json.example index 4b4d5d2..0899729 100644 --- a/config.json.example +++ b/config.json.example @@ -3,12 +3,42 @@ "token": "", "cookies": "", "tokens": [], + "accounts": [ + { + "id": "acc-1", + "email": "first@example.com", + "cookies": "", + "token": "", + "proxy": "", + "enabled": true, + "weight": 1 + }, + { + "id": "acc-2", + "email": "second@example.com", + "cookies": "", + "token": "", + "proxy": "", + "enabled": true, + "weight": 1 + } + ], "proxy": "", + "models": [ + "gemini-3-flash", + "gemini-3-pro", + "gemini-2.5-flash" + ], + "model_aliases": { + "gpt-4.1": "gemini-3-pro", + "gpt-4o-mini": "gemini-3-flash" + }, "gemini_url": "", "gemini_home_url": "", "port": 8080, "log_file": "", "log_level": "info", + "public_account_status": false, "note": [ "Auto-generated config" ] diff --git a/docker-compose.yml b/docker-compose.yml index fa4f27d..8e940e4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,7 +3,16 @@ services: build: . container_name: geminiweb2api restart: unless-stopped + environment: + GEMINIWEB2API_API_KEY: ${GEMINIWEB2API_API_KEY:-your-api-key-here} + GEMINIWEB2API_PUBLIC_ACCOUNT_STATUS: ${GEMINIWEB2API_PUBLIC_ACCOUNT_STATUS:-false} ports: - "8080:8080" volumes: - - ./config.json:/app/config.json:ro + - ./config.json:/app/config.json + healthcheck: + test: ["CMD-SHELL", "wget -q -O /dev/null http://127.0.0.1:8080/api/telemetry || exit 1"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 45s diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh new file mode 100644 index 0000000..83238e1 --- /dev/null +++ b/docker-entrypoint.sh @@ -0,0 +1,26 @@ +#!/bin/sh +set -eu + +APP_BIN="/app/geminiweb2api" +APP_DIR="/app" + +can_write_app_dir() { + test_file="$APP_DIR/.perm-check" + if touch "$test_file" 2>/dev/null; then + rm -f "$test_file" 2>/dev/null || true + return 0 + fi + return 1 +} + +if [ "$(id -u)" = "0" ]; then + if can_write_app_dir; then + chown -R app:app "$APP_DIR" 2>/dev/null || true + exec su-exec app "$APP_BIN" + fi + + echo "[WARN] /app volume is not writable by non-root user; falling back to root runtime" + exec "$APP_BIN" +fi + +exec "$APP_BIN" diff --git a/internal/config/store.go b/internal/config/store.go index ec7633c..6e24576 100644 --- a/internal/config/store.go +++ b/internal/config/store.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "os" + "strconv" "sync" "time" ) @@ -14,17 +15,31 @@ const ( ) type Config struct { - APIKey string `json:"api_key"` - Token string `json:"token"` - Cookies string `json:"cookies"` - Tokens []string `json:"tokens"` - Proxy string `json:"proxy"` - GeminiURL string `json:"gemini_url"` - GeminiHomeURL string `json:"gemini_home_url"` - Port int `json:"port"` - LogFile string `json:"log_file"` - LogLevel string `json:"log_level"` - Note []string `json:"note"` + APIKey string `json:"api_key"` + Token string `json:"token"` + Cookies string `json:"cookies"` + Tokens []string `json:"tokens"` + Accounts []Account `json:"accounts"` + Proxy string `json:"proxy"` + Models []string `json:"models"` + ModelAliases map[string]string `json:"model_aliases"` + GeminiURL string `json:"gemini_url"` + GeminiHomeURL string `json:"gemini_home_url"` + Port int `json:"port"` + LogFile string `json:"log_file"` + LogLevel string `json:"log_level"` + PublicAccountStatus bool `json:"public_account_status"` + Note []string `json:"note"` +} + +type Account struct { + ID string `json:"id"` + Email string `json:"email"` + Cookies string `json:"cookies"` + Token string `json:"token"` + Proxy string `json:"proxy"` + Enabled bool `json:"enabled"` + Weight int `json:"weight"` } type Store struct { @@ -38,6 +53,10 @@ func NewStore(path string) *Store { return &Store{path: path} } +func (s *Store) Path() string { + return s.path +} + func (s *Store) Snapshot() Config { s.mu.RLock() defer s.mu.RUnlock() @@ -64,6 +83,7 @@ func (s *Store) Load() error { if err := os.WriteFile(s.path, data, 0644); err != nil { return fmt.Errorf("failed to write default config: %w", err) } + applyEnvOverrides(&defaultConfig) s.cfg = defaultConfig return nil } @@ -86,6 +106,7 @@ func (s *Store) Load() error { cfg.LogLevel = DefaultLogLevel } + applyEnvOverrides(&cfg) s.cfg = cfg return nil } @@ -94,6 +115,63 @@ func (s *Store) Reload() error { return s.Load() } +func (s *Store) Update(mutator func(*Config) error) error { + s.mu.Lock() + defer s.mu.Unlock() + + cfg := s.cfg + if err := mutator(&cfg); err != nil { + return err + } + + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal config: %w", err) + } + if err := os.WriteFile(s.path, data, 0644); err != nil { + return fmt.Errorf("failed to write config: %w", err) + } + + applyEnvOverrides(&cfg) + s.cfg = cfg + return nil +} + +func applyEnvOverrides(cfg *Config) { + if value := os.Getenv("GEMINIWEB2API_API_KEY"); value != "" { + cfg.APIKey = value + } + if value := os.Getenv("GEMINIWEB2API_PROXY"); value != "" { + cfg.Proxy = value + } + if value := os.Getenv("GEMINIWEB2API_PORT"); value != "" { + if port, err := strconv.Atoi(value); err == nil && port > 0 { + cfg.Port = port + } + } + if value := os.Getenv("GEMINIWEB2API_LOG_LEVEL"); value != "" { + cfg.LogLevel = value + } + if value := os.Getenv("GEMINIWEB2API_PUBLIC_ACCOUNT_STATUS"); value != "" { + if parsed, err := strconv.ParseBool(value); err == nil { + cfg.PublicAccountStatus = parsed + } + } +} + +func (s *Store) UpdateInMemory(mutator func(*Config) error) error { + s.mu.Lock() + defer s.mu.Unlock() + + cfg := s.cfg + if err := mutator(&cfg); err != nil { + return err + } + + s.cfg = cfg + return nil +} + func (s *Store) Watch(onReload func() error) { go func() { var lastModTime time.Time diff --git a/internal/config/store_test.go b/internal/config/store_test.go new file mode 100644 index 0000000..e3b8aef --- /dev/null +++ b/internal/config/store_test.go @@ -0,0 +1,30 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadAppliesEnvironmentOverrides(t *testing.T) { + t.Setenv("GEMINIWEB2API_API_KEY", "env-key") + t.Setenv("GEMINIWEB2API_PROXY", "http://127.0.0.1:7890") + t.Setenv("GEMINIWEB2API_PORT", "9090") + t.Setenv("GEMINIWEB2API_LOG_LEVEL", "debug") + t.Setenv("GEMINIWEB2API_PUBLIC_ACCOUNT_STATUS", "true") + + dir := t.TempDir() + path := filepath.Join(dir, "config.json") + if err := os.WriteFile(path, []byte(`{"api_key":"file-key","port":8080,"log_level":"info"}`), 0o600); err != nil { + t.Fatal(err) + } + + store := NewStore(path) + if err := store.Load(); err != nil { + t.Fatal(err) + } + cfg := store.Snapshot() + if cfg.APIKey != "env-key" || cfg.Proxy != "http://127.0.0.1:7890" || cfg.Port != 9090 || cfg.LogLevel != "debug" || !cfg.PublicAccountStatus { + t.Fatalf("environment overrides were not applied: %+v", cfg) + } +} diff --git a/internal/gemini/client.go b/internal/gemini/client.go index 912b7bc..43f408f 100644 --- a/internal/gemini/client.go +++ b/internal/gemini/client.go @@ -1,6 +1,8 @@ package gemini import ( + "bufio" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -85,14 +87,27 @@ func (s *GeminiSession) SetConversationID(conversationID string) { } type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Stream bool `json:"stream"` - Tools []Tool `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - ConversationID string `json:"conversation_id,omitempty"` + Model string `json:"model"` + Messages []Message `json:"messages"` + Stream bool `json:"stream"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + N int `json:"n,omitempty"` + Stop interface{} `json:"stop,omitempty"` + TopP float64 `json:"top_p,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + ResponseFormat map[string]any `json:"response_format,omitempty"` + User string `json:"user,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` +} + +type StreamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` } type Tool struct { @@ -107,15 +122,33 @@ type Function struct { } type Message struct { - Role string `json:"role"` - Content interface{} `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content interface{} `json:"content"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` } type ContentPart struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL *ImageURL `json:"image_url,omitempty"` +} + +type ImageURL struct { + URL string `json:"url"` + Detail string `json:"detail,omitempty"` +} + +type ParsedMessage struct { + Text string + Images []ImageData +} + +type ImageData struct { + MimeType string + Base64 string + URL string } type ToolCall struct { @@ -147,9 +180,10 @@ type Choice struct { } type Delta struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } type Usage struct { @@ -170,13 +204,68 @@ type Model struct { OwnedBy string `json:"owned_by"` } +type ResponsesRequest struct { + Model string `json:"model"` + Input interface{} `json:"input"` + Stream bool `json:"stream,omitempty"` +} + +type ResponsesResponse struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Model string `json:"model"` + Output []struct { + Type string `json:"type"` + Role string `json:"role"` + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + } `json:"output"` +} + type ErrorResponse struct { Error struct { Message string `json:"message"` Type string `json:"type"` + Code string `json:"code,omitempty"` } `json:"error"` } +type OpenAIError struct { + Status int + Type string + Code string + Message string +} + +type AccountContext struct { + ID string + Email string + Cookies string + Proxy string + Token string + BLToken string + FSID string + ReqID string +} + +var accountHTTPClients sync.Map + +func httpClientForAccount(accountCtx AccountContext) *http.Client { + proxyValue := strings.TrimSpace(accountCtx.Proxy) + if proxyValue == "" { + return depGetHTTPClient() + } + if client, ok := accountHTTPClients.Load(proxyValue); ok { + return client.(*http.Client) + } + client, _, _ := httpclient.NewWithProxy(depGetConfig(), proxyValue, depGetLogger()) + actual, _ := accountHTTPClients.LoadOrStore(proxyValue, client) + return actual.(*http.Client) +} + var errorCodeMap = map[int]string{ 0: "success", 1: "invalid_request", @@ -197,11 +286,23 @@ var errorCodeMap = map[int]string{ const ( geminiInnerReqLen = 69 + geminiInnerReqLenThinking = 80 geminiStreamingFlagIndex = 7 geminiDefaultMetadataSlots = 10 geminiWebLanguage = "zh-CN" headerModelJSPB = "x-goog-ext-525001261-jspb" headerRequestUUIDJSPB = "x-goog-ext-525005358-jspb" + + idxFeatureMode = 49 + idxThinkingLevel = 79 + + thinkingLevelStandard = 1 + thinkingLevelExtended = 2 + thinkingLevelDeepThink = 3 + + featureModeDeepThink = 20 + featureModeVideo = 11 + featureModeImage = 14 ) type webModelSpec struct { @@ -209,6 +310,14 @@ type webModelSpec struct { Capacity int } +type experimentalRequestConfig struct { + FeatureMode int + ThinkingLevel int + Ef int + Xpc string + Lo *bool +} + var modelSpecMap = map[string]webModelSpec{ "gemini-3-flash": {"fbb127bbb056c959", 1}, "gemini-3": {"fbb127bbb056c959", 1}, @@ -220,6 +329,9 @@ var modelSpecMap = map[string]webModelSpec{ "gemini-3-pro": {"9d8ca3786ebdfbea", 1}, "gemini-pro": {"9d8ca3786ebdfbea", 1}, "gemini-2.5-pro": {"9d8ca3786ebdfbea", 1}, + "gemini-3-pro-deep-think": {"e6fa609c3fa255c0", 4}, + "gemini-3-pro-image": {"e6fa609c3fa255c0", 4}, + "gemini-3-pro-video": {"e6fa609c3fa255c0", 4}, "gemini-3-pro-plus": {"e6fa609c3fa255c0", 4}, "gemini-3-pro-advanced": {"e6fa609c3fa255c0", 2}, "gemini-3.1": {"e6fa609c3fa255c0", 2}, @@ -253,7 +365,11 @@ func sessionToGeminiMetadata(snapshot GeminiSessionSnapshot) []interface{} { return m } -func buildModelHeaderJSPB(spec webModelSpec) string { +func buildModelHeaderJSPB(spec webModelSpec, thinkingLevel int, uuidVal string) string { + if thinkingLevel > 0 { + return fmt.Sprintf(`[1,null,null,null,"%s",null,null,0,[4],null,null,%d,null,null,%d,null,"%s"]`, + spec.HexID, thinkingLevel, thinkingLevel, uuidVal) + } return fmt.Sprintf(`[1,null,null,null,"%s",null,null,0,[4],null,null,%d]`, spec.HexID, spec.Capacity) } @@ -271,25 +387,103 @@ func noteGeminiResponseErrors(body string, sessionKey string, mode string) { } func extractMessageContent(msg Message) string { + return extractMultimodalContent(msg).Text +} + +func extractMultimodalContent(msg Message) ParsedMessage { switch v := msg.Content.(type) { case string: - return v + return ParsedMessage{Text: v} case []interface{}: - var parts []string + var parsed ParsedMessage + var textParts []string for _, part := range v { - if p, ok := part.(map[string]interface{}); ok { - if text, ok := p["text"].(string); ok { - parts = append(parts, text) - } + p, ok := part.(map[string]interface{}) + if !ok { + continue } + if text, ok := p["text"].(string); ok { + textParts = append(textParts, text) + } + imageURL, ok := extractImageURLPart(p) + if !ok || imageURL == "" { + continue + } + image := ImageData{URL: imageURL} + if mimeType, data, ok := parseDataURI(imageURL); ok { + image.MimeType = mimeType + image.Base64 = data + } + parsed.Images = append(parsed.Images, image) } - return strings.Join(parts, "\n") + parsed.Text = strings.Join(textParts, "\n") + return parsed default: if v != nil { - return fmt.Sprintf("%v", v) + return ParsedMessage{Text: fmt.Sprintf("%v", v)} } - return "" + return ParsedMessage{} + } +} + +func extractImageURLPart(part map[string]interface{}) (string, bool) { + imageURL, ok := part["image_url"] + if !ok { + return "", false + } + switch v := imageURL.(type) { + case string: + return v, true + case map[string]interface{}: + urlValue, ok := v["url"].(string) + return urlValue, ok + default: + return "", false + } +} + +func parseDataURI(uri string) (mimeType string, data string, ok bool) { + const marker = ";base64," + if !strings.HasPrefix(uri, "data:") { + return "", "", false + } + idx := strings.Index(uri, marker) + if idx < 0 { + return "", "", false + } + mimeType = uri[len("data:"):idx] + data = uri[idx+len(marker):] + return mimeType, data, mimeType != "" && data != "" +} + +func downloadImageAsBase64(imageURL string, httpClient *http.Client) (ImageData, error) { + if httpClient == nil { + httpClient = http.DefaultClient + } + resp, err := httpClient.Get(imageURL) + if err != nil { + return ImageData{}, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return ImageData{}, fmt.Errorf("download image returned status %d", resp.StatusCode) } + body, err := io.ReadAll(resp.Body) + if err != nil { + return ImageData{}, err + } + mimeType := resp.Header.Get("Content-Type") + if idx := strings.Index(mimeType, ";"); idx >= 0 { + mimeType = strings.TrimSpace(mimeType[:idx]) + } + if mimeType == "" { + mimeType = http.DetectContentType(body) + } + return ImageData{ + MimeType: mimeType, + Base64: base64.StdEncoding.EncodeToString(body), + URL: imageURL, + }, nil } func buildToolsPrompt(tools []Tool) string { @@ -326,14 +520,24 @@ func buildToolsPrompt(tools []Tool) string { } func BuildPrompt(req ChatCompletionRequest) string { + prompt, _ := BuildPromptWithMedia(req) + return prompt +} + +func BuildPromptWithMedia(req ChatCompletionRequest) (string, []ImageData) { var prompt strings.Builder + var images []ImageData toolsPrompt := buildToolsPrompt(req.Tools) if toolsPrompt != "" { prompt.WriteString(toolsPrompt) prompt.WriteString("\n---\n\n") } for _, msg := range req.Messages { - content := extractMessageContent(msg) + parsed := extractMultimodalContent(msg) + if len(parsed.Images) > 0 { + images = append(images, parsed.Images...) + } + content := parsed.Text switch msg.Role { case "system": prompt.WriteString(fmt.Sprintf("[System Instruction]\n%s\n[/System Instruction]\n\n", content)) @@ -351,7 +555,76 @@ func BuildPrompt(req ChatCompletionRequest) string { prompt.WriteString(fmt.Sprintf("Tool Result [%s]: %s\n\n", msg.ToolCallID, content)) } } - return prompt.String() + return prompt.String(), images +} + +func isDeepThinkAlias(modelName string) bool { + n := strings.ToLower(strings.TrimSpace(modelName)) + return n == "gemini-3-pro-deep-think" +} + +// isThinkingModel 判断模型是否需要设置 thinking 协议字段 +func isThinkingModel(modelName string) bool { + n := strings.ToLower(strings.TrimSpace(modelName)) + switch n { + case "gemini-3-flash-thinking", + "gemini-3-flash-thinking-plus": + return true + default: + return false + } +} + +// getThinkingLevel 返回模型对应的 thinking level 和 feature mode +// 返回 (thinkingLevel, featureMode, needsThinkingFields) +func getThinkingLevel(modelName string) (int, int, bool) { + n := strings.ToLower(strings.TrimSpace(modelName)) + switch n { + case "gemini-3-pro-deep-think": + return thinkingLevelDeepThink, featureModeDeepThink, true + case "gemini-3-flash-thinking": + return thinkingLevelStandard, 0, true + case "gemini-3-flash-thinking-plus": + return thinkingLevelExtended, 0, true + default: + return 0, 0, false + } +} + +func getExperimentalFeatureMode(modelName string) (int, bool) { + switch strings.ToLower(strings.TrimSpace(modelName)) { + case "gemini-3-pro-image": + return featureModeImage, true + case "gemini-3-pro-video": + return featureModeVideo, true + default: + return 0, false + } +} + +func getExperimentalRequestConfig(modelName string) (experimentalRequestConfig, bool) { + switch strings.ToLower(strings.TrimSpace(modelName)) { + case "gemini-3-pro-image": + lo := false + return experimentalRequestConfig{ + FeatureMode: featureModeImage, + ThinkingLevel: 5, + Ef: featureModeImage, + Xpc: "MODE_CATEGORY_FAST", + Lo: &lo, + }, true + case "gemini-3-pro-video": + lo := false + return experimentalRequestConfig{ + FeatureMode: featureModeVideo, + ThinkingLevel: 5, + Ef: featureModeVideo, + Xpc: "MODE_CATEGORY_FAST", + Lo: &lo, + }, true + default: + return experimentalRequestConfig{}, false + } } func parseToolCalls(content string, tools []Tool) (string, []ToolCall) { @@ -359,74 +632,159 @@ func parseToolCalls(content string, tools []Tool) (string, []ToolCall) { return content, nil } - var toolCalls []ToolCall - cleanContent := content - re1 := regexp.MustCompile(`(?s)\{\s*"name"\s*:\s*"([^"]+)"\s*,\s*"arguments"\s*:\s*(\{[^}]*\})\s*\}`) - matches1 := re1.FindAllStringSubmatch(content, -1) - for i, match := range matches1 { - name := match[1] - args := match[2] - for _, t := range tools { - if t.Function.Name == name { - toolCalls = append(toolCalls, ToolCall{ - ID: fmt.Sprintf("call_%s_%d", support.GenerateRandomHex(8), i), - Type: "function", - Function: FunctionCall{ - Name: name, - Arguments: args, - }, - }) - cleanContent = strings.Replace(cleanContent, match[0], "", 1) - break - } - } + allowed := make(map[string]struct{}, len(tools)) + for _, t := range tools { + allowed[t.Function.Name] = struct{}{} } - if len(toolCalls) > 0 { - return strings.TrimSpace(cleanContent), toolCalls - } + clean := content + toolCalls := make([]ToolCall, 0) + seen := make(map[string]struct{}) - re2 := regexp.MustCompile("(?s)```tool_call\\s*\\n?(\\{.*?\\})\\s*```") - matches2 := re2.FindAllStringSubmatch(content, -1) - for i, match := range matches2 { - var tc struct { - Name string `json:"name"` - Arguments json.RawMessage `json:"arguments"` + addCall := func(name, args, rawBlock string) { + if _, ok := allowed[name]; !ok { + return } + key := name + "\n" + args + if _, ok := seen[key]; ok { + return + } + seen[key] = struct{}{} + toolCalls = append(toolCalls, ToolCall{ + ID: fmt.Sprintf("call_%s_%d", support.GenerateRandomHex(8), len(toolCalls)), + Type: "function", + Function: FunctionCall{ + Name: name, + Arguments: args, + }, + }) + if rawBlock != "" { + clean = strings.Replace(clean, rawBlock, "", 1) + } + } - jsonStr := match[1] - if err := json.Unmarshal([]byte(jsonStr), &tc); err != nil { - depGetLogger().Debug("解析工具调用失败: %v", err) + // 1) 优先解析 markdown fenced tool_call 块 + fenceRe := regexp.MustCompile("(?is)```tool_call\\s*(\\{[\\s\\S]*?\\})\\s*```") + for _, m := range fenceRe.FindAllStringSubmatch(content, -1) { + if len(m) < 2 { continue } - toolExists := false - for _, t := range tools { - if t.Function.Name == tc.Name { - toolExists = true - break - } + name, args, ok := parseOneToolCallJSON(strings.TrimSpace(m[1])) + if ok { + addCall(name, args, m[0]) } - if !toolExists { - continue + } + + // 2) 再扫描正文里可能出现的 JSON 对象 + for _, raw := range extractJSONObjectCandidates(content) { + name, args, ok := parseOneToolCallJSON(raw) + if ok { + addCall(name, args, raw) } + } - toolCall := ToolCall{ - ID: fmt.Sprintf("call_%s_%d", support.GenerateRandomHex(8), i), - Type: "function", - Function: FunctionCall{ - Name: tc.Name, - Arguments: string(tc.Arguments), - }, + return strings.TrimSpace(clean), toolCalls +} + +func parseOneToolCallJSON(raw string) (name string, args string, ok bool) { + var tc struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` + } + if err := json.Unmarshal([]byte(raw), &tc); err != nil { + return "", "", false + } + if strings.TrimSpace(tc.Name) == "" { + return "", "", false + } + argsNorm, ok := normalizeToolArguments(tc.Arguments) + if !ok { + return "", "", false + } + return tc.Name, argsNorm, true +} + +func normalizeToolArguments(raw json.RawMessage) (string, bool) { + trimmed := strings.TrimSpace(string(raw)) + if trimmed == "" || trimmed == "null" { + return "{}", true + } + + // arguments 可能被模型输出为 JSON 字符串,需要二次反序列化 + if strings.HasPrefix(trimmed, `"`) { + var inner string + if err := json.Unmarshal([]byte(trimmed), &inner); err != nil { + return "", false } - toolCalls = append(toolCalls, toolCall) - cleanContent = strings.Replace(cleanContent, match[0], "", 1) + trimmed = strings.TrimSpace(inner) + } + + if trimmed == "" { + return "{}", true + } + + var obj interface{} + if err := json.Unmarshal([]byte(trimmed), &obj); err != nil { + return "", false + } + canon, err := json.Marshal(obj) + if err != nil { + return "", false } + return string(canon), true +} + +func extractJSONObjectCandidates(s string) []string { + result := make([]string, 0) + b := []byte(s) + for i := 0; i < len(b); i++ { + if b[i] != '{' { + continue + } + depth := 0 + inString := false + escaped := false + for j := i; j < len(b); j++ { + ch := b[j] + if inString { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + inString = false + } + continue + } - return strings.TrimSpace(cleanContent), toolCalls + if ch == '"' { + inString = true + continue + } + if ch == '{' { + depth++ + } + if ch == '}' { + depth-- + if depth == 0 { + candidate := strings.TrimSpace(string(b[i : j+1])) + if strings.Contains(candidate, `"name"`) && strings.Contains(candidate, `"arguments"`) { + result = append(result, candidate) + } + i = j + break + } + } + } + } + return result } -func buildGeminiRequest(prompt string, session *GeminiSession, modelName string, snlm0eToken string) (*http.Request, error) { - depTokens.RefreshTokenIfNeeded() +func buildGeminiRequest(prompt string, images []ImageData, session *GeminiSession, modelName string, accountCtx AccountContext) (*http.Request, error) { uuidVal := strings.ToUpper(support.GenerateUUIDv4()) spec := modelSpecMap["gemini-3-flash"] @@ -444,13 +802,21 @@ func buildGeminiRequest(prompt string, session *GeminiSession, modelName string, depGetLogger().Debug("正在开始新对话") } - currentToken := snlm0eToken + currentToken := accountCtx.Token if currentToken == "" { currentToken = depTokens.GetToken() } messageContent := []interface{}{prompt, 0, nil, nil, nil, nil, 0} - inner := make([]interface{}, geminiInnerReqLen) + + // 根据是否为 thinking 模型决定数组长度 + thinkingLevel, featureMode, needsThinking := getThinkingLevel(modelName) + experimentalCfg, hasExperimentalCfg := getExperimentalRequestConfig(modelName) + reqLen := geminiInnerReqLen + if needsThinking { + reqLen = geminiInnerReqLenThinking + } + inner := make([]interface{}, reqLen) inner[0] = messageContent inner[1] = []interface{}{geminiWebLanguage} inner[2] = meta @@ -468,6 +834,19 @@ func buildGeminiRequest(prompt string, session *GeminiSession, modelName string, inner[61] = []interface{}{} inner[68] = 2 + // 设置 Deep Think / Thinking 协议字段 + if needsThinking { + if featureMode != 0 { + inner[idxFeatureMode] = featureMode + } + inner[idxThinkingLevel] = thinkingLevel + depGetLogger().Debug("已设置 thinking 协议字段: level=%d, featureMode=%d, reqLen=%d", thinkingLevel, featureMode, reqLen) + } else if hasExperimentalCfg { + inner[idxFeatureMode] = experimentalCfg.FeatureMode + inner[idxThinkingLevel] = experimentalCfg.ThinkingLevel + depGetLogger().Debug("已设置实验工具字段: featureMode=%d, ef=%d, xpc=%s", experimentalCfg.FeatureMode, experimentalCfg.Ef, experimentalCfg.Xpc) + } + innerJSON, err := json.Marshal(inner) if err != nil { return nil, fmt.Errorf("marshal f.req inner: %w", err) @@ -477,7 +856,7 @@ func buildGeminiRequest(prompt string, session *GeminiSession, modelName string, data.Set("at", currentToken) data.Set("f.req", freqData) endpoints := httpclient.CurrentGeminiEndpoints(depGetConfig()) - requestURL, err := buildGeminiRequestURL(endpoints.URL) + requestURL, err := buildGeminiRequestURL(endpoints.URL, accountCtx) if err != nil { return nil, err } @@ -491,7 +870,9 @@ func buildGeminiRequest(prompt string, session *GeminiSession, modelName string, req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Safari/537.36") req.Header.Set("Content-Type", "application/x-www-form-urlencoded;charset=UTF-8") req.Header.Set("accept-language", "zh-CN") - if cfg := depGetConfig(); cfg.Cookies != "" { + if accountCtx.Cookies != "" { + req.Header.Set("Cookie", accountCtx.Cookies) + } else if cfg := depGetConfig(); cfg.Cookies != "" { req.Header.Set("Cookie", cfg.Cookies) } req.Header.Set("cache-control", "no-cache") @@ -513,19 +894,15 @@ func buildGeminiRequest(prompt string, session *GeminiSession, modelName string, req.Header.Set("sec-fetch-dest", "empty") req.Header.Set("sec-fetch-mode", "cors") req.Header.Set("sec-fetch-site", "same-origin") - req.Header.Set(headerModelJSPB, buildModelHeaderJSPB(spec)) + req.Header.Set(headerModelJSPB, buildModelHeaderJSPB(spec, thinkingLevel, uuidVal)) req.Header.Set(headerRequestUUIDJSPB, fmt.Sprintf(`["%s",1]`, uuidVal)) req.Header.Set("x-goog-ext-73010989-jspb", "[0]") req.Header.Set("x-goog-ext-73010990-jspb", "[0]") req.Header.Set("x-same-domain", "1") - randomIP := support.GenerateRandomIP() - req.Header.Set("X-Forwarded-For", randomIP) - req.Header.Set("X-Real-IP", randomIP) - depGetLogger().Debug("正在使用随机 XFF IP: %s", randomIP) return req, nil } -func buildGeminiRequestURL(rawURL string) (string, error) { +func buildGeminiRequestURL(rawURL string, accountCtx AccountContext) (string, error) { parsedURL, err := url.Parse(rawURL) if err != nil { return "", err @@ -539,48 +916,76 @@ func buildGeminiRequestURL(rawURL string) (string, error) { query.Set("rt", "c") } if query.Get("bl") == "" { - if blToken := depTokens.GetBLToken(); blToken != "" { + if blToken := firstNonEmpty(accountCtx.BLToken, depTokens.GetBLToken()); blToken != "" { query.Set("bl", blToken) } } if query.Get("f.sid") == "" { - if fsid := depTokens.GetFSID(); fsid != "" { + if fsid := firstNonEmpty(accountCtx.FSID, depTokens.GetFSID()); fsid != "" { query.Set("f.sid", fsid) } } - query.Set("_reqid", depTokens.NextReqID()) + query.Set("_reqid", firstNonEmpty(accountCtx.ReqID, depTokens.NextReqID())) parsedURL.RawQuery = query.Encode() return parsedURL.String(), nil } -func HandleStreamResponse(w http.ResponseWriter, prompt string, model string, session *GeminiSession, tools []Tool, sessionKey string, snlm0eToken string, writeError func(http.ResponseWriter, int, string)) { +func HandleStreamResponse(w http.ResponseWriter, prompt string, images []ImageData, model string, session *GeminiSession, tools []Tool, sessionKey string, snlm0eToken string, streamOptions *StreamOptions, writeError func(http.ResponseWriter, int, string), writeMappedError func(http.ResponseWriter, OpenAIError)) { start := time.Now() const maxRetries = 3 var bodyStr, content, lastErr string + var lastMappedErr *OpenAIError + var accountID string for attempt := 1; attempt <= maxRetries; attempt++ { if attempt > 1 { depGetLogger().Info("流式请求正在进行第 %d/%d 次重试", attempt, maxRetries) - snlm0eToken, _ = depTokens.GetTokenForSession(sessionKey, true) time.Sleep(time.Duration(attempt*500) * time.Millisecond) } - req, err := buildGeminiRequest(prompt, session, model, snlm0eToken) + selected, err := depTokens.SelectAccountForSession(sessionKey, attempt > 1) + if err != nil { + lastErr = err.Error() + mapped := OpenAIError{Status: http.StatusBadGateway, Type: "api_error", Code: "no_healthy_accounts", Message: err.Error()} + lastMappedErr = &mapped + continue + } + accountID = selected.ID + accountCtx := AccountContext{ + ID: selected.ID, + Email: selected.Email, + Cookies: selected.Cookies, + Proxy: selected.Proxy, + Token: firstNonEmpty(selected.Token, snlm0eToken), + BLToken: selected.BLToken, + FSID: selected.FSID, + ReqID: selected.ReqID, + } + + req, err := buildGeminiRequest(prompt, images, session, model, accountCtx) if err != nil { depGetLogger().Error("构建 Gemini 请求失败: %v", err) + depTokens.MarkAccountFailure(accountID, err.Error()) lastErr = err.Error() + mapped := OpenAIError{Status: http.StatusBadRequest, Type: "invalid_request_error", Code: "request_build_failed", Message: err.Error()} + lastMappedErr = &mapped continue } depGetLogger().Debug("正在发送请求到 Gemini API...") - resp, err := depGetHTTPClient().Do(req) + resp, err := httpClientForAccount(accountCtx).Do(req) if err != nil { if httpclient.IsConnectionError(err) { depGetLogger().Warn("通过代理连接出错 (尝试 %d/%d): %v", attempt, maxRetries, err) } else { depGetLogger().Error("Gemini API 请求失败: %v", err) } + if !isTransientNetworkError(err) { + depTokens.MarkAccountFailure(accountID, err.Error()) + } lastErr = err.Error() + mapped := OpenAIError{Status: http.StatusBadGateway, Type: "api_error", Code: "upstream_connection_error", Message: err.Error()} + lastMappedErr = &mapped continue } @@ -594,44 +999,57 @@ func HandleStreamResponse(w http.ResponseWriter, prompt string, model string, se depGetLogger().Warn("检测到 HTML 错误响应,已标记会话令牌失效") depTokens.MarkSessionTokenBad(sessionKey) } - lastErr = fmt.Sprintf("Gemini API error: %d", resp.StatusCode) + mapped := mapGeminiError(resp.StatusCode, bodyStr) + depTokens.MarkAccountFailure(accountID, mapped.Message) + lastErr = mapped.Message + lastMappedErr = &mapped continue } - body, err := readResponseBody(resp, "流式") + streamedBody, streamedContent, err := streamGeminiResponse(w, resp, model, session, tools, streamOptions, accountCtx) if err != nil { + depTokens.MarkAccountFailure(accountID, err.Error()) lastErr = err.Error() + mapped := OpenAIError{Status: http.StatusBadGateway, Type: "api_error", Code: "stream_read_error", Message: err.Error()} + lastMappedErr = &mapped continue } - - depGetLogger().Debug("流式响应体大小: %d 字节", len(body)) - bodyStr = string(body) + bodyStr = streamedBody noteGeminiResponseErrors(bodyStr, sessionKey, "流式") if isHTMLErrorResponse(bodyStr) { depGetLogger().Warn("响应体中检测到 HTML 错误,已标记会话令牌失效") depTokens.MarkSessionTokenBad(sessionKey) + depTokens.MarkAccountFailure(accountID, "Request failed due to token issue") lastErr = "Request failed due to token issue" + mapped := OpenAIError{Status: http.StatusUnauthorized, Type: "authentication_error", Code: "token_invalid", Message: lastErr} + lastMappedErr = &mapped continue } - content = extractFinalContent(bodyStr) - content = filterContent(content) + content = streamedContent if content == "" { if code, msg := parseGeminiErrorCode(bodyStr); code != 0 { depGetLogger().Error("流式响应无正文,错误码 %d: %s", code, msg) - lastErr = fmt.Sprintf("Gemini error %d: %s", code, msg) + mapped := mapGeminiError(http.StatusBadGateway, bodyStr) + depTokens.MarkAccountFailure(accountID, mapped.Message) + lastErr = mapped.Message + lastMappedErr = &mapped continue } if isEmptyAcknowledgmentResponse(bodyStr) { depGetLogger().Error("流式响应收到空的确认包 - 令牌可能已失效或过期") depTokens.MarkSessionTokenBad(sessionKey) + depTokens.MarkAccountFailure(accountID, "Gemini returned empty response - token issue") lastErr = "Gemini returned empty response - token issue" + mapped := OpenAIError{Status: http.StatusUnauthorized, Type: "authentication_error", Code: "empty_acknowledgment", Message: lastErr} + lastMappedErr = &mapped continue } } + depTokens.MarkAccountSuccess(accountID) lastErr = "" break } @@ -639,47 +1057,42 @@ func HandleStreamResponse(w http.ResponseWriter, prompt string, model string, se if lastErr != "" { depGetLogger().Error("所有 %d 次重试均失败,最后一次错误: %s", maxRetries, lastErr) depMetrics.AddRequest(false, len(prompt)/4, 0) + if lastMappedErr != nil { + writeMappedError(w, *lastMappedErr) + return + } writeError(w, http.StatusBadGateway, lastErr) return } - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - - flusher, ok := w.(http.Flusher) - if !ok { - writeError(w, http.StatusInternalServerError, "streaming not supported") - return - } - - updateSessionFromResponse(session, bodyStr) - sessionSnapshot := session.Snapshot() - sendStreamChunkWithConversation(w, flusher, model, "", "assistant", false, sessionSnapshot.ConversationID) + inputTokens := len(prompt) / 4 + outputTokens := len(content) / 4 + depMetrics.AddRequest(true, inputTokens, outputTokens) + depGetLogger().Info("流式响应完成,耗时 %.3fms", float64(time.Since(start).Microseconds())/1000) +} - if content != "" { - depGetLogger().Debug("已提取流式内容 (长度=%d): %.100s", len(content), content) - cleanContent, toolCalls := parseToolCalls(content, tools) - cleanContent = filterContent(cleanContent) - if len(toolCalls) > 0 { - sendStreamChunkWithTools(w, flusher, model, cleanContent, toolCalls) - } else { - sendStreamChunk(w, flusher, model, cleanContent, "", false) - } +func sendStreamUsageChunk(w http.ResponseWriter, flusher http.Flusher, model string, usage Usage) { + chunk := ChatCompletionResponse{ + ID: fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), + Object: "chat.completion.chunk", + Created: time.Now().Unix(), + Model: model, + Choices: []Choice{}, + Usage: usage, } + jsonData, _ := json.Marshal(chunk) + w.Write([]byte(fmt.Sprintf("data: %s\n\n", jsonData))) + flusher.Flush() +} +func inferStreamUsage(prompt string, content string) Usage { inputTokens := len(prompt) / 4 outputTokens := len(content) / 4 - depMetrics.AddRequest(true, inputTokens, outputTokens) - _, toolCalls := parseToolCalls(content, tools) - if len(toolCalls) > 0 { - sendStreamChunkFinish(w, flusher, model, "tool_calls") - } else { - sendStreamChunk(w, flusher, model, "", "", true) + return Usage{ + PromptTokens: inputTokens, + CompletionTokens: outputTokens, + TotalTokens: inputTokens + outputTokens, } - w.Write([]byte("data: [DONE]\n\n")) - flusher.Flush() - depGetLogger().Info("流式响应完成,耗时 %.3fms", float64(time.Since(start).Microseconds())/1000) } func sendStreamChunk(w http.ResponseWriter, flusher http.Flusher, model string, content string, role string, isFinish bool) { @@ -755,40 +1168,210 @@ func sendStreamChunkFinish(w http.ResponseWriter, flusher http.Flusher, model st flusher.Flush() } -func HandleNonStreamResponse(w http.ResponseWriter, prompt string, model string, session *GeminiSession, tools []Tool, sessionKey string, snlm0eToken string, writeError func(http.ResponseWriter, int, string), writeJSON func(http.ResponseWriter, int, interface{})) { +func sendStreamReasoningChunk(w http.ResponseWriter, flusher http.Flusher, model string, reasoningContent string) { + chunk := ChatCompletionResponse{ + ID: fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), + Object: "chat.completion.chunk", + Created: time.Now().Unix(), + Model: model, + Choices: []Choice{{Index: 0, Delta: &Delta{ReasoningContent: reasoningContent}}}, + } + jsonData, _ := json.Marshal(chunk) + w.Write([]byte(fmt.Sprintf("data: %s\n\n", jsonData))) + flusher.Flush() +} + +func pollDeepThinkResult(session *GeminiSession, modelName string, accountCtx AccountContext) (string, string, error) { + snapshot := session.Snapshot() + if snapshot.ConversationID == "" { + return "", "", fmt.Errorf("no conversation ID for deep think polling") + } + + endpoints := httpclient.CurrentGeminiEndpoints(depGetConfig()) + baseURL := strings.Replace(endpoints.URL, "/assistant.lamda.BardFrontendService/StreamGenerate", "/batchexecute", 1) + if baseURL == endpoints.URL { + baseURL = strings.Replace(endpoints.URL, "StreamGenerate", "batchexecute", 1) + } + parsedURL, err := url.Parse(baseURL) + if err != nil { + return "", "", fmt.Errorf("parse batchexecute URL: %w", err) + } + + currentToken := accountCtx.Token + if currentToken == "" { + currentToken = depTokens.GetToken() + } + + convID := snapshot.ConversationID + freqPayload := fmt.Sprintf(`[\"%s\",10,null,1,[0],[4],null,1]`, convID) + freqData := fmt.Sprintf(`[[["hNvQHb","%s",null,"generic"]]]`, freqPayload) + + query := parsedURL.Query() + query.Set("rpcids", "hNvQHb") + query.Set("source-path", fmt.Sprintf("/app/%s", strings.TrimPrefix(convID, "c_"))) + query.Set("hl", "en-GB") + query.Set("rt", "c") + if blToken := firstNonEmpty(accountCtx.BLToken, depTokens.GetBLToken()); blToken != "" { + query.Set("bl", blToken) + } + if fsid := firstNonEmpty(accountCtx.FSID, depTokens.GetFSID()); fsid != "" { + query.Set("f.sid", fsid) + } + query.Set("_reqid", firstNonEmpty(accountCtx.ReqID, depTokens.NextReqID())) + parsedURL.RawQuery = query.Encode() + + maxPolls := 30 + interval := 3 + var lastBody string + + for i := 0; i < maxPolls; i++ { + time.Sleep(time.Duration(interval) * time.Second) + interval += 2 + if interval > 15 { + interval = 15 + } + + depGetLogger().Debug("Deep Think 轮询 %d/%d, convID=%s", i+1, maxPolls, convID) + + postData := url.Values{} + postData.Set("f.req", freqData) + postData.Set("at", currentToken) + req, err := http.NewRequest("POST", parsedURL.String(), strings.NewReader(postData.Encode())) + if err != nil { + return "", "", fmt.Errorf("create poll request: %w", err) + } + + req.Header.Set("Accept", "*/*") + req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Safari/537.36") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded;charset=UTF-8") + req.Header.Set("accept-language", "zh-CN") + if accountCtx.Cookies != "" { + req.Header.Set("Cookie", accountCtx.Cookies) + } else if cfg := depGetConfig(); cfg.Cookies != "" { + req.Header.Set("Cookie", cfg.Cookies) + } + req.Header.Set("cache-control", "no-cache") + req.Header.Set("origin", endpoints.Origin) + req.Header.Set("pragma", "no-cache") + req.Header.Set("priority", "u=1, i") + req.Header.Set("referer", endpoints.Referer) + req.Header.Set("sec-ch-ua", `"Chromium";v="146", "Not-A.Brand";v="24", "Google Chrome";v="146"`) + req.Header.Set("sec-ch-ua-arch", `"x86"`) + req.Header.Set("sec-ch-ua-bitness", `"64"`) + req.Header.Set("sec-ch-ua-form-factors", `"Desktop"`) + req.Header.Set("sec-ch-ua-full-version", `"146.0.7680.179"`) + req.Header.Set("sec-ch-ua-full-version-list", `"Chromium";v="146.0.7680.179", "Not-A.Brand";v="24.0.0.0", "Google Chrome";v="146.0.7680.179"`) + req.Header.Set("sec-ch-ua-mobile", "?0") + req.Header.Set("sec-ch-ua-model", `""`) + req.Header.Set("sec-ch-ua-platform", `"Windows"`) + req.Header.Set("sec-ch-ua-platform-version", `"19.0.0"`) + req.Header.Set("sec-ch-ua-wow64", "?0") + req.Header.Set("sec-fetch-dest", "empty") + req.Header.Set("sec-fetch-mode", "cors") + req.Header.Set("sec-fetch-site", "same-origin") + req.Header.Set(headerRequestUUIDJSPB, fmt.Sprintf(`["%s",1]`, strings.ToUpper(support.GenerateUUIDv4()))) + req.Header.Set("x-goog-ext-73010989-jspb", "[0]") + req.Header.Set("x-goog-ext-73010990-jspb", "[0]") + req.Header.Set("x-same-domain", "1") + + resp, err := httpClientForAccount(accountCtx).Do(req) + if err != nil { + depGetLogger().Warn("Deep Think 轮询请求失败: %v", err) + continue + } + + body, err := readResponseBody(resp, "Deep Think 轮询") + if err != nil { + resp.Body.Close() + continue + } + lastBody = string(body) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + depGetLogger().Warn("Deep Think 轮询返回异常状态码: %d, body preview: %.200s", resp.StatusCode, lastBody) + continue + } + + result := extractFinalContentWithThinking(lastBody) + content := filterContent(result.Content) + + if content != "" && !isDeepThinkPlaceholder(content) && !strings.Contains(content, "I'm on it") { + reasoning := result.ReasoningContent + depGetLogger().Info("Deep Think 轮询成功, 内容长度=%d, 推理长度=%d, 轮询次数=%d", len(content), len(reasoning), i+1) + return content, reasoning, nil + } + + depGetLogger().Debug("Deep Think 轮询 %d: 内容仍为占位符或为空", i+1) + } + + return "", "", fmt.Errorf("deep think polling timed out after %d attempts, last body preview: %.200s", maxPolls, lastBody) +} + +func HandleNonStreamResponse(w http.ResponseWriter, prompt string, images []ImageData, model string, session *GeminiSession, tools []Tool, sessionKey string, snlm0eToken string, writeError func(http.ResponseWriter, int, string), writeMappedError func(http.ResponseWriter, OpenAIError), writeJSON func(http.ResponseWriter, int, interface{})) { start := time.Now() const maxRetries = 3 var bodyStr, content, lastErr string + var reasoningContent string + var lastMappedErr *OpenAIError + var accountID string for attempt := 1; attempt <= maxRetries; attempt++ { if attempt > 1 { depGetLogger().Info("非流式请求正在进行第 %d/%d 次重试", attempt, maxRetries) - snlm0eToken, _ = depTokens.GetTokenForSession(sessionKey, true) time.Sleep(time.Duration(attempt*500) * time.Millisecond) } - req, err := buildGeminiRequest(prompt, session, model, snlm0eToken) + selected, err := depTokens.SelectAccountForSession(sessionKey, attempt > 1) + if err != nil { + lastErr = err.Error() + mapped := OpenAIError{Status: http.StatusBadGateway, Type: "api_error", Code: "no_healthy_accounts", Message: err.Error()} + lastMappedErr = &mapped + continue + } + accountID = selected.ID + accountCtx := AccountContext{ + ID: selected.ID, + Email: selected.Email, + Cookies: selected.Cookies, + Proxy: selected.Proxy, + Token: firstNonEmpty(selected.Token, snlm0eToken), + BLToken: selected.BLToken, + FSID: selected.FSID, + ReqID: selected.ReqID, + } + + req, err := buildGeminiRequest(prompt, images, session, model, accountCtx) if err != nil { depGetLogger().Error("构建 Gemini 请求失败: %v", err) + depTokens.MarkAccountFailure(accountID, err.Error()) lastErr = err.Error() + mapped := OpenAIError{Status: http.StatusBadRequest, Type: "invalid_request_error", Code: "request_build_failed", Message: err.Error()} + lastMappedErr = &mapped continue } depGetLogger().Debug("正在发送请求到 Gemini API...") - resp, err := depGetHTTPClient().Do(req) + resp, err := httpClientForAccount(accountCtx).Do(req) if err != nil { if httpclient.IsConnectionError(err) { depGetLogger().Warn("通过代理连接出错 (尝试 %d/%d): %v", attempt, maxRetries, err) } else { depGetLogger().Error("Gemini API 请求失败: %v", err) } + depTokens.MarkAccountFailure(accountID, err.Error()) lastErr = err.Error() + mapped := OpenAIError{Status: http.StatusBadGateway, Type: "api_error", Code: "upstream_connection_error", Message: err.Error()} + lastMappedErr = &mapped continue } body, err := readResponseBody(resp, "非流式") if err != nil { + if !isTransientNetworkError(err) { + depTokens.MarkAccountFailure(accountID, err.Error()) + } lastErr = err.Error() + mapped := OpenAIError{Status: http.StatusBadGateway, Type: "api_error", Code: "response_read_error", Message: err.Error()} + lastMappedErr = &mapped continue } depGetLogger().Debug("Gemini API 响应状态码: %d", resp.StatusCode) @@ -802,35 +1385,60 @@ func HandleNonStreamResponse(w http.ResponseWriter, prompt string, model string, depGetLogger().Warn("检测到 HTML 错误响应,已标记会话令牌失效") depTokens.MarkSessionTokenBad(sessionKey) } - lastErr = fmt.Sprintf("Gemini API error: %d", resp.StatusCode) + mapped := mapGeminiError(resp.StatusCode, bodyStr) + depTokens.MarkAccountFailure(accountID, mapped.Message) + lastErr = mapped.Message + lastMappedErr = &mapped continue } if isHTMLErrorResponse(bodyStr) { depGetLogger().Warn("响应体中检测到 HTML 错误,已标记会话令牌失效") depTokens.MarkSessionTokenBad(sessionKey) + depTokens.MarkAccountFailure(accountID, "Request failed due to token issue") lastErr = "Request failed due to token issue" + mapped := OpenAIError{Status: http.StatusUnauthorized, Type: "authentication_error", Code: "token_invalid", Message: lastErr} + lastMappedErr = &mapped continue } - content = extractFinalContent(bodyStr) - content = filterContent(content) + result := extractFinalContentWithThinking(bodyStr) + content = filterContent(result.Content) + reasoningContent = result.ReasoningContent + + if isDeepThinkPlaceholder(result.Content) && isDeepThinkAlias(model) { + updateSessionFromResponse(session, bodyStr) + polledContent, polledReasoning, pollErr := pollDeepThinkResult(session, model, accountCtx) + if pollErr != nil { + depGetLogger().Warn("Deep Think 轮询失败: %v", pollErr) + } else { + content = polledContent + reasoningContent = polledReasoning + } + } if content == "" { if code, msg := parseGeminiErrorCode(bodyStr); code != 0 { depGetLogger().Error("非流式响应无正文,错误码 %d: %s", code, msg) - lastErr = fmt.Sprintf("Gemini error %d: %s", code, msg) + mapped := mapGeminiError(http.StatusBadGateway, bodyStr) + depTokens.MarkAccountFailure(accountID, mapped.Message) + lastErr = mapped.Message + lastMappedErr = &mapped continue } depGetLogger().Warn("从响应中提取的内容为空,响应体预览: %.500s", bodyStr) if isEmptyAcknowledgmentResponse(bodyStr) { depGetLogger().Error("收到空的确认响应 - 令牌可能已失效或过期") depTokens.MarkSessionTokenBad(sessionKey) + depTokens.MarkAccountFailure(accountID, "Gemini returned empty response - token issue") lastErr = "Gemini returned empty response - token issue" + mapped := OpenAIError{Status: http.StatusUnauthorized, Type: "authentication_error", Code: "empty_acknowledgment", Message: lastErr} + lastMappedErr = &mapped continue } } + depTokens.MarkAccountSuccess(accountID) lastErr = "" break } @@ -838,6 +1446,10 @@ func HandleNonStreamResponse(w http.ResponseWriter, prompt string, model string, if lastErr != "" { depGetLogger().Error("所有 %d 次重试均失败,最后一次错误: %s", maxRetries, lastErr) depMetrics.AddRequest(false, len(prompt)/4, 0) + if lastMappedErr != nil { + writeMappedError(w, *lastMappedErr) + return + } writeError(w, http.StatusBadGateway, lastErr) return } @@ -869,9 +1481,10 @@ func HandleNonStreamResponse(w http.ResponseWriter, prompt string, model string, Choices: []Choice{{ Index: 0, Message: &Message{ - Role: "assistant", - Content: cleanContent, - ToolCalls: toolCalls, + Role: "assistant", + Content: cleanContent, + ReasoningContent: reasoningContent, + ToolCalls: toolCalls, }, FinishReason: &finishReason, }}, @@ -885,23 +1498,44 @@ func HandleNonStreamResponse(w http.ResponseWriter, prompt string, model string, writeJSON(w, http.StatusOK, response) } +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + +func isTransientNetworkError(err error) bool { + if err == nil { + return false + } + lower := strings.ToLower(err.Error()) + return strings.Contains(lower, "context deadline exceeded") || + strings.Contains(lower, "client.timeout") || + strings.Contains(lower, "timeout awaiting response headers") || + strings.Contains(lower, "i/o timeout") || + strings.Contains(lower, "unexpected eof") +} + func updateSessionFromResponse(session *GeminiSession, body string) { if session == nil { return } snapshot := session.Snapshot() - convRe := regexp.MustCompile(`"c_([a-f0-9]+)"`) + convRe := regexp.MustCompile(`\\?"c_([a-f0-9]+)\\?"`) if matches := convRe.FindStringSubmatch(body); len(matches) > 1 { snapshot.ConversationID = "c_" + matches[1] } - respRe := regexp.MustCompile(`"r_([a-f0-9]+)"`) + respRe := regexp.MustCompile(`\\?"r_([a-f0-9]+)\\?"`) if matches := respRe.FindStringSubmatch(body); len(matches) > 1 { snapshot.ResponseID = "r_" + matches[1] } - choiceRe := regexp.MustCompile(`"rc_([a-f0-9]+)"`) + choiceRe := regexp.MustCompile(`\\?"rc_([a-f0-9]+)\\?"`) if matches := choiceRe.FindStringSubmatch(body); len(matches) > 1 { snapshot.ChoiceID = "rc_" + matches[1] } @@ -918,132 +1552,6 @@ func updateSessionFromResponse(session *GeminiSession, body string) { } } -func extractFinalContent(body string) string { - if content := extractContentFromWrbFrames(body); content != "" { - return content - } - - var contents []string - patterns := []struct { - startPattern string - arrPattern string - escaped bool - }{ - {`"rc_`, `",["`, false}, - {`\"rc_`, `\",[\"`, true}, - } - - for _, p := range patterns { - idx := 0 - for { - start := strings.Index(body[idx:], p.startPattern) - if start == -1 { - break - } - start += idx - arrStart := strings.Index(body[start:], p.arrPattern) - if arrStart == -1 { - idx = start + len(p.startPattern) - continue - } - if p.escaped { - arrStart += start + len(p.arrPattern) - endPos := strings.Index(body[arrStart:], `"]`) - if endPos == -1 { - idx = arrStart - continue - } - content := body[arrStart : arrStart+endPos] - if content != "" { - contents = append(contents, content) - } - idx = arrStart + endPos + 2 - } else { - arrStart += start + len(p.arrPattern) - content, endPos := extractQuotedString(body[arrStart:]) - if content != "" { - contents = append(contents, content) - } - idx = arrStart + endPos + 1 - } - } - } - - jsonArrayRe := regexp.MustCompile(`\[\s*"rc_[a-f0-9]+"\s*,\s*\[\s*"([^"\\]*(?:\\.[^"\\]*)*)"\s*\]`) - matches := jsonArrayRe.FindAllStringSubmatch(body, -1) - for _, match := range matches { - if len(match) > 1 && match[1] != "" { - contents = append(contents, match[1]) - } - } - - return assembleContentFragments(contents) -} - -func extractContentFromWrbFrames(body string) string { - lines := strings.Split(body, "\n") - best := "" - - for _, line := range lines { - line = strings.TrimSpace(line) - if line == "" || !strings.HasPrefix(line, "[[") { - continue - } - - var frames []interface{} - if err := json.Unmarshal([]byte(line), &frames); err != nil { - continue - } - - for _, frame := range frames { - frameItems, ok := frame.([]interface{}) - if !ok || len(frameItems) < 3 { - continue - } - - eventName, _ := frameItems[0].(string) - if eventName != "wrb.fr" { - continue - } - - payload, _ := frameItems[2].(string) - if payload == "" { - continue - } - - candidate := extractContentFromPayload(payload) - if len(candidate) > len(best) { - best = candidate - } - } - } - - return best -} - -func extractContentFromPayload(payload string) string { - var data interface{} - if err := json.Unmarshal([]byte(payload), &data); err != nil { - return "" - } - - best := "" - visitRCNodes(data, &best) - return strings.TrimSpace(best) -} - -func visitRCNodes(node interface{}, best *string) { - switch value := node.(type) { - case []interface{}: - if text, ok := extractRCText(value); ok && len(text) > len(*best) { - *best = text - } - for _, item := range value { - visitRCNodes(item, best) - } - } -} - func extractRCText(items []interface{}) (string, bool) { if len(items) < 2 { return "", false @@ -1200,6 +1708,12 @@ func filterContent(content string) string { `温馨提示:如要解锁所有应用的完整功能,请开启 \[Gemini 应用活动记录\]\([^)]+\)\s*。?\s*`, `温馨提示:如要解锁所有应用的完整功能,请开启 Gemini 应用活动记录[^。]*。?\s*`, `温馨提示[::][^\n]*Gemini[^\n]*活动记录[^\n]*\n?`, + `我正在处理.*Deep Think[^\n]*\n?`, + `正在生成回答[^\n]*\n?`, + `稍后.*查看[^\n]*\n?`, + `Responses with Deep Think[^\n]*\n?`, + `check back in a bit[^\n]*\n?`, + `http://googleusercontent\.com/agentic_processing_chip/\d+[^\n]*\n?`, } result := content for _, pattern := range patterns { @@ -1209,6 +1723,212 @@ func filterContent(content string) string { return strings.TrimSpace(result) } +func isDeepThinkPlaceholder(body string) bool { + return strings.Contains(body, "agentic_processing_chip") || + strings.Contains(body, "Deep Think") || + strings.Contains(body, "正在生成回答") +} + +type contentResult struct { + Content string + ReasoningContent string +} + +func extractFinalContentWithThinking(body string) contentResult { + if result := extractContentFromWrbFramesV2(body); result.Content != "" || result.ReasoningContent != "" { + return result + } + return extractFinalContentFallback(body) +} + +func extractFinalContentFallback(body string) contentResult { + var contents []string + patterns := []struct { + startPattern string + arrPattern string + escaped bool + }{ + {`"rc_`, `",["`, false}, + {`\"rc_`, `\",[\"`, true}, + } + + for _, p := range patterns { + idx := 0 + for { + start := strings.Index(body[idx:], p.startPattern) + if start == -1 { + break + } + start += idx + arrStart := strings.Index(body[start:], p.arrPattern) + if arrStart == -1 { + idx = start + len(p.startPattern) + continue + } + if p.escaped { + arrStart += start + len(p.arrPattern) + endPos := strings.Index(body[arrStart:], `"]`) + if endPos == -1 { + idx = arrStart + continue + } + content := body[arrStart : arrStart+endPos] + if content != "" { + contents = append(contents, content) + } + idx = arrStart + endPos + 2 + } else { + arrStart += start + len(p.arrPattern) + content, endPos := extractQuotedString(body[arrStart:]) + if content != "" { + contents = append(contents, content) + } + idx = arrStart + endPos + 1 + } + } + } + + jsonArrayRe := regexp.MustCompile(`\[\s*"rc_[a-f0-9]+"\s*,\s*\[\s*"([^"\\]*(?:\\.[^"\\]*)*)"\s*\]`) + matches := jsonArrayRe.FindAllStringSubmatch(body, -1) + for _, match := range matches { + if len(match) > 1 && match[1] != "" { + contents = append(contents, match[1]) + } + } + + return contentResult{Content: assembleContentFragments(contents)} +} + +func extractContentFromWrbFramesV2(body string) contentResult { + lines := strings.Split(body, "\n") + var best contentResult + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || !strings.HasPrefix(line, "[[") { + continue + } + + var frames []interface{} + if err := json.Unmarshal([]byte(line), &frames); err != nil { + continue + } + + for _, frame := range frames { + frameItems, ok := frame.([]interface{}) + if !ok || len(frameItems) < 3 { + continue + } + + eventName, _ := frameItems[0].(string) + if eventName != "wrb.fr" { + continue + } + + payload, _ := frameItems[2].(string) + if payload == "" { + continue + } + + candidate := extractContentFromPayloadV2(payload) + if len(candidate.Content) > len(best.Content) || (best.Content != "" && candidate.ReasoningContent != "" && best.ReasoningContent == "") { + best = candidate + } + } + } + + return best +} + +func extractContentFromPayloadV2(payload string) contentResult { + var data interface{} + if err := json.Unmarshal([]byte(payload), &data); err != nil { + return contentResult{} + } + + var result contentResult + visitRCNodesV2(data, &result) + result.Content = strings.TrimSpace(result.Content) + return result +} + +func visitRCNodesV2(node interface{}, result *contentResult) { + switch value := node.(type) { + case []interface{}: + if text, ok := extractRCText(value); ok && len(text) > len(result.Content) { + result.Content = text + } + if thinking := extractThinkingFromRCNode(value); thinking != "" && result.ReasoningContent == "" { + result.ReasoningContent = thinking + } + for _, item := range value { + visitRCNodesV2(item, result) + } + } +} + +func extractThinkingFromRCNode(items []interface{}) string { + for i := len(items) - 1; i >= 3; i-- { + arr, ok := items[i].([]interface{}) + if !ok || len(arr) == 0 { + continue + } + text := extractThinkingFromIndex(arr) + if isLikelyThinkingContent(text) { + return text + } + } + return "" +} + +func isLikelyThinkingContent(text string) bool { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return false + } + if strings.HasPrefix(trimmed, "rc_") || strings.Contains(trimmed, "e6fa609c3fa255c0") { + return false + } + markers := []string{ + "**Step", + "Step ", + "Thinking", + "thinking", + "思考", + "推理", + "分析", + } + for _, marker := range markers { + if strings.Contains(trimmed, marker) { + return true + } + } + return false +} + +func extractThinkingFromIndex(arr []interface{}) string { + var sb strings.Builder + for _, item := range arr { + switch v := item.(type) { + case string: + trimmed := strings.TrimSpace(v) + if trimmed != "" { + sb.WriteString(trimmed) + } + case []interface{}: + if len(v) > 0 { + if s, ok := v[0].(string); ok { + trimmed := strings.TrimSpace(s) + if trimmed != "" { + sb.WriteString(trimmed) + } + } + } + } + } + return strings.TrimSpace(sb.String()) +} + func isEmptyAcknowledgmentResponse(body string) bool { hasResponseID := strings.Contains(body, `"r_`) || strings.Contains(body, `\"r_`) hasChoiceContent := strings.Contains(body, `"rc_`) || strings.Contains(body, `\"rc_`) @@ -1249,6 +1969,117 @@ func isHTMLErrorResponse(body string) bool { return false } +func mapGeminiError(statusCode int, body string) OpenAIError { + if isHTMLErrorResponse(body) { + return OpenAIError{Status: http.StatusBadGateway, Type: "invalid_request_error", Code: "upstream_html_error", Message: "Gemini returned login, consent, or protection page"} + } + if code, msg := parseGeminiErrorCode(body); code != 0 { + switch code { + case 2, 7, 1037: + return OpenAIError{Status: http.StatusTooManyRequests, Type: "rate_limit_error", Code: msg, Message: fmt.Sprintf("Gemini error %d: %s", code, msg)} + case 4, 1016: + return OpenAIError{Status: http.StatusUnauthorized, Type: "authentication_error", Code: msg, Message: fmt.Sprintf("Gemini error %d: %s", code, msg)} + case 8: + return OpenAIError{Status: http.StatusBadRequest, Type: "invalid_request_error", Code: msg, Message: fmt.Sprintf("Gemini error %d: %s", code, msg)} + default: + return OpenAIError{Status: http.StatusBadGateway, Type: "api_error", Code: msg, Message: fmt.Sprintf("Gemini error %d: %s", code, msg)} + } + } + if statusCode == http.StatusUnauthorized { + return OpenAIError{Status: http.StatusUnauthorized, Type: "authentication_error", Code: "unauthorized", Message: "Gemini unauthorized"} + } + if statusCode == http.StatusForbidden { + return OpenAIError{Status: http.StatusForbidden, Type: "permission_error", Code: "forbidden", Message: "Gemini forbidden"} + } + if statusCode == http.StatusTooManyRequests { + return OpenAIError{Status: http.StatusTooManyRequests, Type: "rate_limit_error", Code: "rate_limit_exceeded", Message: "Gemini rate limited the request"} + } + return OpenAIError{Status: http.StatusBadGateway, Type: "api_error", Code: "bad_gateway", Message: fmt.Sprintf("Gemini API error: %d", statusCode)} +} + +func streamGeminiResponse(w http.ResponseWriter, resp *http.Response, model string, session *GeminiSession, tools []Tool, streamOptions *StreamOptions, accountCtx AccountContext) (string, string, error) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + flusher, ok := w.(http.Flusher) + if !ok { + return "", "", fmt.Errorf("streaming not supported") + } + body, err := readResponseBody(resp, "流式") + if err != nil { + return "", "", err + } + bodyStr := string(body) + updateSessionFromResponse(session, bodyStr) + sessionSnapshot := session.Snapshot() + sendStreamChunkWithConversation(w, flusher, model, "", "assistant", false, sessionSnapshot.ConversationID) + + result := extractFinalContentWithThinking(bodyStr) + content := filterContent(result.Content) + reasoningContent := result.ReasoningContent + + if isDeepThinkPlaceholder(result.Content) && isDeepThinkAlias(model) { + polledContent, polledReasoning, pollErr := pollDeepThinkResult(session, model, accountCtx) + if pollErr != nil { + depGetLogger().Warn("Deep Think 流式轮询失败: %v", pollErr) + } else { + content = polledContent + reasoningContent = polledReasoning + } + } + + if reasoningContent != "" { + sendStreamReasoningChunk(w, flusher, model, reasoningContent) + } + + if content != "" { + cleanContent, toolCalls := parseToolCalls(content, tools) + cleanContent = filterContent(cleanContent) + for _, part := range chunkText(cleanContent, 48) { + if len(toolCalls) > 0 && part == cleanContent { + sendStreamChunkWithTools(w, flusher, model, part, toolCalls) + } else { + sendStreamChunk(w, flusher, model, part, "", false) + } + } + } + _, toolCalls := parseToolCalls(content, tools) + if len(toolCalls) > 0 { + sendStreamChunkFinish(w, flusher, model, "tool_calls") + } else { + sendStreamChunk(w, flusher, model, "", "", true) + } + if streamOptions != nil && streamOptions.IncludeUsage { + sendStreamUsageChunk(w, flusher, model, inferStreamUsage("", content)) + } + _, _ = w.Write([]byte("data: [DONE]\n\n")) + flusher.Flush() + return bodyStr, content, nil +} + +func chunkText(content string, size int) []string { + if size <= 0 || len(content) <= size { + return []string{content} + } + chunks := make([]string, 0, (len(content)/size)+1) + reader := bufio.NewReader(strings.NewReader(content)) + for { + buf := make([]rune, 0, size) + for len(buf) < size { + r, _, err := reader.ReadRune() + if err != nil { + break + } + buf = append(buf, r) + } + if len(buf) == 0 { + break + } + chunks = append(chunks, string(buf)) + } + return chunks +} + func checkGeminiError(body string) (bool, string) { code, msg := parseGeminiErrorCode(body) if code != 0 { diff --git a/internal/gemini/client_test.go b/internal/gemini/client_test.go new file mode 100644 index 0000000..645f25e --- /dev/null +++ b/internal/gemini/client_test.go @@ -0,0 +1,364 @@ +package gemini + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestIsTransientNetworkError(t *testing.T) { + if !isTransientNetworkError(context.DeadlineExceeded) { + t.Fatal("expected context deadline exceeded to be transient") + } + if !isTransientNetworkError(errors.New("context deadline exceeded (Client.Timeout or context cancellation while reading body)")) { + t.Fatal("expected client timeout while reading body to be transient") + } + if isTransientNetworkError(errors.New("Gemini returned login/consent page")) { + t.Fatal("expected login/consent errors to remain non-transient") + } +} + +func TestParseToolCalls_FencedBlock(t *testing.T) { + tools := []Tool{{Function: Function{Name: "get_weather"}}} + content := "before\n```tool_call\n{\"name\":\"get_weather\",\"arguments\":{\"city\":\"Shanghai\",\"unit\":\"c\"}}\n```\nafter" + + clean, calls := parseToolCalls(content, tools) + if len(calls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(calls)) + } + if calls[0].Function.Name != "get_weather" { + t.Fatalf("unexpected tool name: %s", calls[0].Function.Name) + } + if calls[0].Function.Arguments != `{"city":"Shanghai","unit":"c"}` { + t.Fatalf("unexpected arguments: %s", calls[0].Function.Arguments) + } + if strings.Contains(clean, "tool_call") { + t.Fatalf("expected fenced block removed, got: %s", clean) + } +} + +func TestParseToolCalls_InlineAndStringifiedArgs(t *testing.T) { + tools := []Tool{ + {Function: Function{Name: "search_web"}}, + {Function: Function{Name: "calculator"}}, + } + content := strings.Join([]string{ + "Please run:", + `{"name":"search_web","arguments":{"q":"golang regexp"}}`, + "and then", + `{"name":"calculator","arguments":"{\"expr\":\"15*37\"}"}`, + }, "\n") + + clean, calls := parseToolCalls(content, tools) + if len(calls) != 2 { + t.Fatalf("expected 2 tool calls, got %d", len(calls)) + } + if calls[0].Function.Name != "search_web" || calls[0].Function.Arguments != `{"q":"golang regexp"}` { + t.Fatalf("unexpected first call: %+v", calls[0]) + } + if calls[1].Function.Name != "calculator" || calls[1].Function.Arguments != `{"expr":"15*37"}` { + t.Fatalf("unexpected second call: %+v", calls[1]) + } + if strings.Contains(clean, `"name":"search_web"`) || strings.Contains(clean, `"name":"calculator"`) { + t.Fatalf("expected tool json removed from content, got: %s", clean) + } +} + +func TestParseToolCalls_IgnoreUnknownAndDeduplicate(t *testing.T) { + tools := []Tool{{Function: Function{Name: "get_weather"}}} + content := strings.Join([]string{ + `{"name":"unknown_tool","arguments":{"x":1}}`, + `{"name":"get_weather","arguments":{"city":"Beijing"}}`, + `{"name":"get_weather","arguments":{"city":"Beijing"}}`, + }, "\n") + + clean, calls := parseToolCalls(content, tools) + if len(calls) != 1 { + t.Fatalf("expected 1 deduplicated tool call, got %d", len(calls)) + } + if calls[0].Function.Arguments != `{"city":"Beijing"}` { + t.Fatalf("unexpected arguments: %s", calls[0].Function.Arguments) + } + if !strings.Contains(clean, "unknown_tool") { + t.Fatalf("unknown tool should be preserved in content, got: %s", clean) + } +} + +func TestExtractDeepThinkContent(t *testing.T) { + rcNode := []interface{}{ + "rc_testid", + []interface{}{"placeholder text"}, + nil, nil, nil, nil, nil, + []interface{}{1}, + "zh", + nil, nil, + nil, + nil, nil, nil, nil, nil, nil, nil, nil, + []interface{}{false}, + nil, nil, nil, nil, nil, nil, + []interface{}{}, + nil, nil, nil, nil, nil, nil, nil, nil, + []interface{}{ + []interface{}{"**Step One**\n\nFirst thinking step.\n\n\n**Step Two**\n\nSecond thinking step.\n\n\n"}, + []interface{}{ + []interface{}{ + "**Step One**\n\nFirst thinking step.\n\n\n**Step Two**\n\nSecond thinking step.\n\n\n", + "", "", + []interface{}{ + []interface{}{nil, []interface{}{nil, 0, "Step One", nil}}, + []interface{}{nil, []interface{}{nil, 0, "First thinking step."}}, + []interface{}{nil, []interface{}{nil, 0, "Step Two", nil}}, + []interface{}{nil, []interface{}{nil, 0, "Second thinking step."}}, + }, + }, + }, + }, + } + + inner := []interface{}{ + []interface{}{rcNode}, + nil, nil, + "rc_testid", + } + + innerJSON, err := json.Marshal(inner) + if err != nil { + t.Fatal(err) + } + + var result contentResult + visitRCNodesV2(inner, &result) + + if result.Content == "" { + t.Fatal("expected non-empty content") + } + if !strings.Contains(result.Content, "placeholder") { + t.Fatalf("unexpected content: %s", result.Content) + } + if result.ReasoningContent == "" { + t.Fatalf("expected non-empty reasoning content. rcNode len=%d, innerJSON=%s", len(rcNode), string(innerJSON)) + } + if !strings.Contains(result.ReasoningContent, "Step One") { + t.Fatalf("expected 'Step One' in reasoning, got: %s", result.ReasoningContent) + } + if !strings.Contains(result.ReasoningContent, "Step Two") { + t.Fatalf("expected 'Step Two' in reasoning, got: %s", result.ReasoningContent) + } +} + +func TestExtractThinkingFromRCNode(t *testing.T) { + payload := `["rc_test",["placeholder"],null,null,null,null,null,null,null,[1],"zh",null,null,null,null,null,null,null,null,null,null,null,[false],null,null,null,null,null,null,[],null,null,null,null,null,null,null,null,[["**Step 1**\n\nThinking text here\n\n\n**Step 2**\n\nMore thinking\n\n\n"]]]` + var items []interface{} + if err := json.Unmarshal([]byte(payload), &items); err != nil { + t.Fatal(err) + } + thinking := extractThinkingFromRCNode(items) + if thinking == "" { + t.Fatal("expected non-empty thinking content") + } + if !strings.Contains(thinking, "Step 1") || !strings.Contains(thinking, "Step 2") { + t.Fatalf("expected thinking to contain steps, got: %s", thinking) + } +} + +func TestExtractThinkingFromRCNode_NoThinking(t *testing.T) { + payload := `["rc_test",["hello world"],null,null,null,null,null,null,null,[1],"zh"]` + var items []interface{} + if err := json.Unmarshal([]byte(payload), &items); err != nil { + t.Fatal(err) + } + thinking := extractThinkingFromRCNode(items) + if thinking != "" { + t.Fatalf("expected empty thinking for normal response, got: %s", thinking) + } +} + +func TestExtractThinkingFromRCNode_IgnoresMetadata(t *testing.T) { + items := []interface{}{ + "rc_4083678137dd176e", + []interface{}{"9.9更大。"}, + nil, + []interface{}{"rc_4083678137dd176e", "US", "e6fa609c3fa255c0", "e6fa609c3fa255c0", "3.1 Pro"}, + } + + thinking := extractThinkingFromRCNode(items) + if thinking != "" { + t.Fatalf("expected metadata to be ignored, got: %s", thinking) + } +} + +func TestParseDataURI(t *testing.T) { + mimeType, data, ok := parseDataURI("data:image/png;base64,iVBORw0KGgo=") + if !ok { + t.Fatal("expected data URI to parse") + } + if mimeType != "image/png" { + t.Fatalf("unexpected mime type: %s", mimeType) + } + if data != "iVBORw0KGgo=" { + t.Fatalf("unexpected data: %s", data) + } +} + +func TestExtractMultimodalContent_TextAndImages(t *testing.T) { + msg := Message{ + Role: "user", + Content: []interface{}{ + map[string]interface{}{"type": "text", "text": "Describe this"}, + map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{ + "url": "data:image/png;base64,iVBORw0KGgo=", + }, + }, + map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{ + "url": "https://example.com/cat.jpg", + "detail": "high", + }, + }, + map[string]interface{}{"type": "text", "text": "Use concise language"}, + }, + } + + parsed := extractMultimodalContent(msg) + + if parsed.Text != "Describe this\nUse concise language" { + t.Fatalf("unexpected text: %q", parsed.Text) + } + if len(parsed.Images) != 2 { + t.Fatalf("expected 2 images, got %d", len(parsed.Images)) + } + if parsed.Images[0].MimeType != "image/png" || parsed.Images[0].Base64 != "iVBORw0KGgo=" { + t.Fatalf("unexpected data URI image: %+v", parsed.Images[0]) + } + if parsed.Images[1].URL != "https://example.com/cat.jpg" { + t.Fatalf("unexpected URL image: %+v", parsed.Images[1]) + } +} + +func TestExtractMessageContentUsesMultimodalText(t *testing.T) { + msg := Message{ + Role: "user", + Content: []interface{}{ + map[string]interface{}{"type": "text", "text": "first"}, + map[string]interface{}{"type": "image_url", "image_url": map[string]interface{}{"url": "data:image/jpeg;base64,abc"}}, + map[string]interface{}{"type": "text", "text": "second"}, + }, + } + + got := extractMessageContent(msg) + if got != "first\nsecond" { + t.Fatalf("unexpected text content: %q", got) + } +} + +func TestDownloadImageAsBase64(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/png") + _, _ = w.Write([]byte{1, 2, 3, 4}) + })) + defer server.Close() + + image, err := downloadImageAsBase64(server.URL, server.Client()) + if err != nil { + t.Fatalf("download image: %v", err) + } + if image.MimeType != "image/png" { + t.Fatalf("unexpected mime type: %s", image.MimeType) + } + if image.Base64 != "AQIDBA==" { + t.Fatalf("unexpected base64: %s", image.Base64) + } + if image.URL != server.URL { + t.Fatalf("unexpected url: %s", image.URL) + } +} + +func TestBuildPromptWithMedia_TextOnlyMatchesBuildPrompt(t *testing.T) { + req := ChatCompletionRequest{ + Messages: []Message{{Role: "user", Content: "hello"}}, + } + + legacy := BuildPrompt(req) + prompt, images := BuildPromptWithMedia(req) + + if prompt != legacy { + t.Fatalf("expected text prompt to match legacy output, got %q vs %q", prompt, legacy) + } + if len(images) != 0 { + t.Fatalf("expected no images, got %d", len(images)) + } +} + +func TestBuildPromptWithMedia_CollectsImages(t *testing.T) { + req := ChatCompletionRequest{ + Messages: []Message{{ + Role: "user", + Content: []interface{}{ + map[string]interface{}{"type": "text", "text": "see attached"}, + map[string]interface{}{"type": "image_url", "image_url": map[string]interface{}{"url": "data:image/png;base64,AAAA"}}, + }, + }}, + } + + prompt, images := BuildPromptWithMedia(req) + + if !strings.Contains(prompt, "User: see attached") { + t.Fatalf("unexpected prompt: %s", prompt) + } + if len(images) != 1 { + t.Fatalf("expected 1 image, got %d", len(images)) + } + if images[0].MimeType != "image/png" || images[0].Base64 != "AAAA" { + t.Fatalf("unexpected image: %+v", images[0]) + } +} + +func TestGetExperimentalFeatureMode(t *testing.T) { + tests := []struct { + model string + want int + ok bool + }{ + {model: "gemini-3-pro-image", want: featureModeImage, ok: true}, + {model: "gemini-3-pro-video", want: featureModeVideo, ok: true}, + {model: "gemini-3-pro", want: 0, ok: false}, + } + + for _, tt := range tests { + got, ok := getExperimentalFeatureMode(tt.model) + if got != tt.want || ok != tt.ok { + t.Fatalf("model %s => (%d,%v), want (%d,%v)", tt.model, got, ok, tt.want, tt.ok) + } + } +} + +func TestGetExperimentalRequestConfig(t *testing.T) { + imageCfg, ok := getExperimentalRequestConfig("gemini-3-pro-image") + if !ok { + t.Fatal("expected image experimental config") + } + if imageCfg.FeatureMode != featureModeImage || imageCfg.Ef != featureModeImage || imageCfg.Xpc != "MODE_CATEGORY_FAST" { + t.Fatalf("unexpected image config: %+v", imageCfg) + } + if imageCfg.Lo == nil || *imageCfg.Lo { + t.Fatalf("expected image Lo=false, got %+v", imageCfg.Lo) + } + + videoCfg, ok := getExperimentalRequestConfig("gemini-3-pro-video") + if !ok { + t.Fatal("expected video experimental config") + } + if videoCfg.FeatureMode != featureModeVideo || videoCfg.Ef != featureModeVideo || videoCfg.Xpc != "MODE_CATEGORY_FAST" { + t.Fatalf("unexpected video config: %+v", videoCfg) + } + if videoCfg.Lo == nil || *videoCfg.Lo { + t.Fatalf("expected video Lo=false, got %+v", videoCfg.Lo) + } +} diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go index 7798545..5429293 100644 --- a/internal/httpclient/client.go +++ b/internal/httpclient/client.go @@ -53,6 +53,16 @@ func CurrentGeminiEndpoints(cfg config.Config) GeminiEndpoints { } func New(cfg config.Config, logger *logging.Logger) *http.Client { + client, proxyConfigured, proxyValue := NewWithProxy(cfg, strings.TrimSpace(cfg.Proxy), logger) + if proxyConfigured { + go testProxyConnectivity(client, proxyValue, logger) + } else { + logger.Info("HTTP 客户端已初始化 (未配置显式代理)") + } + return client +} + +func NewWithProxy(cfg config.Config, proxyOverride string, logger *logging.Logger) (*http.Client, bool, string) { dialer := &net.Dialer{ Timeout: 10 * time.Second, KeepAlive: 30 * time.Second, @@ -70,28 +80,26 @@ func New(cfg config.Config, logger *logging.Logger) *http.Client { } proxyConfigured := false - if strings.TrimSpace(cfg.Proxy) != "" { - proxyURL, err := url.Parse(strings.TrimSpace(cfg.Proxy)) + proxyValue := strings.TrimSpace(proxyOverride) + if proxyValue == "" { + proxyValue = strings.TrimSpace(cfg.Proxy) + } + if proxyValue != "" { + proxyURL, err := url.Parse(proxyValue) if err == nil { transport.Proxy = http.ProxyURL(proxyURL) proxyConfigured = true } else { - logger.Warn("无效的代理 URL: %s,将回退到系统环境变量代理,错误: %v", cfg.Proxy, err) + logger.Warn("无效的代理 URL: %s,将回退到系统环境变量代理,错误: %v", proxyValue, err) } } client := &http.Client{ Transport: transport, - Timeout: 120 * time.Second, + Timeout: 300 * time.Second, } - if proxyConfigured { - go testProxyConnectivity(client, cfg.Proxy, logger) - } else { - logger.Info("HTTP 客户端已初始化 (未配置显式代理)") - } - - return client + return client, proxyConfigured, proxyValue } func testProxyConnectivity(client *http.Client, proxyStr string, logger *logging.Logger) { diff --git a/internal/server/server.go b/internal/server/server.go index 01e37fb..94a3ef0 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,7 +3,9 @@ package server import ( "encoding/json" "fmt" + "io" "net/http" + "slices" "strings" "sync" "sync/atomic" @@ -19,6 +21,12 @@ import ( "main/internal/web" ) +var ( + buildPromptWithMedia = gemini.BuildPromptWithMedia + handleStreamResponse = gemini.HandleStreamResponse + handleNonStreamResponse = gemini.HandleNonStreamResponse +) + type Server struct { configStore *config.Store metrics *metrics.Metrics @@ -30,9 +38,13 @@ type Server struct { httpClient *http.Client tokenManager *token.Manager + stateStore *stateStore sessions map[string]*gemini.GeminiSession sessionsMu sync.RWMutex + + discoveredModelsMu sync.RWMutex + discoveredModels map[string]time.Time } type loggingResponseWriter struct { @@ -41,18 +53,64 @@ type loggingResponseWriter struct { size int } +type updateCookiesRequest struct { + Cookies string `json:"cookies"` + Email string `json:"email"` + Proxy string `json:"proxy"` + Persist *bool `json:"persist"` +} + +type upsertAccountRequest struct { + ID string `json:"id"` + Email string `json:"email"` + Cookies string `json:"cookies"` + Token string `json:"token"` + Proxy string `json:"proxy"` + Enabled *bool `json:"enabled"` + Weight int `json:"weight"` +} + +type rebindSessionRequest struct { + AccountID string `json:"account_id"` +} + +type responseCapture struct { + header http.Header + body []byte + statusCode int +} + +func (r *responseCapture) Header() http.Header { return r.header } +func (r *responseCapture) WriteHeader(statusCode int) { r.statusCode = statusCode } +func (r *responseCapture) Write(data []byte) (int, error) { + if r.statusCode == 0 { + r.statusCode = http.StatusOK + } + r.body = append(r.body, data...) + return len(data), nil +} + +type webLoginRequest struct { + APIKey string `json:"api_key"` +} + func New(configStore *config.Store) (*Server, error) { s := &Server{ - configStore: configStore, - metrics: metrics.New(), - sessions: make(map[string]*gemini.GeminiSession), + configStore: configStore, + metrics: metrics.New(), + sessions: make(map[string]*gemini.GeminiSession), + discoveredModels: make(map[string]time.Time), } if err := s.reloadRuntime(); err != nil { return nil, err } + s.stateStore = newStateStore(configStore.Path()) - s.tokenManager = token.NewManager(s.ConfigSnapshot, s.HTTPClient, s.Logger) + s.tokenManager = token.NewManager(s.ConfigSnapshot, s.HTTPClient, s.Logger, s.configStore.Update) + if err := s.loadPersistentState(); err != nil { + s.Logger().Warn("加载持久化状态失败: %v", err) + } gemini.Initialize(s.ConfigSnapshot, s.HTTPClient, s.Logger, s.metrics, s.tokenManager) return s, nil } @@ -63,11 +121,22 @@ func (s *Server) Run() error { s.configStore.Watch(s.reloadConfig) mux := http.NewServeMux() - mux.HandleFunc("/", web.HandleIndex) + mux.HandleFunc("/", s.handleIndex) mux.HandleFunc("/help", web.HandleHelp) mux.HandleFunc("/help/", web.HandleHelp) + mux.HandleFunc("/login", web.HandleLogin) + mux.HandleFunc("/api/web/login", s.loggingMiddleware(s.handleWebLogin)) + mux.HandleFunc("/api/web/logout", s.loggingMiddleware(s.handleWebLogout)) mux.HandleFunc("/api/telemetry", s.handleTelemetry) + mux.HandleFunc("/healthz", s.handleHealthz) + mux.HandleFunc("/api/session/cookies", s.loggingMiddleware(s.handleUpdateCookies)) + mux.HandleFunc("/api/accounts", s.loggingMiddleware(s.handleAccounts)) + mux.HandleFunc("/api/accounts/", s.loggingMiddleware(s.handleAccountAction)) + mux.HandleFunc("/api/accounts/bindings", s.loggingMiddleware(s.handleAccountBindings)) + mux.HandleFunc("/api/accounts/refresh-all", s.loggingMiddleware(s.handleAccountsRefreshAll)) + mux.HandleFunc("/api/accounts/bindings/", s.loggingMiddleware(s.handleBindingAction)) mux.HandleFunc("/v1/models", s.loggingMiddleware(s.handleModels)) + mux.HandleFunc("/v1/responses", s.loggingMiddleware(s.handleResponses)) mux.HandleFunc("/v1/chat/completions", s.loggingMiddleware(s.handleChatCompletions)) cfg := s.ConfigSnapshot() @@ -99,6 +168,7 @@ func (s *Server) reloadConfig() error { if err := s.reloadRuntime(); err != nil { return err } + s.tokenManager.RefreshAccountsFromConfig() gemini.Initialize(s.ConfigSnapshot, s.HTTPClient, s.Logger, s.metrics, s.tokenManager) s.Logger().Info("配置文件已成功重载") return nil @@ -123,6 +193,57 @@ func (s *Server) reloadRuntime() error { return nil } +func (s *Server) loadPersistentState() error { + if s.stateStore == nil { + return nil + } + state, err := s.stateStore.load() + if err != nil { + return err + } + bindings := make([]token.SessionBinding, 0, len(state.SessionBindings)) + for sessionKey, binding := range state.SessionBindings { + bindings = append(bindings, token.SessionBinding{SessionKey: sessionKey, AccountID: binding.AccountID, BoundAt: binding.BoundAt, LastUsedAt: binding.LastUsedAt}) + } + s.tokenManager.RestoreSessionBindings(bindings) + tokenSnapshots := make(map[string]token.AccountTokenSnapshot, len(state.AccountTokens)) + for accountID, snapshot := range state.AccountTokens { + tokenSnapshots[accountID] = token.AccountTokenSnapshot{ + SNlM0e: snapshot.SNlM0e, + BLToken: snapshot.BLToken, + FSID: snapshot.FSID, + ReqID: snapshot.ReqID, + FetchedAt: snapshot.FetchedAt, + } + } + s.tokenManager.RestoreTokenSnapshots(tokenSnapshots) + return nil +} + +func (s *Server) savePersistentState() { + if s.stateStore == nil { + return + } + bindings := s.tokenManager.SessionBindings() + tokenSnapshots := s.tokenManager.TokenSnapshots() + state := persistentState{SessionBindings: map[string]persistentBinding{}, AccountTokens: map[string]tokenSnapshot{}} + for _, binding := range bindings { + state.SessionBindings[binding.SessionKey] = persistentBinding{AccountID: binding.AccountID, BoundAt: binding.BoundAt, LastUsedAt: binding.LastUsedAt} + } + for accountID, snapshot := range tokenSnapshots { + state.AccountTokens[accountID] = tokenSnapshot{ + SNlM0e: snapshot.SNlM0e, + BLToken: snapshot.BLToken, + FSID: snapshot.FSID, + ReqID: snapshot.ReqID, + FetchedAt: snapshot.FetchedAt, + } + } + if err := s.stateStore.save(state); err != nil { + s.Logger().Warn("保存持久化状态失败: %v", err) + } +} + func (s *Server) printBanner() { cfg := s.ConfigSnapshot() println("======================================================") @@ -190,9 +311,90 @@ func (s *Server) writeError(w http.ResponseWriter, status int, message string) { resp := gemini.ErrorResponse{} resp.Error.Message = message resp.Error.Type = "invalid_request_error" + resp.Error.Code = strings.ToLower(strings.ReplaceAll(http.StatusText(status), " ", "_")) s.writeJSON(w, status, resp) } +func (s *Server) writeMappedError(w http.ResponseWriter, err gemini.OpenAIError) { + resp := gemini.ErrorResponse{} + resp.Error.Message = err.Message + resp.Error.Type = err.Type + resp.Error.Code = err.Code + s.writeJSON(w, err.Status, resp) +} + +func (s *Server) authenticateRequest(r *http.Request) error { + auth := r.Header.Get("Authorization") + cfg := s.ConfigSnapshot() + if auth == "" { + if cookie, err := r.Cookie("geminiweb2api_session"); err == nil && cookie.Value == cfg.APIKey { + return nil + } + return fmt.Errorf("缺失 authorization 请求头") + } + + auth = strings.TrimPrefix(auth, "Bearer ") + if auth != cfg.APIKey { + return fmt.Errorf("无效的 api key") + } + return nil +} + +func (s *Server) webAuthenticated(r *http.Request) bool { + cookie, err := r.Cookie("geminiweb2api_session") + return err == nil && cookie.Value == s.ConfigSnapshot().APIKey +} + +func (s *Server) handleIndex(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + if !s.webAuthenticated(r) { + http.Redirect(w, r, "/login", http.StatusFound) + return + } + web.HandleIndex(w, r) +} + +func (s *Server) handleWebLogin(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + var req webLoginRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.writeError(w, http.StatusBadRequest, err.Error()) + return + } + if strings.TrimSpace(req.APIKey) != s.ConfigSnapshot().APIKey { + s.writeError(w, http.StatusUnauthorized, "无效的 api key") + return + } + http.SetCookie(w, &http.Cookie{ + Name: "geminiweb2api_session", + Value: s.ConfigSnapshot().APIKey, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 86400, + }) + s.writeJSON(w, http.StatusOK, map[string]interface{}{"message": "logged in"}) +} + +func (s *Server) handleWebLogout(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{Name: "geminiweb2api_session", Value: "", Path: "/", HttpOnly: true, SameSite: http.SameSiteLaxMode, MaxAge: -1}) + s.writeJSON(w, http.StatusOK, map[string]interface{}{"message": "logged out"}) +} + +func (s *Server) requireAuth(w http.ResponseWriter, r *http.Request) bool { + if err := s.authenticateRequest(r); err != nil { + s.writeError(w, http.StatusUnauthorized, err.Error()) + return false + } + return true +} + func (s *Server) handleTelemetry(w http.ResponseWriter, _ *http.Request) { note := s.ConfigSnapshot().Note uptime := time.Since(s.metrics.StartTime).Seconds() @@ -210,6 +412,382 @@ func (s *Server) handleTelemetry(w http.ResponseWriter, _ *http.Request) { s.writeJSON(w, http.StatusOK, response) } +func (s *Server) handleHealthz(w http.ResponseWriter, _ *http.Request) { + stats := s.tokenManager.PoolStats() + status := http.StatusOK + body := map[string]interface{}{"status": "ok", "healthy_accounts": stats.HealthyAccounts, "enabled_accounts": stats.EnabledAccounts} + if stats.HealthyAccounts == 0 { + status = http.StatusServiceUnavailable + body["status"] = "degraded" + } + s.writeJSON(w, status, body) +} + +func (s *Server) recordDiscoveredModel(model string) { + model = strings.TrimSpace(model) + if model == "" { + return + } + s.discoveredModelsMu.Lock() + defer s.discoveredModelsMu.Unlock() + s.discoveredModels[model] = time.Now() +} + +func (s *Server) currentModelList() []string { + s.discoveredModelsMu.RLock() + models := make([]string, 0, len(s.discoveredModels)) + for model := range s.discoveredModels { + models = append(models, model) + } + s.discoveredModelsMu.RUnlock() + if len(models) > 0 { + slices.Sort(models) + return models + } + configured := s.ConfigSnapshot().Models + if len(configured) > 0 { + return configured + } + return []string{"gemini-3-pro", "gemini-3-pro-deep-think", "gemini-3-flash"} +} + +func (s *Server) normalizeModel(model string) string { + model = strings.TrimSpace(model) + if alias := s.ConfigSnapshot().ModelAliases[model]; strings.TrimSpace(alias) != "" { + return strings.TrimSpace(alias) + } + return model +} + +func (s *Server) handleUpdateCookies(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + + cfg := s.ConfigSnapshot() + auth := strings.TrimSpace(strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")) + if auth == "" || auth != cfg.APIKey { + s.writeError(w, http.StatusUnauthorized, "无效的回调凭证") + return + } + + var req updateCookiesRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.writeError(w, http.StatusBadRequest, err.Error()) + return + } + + req.Cookies = strings.TrimSpace(req.Cookies) + if req.Cookies == "" { + s.writeError(w, http.StatusBadRequest, "cookies 不能为空") + return + } + + persist := true + if req.Persist != nil { + persist = *req.Persist + } + + updateFn := s.configStore.UpdateInMemory + updateErrMsg := "更新运行时配置失败" + if persist { + updateFn = s.configStore.Update + updateErrMsg = "写入配置失败" + } + + callbackAccountID := accountIDFromEmail(req.Email) + if err := updateFn(func(cfg *config.Config) error { + accountID := callbackAccountID + if accountID != "" { + account := config.Account{ + ID: accountID, + Email: strings.TrimSpace(req.Email), + Cookies: req.Cookies, + Proxy: strings.TrimSpace(req.Proxy), + Enabled: true, + Weight: 1, + } + updated := false + for i := range cfg.Accounts { + if cfg.Accounts[i].ID == account.ID || (account.Email != "" && cfg.Accounts[i].Email == account.Email) { + cfg.Accounts[i].ID = account.ID + cfg.Accounts[i].Email = account.Email + cfg.Accounts[i].Cookies = account.Cookies + cfg.Accounts[i].Proxy = account.Proxy + cfg.Accounts[i].Enabled = true + if cfg.Accounts[i].Weight <= 0 { + cfg.Accounts[i].Weight = 1 + } + updated = true + break + } + } + if !updated { + cfg.Accounts = append(cfg.Accounts, account) + } + return nil + } + cfg.Cookies = req.Cookies + return nil + }); err != nil { + s.Logger().Error("更新 cookies 失败: %v", err) + s.writeError(w, http.StatusInternalServerError, updateErrMsg) + return + } + + if err := s.reloadRuntime(); err != nil { + s.Logger().Error("重载运行时失败: %v", err) + s.writeError(w, http.StatusInternalServerError, "重载运行时失败") + return + } + gemini.Initialize(s.ConfigSnapshot, s.HTTPClient, s.Logger, s.metrics, s.tokenManager) + + refreshErr := error(nil) + if accountID := callbackAccountID; accountID != "" { + refreshErr = s.tokenManager.RefreshAccountNow(accountID) + } else { + refreshErr = s.tokenManager.RefreshTokenNow() + } + if refreshErr != nil { + s.Logger().Error("刷新 token 失败: %v", refreshErr) + s.writeError(w, http.StatusBadGateway, fmt.Sprintf("cookies 已接收但刷新 token 失败: %v", refreshErr)) + return + } + + s.Logger().Info("cookies 回调更新成功: email=%s persist=%v", req.Email, persist) + s.writeJSON(w, http.StatusOK, map[string]interface{}{ + "code": 0, + "message": "cookies updated", + "persist": persist, + "email": req.Email, + "account_id": callbackAccountID, + }) +} + +func accountIDFromEmail(email string) string { + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" { + return "" + } + replacer := strings.NewReplacer("@", "_", ".", "_", "+", "_", "-", "_") + return replacer.Replace(email) +} + +func (s *Server) handleAccounts(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + if !s.ConfigSnapshot().PublicAccountStatus && !s.requireAuth(w, r) { + return + } + s.writeJSON(w, http.StatusOK, map[string]interface{}{ + "accounts": s.tokenManager.AccountsStatus(), + "bindings": s.tokenManager.SessionBindings(), + "stats": s.tokenManager.PoolStats(), + }) + return + } + if !s.requireAuth(w, r) { + return + } + if r.Method != http.MethodPost && r.Method != http.MethodPut { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + + var req upsertAccountRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.writeError(w, http.StatusBadRequest, err.Error()) + return + } + enabled := true + if req.Enabled != nil { + enabled = *req.Enabled + } + account := config.Account{ + ID: strings.TrimSpace(req.ID), + Email: strings.TrimSpace(req.Email), + Cookies: strings.TrimSpace(req.Cookies), + Token: strings.TrimSpace(req.Token), + Proxy: strings.TrimSpace(req.Proxy), + Enabled: enabled, + Weight: req.Weight, + } + if err := s.tokenManager.UpsertAccount(account); err != nil { + s.writeError(w, http.StatusBadRequest, err.Error()) + return + } + if err := s.reloadRuntime(); err != nil { + s.writeError(w, http.StatusInternalServerError, "重载运行时失败") + return + } + s.writeJSON(w, http.StatusOK, map[string]interface{}{"message": "account updated", "account_id": account.ID}) +} + +func (s *Server) handleAccountBindings(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + if !s.ConfigSnapshot().PublicAccountStatus && !s.requireAuth(w, r) { + return + } + s.writeJSON(w, http.StatusOK, map[string]interface{}{"bindings": s.tokenManager.SessionBindings(), "stats": s.tokenManager.PoolStats()}) +} + +func (s *Server) handleAccountsRefreshAll(w http.ResponseWriter, r *http.Request) { + if !s.requireAuth(w, r) { + return + } + if r.Method != http.MethodPost { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + if err := s.tokenManager.RefreshTokenNow(); err != nil { + s.writeError(w, http.StatusBadGateway, err.Error()) + return + } + s.writeJSON(w, http.StatusOK, map[string]interface{}{"message": "all accounts refreshed"}) +} + +func (s *Server) handleBindingAction(w http.ResponseWriter, r *http.Request) { + if !s.requireAuth(w, r) { + return + } + path := strings.TrimPrefix(r.URL.Path, "/api/accounts/bindings/") + parts := strings.Split(strings.Trim(path, "/"), "/") + if len(parts) < 2 { + s.writeError(w, http.StatusNotFound, "绑定操作不存在") + return + } + sessionKey := parts[0] + action := parts[1] + switch action { + case "unbind": + if r.Method != http.MethodPost && r.Method != http.MethodDelete { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + if err := s.tokenManager.UnbindSession(sessionKey); err != nil { + s.writeError(w, http.StatusBadRequest, err.Error()) + return + } + s.writeJSON(w, http.StatusOK, map[string]interface{}{"message": "session unbound", "session_key": sessionKey}) + s.savePersistentState() + case "rebind": + if r.Method != http.MethodPost { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + var req rebindSessionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.writeError(w, http.StatusBadRequest, err.Error()) + return + } + if err := s.tokenManager.RebindSession(sessionKey, strings.TrimSpace(req.AccountID)); err != nil { + s.writeError(w, http.StatusBadRequest, err.Error()) + return + } + s.savePersistentState() + s.writeJSON(w, http.StatusOK, map[string]interface{}{"message": "session rebound", "session_key": sessionKey, "account_id": strings.TrimSpace(req.AccountID)}) + default: + s.writeError(w, http.StatusNotFound, "绑定操作不存在") + } +} + +func (s *Server) handleAccountAction(w http.ResponseWriter, r *http.Request) { + if !s.requireAuth(w, r) { + return + } + path := strings.TrimPrefix(r.URL.Path, "/api/accounts/") + parts := strings.Split(strings.Trim(path, "/"), "/") + if len(parts) < 2 { + s.writeError(w, http.StatusNotFound, "账号操作不存在") + return + } + accountID := parts[0] + action := parts[1] + switch action { + case "enable": + if r.Method != http.MethodPost { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + if err := s.tokenManager.SetAccountEnabled(accountID, true); err != nil { + s.writeError(w, http.StatusBadRequest, err.Error()) + return + } + s.writeJSON(w, http.StatusOK, map[string]interface{}{"message": "account enabled", "account_id": accountID}) + case "disable": + if r.Method != http.MethodPost { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + if err := s.tokenManager.SetAccountEnabled(accountID, false); err != nil { + s.writeError(w, http.StatusBadRequest, err.Error()) + return + } + s.writeJSON(w, http.StatusOK, map[string]interface{}{"message": "account disabled", "account_id": accountID}) + case "refresh": + if r.Method != http.MethodPost { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + if err := s.tokenManager.RefreshAccountNow(accountID); err != nil { + s.writeError(w, http.StatusBadGateway, err.Error()) + return + } + s.writeJSON(w, http.StatusOK, map[string]interface{}{"message": "account refreshed", "account_id": accountID}) + case "details": + if r.Method != http.MethodGet { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + cfg := s.ConfigSnapshot() + for _, account := range cfg.Accounts { + if account.ID == accountID { + s.writeJSON(w, http.StatusOK, account) + return + } + } + if accountID == "__default__" && len(cfg.Accounts) == 0 { + s.writeJSON(w, http.StatusOK, config.Account{ + ID: "__default__", + Email: "default", + Cookies: cfg.Cookies, + Token: cfg.Token, + Proxy: cfg.Proxy, + Enabled: true, + Weight: 1, + }) + return + } + s.writeError(w, http.StatusNotFound, "account not found") + case "cookie-health": + if r.Method != http.MethodGet { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + health, ok := s.tokenManager.CookieHealth(accountID) + if !ok { + s.writeError(w, http.StatusNotFound, "account not found") + return + } + s.writeJSON(w, http.StatusOK, health) + case "delete": + if r.Method != http.MethodPost && r.Method != http.MethodDelete { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + if err := s.tokenManager.DeleteAccount(accountID); err != nil { + s.writeError(w, http.StatusBadRequest, err.Error()) + return + } + s.writeJSON(w, http.StatusOK, map[string]interface{}{"message": "account deleted", "account_id": accountID}) + default: + s.writeError(w, http.StatusNotFound, "账号操作不存在") + } +} + func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { s.Logger().Warn("接口 /v1/models 收到无效的请求方法: %s", r.Method) @@ -217,32 +795,88 @@ func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) { return } + if err := s.authenticateRequest(r); err != nil { + s.Logger().Warn("来自 %s 的 /v1/models 请求鉴权失败: %v", r.RemoteAddr, err) + s.writeError(w, http.StatusUnauthorized, err.Error()) + return + } + now := time.Now().Unix() + modelIDs := s.currentModelList() + data := make([]gemini.Model, 0, len(modelIDs)) + for _, id := range modelIDs { + data = append(data, gemini.Model{ID: id, Object: "model", Created: now, OwnedBy: "google"}) + } models := gemini.ModelsResponse{ Object: "list", - Data: []gemini.Model{ - {ID: "gemini-3-flash", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-3", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-3-flash-thinking", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-3-flash-plus", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-3-flash-thinking-plus", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-3-flash-advanced", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-3-pro", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-3-pro-advanced", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-3-pro-plus", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-3.1", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-3.1-pro", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-2.5-flash", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-2.5-pro", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-2-flash", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-2.0-flash", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-flash", Object: "model", Created: now, OwnedBy: "google"}, - {ID: "gemini-pro", Object: "model", Created: now, OwnedBy: "google"}, - }, + Data: data, } s.writeJSON(w, http.StatusOK, models) } +func (s *Server) handleResponses(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + s.writeError(w, http.StatusMethodNotAllowed, "请求方法不允许") + return + } + if err := s.authenticateRequest(r); err != nil { + s.writeError(w, http.StatusUnauthorized, err.Error()) + return + } + var req gemini.ResponsesRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.writeError(w, http.StatusBadRequest, err.Error()) + return + } + inputText := strings.TrimSpace(fmt.Sprint(req.Input)) + if inputText == "" { + s.writeError(w, http.StatusBadRequest, "input 不能为空") + return + } + chatReq := gemini.ChatCompletionRequest{Model: s.normalizeModel(req.Model), Stream: false, Messages: []gemini.Message{{Role: "user", Content: inputText}}} + body, _ := json.Marshal(chatReq) + proxyReq := r.Clone(r.Context()) + proxyReq.Body = io.NopCloser(strings.NewReader(string(body))) + proxyReq.ContentLength = int64(len(body)) + proxyReq.Header.Set("Content-Type", "application/json") + wrapper := &responseCapture{header: http.Header{}} + s.handleChatCompletions(wrapper, proxyReq) + if wrapper.statusCode >= 400 { + for k, values := range wrapper.header { + for _, v := range values { + w.Header().Add(k, v) + } + } + w.WriteHeader(wrapper.statusCode) + _, _ = w.Write(wrapper.body) + return + } + var chatResp gemini.ChatCompletionResponse + if err := json.Unmarshal(wrapper.body, &chatResp); err != nil { + s.writeError(w, http.StatusBadGateway, "无法解析 chat completion 响应") + return + } + resolvedModel := s.normalizeModel(req.Model) + s.recordDiscoveredModel(resolvedModel) + resp := gemini.ResponsesResponse{ID: chatResp.ID, Object: "response", CreatedAt: chatResp.Created, Model: resolvedModel} + if len(chatResp.Choices) > 0 && chatResp.Choices[0].Message != nil { + item := struct { + Type string `json:"type"` + Role string `json:"role"` + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + }{Type: "message", Role: "assistant"} + item.Content = append(item.Content, struct { + Type string `json:"type"` + Text string `json:"text"` + }{Type: "output_text", Text: fmt.Sprint(chatResp.Choices[0].Message.Content)}) + resp.Output = append(resp.Output, item) + } + s.writeJSON(w, http.StatusOK, resp) +} + func (s *Server) getOrCreateSession(sessionKey, conversationID string) (*gemini.GeminiSession, bool) { s.sessionsMu.RLock() session, exists := s.sessions[sessionKey] @@ -273,18 +907,9 @@ func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) { return } - auth := r.Header.Get("Authorization") - if auth == "" { - s.Logger().Warn("来自 %s 的请求缺失 Authorization 请求头", r.RemoteAddr) - s.writeError(w, http.StatusUnauthorized, "缺失 authorization 请求头") - return - } - - cfg := s.ConfigSnapshot() - auth = strings.TrimPrefix(auth, "Bearer ") - if auth != cfg.APIKey { - s.Logger().Warn("来自 %s 的请求使用了无效的 API Key", r.RemoteAddr) - s.writeError(w, http.StatusUnauthorized, "无效的 api key") + if err := s.authenticateRequest(r); err != nil { + s.Logger().Warn("来自 %s 的请求鉴权失败: %v", r.RemoteAddr, err) + s.writeError(w, http.StatusUnauthorized, err.Error()) return } @@ -295,7 +920,11 @@ func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) { return } + req.Model = s.normalizeModel(req.Model) s.Logger().Info("对话请求: 模型=%s, 消息数=%d, 流式=%v", req.Model, len(req.Messages), req.Stream) + if req.MaxCompletionTokens > 0 && req.MaxTokens == 0 { + req.MaxTokens = req.MaxCompletionTokens + } s.Logger().Debug("请求消息内容: %+v", req.Messages) sessionKey := r.Header.Get("X-Session-ID") @@ -326,14 +955,18 @@ func (s *Server) handleChatCompletions(w http.ResponseWriter, r *http.Request) { } snlm0eToken, _ := s.tokenManager.GetTokenForSession(sessionKey, isNewSession) - prompt := gemini.BuildPrompt(req) + prompt, images := buildPromptWithMedia(req) if req.Stream { s.Logger().Debug("开始流式响应") - gemini.HandleStreamResponse(w, prompt, req.Model, session, req.Tools, sessionKey, snlm0eToken, s.writeError) + s.recordDiscoveredModel(req.Model) + handleStreamResponse(w, prompt, images, req.Model, session, req.Tools, sessionKey, snlm0eToken, req.StreamOptions, s.writeError, s.writeMappedError) + s.savePersistentState() return } s.Logger().Debug("开始非流式响应") - gemini.HandleNonStreamResponse(w, prompt, req.Model, session, req.Tools, sessionKey, snlm0eToken, s.writeError, s.writeJSON) + s.recordDiscoveredModel(req.Model) + handleNonStreamResponse(w, prompt, images, req.Model, session, req.Tools, sessionKey, snlm0eToken, s.writeError, s.writeMappedError, s.writeJSON) + s.savePersistentState() } diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..b167108 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,224 @@ +package server + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + "time" + + "main/internal/config" + "main/internal/gemini" + "main/internal/token" +) + +func newTestServer(t *testing.T, publicAccountStatus bool) *Server { + t.Helper() + path := filepath.Join(t.TempDir(), "config.json") + store := config.NewStore(path) + if err := store.Load(); err != nil { + t.Fatal(err) + } + if err := store.Update(func(cfg *config.Config) error { + cfg.APIKey = "test-key" + cfg.PublicAccountStatus = publicAccountStatus + return nil + }); err != nil { + t.Fatal(err) + } + s, err := New(store) + if err != nil { + t.Fatal(err) + } + return s +} + +func TestHandleIndexRedirectsWithoutWebSession(t *testing.T) { + s := newTestServer(t, false) + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + s.handleIndex(w, req) + + if w.Code != http.StatusFound { + t.Fatalf("expected redirect, got %d", w.Code) + } + if location := w.Header().Get("Location"); location != "/login" { + t.Fatalf("expected /login redirect, got %q", location) + } +} + +func TestHandleIndexAllowsValidWebSession(t *testing.T) { + s := newTestServer(t, false) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "geminiweb2api_session", Value: "test-key"}) + w := httptest.NewRecorder() + + s.handleIndex(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +func TestHandleWebLoginSetsHttpOnlyCookie(t *testing.T) { + s := newTestServer(t, false) + req := httptest.NewRequest(http.MethodPost, "/api/web/login", strings.NewReader(`{"api_key":"test-key"}`)) + w := httptest.NewRecorder() + + s.handleWebLogin(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + cookies := w.Result().Cookies() + if len(cookies) == 0 || cookies[0].Name != "geminiweb2api_session" || !cookies[0].HttpOnly { + t.Fatalf("expected httponly session cookie, got %#v", cookies) + } +} + +func TestHandleAccountsRequiresAuthByDefault(t *testing.T) { + s := newTestServer(t, false) + req := httptest.NewRequest(http.MethodGet, "/api/accounts", nil) + w := httptest.NewRecorder() + + s.handleAccounts(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", w.Code) + } +} + +func TestHandleAccountsAllowsPublicStatusWhenConfigured(t *testing.T) { + s := newTestServer(t, true) + req := httptest.NewRequest(http.MethodGet, "/api/accounts", nil) + w := httptest.NewRecorder() + + s.handleAccounts(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +func TestHandleModelsUses3SeriesDefaults(t *testing.T) { + s := newTestServer(t, false) + req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + req.Header.Set("Authorization", "Bearer test-key") + w := httptest.NewRecorder() + + s.handleModels(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal models response: %v", err) + } + + got := make([]string, 0, len(resp.Data)) + for _, model := range resp.Data { + got = append(got, model.ID) + } + expected := []string{"gemini-3-pro", "gemini-3-pro-deep-think", "gemini-3-flash"} + if len(got) != len(expected) { + t.Fatalf("expected %d models, got %d: %v", len(expected), len(got), got) + } + for i := range expected { + if got[i] != expected[i] { + t.Fatalf("expected models %v, got %v", expected, got) + } + } +} + +func TestHandleChatCompletionsPassesMultimodalContent(t *testing.T) { + s := newTestServer(t, false) + origBuildPromptWithMedia := buildPromptWithMedia + origHandleNonStreamResponse := handleNonStreamResponse + defer func() { + buildPromptWithMedia = origBuildPromptWithMedia + handleNonStreamResponse = origHandleNonStreamResponse + }() + + var capturedModel string + var capturedPrompt string + var capturedImages []gemini.ImageData + buildPromptWithMedia = func(req gemini.ChatCompletionRequest) (string, []gemini.ImageData) { + capturedModel = req.Model + capturedPrompt = "prompt-with-media" + capturedImages = []gemini.ImageData{{MimeType: "image/png", Base64: "AAAA", URL: "data:image/png;base64,AAAA"}} + return capturedPrompt, capturedImages + } + + var called bool + handleNonStreamResponse = func(w http.ResponseWriter, prompt string, images []gemini.ImageData, model string, session *gemini.GeminiSession, tools []gemini.Tool, sessionKey string, snlm0eToken string, writeError func(http.ResponseWriter, int, string), writeMappedError func(http.ResponseWriter, gemini.OpenAIError), writeJSON func(http.ResponseWriter, int, interface{})) { + called = true + if prompt != capturedPrompt { + t.Fatalf("expected prompt %q, got %q", capturedPrompt, prompt) + } + if model != "gemini-3-pro" { + t.Fatalf("expected normalized model, got %q", model) + } + if len(images) != 1 || images[0].MimeType != "image/png" { + t.Fatalf("expected images to be passed through, got %+v", images) + } + w.WriteHeader(http.StatusOK) + } + + reqBody := `{"model":"gemini-3-pro","messages":[{"role":"user","content":[{"type":"text","text":"see attached"},{"type":"image_url","image_url":{"url":"data:image/png;base64,AAAA"}}]}],"stream":false}` + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer test-key") + w := httptest.NewRecorder() + + s.handleChatCompletions(w, req) + + if !called { + t.Fatal("expected non-stream handler to be called") + } + if capturedModel != "gemini-3-pro" { + t.Fatalf("expected model to be normalized, got %q", capturedModel) + } + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +func TestPersistentStateRestoresAccountTokenSnapshots(t *testing.T) { + s := newTestServer(t, false) + + s.tokenManager.RestoreTokenSnapshots(map[string]token.AccountTokenSnapshot{ + "__default__": { + SNlM0e: "persisted-token", + BLToken: "persisted-bl", + FSID: "persisted-fsid", + ReqID: 12345, + FetchedAt: time.Now().UTC().Truncate(time.Second), + }, + }) + s.savePersistentState() + + reloaded, err := New(s.configStore) + if err != nil { + t.Fatal(err) + } + + accounts := reloaded.tokenManager.TokenSnapshots() + snapshot, ok := accounts["__default__"] + if !ok { + t.Fatal("expected persisted token snapshot to load") + } + if snapshot.SNlM0e != "persisted-token" || snapshot.BLToken != "persisted-bl" || snapshot.FSID != "persisted-fsid" { + t.Fatalf("unexpected restored snapshot: %+v", snapshot) + } + if snapshot.ReqID != 12345 { + t.Fatalf("expected req id 12345, got %d", snapshot.ReqID) + } +} diff --git a/internal/server/state_store.go b/internal/server/state_store.go new file mode 100644 index 0000000..33c6338 --- /dev/null +++ b/internal/server/state_store.go @@ -0,0 +1,92 @@ +package server + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "time" +) + +type persistentBinding struct { + AccountID string `json:"account_id"` + BoundAt time.Time `json:"bound_at"` + LastUsedAt time.Time `json:"last_used_at"` +} + +type persistentState struct { + SessionBindings map[string]persistentBinding `json:"session_bindings"` + AccountTokens map[string]tokenSnapshot `json:"account_tokens"` +} + +type tokenSnapshot struct { + SNlM0e string `json:"snlm0e"` + BLToken string `json:"bl_token"` + FSID string `json:"fsid"` + ReqID int64 `json:"req_id"` + FetchedAt time.Time `json:"fetched_at"` +} + +type stateStore struct { + path string + mu sync.Mutex +} + +func newStateStore(configPath string) *stateStore { + return &stateStore{path: filepath.Join(filepath.Dir(configPath), "state.json")} +} + +func (s *stateStore) load() (persistentState, error) { + state := persistentState{SessionBindings: map[string]persistentBinding{}, AccountTokens: map[string]tokenSnapshot{}} + data, err := os.ReadFile(s.path) + if err != nil { + if os.IsNotExist(err) { + return state, nil + } + return state, err + } + if len(data) == 0 { + return state, nil + } + if err := json.Unmarshal(data, &state); err != nil { + return persistentState{SessionBindings: map[string]persistentBinding{}, AccountTokens: map[string]tokenSnapshot{}}, err + } + if state.SessionBindings == nil { + state.SessionBindings = map[string]persistentBinding{} + } + if state.AccountTokens == nil { + state.AccountTokens = map[string]tokenSnapshot{} + } + return state, nil +} + +func (s *stateStore) save(state persistentState) error { + s.mu.Lock() + defer s.mu.Unlock() + + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return err + } + + dir := filepath.Dir(s.path) + tmpFile, err := os.CreateTemp(dir, "state.json.tmp-*") + if err != nil { + return err + } + defer os.Remove(tmpFile.Name()) + + if err := tmpFile.Chmod(0o600); err != nil { + tmpFile.Close() + return err + } + if _, err := tmpFile.Write(data); err != nil { + tmpFile.Close() + return err + } + if err := tmpFile.Close(); err != nil { + return err + } + + return os.Rename(tmpFile.Name(), s.path) +} diff --git a/internal/token/manager.go b/internal/token/manager.go index dac055c..bf77ff0 100644 --- a/internal/token/manager.go +++ b/internal/token/manager.go @@ -3,8 +3,11 @@ package token import ( "fmt" "io" + "math" "net/http" "regexp" + "slices" + "sort" "strconv" "strings" "sync" @@ -13,9 +16,10 @@ import ( "main/internal/config" "main/internal/httpclient" "main/internal/logging" - "main/internal/support" ) +const defaultAccountID = "__default__" + type TokenInfo struct { SNlM0e string BLToken string @@ -37,34 +41,169 @@ type pageState struct { FSID string } +type AccountStatus struct { + ID string `json:"id"` + Email string `json:"email"` + Enabled bool `json:"enabled"` + Weight int `json:"weight"` + StateCode string `json:"state_code"` + StateLabel string `json:"state_label"` + ActionRequired string `json:"action_required,omitempty"` + Retryable bool `json:"retryable"` + NextRetryAt time.Time `json:"next_retry_at,omitempty"` + TokenReady bool `json:"token_ready"` + HasProxy bool `json:"has_proxy"` + UsingCookies bool `json:"using_cookies"` + HasManualToken bool `json:"has_manual_token"` + BoundSessions int `json:"bound_sessions"` + ConsecutiveFailures int `json:"consecutive_failures"` + BackoffUntil time.Time `json:"backoff_until,omitempty"` + LastUsedAt time.Time `json:"last_used_at,omitempty"` + LastError string `json:"last_error,omitempty"` + LastTokenRefreshAt time.Time `json:"last_token_refresh_at,omitempty"` + RecentFailures []FailureEvent `json:"recent_failures,omitempty"` +} + +type FailureEvent struct { + At time.Time `json:"at"` + Code string `json:"code"` + Label string `json:"label"` + Reason string `json:"reason"` + Action string `json:"action,omitempty"` + Retryable bool `json:"retryable"` +} + +type accountState struct { + Code string + Label string + ActionRequired string + Retryable bool + NextRetryAt time.Time +} + +type PoolStats struct { + TotalAccounts int `json:"total_accounts"` + EnabledAccounts int `json:"enabled_accounts"` + HealthyAccounts int `json:"healthy_accounts"` + BackoffAccounts int `json:"backoff_accounts"` + NotReadyAccounts int `json:"not_ready_accounts"` + DisabledAccounts int `json:"disabled_accounts"` + BoundSessions int `json:"bound_sessions"` +} + +type SessionBinding struct { + SessionKey string `json:"session_key"` + AccountID string `json:"account_id"` + BoundAt time.Time `json:"bound_at"` + LastUsedAt time.Time `json:"last_used_at"` +} + +type AccountTokenSnapshot struct { + SNlM0e string `json:"snlm0e"` + BLToken string `json:"bl_token"` + FSID string `json:"fsid"` + ReqID int64 `json:"req_id"` + FetchedAt time.Time `json:"fetched_at"` +} + +type CookieHealth struct { + AccountID string `json:"account_id"` + CookieCount int `json:"cookie_count"` + ImportantMissing []string `json:"important_missing"` + ImportantPresent map[string]bool `json:"important_present"` + AbuseExemption CookieTimeHint `json:"abuse_exemption"` + AnalyticsTimeHints []CookieTimeHint `json:"analytics_time_hints,omitempty"` + OpaqueSessionCookies []string `json:"opaque_session_cookies,omitempty"` + StateCode string `json:"state_code"` + StateLabel string `json:"state_label"` + TokenReady bool `json:"token_ready"` + LastError string `json:"last_error,omitempty"` +} + +type CookieTimeHint struct { + Name string `json:"name"` + Source string `json:"source"` + Epoch int64 `json:"epoch,omitempty"` + Time time.Time `json:"time,omitempty"` + AgeSec int64 `json:"age_sec,omitempty"` + ValueSeen bool `json:"value_seen"` +} + +type SelectedAccount struct { + ID string + Email string + Cookies string + Proxy string + Token string + BLToken string + FSID string + ReqID string + TokenFetched bool +} + +type sessionBinding struct { + AccountID string + BoundAt time.Time + LastUsedAt time.Time +} + +type accountRuntime struct { + cfg config.Account + tokenInfo *TokenInfo + sessionTokens map[string]*AnonToken + consecutiveFailures int + backoffUntil time.Time + lastUsedAt time.Time + lastError string + recentFailures []FailureEvent +} + type Manager struct { - getConfig func() config.Config - getClient func() *http.Client - getLogger func() *logging.Logger + getConfig func() config.Config + getClient func() *http.Client + getLogger func() *logging.Logger + updateConfig func(func(*config.Config) error) error - tokenInfo *TokenInfo + mu sync.RWMutex + accounts map[string]*accountRuntime + sessionBinding map[string]*sessionBinding + roundRobin uint64 + clientMu sync.Mutex + proxyClients map[string]*http.Client +} - mutex sync.RWMutex - sessionTokens map[string]*AnonToken +func NewManager(getConfig func() config.Config, getClient func() *http.Client, getLogger func() *logging.Logger, updateConfig func(func(*config.Config) error) error) *Manager { + m := &Manager{ + getConfig: getConfig, + getClient: getClient, + getLogger: getLogger, + updateConfig: updateConfig, + accounts: make(map[string]*accountRuntime), + sessionBinding: make(map[string]*sessionBinding), + proxyClients: make(map[string]*http.Client), + } + m.reloadAccountsLocked() + return m } -func NewManager(getConfig func() config.Config, getClient func() *http.Client, getLogger func() *logging.Logger) *Manager { - return &Manager{ - getConfig: getConfig, - getClient: getClient, - getLogger: getLogger, - tokenInfo: &TokenInfo{}, - sessionTokens: make(map[string]*AnonToken), +func (m *Manager) clientForProxy(proxyValue string) *http.Client { + proxyValue = strings.TrimSpace(proxyValue) + if proxyValue == "" { + return m.getClient() + } + m.clientMu.Lock() + defer m.clientMu.Unlock() + if client := m.proxyClients[proxyValue]; client != nil { + return client } + client, _, _ := httpclient.NewWithProxy(m.getConfig(), proxyValue, m.getLogger()) + m.proxyClients[proxyValue] = client + return client } func (m *Manager) StartRefresher() { - cfg := m.getConfig() - if cfg.Cookies != "" { - if err := m.fetchToken(); err != nil { - m.getLogger().Warn("初始令牌获取失败: %v", err) - } - } + m.RefreshAccountsFromConfig() + m.refreshAllAccountsIfNeeded(true) go func() { ticker := time.NewTicker(25 * time.Minute) @@ -75,183 +214,863 @@ func (m *Manager) StartRefresher() { }() } +func (m *Manager) RefreshAccountsFromConfig() { + m.mu.Lock() + defer m.mu.Unlock() + m.reloadAccountsLocked() +} + func (m *Manager) RefreshTokenIfNeeded() { - m.tokenInfo.mutex.RLock() - needRefresh := m.tokenInfo.SNlM0e == "" || - m.tokenInfo.BLToken == "" || - m.tokenInfo.FSID == "" || - time.Since(m.tokenInfo.FetchedAt) > 30*time.Minute - m.tokenInfo.mutex.RUnlock() + m.RefreshAccountsFromConfig() + m.refreshAllAccountsIfNeeded(false) +} - cfg := m.getConfig() - if needRefresh && cfg.Cookies != "" { - if err := m.fetchToken(); err != nil { - m.getLogger().Warn("自动刷新令牌失败: %v", err) +func (m *Manager) RefreshTokenNow() error { + m.RefreshAccountsFromConfig() + ids := m.accountIDs() + var errs []string + for _, id := range ids { + if err := m.RefreshAccountNow(id); err != nil { + errs = append(errs, fmt.Sprintf("%s: %v", id, err)) } } + if len(errs) > 0 && len(errs) == len(ids) { + return fmt.Errorf("%s", strings.Join(errs, "; ")) + } + return nil } -func (m *Manager) FetchAnonymousToken() (string, error) { - endpoints := httpclient.CurrentGeminiEndpoints(m.getConfig()) - req, err := http.NewRequest("GET", endpoints.Home, nil) +func (m *Manager) RefreshAccountNow(accountID string) error { + m.RefreshAccountsFromConfig() + m.mu.Lock() + acc, ok := m.accounts[accountID] + if !ok { + m.mu.Unlock() + return fmt.Errorf("account not found: %s", accountID) + } + cfg := acc.cfg + m.mu.Unlock() + + if strings.TrimSpace(cfg.Cookies) == "" { + return nil + } + if err := m.fetchToken(accountID); err != nil { + m.mu.Lock() + if acc := m.accounts[accountID]; acc != nil { + acc.tokenInfo.mutex.RLock() + hasUsableToken := strings.TrimSpace(acc.cfg.Token) != "" || strings.TrimSpace(acc.tokenInfo.SNlM0e) != "" + acc.tokenInfo.mutex.RUnlock() + acc.lastError = err.Error() + if hasUsableToken { + acc.backoffUntil = time.Time{} + acc.consecutiveFailures = 0 + } + } + m.mu.Unlock() + return err + } + + m.mu.Lock() + if acc := m.accounts[accountID]; acc != nil { + acc.sessionTokens = make(map[string]*AnonToken) + acc.lastError = "" + acc.backoffUntil = time.Time{} + acc.consecutiveFailures = 0 + } + m.mu.Unlock() + return nil +} + +func (m *Manager) GetTokenForSession(sessionKey string, isNewSession bool) (string, int) { + selected, err := m.SelectAccountForSession(sessionKey, isNewSession) if err != nil { - return "", fmt.Errorf("create request failed: %w", err) + m.getLogger().Warn("为会话 %s 选择账号失败: %v", sessionKey, err) + return "", 0 } - req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36") - req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8") - req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8") + return selected.Token, 0 +} + +func (m *Manager) SelectAccountForSession(sessionKey string, isNewSession bool) (SelectedAccount, error) { + m.RefreshAccountsFromConfig() + accountID, err := m.pickAccountID(sessionKey, isNewSession) + if err != nil { + return SelectedAccount{}, err + } + return m.GetSelectedAccount(accountID, sessionKey, isNewSession) +} + +func (m *Manager) GetSelectedAccount(accountID, sessionKey string, isNewSession bool) (SelectedAccount, error) { + m.RefreshAccountsFromConfig() + m.mu.RLock() + acc, ok := m.accounts[accountID] + m.mu.RUnlock() + if !ok { + return SelectedAccount{}, fmt.Errorf("account not found: %s", accountID) + } + + if err := m.ensureAccountReady(accountID); err != nil { + return SelectedAccount{}, err + } + + m.mu.Lock() + acc = m.accounts[accountID] + if sessionKey != "" { + m.bindSessionLocked(sessionKey, accountID) + if token, exists := acc.sessionTokens[sessionKey]; exists && !isNewSession && !token.IsBad && time.Since(token.FetchedAt) < 25*time.Minute { + acc.lastUsedAt = time.Now() + binding := m.sessionBinding[sessionKey] + binding.LastUsedAt = time.Now() + m.mu.Unlock() + return m.snapshotSelectedAccount(accountID, token.SNlM0e), nil + } + } + m.mu.Unlock() + + tokenValue := "" + if sessionKey != "" { + if t, err := m.FetchAnonymousTokenForAccount(accountID); err == nil && t != "" { + m.mu.Lock() + if acc = m.accounts[accountID]; acc != nil { + acc.sessionTokens[sessionKey] = &AnonToken{SNlM0e: t, FetchedAt: time.Now()} + acc.lastUsedAt = time.Now() + if binding := m.sessionBinding[sessionKey]; binding != nil { + binding.LastUsedAt = time.Now() + } + } + m.mu.Unlock() + tokenValue = t + } + } + return m.snapshotSelectedAccount(accountID, tokenValue), nil +} + +func (m *Manager) MarkSessionTokenBad(sessionKey string) { + m.mu.Lock() + defer m.mu.Unlock() + binding := m.sessionBinding[sessionKey] + if binding == nil { + return + } + if acc, exists := m.accounts[binding.AccountID]; exists { + if token, ok := acc.sessionTokens[sessionKey]; ok { + token.IsBad = true + } + m.recordFailureLocked(acc, "session token marked bad") + delete(m.sessionBinding, sessionKey) + m.getLogger().Warn("会话 %s 的账号 %s 已标记为失效并解除绑定", sessionKey, binding.AccountID) + } +} + +func (m *Manager) MarkAccountSuccess(accountID string) { + m.mu.Lock() + defer m.mu.Unlock() + acc := m.accounts[accountID] + if acc == nil { + return + } + acc.consecutiveFailures = 0 + acc.backoffUntil = time.Time{} + acc.lastError = "" + acc.lastUsedAt = time.Now() +} - randomIP := support.GenerateRandomIP() - req.Header.Set("X-Forwarded-For", randomIP) - req.Header.Set("X-Real-IP", randomIP) +func (m *Manager) MarkAccountFailure(accountID string, reason string) { + m.mu.Lock() + defer m.mu.Unlock() + acc := m.accounts[accountID] + if acc == nil { + return + } + m.recordFailureLocked(acc, reason) + for sessionKey, binding := range m.sessionBinding { + if binding.AccountID == accountID { + delete(m.sessionBinding, sessionKey) + } + } + acc.sessionTokens = make(map[string]*AnonToken) +} - resp, err := m.getClient().Do(req) +func (m *Manager) GetToken() string { + selected, err := m.SelectAccountForSession("", false) if err != nil { - return "", fmt.Errorf("request failed: %w", err) + return "" } - defer resp.Body.Close() + return selected.Token +} - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("unexpected status: %d", resp.StatusCode) +func (m *Manager) GetBLToken() string { + selected, err := m.SelectAccountForSession("", false) + if err != nil { + return "" } + return selected.BLToken +} - body, err := io.ReadAll(resp.Body) +func (m *Manager) GetFSID() string { + selected, err := m.SelectAccountForSession("", false) if err != nil { - return "", fmt.Errorf("read body failed: %w", err) + return "" } + return selected.FSID +} - state := extractPageState(body) - m.updateTokenInfoFromState(state) - if state.RequestToken == "" { - return "", fmt.Errorf("request token not found in anonymous page") +func (m *Manager) NextReqID() string { + selected, err := m.SelectAccountForSession("", false) + if err != nil { + return strconv.FormatInt(seedReqID(), 10) } + return selected.ReqID +} - m.getLogger().Debug("成功获取匿名请求令牌 (长度=%d)", len(state.RequestToken)) - return state.RequestToken, nil +func (m *Manager) AccountsStatus() []AccountStatus { + m.RefreshAccountsFromConfig() + m.mu.RLock() + defer m.mu.RUnlock() + statuses := make([]AccountStatus, 0, len(m.accounts)) + boundCounts := make(map[string]int) + for _, binding := range m.sessionBinding { + boundCounts[binding.AccountID]++ + } + for _, id := range sortedAccountIDs(m.accounts) { + acc := m.accounts[id] + acc.tokenInfo.mutex.RLock() + tokenReady := strings.TrimSpace(acc.cfg.Token) != "" || strings.TrimSpace(acc.tokenInfo.SNlM0e) != "" + lastTokenRefreshAt := acc.tokenInfo.FetchedAt + status := AccountStatus{ + ID: acc.cfg.ID, + Email: acc.cfg.Email, + Enabled: acc.cfg.Enabled, + Weight: normalizedWeight(acc.cfg.Weight), + TokenReady: tokenReady, + HasProxy: strings.TrimSpace(acc.cfg.Proxy) != "", + UsingCookies: strings.TrimSpace(acc.cfg.Cookies) != "", + HasManualToken: strings.TrimSpace(acc.cfg.Token) != "", + BoundSessions: boundCounts[id], + ConsecutiveFailures: acc.consecutiveFailures, + BackoffUntil: acc.backoffUntil, + LastUsedAt: acc.lastUsedAt, + LastError: acc.lastError, + LastTokenRefreshAt: lastTokenRefreshAt, + RecentFailures: append([]FailureEvent(nil), acc.recentFailures...), + } + acc.tokenInfo.mutex.RUnlock() + state := classifyAccountState(acc, tokenReady) + status.StateCode = state.Code + status.StateLabel = state.Label + status.ActionRequired = state.ActionRequired + status.Retryable = state.Retryable + status.NextRetryAt = state.NextRetryAt + statuses = append(statuses, status) + } + return statuses } -func (m *Manager) GetTokenForSession(sessionKey string, isNewSession bool) (string, int) { - m.mutex.Lock() - defer m.mutex.Unlock() +func (m *Manager) SessionBindings() []SessionBinding { + m.mu.RLock() + defer m.mu.RUnlock() + bindings := make([]SessionBinding, 0, len(m.sessionBinding)) + for sessionKey, binding := range m.sessionBinding { + bindings = append(bindings, SessionBinding{ + SessionKey: sessionKey, + AccountID: binding.AccountID, + BoundAt: binding.BoundAt, + LastUsedAt: binding.LastUsedAt, + }) + } + sort.Slice(bindings, func(i, j int) bool { + return bindings[i].SessionKey < bindings[j].SessionKey + }) + return bindings +} + +func (m *Manager) CookieHealth(accountID string) (CookieHealth, bool) { + m.mu.RLock() + acc := m.accounts[accountID] + m.mu.RUnlock() + if acc == nil { + return CookieHealth{}, false + } + + cookies := parseCookiePairs(acc.cfg.Cookies) + important := []string{"COMPASS", "GOOGLE_ABUSE_EXEMPTION", "SID", "__Secure-1PSID", "__Secure-3PSID", "SAPISID", "__Secure-1PAPISID", "__Secure-3PAPISID", "SIDCC", "__Secure-1PSIDCC", "__Secure-3PSIDCC", "__Secure-1PSIDTS", "__Secure-3PSIDTS"} + present := make(map[string]bool, len(important)) + missing := make([]string, 0) + for _, key := range important { + _, ok := cookies[key] + present[key] = ok + if !ok { + missing = append(missing, key) + } + } - if token, exists := m.sessionTokens[sessionKey]; exists && !isNewSession && !token.IsBad { - if time.Since(token.FetchedAt) < 25*time.Minute { - return token.SNlM0e, 0 + statuses := m.AccountsStatus() + var status AccountStatus + for _, candidate := range statuses { + if candidate.ID == accountID { + status = candidate + break } } - snlm0e, err := m.FetchAnonymousToken() + health := CookieHealth{ + AccountID: accountID, + CookieCount: len(cookies), + ImportantMissing: missing, + ImportantPresent: present, + AbuseExemption: cookieTimeHint("GOOGLE_ABUSE_EXEMPTION", cookies["GOOGLE_ABUSE_EXEMPTION"], `(?:^|:)TM=(\d{10})(?:[:;]|$)`, "TM"), + AnalyticsTimeHints: analyticsCookieTimeHints(cookies), + OpaqueSessionCookies: opaqueSessionCookies(cookies), + StateCode: status.StateCode, + StateLabel: status.StateLabel, + TokenReady: status.TokenReady, + LastError: status.LastError, + } + return health, true +} + +func parseCookiePairs(raw string) map[string]string { + cookies := map[string]string{} + for _, part := range strings.Split(raw, ";") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + key, value, ok := strings.Cut(part, "=") + if !ok { + continue + } + key = strings.TrimSpace(key) + if key == "" { + continue + } + cookies[key] = strings.TrimSpace(value) + } + return cookies +} + +func cookieTimeHint(name string, value string, pattern string, source string) CookieTimeHint { + hint := CookieTimeHint{Name: name, Source: source, ValueSeen: strings.TrimSpace(value) != ""} + if value == "" { + return hint + } + matches := regexp.MustCompile(pattern).FindStringSubmatch(value) + if len(matches) < 2 { + return hint + } + epoch, err := strconv.ParseInt(matches[1], 10, 64) if err != nil { - if token, exists := m.sessionTokens[sessionKey]; exists { - return token.SNlM0e, 0 + return hint + } + hint.Epoch = epoch + hint.Time = time.Unix(epoch, 0).UTC() + hint.AgeSec = int64(time.Since(hint.Time).Seconds()) + return hint +} + +func analyticsCookieTimeHints(cookies map[string]string) []CookieTimeHint { + hints := make([]CookieTimeHint, 0) + for key, value := range cookies { + if strings.HasPrefix(key, "_ga_") { + hints = append(hints, cookieTimeHint(key, value, `\$t(\d{10})`, "$t")) } - return "", 0 } + slices.SortFunc(hints, func(a, b CookieTimeHint) int { return strings.Compare(a.Name, b.Name) }) + return hints +} - m.sessionTokens[sessionKey] = &AnonToken{ - SNlM0e: snlm0e, - FetchedAt: time.Now(), +func opaqueSessionCookies(cookies map[string]string) []string { + keys := make([]string, 0) + for _, key := range []string{"__Secure-1PSIDTS", "__Secure-3PSIDTS"} { + if strings.HasPrefix(cookies[key], "sidts-") { + keys = append(keys, key) + } } - m.getLogger().Debug("已为会话 %s 分配新的匿名令牌", sessionKey) - return snlm0e, 0 + return keys } -func (m *Manager) MarkSessionTokenBad(sessionKey string) { - m.mutex.Lock() - defer m.mutex.Unlock() +func (m *Manager) TokenSnapshots() map[string]AccountTokenSnapshot { + m.mu.RLock() + defer m.mu.RUnlock() + snapshots := make(map[string]AccountTokenSnapshot, len(m.accounts)) + for _, id := range sortedAccountIDs(m.accounts) { + acc := m.accounts[id] + if acc == nil { + continue + } + acc.tokenInfo.mutex.RLock() + snapshot := AccountTokenSnapshot{ + SNlM0e: acc.tokenInfo.SNlM0e, + BLToken: acc.tokenInfo.BLToken, + FSID: acc.tokenInfo.FSID, + ReqID: acc.tokenInfo.ReqID, + FetchedAt: acc.tokenInfo.FetchedAt, + } + acc.tokenInfo.mutex.RUnlock() + if strings.TrimSpace(snapshot.SNlM0e) == "" && strings.TrimSpace(snapshot.BLToken) == "" && strings.TrimSpace(snapshot.FSID) == "" { + continue + } + snapshots[id] = snapshot + } + return snapshots +} - if token, exists := m.sessionTokens[sessionKey]; exists { - token.IsBad = true - m.getLogger().Warn("会话 %s 的令牌已被标记为失效", sessionKey) +func (m *Manager) RestoreTokenSnapshots(snapshots map[string]AccountTokenSnapshot) { + if len(snapshots) == 0 { + return + } + m.mu.Lock() + defer m.mu.Unlock() + for accountID, snapshot := range snapshots { + acc := m.accounts[accountID] + if acc == nil { + continue + } + acc.tokenInfo.mutex.Lock() + acc.tokenInfo.SNlM0e = strings.TrimSpace(snapshot.SNlM0e) + acc.tokenInfo.BLToken = strings.TrimSpace(snapshot.BLToken) + acc.tokenInfo.FSID = strings.TrimSpace(snapshot.FSID) + acc.tokenInfo.ReqID = snapshot.ReqID + acc.tokenInfo.FetchedAt = snapshot.FetchedAt + acc.tokenInfo.mutex.Unlock() } } -func (m *Manager) GetToken() string { - m.tokenInfo.mutex.RLock() - defer m.tokenInfo.mutex.RUnlock() +func (m *Manager) RestoreSessionBindings(bindings []SessionBinding) { + m.mu.Lock() + defer m.mu.Unlock() + for _, binding := range bindings { + if _, exists := m.accounts[binding.AccountID]; !exists { + continue + } + m.sessionBinding[binding.SessionKey] = &sessionBinding{ + AccountID: binding.AccountID, + BoundAt: binding.BoundAt, + LastUsedAt: binding.LastUsedAt, + } + } +} - if m.tokenInfo.SNlM0e != "" { - return m.tokenInfo.SNlM0e +func (m *Manager) PoolStats() PoolStats { + m.RefreshAccountsFromConfig() + m.mu.RLock() + defer m.mu.RUnlock() + stats := PoolStats{ + TotalAccounts: len(m.accounts), + BoundSessions: len(m.sessionBinding), + } + now := time.Now() + for _, id := range sortedAccountIDs(m.accounts) { + acc := m.accounts[id] + if acc == nil { + continue + } + if acc.cfg.Enabled { + stats.EnabledAccounts++ + } else { + stats.DisabledAccounts++ + } + acc.tokenInfo.mutex.RLock() + tokenReady := strings.TrimSpace(acc.cfg.Token) != "" || strings.TrimSpace(acc.tokenInfo.SNlM0e) != "" + acc.tokenInfo.mutex.RUnlock() + if !acc.cfg.Enabled { + continue + } + if !acc.backoffUntil.IsZero() && acc.backoffUntil.After(now) { + stats.BackoffAccounts++ + continue + } + if !tokenReady { + stats.NotReadyAccounts++ + continue + } + stats.HealthyAccounts++ } - return m.getConfig().Token + return stats } -func (m *Manager) GetBLToken() string { - m.tokenInfo.mutex.RLock() - defer m.tokenInfo.mutex.RUnlock() - return m.tokenInfo.BLToken +func (m *Manager) UpsertAccount(account config.Account) error { + account.ID = strings.TrimSpace(account.ID) + if account.ID == "" { + return fmt.Errorf("account id is required") + } + if normalizedWeight(account.Weight) != account.Weight { + account.Weight = normalizedWeight(account.Weight) + } + return m.getConfigStoreUpdate(func(cfg *config.Config) error { + for i := range cfg.Accounts { + if cfg.Accounts[i].ID == account.ID { + if strings.TrimSpace(account.Cookies) == "" { + account.Cookies = cfg.Accounts[i].Cookies + } + if strings.TrimSpace(account.Token) == "" { + account.Token = cfg.Accounts[i].Token + } + if strings.TrimSpace(account.Proxy) == "" { + account.Proxy = cfg.Accounts[i].Proxy + } + cfg.Accounts[i] = account + return nil + } + } + cfg.Accounts = append(cfg.Accounts, account) + return nil + }) } -func (m *Manager) GetFSID() string { - m.tokenInfo.mutex.RLock() - defer m.tokenInfo.mutex.RUnlock() - return m.tokenInfo.FSID +func (m *Manager) SetAccountEnabled(accountID string, enabled bool) error { + return m.getConfigStoreUpdate(func(cfg *config.Config) error { + for i := range cfg.Accounts { + if cfg.Accounts[i].ID == accountID { + cfg.Accounts[i].Enabled = enabled + return nil + } + } + return fmt.Errorf("account not found: %s", accountID) + }) } -func (m *Manager) NextReqID() string { - m.tokenInfo.mutex.Lock() - defer m.tokenInfo.mutex.Unlock() +func (m *Manager) DeleteAccount(accountID string) error { + if strings.TrimSpace(accountID) == "" { + return fmt.Errorf("account id is required") + } + if accountID == defaultAccountID { + return fmt.Errorf("default account cannot be deleted") + } + if err := m.getConfigStoreUpdate(func(cfg *config.Config) error { + filtered := cfg.Accounts[:0] + found := false + for _, account := range cfg.Accounts { + if account.ID == accountID { + found = true + continue + } + filtered = append(filtered, account) + } + if !found { + return fmt.Errorf("account not found: %s", accountID) + } + cfg.Accounts = filtered + return nil + }); err != nil { + return err + } - if m.tokenInfo.ReqID == 0 { - m.tokenInfo.ReqID = seedReqID() + m.mu.Lock() + defer m.mu.Unlock() + delete(m.accounts, accountID) + for sessionKey, binding := range m.sessionBinding { + if binding.AccountID == accountID { + delete(m.sessionBinding, sessionKey) + } } - current := m.tokenInfo.ReqID - m.tokenInfo.ReqID += 100000 - return strconv.FormatInt(current, 10) + return nil } -func (m *Manager) fetchToken() error { - cfg := m.getConfig() - if cfg.Cookies == "" { +func (m *Manager) UnbindSession(sessionKey string) error { + if strings.TrimSpace(sessionKey) == "" { + return fmt.Errorf("session key is required") + } + m.mu.Lock() + defer m.mu.Unlock() + binding := m.sessionBinding[sessionKey] + if binding == nil { + return fmt.Errorf("session binding not found: %s", sessionKey) + } + if acc := m.accounts[binding.AccountID]; acc != nil { + delete(acc.sessionTokens, sessionKey) + } + delete(m.sessionBinding, sessionKey) + return nil +} + +func (m *Manager) RebindSession(sessionKey, accountID string) error { + if strings.TrimSpace(sessionKey) == "" { + return fmt.Errorf("session key is required") + } + if strings.TrimSpace(accountID) == "" { + return fmt.Errorf("account id is required") + } + m.mu.Lock() + defer m.mu.Unlock() + acc := m.accounts[accountID] + if acc == nil { + return fmt.Errorf("account not found: %s", accountID) + } + if !m.accountAvailableLocked(acc) { + return fmt.Errorf("account is not available: %s", accountID) + } + if existing := m.sessionBinding[sessionKey]; existing != nil { + if oldAcc := m.accounts[existing.AccountID]; oldAcc != nil { + delete(oldAcc.sessionTokens, sessionKey) + } + } + m.bindSessionLocked(sessionKey, accountID) + delete(acc.sessionTokens, sessionKey) + return nil +} + +func (m *Manager) pickAccountID(sessionKey string, isNewSession bool) (string, error) { + m.mu.Lock() + defer m.mu.Unlock() + if sessionKey != "" { + if binding := m.sessionBinding[sessionKey]; binding != nil && !isNewSession { + if acc := m.accounts[binding.AccountID]; acc != nil && m.accountAvailableLocked(acc) { + binding.LastUsedAt = time.Now() + return binding.AccountID, nil + } + delete(m.sessionBinding, sessionKey) + } + } + + candidates := m.availableAccountIDsLocked() + if len(candidates) == 0 { + return "", fmt.Errorf("no healthy accounts available") + } + idx := int(m.roundRobin % uint64(len(candidates))) + m.roundRobin++ + accountID := candidates[idx] + if sessionKey != "" { + m.bindSessionLocked(sessionKey, accountID) + } + return accountID, nil +} + +func (m *Manager) snapshotSelectedAccount(accountID string, sessionToken string) SelectedAccount { + m.mu.RLock() + acc := m.accounts[accountID] + m.mu.RUnlock() + selected := SelectedAccount{} + if acc == nil { + return selected + } + acc.tokenInfo.mutex.RLock() + selected = SelectedAccount{ + ID: acc.cfg.ID, + Email: acc.cfg.Email, + Cookies: acc.cfg.Cookies, + Proxy: acc.cfg.Proxy, + Token: firstNonEmpty(sessionToken, acc.tokenInfo.SNlM0e, acc.cfg.Token), + BLToken: acc.tokenInfo.BLToken, + FSID: acc.tokenInfo.FSID, + TokenFetched: !acc.tokenInfo.FetchedAt.IsZero(), + } + acc.tokenInfo.mutex.RUnlock() + acc.tokenInfo.mutex.Lock() + selected.ReqID = nextReqIDLocked(acc.tokenInfo) + acc.tokenInfo.mutex.Unlock() + return selected +} + +func (m *Manager) ensureAccountReady(accountID string) error { + m.mu.RLock() + acc := m.accounts[accountID] + m.mu.RUnlock() + if acc == nil { + return fmt.Errorf("account not found: %s", accountID) + } + + acc.tokenInfo.mutex.RLock() + needRefresh := acc.tokenInfo.SNlM0e == "" || acc.tokenInfo.BLToken == "" || acc.tokenInfo.FSID == "" || time.Since(acc.tokenInfo.FetchedAt) > 30*time.Minute + hasUsableToken := strings.TrimSpace(acc.cfg.Token) != "" || strings.TrimSpace(acc.tokenInfo.SNlM0e) != "" + acc.tokenInfo.mutex.RUnlock() + if needRefresh && strings.TrimSpace(acc.cfg.Cookies) != "" { + if err := m.fetchToken(accountID); err != nil { + m.mu.Lock() + if acc := m.accounts[accountID]; acc != nil { + acc.lastError = err.Error() + if hasUsableToken { + acc.backoffUntil = time.Time{} + acc.consecutiveFailures = 0 + } else { + m.recordFailureLocked(acc, err.Error()) + } + } + m.mu.Unlock() + if hasUsableToken { + return nil + } + return err + } + } + return nil +} + +func (m *Manager) FetchAnonymousTokenForAccount(accountID string) (string, error) { + m.mu.RLock() + acc := m.accounts[accountID] + m.mu.RUnlock() + if acc == nil { + return "", fmt.Errorf("account not found: %s", accountID) + } + + endpoints := httpclient.CurrentGeminiEndpoints(m.getConfig()) + req, err := http.NewRequest("GET", endpoints.Home, nil) + if err != nil { + return "", fmt.Errorf("create request failed: %w", err) + } + req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36") + req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8") + req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8") + if strings.TrimSpace(acc.cfg.Cookies) != "" { + req.Header.Set("Cookie", acc.cfg.Cookies) + } + resp, err := m.clientForProxy(acc.cfg.Proxy).Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status: %d", resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read body failed: %w", err) + } + state := extractPageState(body) + m.updateTokenInfo(accountID, state) + if state.RequestToken == "" { + return "", missingRequestTokenError(body) + } + m.MarkAccountSuccess(accountID) + return state.RequestToken, nil +} + +func (m *Manager) fetchToken(accountID string) error { + m.mu.RLock() + acc := m.accounts[accountID] + m.mu.RUnlock() + if acc == nil { + return fmt.Errorf("account not found: %s", accountID) + } + if strings.TrimSpace(acc.cfg.Cookies) == "" { + if strings.TrimSpace(acc.cfg.Token) != "" { + m.updateTokenInfo(accountID, pageState{RequestToken: acc.cfg.Token}) + return nil + } return nil } - endpoints := httpclient.CurrentGeminiEndpoints(cfg) + endpoints := httpclient.CurrentGeminiEndpoints(m.getConfig()) req, err := http.NewRequest("GET", endpoints.Home, nil) if err != nil { return fmt.Errorf("create request failed: %w", err) } req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36") - req.Header.Set("Cookie", cfg.Cookies) + req.Header.Set("Cookie", acc.cfg.Cookies) req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8") req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8") - req.Header.Set("X-Forwarded-For", support.GenerateRandomIP()) - - resp, err := m.getClient().Do(req) + resp, err := m.clientForProxy(acc.cfg.Proxy).Do(req) if err != nil { return fmt.Errorf("request failed: %w", err) } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { return fmt.Errorf("unexpected status: %d", resp.StatusCode) } - body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("read body failed: %w", err) } - state := extractPageState(body) - m.updateTokenInfoFromState(state) - + m.updateTokenInfo(accountID, state) if state.RequestToken != "" { - m.getLogger().Info("页面态获取成功: token长度=%d, BL=%s, f.sid=%s", len(state.RequestToken), state.BLToken, state.FSID) + m.MarkAccountSuccess(accountID) + m.getLogger().Info("账号 %s 页面态获取成功: token长度=%d, BL=%s, f.sid=%s", accountID, len(state.RequestToken), state.BLToken, state.FSID) return nil } - - if cfg.Token != "" { - m.tokenInfo.mutex.Lock() - m.tokenInfo.SNlM0e = cfg.Token - m.tokenInfo.FetchedAt = time.Now() - m.tokenInfo.mutex.Unlock() - m.getLogger().Info("正在使用配置文件中的令牌") + if strings.TrimSpace(acc.cfg.Token) != "" { + m.updateTokenInfo(accountID, pageState{RequestToken: acc.cfg.Token}) + m.MarkAccountSuccess(accountID) + m.getLogger().Info("账号 %s 正在使用配置文件中的令牌", accountID) return nil } + return missingRequestTokenError(body) +} +func missingRequestTokenError(body []byte) error { + bodyText := strings.ToLower(string(body)) + state := extractPageState(body) + if state.BLToken != "" || state.FSID != "" { + return fmt.Errorf("request token not found in Gemini app page") + } + if strings.Contains(bodyText, "before you continue") || strings.Contains(bodyText, "使用前须知") || strings.Contains(bodyText, "accounts.google") || strings.Contains(bodyText, "sign in") || strings.Contains(bodyText, "登录") { + return fmt.Errorf("Gemini returned login/consent page; open gemini.google.com in the same browser, accept prompts, then copy the full Cookie again") + } + if strings.Contains(bodyText, "captcha") || strings.Contains(bodyText, "unusual traffic") || strings.Contains(bodyText, "sorry/index") { + return fmt.Errorf("Gemini returned anti-abuse challenge; verify the browser session and proxy before copying Cookie again") + } return fmt.Errorf("request token not found in page") } +func classifyAccountState(acc *accountRuntime, tokenReady bool) accountState { + now := time.Now() + if acc == nil { + return accountState{Code: "missing", Label: "账号不存在", ActionRequired: "检查账号池配置", Retryable: false} + } + if !acc.cfg.Enabled { + return accountState{Code: "disabled", Label: "已禁用", ActionRequired: "启用账号后再参与调度", Retryable: false} + } + if !acc.backoffUntil.IsZero() && acc.backoffUntil.After(now) { + failure := classifyFailure(acc.lastError) + return accountState{Code: "backoff", Label: "避退中", ActionRequired: failure.Action, Retryable: failure.Retryable, NextRetryAt: acc.backoffUntil} + } + if tokenReady { + return accountState{Code: "ready", Label: "健康", Retryable: true} + } + if strings.TrimSpace(acc.cfg.Cookies) == "" && strings.TrimSpace(acc.cfg.Token) == "" { + return accountState{Code: "empty_credentials", Label: "无登录态", ActionRequired: "导入 Cookie 或手动 Token", Retryable: false} + } + if acc.lastError != "" { + failure := classifyFailure(acc.lastError) + return accountState{Code: failure.Code, Label: failure.Label, ActionRequired: failure.Action, Retryable: failure.Retryable} + } + return accountState{Code: "not_ready", Label: "未就绪", ActionRequired: "点击刷新验证登录态", Retryable: true} +} + +func classifyFailure(reason string) FailureEvent { + lower := strings.ToLower(reason) + event := FailureEvent{Code: "unknown_error", Label: "未知错误", Reason: reason, Action: "查看日志并重试"} + switch { + case reason == "": + event.Code = "none" + event.Label = "无错误" + event.Reason = "" + event.Action = "" + case strings.Contains(lower, "login/consent") || strings.Contains(lower, "使用前须知") || strings.Contains(lower, "sign in") || strings.Contains(lower, "accounts.google"): + event.Code = "login_consent_required" + event.Label = "需要登录/同意" + event.Reason = reason + event.Action = "在对应浏览器打开 gemini.google.com,完成登录/同意后重新抓 Session" + case strings.Contains(lower, "anti-abuse") || strings.Contains(lower, "captcha") || strings.Contains(lower, "unusual traffic") || strings.Contains(lower, "sorry/index"): + event.Code = "anti_abuse_challenge" + event.Label = "风控验证" + event.Action = "检查代理和浏览器风控状态,通过验证后重新抓 Session" + case strings.Contains(lower, "request token not found") || strings.Contains(lower, "snlm0e"): + event.Code = "request_token_missing" + event.Label = "请求 Token 缺失" + event.Action = "Cookie 不完整或已过期,重新抓完整 Cookie" + case strings.Contains(lower, "unexpected status: 401") || strings.Contains(lower, "unauthorized"): + event.Code = "unauthorized" + event.Label = "未授权" + event.Action = "登录态失效,重新抓 Session" + case strings.Contains(lower, "unexpected status: 403") || strings.Contains(lower, "forbidden"): + event.Code = "forbidden" + event.Label = "账号受限" + event.Action = "账号或地区受限,检查浏览器页面状态" + case strings.Contains(lower, "unexpected status: 429") || strings.Contains(lower, "rate") || strings.Contains(lower, "quota"): + event.Code = "rate_limited" + event.Label = "限流" + event.Action = "等待冷却或切换账号" + case strings.Contains(lower, "request failed") || strings.Contains(lower, "timeout") || strings.Contains(lower, "connect") || strings.Contains(lower, "connection"): + event.Code = "network_error" + event.Label = "网络错误" + event.Action = "检查网络和代理后重试" + } + event.Retryable = event.Code == "rate_limited" || event.Code == "network_error" || event.Code == "unknown_error" + return event +} + func extractPageState(body []byte) pageState { return pageState{ RequestToken: firstMatch(body, []string{ @@ -292,27 +1111,215 @@ func firstMatch(body []byte, patterns []string) string { return "" } -func (m *Manager) updateTokenInfoFromState(state pageState) { +func (m *Manager) updateTokenInfo(accountID string, state pageState) { if state.RequestToken == "" && state.BLToken == "" && state.FSID == "" { return } - - m.tokenInfo.mutex.Lock() - defer m.tokenInfo.mutex.Unlock() - + m.mu.RLock() + acc := m.accounts[accountID] + m.mu.RUnlock() + if acc == nil { + return + } + acc.tokenInfo.mutex.Lock() + defer acc.tokenInfo.mutex.Unlock() if state.RequestToken != "" { - m.tokenInfo.SNlM0e = state.RequestToken + acc.tokenInfo.SNlM0e = state.RequestToken } if state.BLToken != "" { - m.tokenInfo.BLToken = state.BLToken + acc.tokenInfo.BLToken = state.BLToken } if state.FSID != "" { - m.tokenInfo.FSID = state.FSID + acc.tokenInfo.FSID = state.FSID + } + if acc.tokenInfo.ReqID == 0 { + acc.tokenInfo.ReqID = seedReqID() + } + acc.tokenInfo.FetchedAt = time.Now() +} + +func (m *Manager) refreshAllAccountsIfNeeded(force bool) { + for _, id := range m.accountIDs() { + m.mu.RLock() + acc := m.accounts[id] + m.mu.RUnlock() + if acc == nil || strings.TrimSpace(acc.cfg.Cookies) == "" { + continue + } + if !force { + acc.tokenInfo.mutex.RLock() + needRefresh := acc.tokenInfo.SNlM0e == "" || acc.tokenInfo.BLToken == "" || acc.tokenInfo.FSID == "" || time.Since(acc.tokenInfo.FetchedAt) > 30*time.Minute + acc.tokenInfo.mutex.RUnlock() + if !needRefresh { + continue + } + } + if err := m.fetchToken(id); err != nil { + m.mu.Lock() + if acc := m.accounts[id]; acc != nil { + acc.tokenInfo.mutex.RLock() + hasUsableToken := strings.TrimSpace(acc.cfg.Token) != "" || strings.TrimSpace(acc.tokenInfo.SNlM0e) != "" + acc.tokenInfo.mutex.RUnlock() + acc.lastError = err.Error() + if !hasUsableToken { + m.recordFailureLocked(acc, err.Error()) + } else { + acc.backoffUntil = time.Time{} + acc.consecutiveFailures = 0 + } + } + m.mu.Unlock() + m.getLogger().Warn("账号 %s 自动刷新令牌失败: %v", id, err) + } + } +} + +func (m *Manager) reloadAccountsLocked() { + cfg := m.getConfig() + oldAccounts := m.accounts + accounts := configuredAccounts(cfg) + newAccounts := make(map[string]*accountRuntime, len(accounts)) + for _, account := range accounts { + runtime := oldAccounts[account.ID] + if runtime == nil { + runtime = &accountRuntime{tokenInfo: &TokenInfo{}, sessionTokens: make(map[string]*AnonToken)} + } + runtime.cfg = account + if runtime.tokenInfo == nil { + runtime.tokenInfo = &TokenInfo{} + } + if runtime.sessionTokens == nil { + runtime.sessionTokens = make(map[string]*AnonToken) + } + newAccounts[account.ID] = runtime + } + m.accounts = newAccounts + for sessionKey, binding := range m.sessionBinding { + if _, exists := m.accounts[binding.AccountID]; !exists { + delete(m.sessionBinding, sessionKey) + } + } +} + +func configuredAccounts(cfg config.Config) []config.Account { + if len(cfg.Accounts) == 0 { + return []config.Account{{ + ID: defaultAccountID, + Email: "default", + Cookies: cfg.Cookies, + Token: cfg.Token, + Enabled: true, + Weight: 1, + }} + } + accounts := make([]config.Account, 0, len(cfg.Accounts)) + for i, account := range cfg.Accounts { + account.ID = strings.TrimSpace(account.ID) + if account.ID == "" { + account.ID = fmt.Sprintf("account-%d", i+1) + } + account.Weight = normalizedWeight(account.Weight) + accounts = append(accounts, account) + } + return accounts +} + +func (m *Manager) accountIDs() []string { + m.mu.RLock() + defer m.mu.RUnlock() + return sortedAccountIDs(m.accounts) +} + +func sortedAccountIDs(accounts map[string]*accountRuntime) []string { + ids := make([]string, 0, len(accounts)) + for id := range accounts { + ids = append(ids, id) + } + sort.Strings(ids) + return ids +} + +func (m *Manager) availableAccountIDsLocked() []string { + now := time.Now() + weighted := make([]string, 0, len(m.accounts)) + for _, id := range sortedAccountIDs(m.accounts) { + acc := m.accounts[id] + if acc == nil || !acc.cfg.Enabled || (!acc.backoffUntil.IsZero() && acc.backoffUntil.After(now)) { + continue + } + for i := 0; i < normalizedWeight(acc.cfg.Weight); i++ { + weighted = append(weighted, id) + } + } + return weighted +} + +func (m *Manager) accountAvailableLocked(acc *accountRuntime) bool { + if acc == nil || !acc.cfg.Enabled { + return false + } + return acc.backoffUntil.IsZero() || !acc.backoffUntil.After(time.Now()) +} + +func (m *Manager) bindSessionLocked(sessionKey, accountID string) { + now := time.Now() + m.sessionBinding[sessionKey] = &sessionBinding{AccountID: accountID, BoundAt: now, LastUsedAt: now} + if acc := m.accounts[accountID]; acc != nil { + acc.lastUsedAt = now + } +} + +func (m *Manager) recordFailureLocked(acc *accountRuntime, reason string) { + acc.consecutiveFailures++ + seconds := math.Min(1800, 30*math.Pow(2, float64(acc.consecutiveFailures-1))) + acc.backoffUntil = time.Now().Add(time.Duration(seconds) * time.Second) + acc.lastError = reason + acc.lastUsedAt = time.Now() + failure := classifyFailure(reason) + failure.At = time.Now() + acc.recentFailures = append([]FailureEvent{failure}, acc.recentFailures...) + if len(acc.recentFailures) > 5 { + acc.recentFailures = acc.recentFailures[:5] + } + if acc.cfg.ID != "" { + m.getLogger().Warn("账号 %s 进入避退,失败次数=%d,恢复时间=%s,原因=%s", acc.cfg.ID, acc.consecutiveFailures, acc.backoffUntil.Format(time.RFC3339), reason) + } +} + +func nextReqIDLocked(info *TokenInfo) string { + if info.ReqID == 0 { + info.ReqID = seedReqID() + } + current := info.ReqID + info.ReqID += 100000 + return strconv.FormatInt(current, 10) +} + +func normalizedWeight(weight int) int { + if weight <= 0 { + return 1 + } + return weight +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + +func (m *Manager) getConfigStoreUpdate(mutator func(*config.Config) error) error { + if m.updateConfig == nil { + return fmt.Errorf("config store updates not wired") } - if m.tokenInfo.ReqID == 0 { - m.tokenInfo.ReqID = seedReqID() + if err := m.updateConfig(mutator); err != nil { + return err } - m.tokenInfo.FetchedAt = time.Now() + m.RefreshAccountsFromConfig() + return nil } func seedReqID() int64 { diff --git a/internal/token/manager_test.go b/internal/token/manager_test.go new file mode 100644 index 0000000..fe739eb --- /dev/null +++ b/internal/token/manager_test.go @@ -0,0 +1,272 @@ +package token + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "main/internal/config" + "main/internal/logging" +) + +func TestRefreshAccountNowKeepsExistingTokenOnRefreshFailure(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _, _ = io.WriteString(w, `sign in to continue`) + })) + defer server.Close() + + cfg := config.Config{ + GeminiHomeURL: server.URL, + Accounts: []config.Account{{ + ID: "acc-1", + Email: "first@example.com", + Cookies: "SID=test", + Enabled: true, + Weight: 1, + }}, + } + logger := logging.New(logging.LevelError, io.Discard, nil) + m := NewManager( + func() config.Config { return cfg }, + func() *http.Client { return server.Client() }, + func() *logging.Logger { return logger }, + nil, + ) + + m.mu.Lock() + acc := m.accounts["acc-1"] + acc.tokenInfo.SNlM0e = "existing-token" + acc.tokenInfo.BLToken = "existing-bl" + acc.tokenInfo.FSID = "12345" + acc.tokenInfo.ReqID = 1001 + acc.tokenInfo.FetchedAt = time.Now() + acc.lastError = "older transient error" + acc.consecutiveFailures = 2 + acc.backoffUntil = time.Now().Add(5 * time.Minute) + m.mu.Unlock() + + err := m.RefreshAccountNow("acc-1") + if err == nil { + t.Fatal("expected refresh error") + } + + statuses := m.AccountsStatus() + if len(statuses) != 1 { + t.Fatalf("expected 1 account status, got %d", len(statuses)) + } + status := statuses[0] + if !status.TokenReady { + t.Fatal("expected token_ready to remain true after refresh failure") + } + if status.StateCode != "ready" { + t.Fatalf("expected state_code ready, got %q", status.StateCode) + } + if status.LastError == "" { + t.Fatal("expected last_error to capture refresh failure") + } + if status.ConsecutiveFailures != 0 { + t.Fatalf("expected consecutive failures to reset after preserving usable token, got %d", status.ConsecutiveFailures) + } + if !status.BackoffUntil.IsZero() { + t.Fatal("expected backoff to clear after preserving usable token") + } + + m.mu.RLock() + refreshed := m.accounts["acc-1"] + m.mu.RUnlock() + refreshed.tokenInfo.mutex.RLock() + defer refreshed.tokenInfo.mutex.RUnlock() + if refreshed.tokenInfo.SNlM0e != "existing-token" || refreshed.tokenInfo.BLToken != "existing-bl" || refreshed.tokenInfo.FSID != "12345" { + t.Fatal("expected existing token info to be preserved on refresh failure") + } +} + +func TestAutoRefreshKeepsExistingTokenOnRefreshFailure(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _, _ = io.WriteString(w, `captcha challenge`) + })) + defer server.Close() + + cfg := config.Config{ + GeminiHomeURL: server.URL, + Accounts: []config.Account{{ + ID: "acc-1", + Email: "first@example.com", + Cookies: "SID=test", + Enabled: true, + Weight: 1, + }}, + } + logger := logging.New(logging.LevelError, io.Discard, nil) + m := NewManager( + func() config.Config { return cfg }, + func() *http.Client { return server.Client() }, + func() *logging.Logger { return logger }, + nil, + ) + + m.mu.Lock() + acc := m.accounts["acc-1"] + acc.tokenInfo.SNlM0e = "existing-token" + acc.tokenInfo.BLToken = "existing-bl" + acc.tokenInfo.FSID = "12345" + acc.tokenInfo.FetchedAt = time.Now().Add(-31 * time.Minute) + m.mu.Unlock() + + m.refreshAllAccountsIfNeeded(true) + + statuses := m.AccountsStatus() + if len(statuses) != 1 { + t.Fatalf("expected 1 account status, got %d", len(statuses)) + } + status := statuses[0] + if !status.TokenReady { + t.Fatal("expected token_ready to remain true after auto refresh failure") + } + if status.StateCode != "ready" { + t.Fatalf("expected state_code ready, got %q", status.StateCode) + } + if status.LastError == "" { + t.Fatal("expected last_error to capture auto refresh failure") + } + if status.ConsecutiveFailures != 0 { + t.Fatalf("expected consecutive failures to remain cleared, got %d", status.ConsecutiveFailures) + } + if !status.BackoffUntil.IsZero() { + t.Fatal("expected backoff to stay cleared after auto refresh failure with usable token") + } + if stats := m.PoolStats(); stats.HealthyAccounts != 1 { + t.Fatalf("expected account to remain healthy, got stats %+v", stats) + } + if _, err := m.SelectAccountForSession("session-after-auto-refresh", false); err != nil { + t.Fatalf("expected account to remain selectable after auto refresh failure: %v", err) + } + if refreshed := m.snapshotSelectedAccount("acc-1", ""); refreshed.Token != "existing-token" { + t.Fatalf("expected existing token to be preserved, got %q", refreshed.Token) + } +} + +func TestCookieHealthReportsImportantCookiesAndTimeHints(t *testing.T) { + cfg := config.Config{ + Accounts: []config.Account{{ + ID: "acc-1", + Cookies: "SID=sid; __Secure-1PSID=one; __Secure-3PSID=three; SAPISID=sapi; __Secure-1PAPISID=p1; __Secure-3PAPISID=p3; SIDCC=sidcc; __Secure-1PSIDCC=cc1; __Secure-3PSIDCC=cc3; __Secure-1PSIDTS=sidts-abc; __Secure-3PSIDTS=sidts-def; COMPASS=gemini-pd=abc; GOOGLE_ABUSE_EXEMPTION=ID=x:TM=1777537040:C=>:IP=1.2.3.4-:S=y; _ga_TEST=GS2.1.s1777536739$o1$g0$t1777536739$j60$l0$h0", + Enabled: true, + Weight: 1, + }}, + } + logger := logging.New(logging.LevelError, io.Discard, nil) + m := NewManager( + func() config.Config { return cfg }, + func() *http.Client { return http.DefaultClient }, + func() *logging.Logger { return logger }, + nil, + ) + + health, ok := m.CookieHealth("acc-1") + if !ok { + t.Fatal("expected account health") + } + if health.CookieCount != 14 { + t.Fatalf("expected 14 cookies, got %d", health.CookieCount) + } + if len(health.ImportantMissing) != 0 { + t.Fatalf("expected no missing important cookies, got %v", health.ImportantMissing) + } + if !health.ImportantPresent["COMPASS"] || !health.ImportantPresent["GOOGLE_ABUSE_EXEMPTION"] { + t.Fatal("expected COMPASS and GOOGLE_ABUSE_EXEMPTION to be present") + } + if health.AbuseExemption.Epoch != 1777537040 { + t.Fatalf("expected abuse exemption epoch, got %d", health.AbuseExemption.Epoch) + } + if len(health.AnalyticsTimeHints) != 1 || health.AnalyticsTimeHints[0].Epoch != 1777536739 { + t.Fatalf("unexpected analytics hints: %+v", health.AnalyticsTimeHints) + } + if len(health.OpaqueSessionCookies) != 2 { + t.Fatalf("expected opaque PSIDTS cookies, got %v", health.OpaqueSessionCookies) + } +} + +func TestSelectAccountForSessionUsesWeightedRoundRobin(t *testing.T) { + cfg := config.Config{Accounts: []config.Account{ + {ID: "acc-1", Enabled: true, Weight: 2, Token: "token-1"}, + {ID: "acc-2", Enabled: true, Weight: 1, Token: "token-2"}, + }} + m := newTestManager(cfg) + + var got []string + for i := 0; i < 6; i++ { + selected, err := m.SelectAccountForSession("", false) + if err != nil { + t.Fatalf("select account: %v", err) + } + got = append(got, selected.ID) + } + want := []string{"acc-1", "acc-1", "acc-2", "acc-1", "acc-1", "acc-2"} + for i := range want { + if got[i] != want[i] { + t.Fatalf("unexpected weighted sequence: got %v want %v", got, want) + } + } +} + +func TestSelectAccountForSessionKeepsExistingBinding(t *testing.T) { + cfg := config.Config{Accounts: []config.Account{ + {ID: "acc-1", Enabled: true, Weight: 1, Token: "token-1"}, + {ID: "acc-2", Enabled: true, Weight: 1, Token: "token-2"}, + }} + m := newTestManager(cfg) + + first, err := m.SelectAccountForSession("session-a", false) + if err != nil { + t.Fatalf("first select: %v", err) + } + second, err := m.SelectAccountForSession("session-a", false) + if err != nil { + t.Fatalf("second select: %v", err) + } + if second.ID != first.ID { + t.Fatalf("expected sticky session account %q, got %q", first.ID, second.ID) + } + + newSession, err := m.SelectAccountForSession("session-a", true) + if err != nil { + t.Fatalf("new session select: %v", err) + } + if newSession.ID == first.ID { + t.Fatalf("expected new session to re-enter round robin, got same account %q", newSession.ID) + } +} + +func TestSelectAccountForSessionSkipsBackoffAccount(t *testing.T) { + cfg := config.Config{Accounts: []config.Account{ + {ID: "acc-1", Enabled: true, Weight: 1, Token: "token-1"}, + {ID: "acc-2", Enabled: true, Weight: 1, Token: "token-2"}, + }} + m := newTestManager(cfg) + m.MarkAccountFailure("acc-1", "rate limited") + + for i := 0; i < 3; i++ { + selected, err := m.SelectAccountForSession("", false) + if err != nil { + t.Fatalf("select account: %v", err) + } + if selected.ID != "acc-2" { + t.Fatalf("expected backoff account to be skipped, got %q", selected.ID) + } + } +} + +func newTestManager(cfg config.Config) *Manager { + logger := logging.New(logging.LevelError, io.Discard, nil) + return NewManager( + func() config.Config { return cfg }, + func() *http.Client { return http.DefaultClient }, + func() *logging.Logger { return logger }, + nil, + ) +} diff --git a/internal/web/embed.go b/internal/web/embed.go index a4e34c2..ae0cc1d 100644 --- a/internal/web/embed.go +++ b/internal/web/embed.go @@ -11,6 +11,9 @@ var indexHTML []byte //go:embed help.html var helpHTML []byte +//go:embed login.html +var loginHTML []byte + func HandleIndex(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { http.NotFound(w, r) @@ -28,3 +31,12 @@ func HandleHelp(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") _, _ = w.Write(helpHTML) } + +func HandleLogin(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/login" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _, _ = w.Write(loginHTML) +} diff --git a/internal/web/help.html b/internal/web/help.html index 11a8551..69115ed 100644 --- a/internal/web/help.html +++ b/internal/web/help.html @@ -5,152 +5,100 @@ Gemini Web 2 API 使用手册 - +
-
+
- +
Gemini Web 2 API - 用户手册 · 面向二进制分发版本 + 简约手册页 · 面向部署和接入
-
- - - - 返回监控大屏 +
-
+
-
- Handbook Mode -

部署、接入、排错都放在一页里

-

这份帮助页按最终使用者来组织,而不是按源码开发流程来写。如果你拿到的是已经编译好的程序,只需要准备同目录的 config.json、启动可执行文件、打开服务地址,然后按示例发请求即可。

-
-
更像产品手册把快速开始、配置、接口、FAQ 和排障拆成独立章节。
-
更贴近分发场景默认假设你运行的是二进制,而不是本地编译源码。
-
可直接复制使用保留可复制的配置、curl、Node 与 Python 示例。
+
+ Manual +

部署、接入、排错放在一页

+

这份帮助页面向实际使用者组织,不按源码结构来写。你只需要准备同目录的 config.json、启动程序、确认服务在线,再按示例请求即可。

+
+
更像产品手册把准备、配置、接口、示例和排错拆成清晰章节。
+
更贴近分发场景默认你拿到的是二进制程序,而不是源码工程。
+
可直接复制保留配置、curl、Node.js 和 Python 示例。
-
-
- -
+
Gemini Web 2 API · White Manual Surface
+ diff --git a/internal/web/index.html b/internal/web/index.html index 21dfec0..ea3ffac 100644 --- a/internal/web/index.html +++ b/internal/web/index.html @@ -5,191 +5,872 @@ Gemini Web 2 API 控制台 - +
-
+
- +
Gemini Web 2 API - 数据展示大屏 · 实时运行控制台
-
- - - - 查看使用手册 -
连接中
+
+ + + 使用手册 + 读取中
-
+
-
- Realtime Operations -

把请求、稳定性和吞吐集中到一个视图里

-

首页重构为更偏大屏的数据控制台,重点不是“看起来像 AI 报告”,而是让部署者一眼看到服务健康度、请求趋势、成功率与 token 消耗。

+
+

服务和号池

-
当前状态等待数据
-
最后刷新--:--:--
-
Base URL--
+
服务状态--
+
最后刷新--
+
Base URL--
-
-
-
-

核心指标

用于值守时快速判断系统状态
-
-
运行时长
--
服务在线时长
-
RPM
--
每分钟请求数
-
总请求数
--
累计处理量
-
成功请求
--
成功率 --
-
失败请求
--
失败率 --
-
输入 Token
--
累计输入消耗
-
输出 Token
--
累计输出消耗
-
均次 Token
--
总 token / 总请求
+
+
+
+
+

运行指标

+
+
+ 自动刷新中 + + + +
-
+
+
运行时长--服务在线时长
+
RPM--每分钟请求数
+
总请求数--累计处理量
+
成功请求----
+
失败请求----
+
输入 Token--累计输入
+
输出 Token--累计输出
+
均次 Token--总 token / 总请求
+
+
-
-

流量趋势

-- rpm
-
-
- - - - +
+
+
+

号池健康

+
+
+ + +
-
+
+
总账号--已载入账号
+
健康账号--当前可分配
+
未就绪--需更新登录态
+
禁用账号--人工停用
+
+ -
-

接口面板

面向接入方的关键入口
-
-
GET /api/telemetry当前页和手册页都会依赖这个接口获取实时状态。
-
GET /v1/models适合 SDK 初始化、探活或对接侧做能力发现。
-
POST /v1/chat/completionsOpenAI 兼容聊天入口,支持普通响应与流式返回。
+
+
+
+

会话绑定

+
-
+
+ + + +
+
+
当前还没有活跃绑定。
+
+ -
-

稳定性总览

-
-
0%成功率
+
+
+
+

账号池

+
+
+ +
-
    -
  • 状态:等待数据
  • -
  • 刷新:每 5 秒拉取一次遥测
  • -
  • 建议:多轮会话固定使用 X-Session-ID
  • -
-
+ + -
-

运行备注

来自配置中的 note 字段
-
摘要正在等待说明信息。
-
    -
  • Loading...
  • -
-
+
+
+
+

新增 / 编辑账号

+
+
+
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ +
+
+ +
+
+
+ + + +
+
+
+ +
+
+
+

最近异常

+
+ +
+
+
当前没有异常记录。
+
+
-
Gemini Web 2 API · Dashboard / Help 双视图 · 亮暗主题持久化
+
Gemini Web 2 API
+
+ +
+ diff --git a/internal/web/login.html b/internal/web/login.html new file mode 100644 index 0000000..b30c850 --- /dev/null +++ b/internal/web/login.html @@ -0,0 +1,38 @@ + + + + + + 登录 Gemini Web 2 API + + + +
+
GW
Gemini Web 2 API
+

输入 API Key 进入控制台

+

控制台会显示账号池、Cookie、Token 等敏感信息。生产环境下必须先登录。

+
+ + +
+
+
+ + + diff --git a/test_capture_gemini_mitm.py b/test_capture_gemini_mitm.py new file mode 100644 index 0000000..b229a7a --- /dev/null +++ b/test_capture_gemini_mitm.py @@ -0,0 +1,68 @@ +import unittest + +import capture_gemini_mitm as capture + + +class CaptureGeminiMitmTest(unittest.TestCase): + def test_body_text_keeps_content_up_to_500kb(self): + content = b"a" * (400 * 1024) + + self.assertEqual(capture.body_text(content), "a" * (400 * 1024)) + + def test_body_text_trims_content_over_500kb(self): + content = b"a" * (501 * 1024) + + body = capture.body_text(content) + + self.assertEqual(len(body), 500 * 1024 + len("\n")) + self.assertTrue(body.endswith("\n")) + + def test_multipart_body_preserves_boundary_and_trims_large_binary_part(self): + body = ( + b"--abc123\r\n" + b'Content-Disposition: form-data; name="metadata"\r\n\r\n' + b'{"name":"cat"}\r\n' + b"--abc123\r\n" + b'Content-Disposition: form-data; name="file"; filename="cat.png"\r\n' + b"Content-Type: image/png\r\n\r\n" + + (b"x" * (capture.MULTIPART_PART_LIMIT + 1)) + + b"\r\n--abc123--\r\n" + ) + + text = capture.body_text(body, "multipart/form-data; boundary=abc123") + + self.assertIn("--abc123", text) + self.assertIn('{"name":"cat"}', text) + self.assertIn("