Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 114 additions & 8 deletions src/bedrock_agentcore/runtime/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@

import asyncio
import contextvars
import functools
import inspect
import json
import logging
import queue
import threading
import time
import uuid
from collections.abc import Sequence
from typing import Any, Callable, Dict, Optional

from starlette.applications import Starlette
from starlette.concurrency import run_in_threadpool
from starlette.middleware import Middleware
from starlette.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
Expand All @@ -39,6 +42,30 @@
from .utils import convert_complex_objects


def _is_async_callable(obj: Any) -> bool:
"""Check if obj is async-callable, unwrapping functools.partial."""
while isinstance(obj, functools.partial):
obj = obj.func
return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__))


def _is_async_gen_callable(obj: Any) -> bool:
"""Check if obj is an async generator function, unwrapping functools.partial."""
while isinstance(obj, functools.partial):
obj = obj.func
return inspect.isasyncgenfunction(obj) or (callable(obj) and inspect.isasyncgenfunction(obj.__call__))


def _restore_context(ctx: contextvars.Context) -> None:
"""Restore context variables from a snapshot (Django asgiref pattern)."""
for var, value in ctx.items():
try:
if var.get() != value:
var.set(value)
except LookupError:
var.set(value)


class RequestContextFormatter(logging.Formatter):
"""Formatter including request and session IDs."""

