Skip to content

Commit a0de113

Browse files
authored
Merge pull request #466 from chenyme/vercel_storage
feat: Enhance SQL account repository with serverless support and initialization logic
2 parents 1865fdb + c5bbbb5 commit a0de113

3 files changed

Lines changed: 53 additions & 44 deletions

File tree

app/control/account/backends/sql.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from typing import Any
1212
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
1313

14+
import asyncio
15+
1416
import sqlalchemy as sa
1517
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
1618

@@ -330,11 +332,11 @@ def _build_sql_connect_args(
330332
return {"ssl": ctx} if ctx is not None else None
331333

332334
_validate_pg_ssl_options(mode, ssl_options)
333-
if _has_ssl_options(ssl_options, _PG_SSL_CERT_PARAM_KEYS):
334-
return {"ssl": _build_pg_ssl_context(mode, ssl_options)}
335335
if mode == "disable":
336336
return None
337-
return {"ssl": mode}
337+
# asyncpg does not accept ssl= as a plain string (e.g. "require").
338+
# Always build a proper ssl.SSLContext so the driver can use it directly.
339+
return {"ssl": _build_pg_ssl_context(mode, ssl_options)}
338340

339341

340342
def _prepare_sql_url_and_connect_args(
@@ -350,10 +352,22 @@ def _prepare_sql_url_and_connect_args(
350352
return cleaned_url, _build_sql_connect_args(dialect, ssl_options)
351353

352354

355+
def _is_serverless() -> bool:
356+
"""Detect common serverless environments (Vercel, AWS Lambda, etc.)."""
357+
return bool(
358+
os.getenv("VERCEL")
359+
or os.getenv("AWS_LAMBDA_FUNCTION_NAME")
360+
or os.getenv("FUNCTIONS_WORKER_RUNTIME") # Azure Functions
361+
)
362+
363+
353364
def _sql_engine_kwargs(connect_args: dict[str, Any] | None) -> dict[str, Any]:
365+
# In serverless environments each function instance is short-lived and may
366+
# run concurrently. Keep pools small to avoid exhausting DB connections.
367+
serverless = _is_serverless()
354368
kwargs: dict[str, Any] = {
355-
"pool_size": _get_env_int("ACCOUNT_SQL_POOL_SIZE", 5, minimum=1),
356-
"max_overflow": _get_env_int("ACCOUNT_SQL_MAX_OVERFLOW", 10, minimum=0),
369+
"pool_size": _get_env_int("ACCOUNT_SQL_POOL_SIZE", 1 if serverless else 5, minimum=1),
370+
"max_overflow": _get_env_int("ACCOUNT_SQL_MAX_OVERFLOW", 2 if serverless else 10, minimum=0),
357371
"pool_timeout": _get_env_int("ACCOUNT_SQL_POOL_TIMEOUT", 30, minimum=1),
358372
"pool_recycle": _get_env_int("ACCOUNT_SQL_POOL_RECYCLE", 1800, minimum=0),
359373
"pool_pre_ping": True,
@@ -412,10 +426,12 @@ def __init__(
412426
dialect: str = "mysql",
413427
dispose_engine: bool = True,
414428
) -> None:
415-
self._engine = engine
416-
self._dialect = dialect # "mysql" | "postgresql"
417-
self._session = async_sessionmaker(engine, expire_on_commit=False)
429+
self._engine = engine
430+
self._dialect = dialect # "mysql" | "postgresql"
431+
self._session = async_sessionmaker(engine, expire_on_commit=False)
418432
self._dispose_engine = dispose_engine
433+
self._initialized = False
434+
self._init_lock = asyncio.Lock()
419435

420436
# ------------------------------------------------------------------
421437
# Revision helpers (run inside a transaction)
@@ -463,7 +479,23 @@ def _build_upsert(self, row: dict[str, Any]):
463479
# Public API
464480
# ------------------------------------------------------------------
465481

466-
async def initialize(self) -> None:
482+
async def _ensure_initialized(self) -> None:
483+
"""Idempotent: create tables + seed revision row if not already done.
484+
485+
Safe to call on every request — short-circuits after first success so
486+
repeated calls cost only an asyncio lock check. This allows the
487+
repository to self-initialise even when the ASGI lifespan is not
488+
executed (e.g. Vercel serverless cold-starts).
489+
"""
490+
if self._initialized:
491+
return
492+
async with self._init_lock:
493+
if self._initialized:
494+
return
495+
await self._do_initialize()
496+
self._initialized = True
497+
498+
async def _do_initialize(self) -> None:
467499
async with self._engine.begin() as conn:
468500
await conn.run_sync(metadata.create_all)
469501
# Seed revision row.
@@ -482,11 +514,16 @@ async def initialize(self) -> None:
482514
.on_duplicate_key_update(value="0")
483515
)
484516

517+
async def initialize(self) -> None:
518+
await self._ensure_initialized()
519+
485520
async def get_revision(self) -> int:
521+
await self._ensure_initialized()
486522
async with self._engine.connect() as conn:
487523
return await self._get_revision(conn)
488524

489525
async def runtime_snapshot(self) -> RuntimeSnapshot:
526+
await self._ensure_initialized()
490527
async with self._engine.connect() as conn:
491528
rev = await self._get_revision(conn)
492529
rows = (await conn.execute(
@@ -500,6 +537,7 @@ async def scan_changes(
500537
*,
501538
limit: int = 5000,
502539
) -> AccountChangeSet:
540+
await self._ensure_initialized()
503541
async with self._engine.connect() as conn:
504542
rev = await self._get_revision(conn)
505543
rows = (await conn.execute(
@@ -529,6 +567,7 @@ async def upsert_accounts(
529567
) -> AccountMutationResult:
530568
if not items:
531569
return AccountMutationResult()
570+
await self._ensure_initialized()
532571
async with self._engine.begin() as conn:
533572
rev = await self._bump_revision(conn)
534573
ts = now_ms()
@@ -568,6 +607,7 @@ async def patch_accounts(
568607
) -> AccountMutationResult:
569608
if not patches:
570609
return AccountMutationResult()
610+
await self._ensure_initialized()
571611
async with self._engine.begin() as conn:
572612
rev = await self._bump_revision(conn)
573613
ts = now_ms()
@@ -652,6 +692,7 @@ async def delete_accounts(
652692
) -> AccountMutationResult:
653693
if not tokens:
654694
return AccountMutationResult()
695+
await self._ensure_initialized()
655696
async with self._engine.begin() as conn:
656697
rev = await self._bump_revision(conn)
657698
ts = now_ms()
@@ -671,6 +712,7 @@ async def get_accounts(
671712
) -> list[AccountRecord]:
672713
if not tokens:
673714
return []
715+
await self._ensure_initialized()
674716
async with self._engine.connect() as conn:
675717
rows = (await conn.execute(
676718
sa.select(accounts_table).where(accounts_table.c.token.in_(tokens))
@@ -681,6 +723,7 @@ async def list_accounts(
681723
self,
682724
query: ListAccountsQuery,
683725
) -> AccountPage:
726+
await self._ensure_initialized()
684727
async with self._engine.connect() as conn:
685728
stmt = sa.select(accounts_table)
686729
if not query.include_deleted:
@@ -718,6 +761,7 @@ async def replace_pool(
718761
self,
719762
command: BulkReplacePoolCommand,
720763
) -> AccountMutationResult:
764+
await self._ensure_initialized()
721765
async with self._engine.begin() as conn:
722766
rev = await self._bump_revision(conn)
723767
ts = now_ms()

render.yaml

Lines changed: 0 additions & 21 deletions
This file was deleted.

vercel.json

Lines changed: 0 additions & 14 deletions
This file was deleted.

0 commit comments

Comments
 (0)