Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 53 additions & 9 deletions app/control/account/backends/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from typing import Any
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

import asyncio

import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine

Expand Down Expand Up @@ -330,11 +332,11 @@ def _build_sql_connect_args(
return {"ssl": ctx} if ctx is not None else None

_validate_pg_ssl_options(mode, ssl_options)
if _has_ssl_options(ssl_options, _PG_SSL_CERT_PARAM_KEYS):
return {"ssl": _build_pg_ssl_context(mode, ssl_options)}
if mode == "disable":
return None
return {"ssl": mode}
# asyncpg does not accept ssl= as a plain string (e.g. "require").
# Always build a proper ssl.SSLContext so the driver can use it directly.
return {"ssl": _build_pg_ssl_context(mode, ssl_options)}


def _prepare_sql_url_and_connect_args(
Expand All @@ -350,10 +352,22 @@ def _prepare_sql_url_and_connect_args(
return cleaned_url, _build_sql_connect_args(dialect, ssl_options)


def _is_serverless() -> bool:
"""Detect common serverless environments (Vercel, AWS Lambda, etc.)."""
return bool(
os.getenv("VERCEL")
or os.getenv("AWS_LAMBDA_FUNCTION_NAME")
or os.getenv("FUNCTIONS_WORKER_RUNTIME") # Azure Functions
)


def _sql_engine_kwargs(connect_args: dict[str, Any] | None) -> dict[str, Any]:
# In serverless environments each function instance is short-lived and may
# run concurrently. Keep pools small to avoid exhausting DB connections.
serverless = _is_serverless()
kwargs: dict[str, Any] = {
"pool_size": _get_env_int("ACCOUNT_SQL_POOL_SIZE", 5, minimum=1),
"max_overflow": _get_env_int("ACCOUNT_SQL_MAX_OVERFLOW", 10, minimum=0),
"pool_size": _get_env_int("ACCOUNT_SQL_POOL_SIZE", 1 if serverless else 5, minimum=1),
"max_overflow": _get_env_int("ACCOUNT_SQL_MAX_OVERFLOW", 2 if serverless else 10, minimum=0),
"pool_timeout": _get_env_int("ACCOUNT_SQL_POOL_TIMEOUT", 30, minimum=1),
"pool_recycle": _get_env_int("ACCOUNT_SQL_POOL_RECYCLE", 1800, minimum=0),
"pool_pre_ping": True,
Expand Down Expand Up @@ -412,10 +426,12 @@ def __init__(
dialect: str = "mysql",
dispose_engine: bool = True,
) -> None:
self._engine = engine
self._dialect = dialect # "mysql" | "postgresql"
self._session = async_sessionmaker(engine, expire_on_commit=False)
self._engine = engine
self._dialect = dialect # "mysql" | "postgresql"
self._session = async_sessionmaker(engine, expire_on_commit=False)
self._dispose_engine = dispose_engine
self._initialized = False
self._init_lock = asyncio.Lock()

# ------------------------------------------------------------------
# Revision helpers (run inside a transaction)
Expand Down Expand Up @@ -463,7 +479,23 @@ def _build_upsert(self, row: dict[str, Any]):
# Public API
# ------------------------------------------------------------------

async def initialize(self) -> None:
async def _ensure_initialized(self) -> None:
"""Idempotent: create tables + seed revision row if not already done.

Safe to call on every request — short-circuits after first success so
repeated calls cost only an asyncio lock check. This allows the
repository to self-initialise even when the ASGI lifespan is not
executed (e.g. Vercel serverless cold-starts).
"""
if self._initialized:
return
async with self._init_lock:
if self._initialized:
return
await self._do_initialize()
self._initialized = True

async def _do_initialize(self) -> None:
async with self._engine.begin() as conn:
await conn.run_sync(metadata.create_all)
# Seed revision row.
Expand All @@ -482,11 +514,16 @@ async def initialize(self) -> None:
.on_duplicate_key_update(value="0")
)

async def initialize(self) -> None:
await self._ensure_initialized()

async def get_revision(self) -> int:
await self._ensure_initialized()
async with self._engine.connect() as conn:
return await self._get_revision(conn)

async def runtime_snapshot(self) -> RuntimeSnapshot:
await self._ensure_initialized()
async with self._engine.connect() as conn:
rev = await self._get_revision(conn)
rows = (await conn.execute(
Expand All @@ -500,6 +537,7 @@ async def scan_changes(
*,
limit: int = 5000,
) -> AccountChangeSet:
await self._ensure_initialized()
async with self._engine.connect() as conn:
rev = await self._get_revision(conn)
rows = (await conn.execute(
Expand Down Expand Up @@ -529,6 +567,7 @@ async def upsert_accounts(
) -> AccountMutationResult:
if not items:
return AccountMutationResult()
await self._ensure_initialized()
async with self._engine.begin() as conn:
rev = await self._bump_revision(conn)
ts = now_ms()
Expand Down Expand Up @@ -568,6 +607,7 @@ async def patch_accounts(
) -> AccountMutationResult:
if not patches:
return AccountMutationResult()
await self._ensure_initialized()
async with self._engine.begin() as conn:
rev = await self._bump_revision(conn)
ts = now_ms()
Expand Down Expand Up @@ -652,6 +692,7 @@ async def delete_accounts(
) -> AccountMutationResult:
if not tokens:
return AccountMutationResult()
await self._ensure_initialized()
async with self._engine.begin() as conn:
rev = await self._bump_revision(conn)
ts = now_ms()
Expand All @@ -671,6 +712,7 @@ async def get_accounts(
) -> list[AccountRecord]:
if not tokens:
return []
await self._ensure_initialized()
async with self._engine.connect() as conn:
rows = (await conn.execute(
sa.select(accounts_table).where(accounts_table.c.token.in_(tokens))
Expand All @@ -681,6 +723,7 @@ async def list_accounts(
self,
query: ListAccountsQuery,
) -> AccountPage:
await self._ensure_initialized()
async with self._engine.connect() as conn:
stmt = sa.select(accounts_table)
if not query.include_deleted:
Expand Down Expand Up @@ -718,6 +761,7 @@ async def replace_pool(
self,
command: BulkReplacePoolCommand,
) -> AccountMutationResult:
await self._ensure_initialized()
async with self._engine.begin() as conn:
rev = await self._bump_revision(conn)
ts = now_ms()
Expand Down
21 changes: 0 additions & 21 deletions render.yaml

This file was deleted.

14 changes: 0 additions & 14 deletions vercel.json

This file was deleted.

Loading