Expand Down Expand Up @@ -96,6 +123,9 @@ def __init__(
self._task_counter_lock: threading.Lock = threading.Lock()
self._forced_ping_status: Optional[PingStatus] = None
self._last_status_update_time: float = time.time()
self._worker_loop: Optional[asyncio.AbstractEventLoop] = None
self._worker_thread: Optional[threading.Thread] = None
self._worker_loop_lock: threading.Lock = threading.Lock()

routes = [
Route("/invocations", self._handle_invocation, methods=["POST"]),
Expand Down Expand Up @@ -163,7 +193,7 @@ def async_task(self, func: Callable) -> Callable:
- Set ping status to HEALTHY_BUSY while running
- Revert to HEALTHY when complete
"""
if not asyncio.iscoroutinefunction(func):
if not _is_async_callable(func):
raise ValueError("@async_task can only be applied to async functions")

async def wrapper(*args, **kwargs):
Expand Down Expand Up @@ -463,16 +493,92 @@ def run(self, port: int = 8080, host: Optional[str] = None, **kwargs):

uvicorn.run(self, **uvicorn_params)

async def _invoke_handler(self, handler, request_context, takes_context, payload):
def _ensure_worker_loop(self) -> asyncio.AbstractEventLoop:
"""Lazily create and start a dedicated worker event loop in a background thread.
The worker loop isolates async handler execution from the main event loop,
ensuring that blocking async handlers do not prevent /ping from responding.
"""
if self._worker_loop is not None and self._worker_loop.is_running():
return self._worker_loop
with self._worker_loop_lock:
if self._worker_loop is None or not self._worker_loop.is_running():
self._worker_loop = asyncio.new_event_loop()
self._worker_thread = threading.Thread(
target=self._run_worker_loop,
daemon=True,
name="agentcore-worker-loop",
)
self._worker_thread.start()
return self._worker_loop

def _run_worker_loop(self) -> None:
"""Entry point for the worker loop background thread."""
asyncio.set_event_loop(self._worker_loop)
self._worker_loop.run_forever()

@staticmethod
async def _run_with_context(coro: Any, ctx: contextvars.Context) -> Any:
    """Await ``coro`` with the context variables from ``ctx`` applied.

    Args:
        coro: Awaitable to drive to completion.
        ctx: Snapshot of context variables to apply before ``coro`` starts.

    Returns:
        Whatever ``coro`` evaluates to.
    """
    # The variables must be in place before the coroutine takes its first step.
    _restore_context(ctx)
    outcome = await coro
    return outcome

def _async_gen_to_sync_gen(self, async_gen: Any, ctx: contextvars.Context) -> Any:
    """Bridge an async generator through the worker loop as a sync generator.

    The async generator is iterated on the worker loop. Chunks are handed
    over through a bounded thread-safe queue and yielded synchronously, so
    the caller can consume them from a plain thread without touching the
    main event loop.

    The producer never blocks the worker loop: when the queue is full it
    yields to the loop and retries, so other handlers sharing the loop keep
    running. When the consumer stops (exhausted, closed early, or failed) the
    producer is cancelled and the async generator is closed instead of being
    left parked on a full queue.

    Args:
        async_gen: Async generator producing the chunks.
        ctx: Context-variable snapshot applied before iteration starts.

    Yields:
        Each chunk produced by ``async_gen``, in order.

    Raises:
        BaseException: Whatever ``async_gen`` (or context restoration) raised
            on the worker loop, re-raised in the consuming thread.
    """
    worker_loop = self._ensure_worker_loop()
    q: queue.Queue = queue.Queue(maxsize=100)
    _DONE = object()
    # Set once the consumer stops reading, so the producer can bail out.
    consumer_gone = threading.Event()
    retry_delay = 0.01  # seconds to yield to the loop while the queue is full

    async def _offer(item: Any) -> bool:
        """Enqueue ``item`` without blocking the loop; False if the consumer left."""
        while not consumer_gone.is_set():
            try:
                q.put_nowait(item)
                return True
            except queue.Full:
                # A blocking q.put() here would freeze every coroutine on the
                # worker loop, so back off cooperatively instead.
                await asyncio.sleep(retry_delay)
        return False

    async def _produce() -> None:
        try:
            # Inside the try so a failure here reaches the consumer rather
            # than leaving it waiting on an empty queue.
            _restore_context(ctx)
            async for chunk in async_gen:
                if not await _offer((True, chunk)):
                    return
            await _offer((True, _DONE))
        except BaseException as e:
            await _offer((False, e))
        finally:
            # An abandoned ``async for`` does not close the generator itself.
            aclose = getattr(async_gen, "aclose", None)
            if aclose is not None:
                try:
                    await aclose()
                except Exception:
                    self.logger.debug("Failed to close async generator", exc_info=True)

    # Keeping the returned future alive keeps the task referenced, and
    # cancelling the future cancels the task on the worker loop.
    producer = asyncio.run_coroutine_threadsafe(_produce(), worker_loop)
    try:
        while True:
            ok, value = q.get()
            if not ok:
                raise value
            if value is _DONE:
                break
            yield value
    finally:
        consumer_gone.set()
        producer.cancel()  # no-op when the producer already finished

async def _invoke_handler(self, handler: Callable, request_context: Any, takes_context: bool, payload: Any) -> Any:
"""Dispatch handler execution based on handler type.
- Async generator functions: bridged through the worker loop as a sync generator
- Regular async functions: run on the dedicated worker event loop
- Sync functions (including sync generators): run in the thread pool
This ensures the main event loop stays responsive for /ping health checks
regardless of whether handlers contain blocking operations.
"""
try:
args = (payload, request_context) if takes_context else (payload,)

if asyncio.iscoroutinefunction(handler):
return await handler(*args)
ctx = contextvars.copy_context()

if _is_async_gen_callable(handler):
return self._async_gen_to_sync_gen(handler(*args), ctx)
elif _is_async_callable(handler):
worker_loop = self._ensure_worker_loop()
future = asyncio.run_coroutine_threadsafe(self._run_with_context(handler(*args), ctx), worker_loop)
result = await asyncio.wrap_future(future)
if inspect.isasyncgen(result):
return self._async_gen_to_sync_gen(result, ctx)
return result
else:
loop = asyncio.get_event_loop()
ctx = contextvars.copy_context()
return await loop.run_in_executor(None, ctx.run, handler, *args)
return await run_in_threadpool(ctx.run, handler, *args)
except Exception:
handler_name = getattr(handler, "__name__", "unknown")
self.logger.debug("Handler '%s' execution failed", handler_name)
Expand Down
Loading
Loading