Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
75 changes: 75 additions & 0 deletions backend/app/observability/request_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Request correlation IDs (E4 / audit M-7).

A request id is bound to a contextvar at the API boundary (and re-bound inside
Celery tasks), exposed on the response as ``X-Request-Id``, and injected into
every log record so a single request can be traced across the API and the
workers it enqueues.
"""

import logging
import uuid
from contextvars import ContextVar

# Name of the HTTP header that carries the correlation id on requests/responses.
REQUEST_ID_HEADER = "X-Request-Id"

# Context-local holder for the id of the request currently being handled.
# Empty string => "no request bound" (rendered as "-" in logs).
_request_id: ContextVar[str] = ContextVar("request_id", default="")


def new_request_id() -> str:
    """Return a fresh random request id (32 lowercase hex chars from a UUID4)."""
    fresh = uuid.uuid4()
    return fresh.hex


def get_request_id() -> str:
    """Return the request id bound to the current context ("" when none is bound)."""
    current = _request_id.get()
    return current


def bind_request_id(request_id: str) -> str:
    """Bind a request id to the current context; generate one if blank.

    Args:
        request_id: Id supplied by the caller (e.g. an incoming header value).
            ``None``, an empty string, or a whitespace-only string all count
            as blank.

    Returns:
        The id that was actually bound: the supplied id with surrounding
        whitespace removed, or a newly generated one when the input was blank.
    """
    # Strip before the blank check: a whitespace-only value is truthy and would
    # otherwise be bound verbatim, producing log lines with an empty-looking id.
    rid = (request_id or "").strip() or new_request_id()
    _request_id.set(rid)
    return rid


class RequestIdFilter(logging.Filter):
    """Logging filter that stamps each record with a ``request_id`` attribute.

    The value is the id bound to the current context, or ``"-"`` when no id
    is bound. The filter never drops a record.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        current = get_request_id()
        # An empty id means "nothing bound"; show a dash instead of a blank.
        record.request_id = current if current else "-"
        return True


def install_request_id_logging() -> None:
    """Wire the request-id filter and log format into the active loggers.

    For the root, uvicorn, app and celery loggers this adds a
    ``RequestIdFilter`` to the logger and to each of its EXISTING handlers,
    and replaces those handlers' formatters with one that prints the request
    id. No handler is added unless the root logger has none at all, so log
    lines are not duplicated. Loggers and handlers that already carry a
    ``RequestIdFilter`` do not get a second one, so repeated calls are safe.
    """
    line_format = "%(asctime)s %(levelname)s [req:%(request_id)s] %(name)s: %(message)s"
    shared_formatter = logging.Formatter(line_format)
    shared_filter = RequestIdFilter()

    def _already_filtered(target) -> bool:
        # True when a RequestIdFilter (from this or an earlier call) is attached.
        return any(isinstance(existing, RequestIdFilter) for existing in target.filters)

    visited_handler_ids = set()
    for logger_name in ("", "uvicorn", "uvicorn.access", "uvicorn.error", "app", "celery"):
        target_logger = logging.getLogger(logger_name)
        if not _already_filtered(target_logger):
            target_logger.addFilter(shared_filter)

        for existing_handler in target_logger.handlers:
            handler_key = id(existing_handler)
            # The same handler object may hang off several loggers; touch it once.
            if handler_key not in visited_handler_ids:
                visited_handler_ids.add(handler_key)
                if not _already_filtered(existing_handler):
                    existing_handler.addFilter(shared_filter)
                existing_handler.setFormatter(shared_formatter)

    # If nothing has configured a root handler yet, install a basic one so app
    # logs still carry the request id.
    root_logger = logging.getLogger()
    if root_logger.handlers:
        return
    fallback_handler = logging.StreamHandler()
    fallback_handler.setFormatter(shared_formatter)
    fallback_handler.addFilter(shared_filter)
    root_logger.addHandler(fallback_handler)
    root_logger.setLevel(logging.INFO)
17 changes: 17 additions & 0 deletions backend/app/observability/tasking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Celery enqueue helper that propagates the request correlation id (E4)."""

from app.observability.request_id import get_request_id


def enqueue(task, *args, **kwargs):
    """Enqueue a Celery task, carrying the current request id in its headers.

    Intended as a drop-in replacement for ``task.delay(*args, **kwargs)``:
    positional and keyword arguments are forwarded unchanged through
    ``task.apply_async``, with the id from ``get_request_id()`` attached under
    the ``request_id`` header key.

    Returns:
        Whatever ``task.apply_async`` returns.
    """
    correlation_headers = {"request_id": get_request_id()}
    return task.apply_async(args=args, kwargs=kwargs, headers=correlation_headers)
5 changes: 3 additions & 2 deletions backend/app/routes/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from app.services.document_identity import resolve_for_embed
from app.services.embedding_reconciler import find_unembedded_docs, reconcile_unembedded
from app.services.worker_health import inspect_celery_workers
from app.observability.tasking import enqueue
from app.storage.minio_client import get_storage_client
from app.tasks.embed_nodes import embed_nodes_task

Expand Down Expand Up @@ -150,7 +151,7 @@ async def embed_document(
job_id = str(job.job_id)

# Queue Graph embedding task with actual graph doc_id
embed_nodes_task.delay(actual_doc_id, version, job_id=job_id)
enqueue(embed_nodes_task, actual_doc_id, version, job_id=job_id)

return EmbedDocumentResponse(
job_id=job_id,
Expand Down Expand Up @@ -391,6 +392,6 @@ async def reconcile_embeddings(
with session_scope() as session:
return reconcile_unembedded(
session,
enqueue=lambda doc_id, version: embed_nodes_task.delay(doc_id, version),
enqueue=lambda doc_id, version: enqueue(embed_nodes_task, doc_id, version),
limit=max(1, min(limit, 200)),
)
5 changes: 3 additions & 2 deletions backend/app/routes/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from app.metadata.extractor import MetadataExtractor
from app.services.worker_health import inspect_celery_workers
from app.services.job_sweeper import sweep_stale_jobs
from app.observability.tasking import enqueue
from app.storage.minio_client import get_storage_client
from app.tasks.ingest import ingest_document_task

Expand Down Expand Up @@ -681,7 +682,7 @@ async def process_metadata_preview(
preview.processed_at = datetime.now(timezone.utc)
preview.tenant_id = effective_tenant

ingest_document_task.delay(
enqueue(ingest_document_task,
job_id=job_id,
doc_id=final_doc_id,
version_id=final_version_id,
Expand Down Expand Up @@ -832,7 +833,7 @@ async def ingest_document(
# SIDE EFFECT:
# DB state is committed before broker publish; enqueue failures leave pending jobs requiring explicit recovery/retry tooling.
# Queue Celery task
ingest_document_task.delay(
enqueue(ingest_document_task,
job_id=job_id,
doc_id=doc_id,
version_id=version_id,
Expand Down
3 changes: 2 additions & 1 deletion backend/app/tasks/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any

from app.worker import celery_app
from app.observability.tasking import enqueue
from app.storage.minio_client import get_storage_client
from app.embeddings.client import EmbeddingClient
from app.embeddings.vector_record import MultiViewVectorRecordBuilder
Expand Down Expand Up @@ -85,7 +86,7 @@ def embed_chunks_task(

# 7. Queue index task
from app.tasks.index import index_vectors_task
index_vectors_task.delay(doc_id, version_id)
enqueue(index_vectors_task, doc_id, version_id)
logger.info("Queued index task")

return {
Expand Down
22 changes: 22 additions & 0 deletions backend/app/worker.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Celery worker configuration for NPR RAG."""

from celery import Celery
from celery.signals import after_setup_logger, task_prerun

from app.config import get_settings
from app.observability.request_id import bind_request_id, install_request_id_logging

settings = get_settings()

Expand Down Expand Up @@ -36,3 +38,23 @@
broker_connection_retry_on_startup=True,
broker_connection_max_retries=None,
)


@after_setup_logger.connect
def _install_worker_request_id_logging(**_signal_kwargs):
    """Signal receiver: add the request id to worker log lines (E4).

    Connected to Celery's ``after_setup_logger`` signal; the signal's keyword
    arguments are accepted and ignored.
    """
    install_request_id_logging()


@task_prerun.connect
def _bind_task_request_id(task=None, **_):
    """Re-bind the correlation id propagated from the enqueuing request (E4).

    Looks for ``request_id`` in ``task.request.headers`` first and, if it is
    not found there, as a top-level attribute of ``task.request``. Falls back
    to a fresh id when a task was enqueued without one or when the request
    context cannot be inspected.

    Args:
        task: The task instance about to run, as passed by the
            ``task_prerun`` signal. May be ``None``.
        **_: Remaining signal keyword arguments; ignored.
    """
    propagated = ""
    try:
        request = getattr(task, "request", None)
        headers = getattr(request, "headers", None) or {}
        propagated = headers.get("request_id") or ""
        if not propagated:
            # NOTE(review): depending on the Celery version / task protocol,
            # custom headers passed via apply_async(headers=...) may be exposed
            # as top-level attributes of the request context rather than under
            # request.headers - confirm against the deployed Celery version.
            propagated = getattr(request, "request_id", None) or ""
    except Exception:  # pragma: no cover - defensive
        # Best effort: a signal receiver must never break task execution.
        propagated = ""
    # Only bind real strings; anything else (e.g. a mock attribute) is treated
    # as "no id propagated" so a fresh id is generated instead.
    bind_request_id(propagated if isinstance(propagated, str) else "")
16 changes: 16 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
from app.config import get_settings
from app.graph.backend_selector import get_supported_types
from app.models import HealthResponse, ServiceStatus
from app.observability.request_id import (
REQUEST_ID_HEADER,
bind_request_id,
install_request_id_logging,
)
from app.routes import ingest, embed, retrieve, qa, documents, prompts, acl
from app.routes import settings as settings_routes
from app.services.worker_health import inspect_celery_workers
Expand All @@ -26,6 +31,7 @@
async def lifespan(app: FastAPI):
"""Application lifespan events."""
# Startup
install_request_id_logging() # add request-id to log lines (E4)
logger.info(
"%s v%s | %s | %s",
settings.app_name,
Expand Down Expand Up @@ -82,9 +88,19 @@ async def lifespan(app: FastAPI):
"X-User-Id",
"X-Roles",
"X-Groups",
REQUEST_ID_HEADER,
],
)


@app.middleware("http")
async def request_id_middleware(request: Request, call_next):
    """Bind a correlation id for the request and echo it on the response (E4).

    The id is taken from the incoming ``X-Request-Id`` header; when that
    header is absent or empty, ``bind_request_id`` generates one.
    """
    incoming_id = request.headers.get(REQUEST_ID_HEADER, "")
    bound_id = bind_request_id(incoming_id)
    outgoing = await call_next(request)
    outgoing.headers[REQUEST_ID_HEADER] = bound_id
    return outgoing

# Include routers
app.include_router(ingest.router)
app.include_router(embed.router)
Expand Down
62 changes: 62 additions & 0 deletions backend/tests/test_request_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Unit tests for request correlation ids (E4 / audit M-7)."""

import logging

from app.observability.request_id import (
RequestIdFilter,
bind_request_id,
get_request_id,
new_request_id,
)
from app.observability.tasking import enqueue


def test_bind_generates_id_when_blank():
    """Binding a blank id yields a non-empty generated id that is then current."""
    generated = bind_request_id("")
    assert generated
    assert generated == get_request_id()


def test_bind_preserves_provided_id():
    """A caller-supplied id is returned and bound unchanged."""
    supplied = "trace-abc"
    bound = bind_request_id(supplied)
    assert bound == "trace-abc"
    assert get_request_id() == "trace-abc"


def test_new_request_id_unique():
    """Two consecutively generated ids differ."""
    first = new_request_id()
    second = new_request_id()
    assert first != second


def test_filter_injects_request_id_onto_record():
    """The filter copies the bound id onto the record as ``request_id``."""
    bind_request_id("rid-123")
    log_record = logging.LogRecord("n", logging.INFO, "p", 1, "msg", None, None)
    rid_filter = RequestIdFilter()
    rid_filter.filter(log_record)
    assert log_record.request_id == "rid-123"


def test_filter_defaults_to_dash_when_unbound():
    """With no request id bound, the filter renders ``request_id`` as ``"-"``.

    The filter runs inside a fresh, empty ``contextvars.Context`` so the
    context variable falls back to its declared default. This exercises the
    real unbound case without reaching into the module's private state and
    without changing the id bound in the test's own context.
    """
    import contextvars

    record = logging.LogRecord("n", logging.INFO, "p", 1, "msg", None, None)
    # An empty Context holds no values, so nothing is bound inside run().
    contextvars.Context().run(RequestIdFilter().filter, record)
    assert record.request_id == "-"


def test_enqueue_propagates_request_id_via_headers():
    """``enqueue`` forwards args/kwargs and sends the bound id as a header."""
    bind_request_id("rid-xyz")
    recorded_calls = []

    class RecordingTask:
        # Minimal stand-in exposing only the apply_async call that enqueue uses.
        def apply_async(self, args=None, kwargs=None, headers=None):
            recorded_calls.append((args, kwargs, headers))
            return "async-result"

    outcome = enqueue(RecordingTask(), "a", 1, job_id="j1")

    assert outcome == "async-result"
    sent_args, sent_kwargs, sent_headers = recorded_calls[-1]
    assert sent_args == ("a", 1)
    assert sent_kwargs == {"job_id": "j1"}
    assert sent_headers == {"request_id": "rid-xyz"}
16 changes: 9 additions & 7 deletions backend/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def test_pdf_ingestion_success(

mock_require_embedding.return_value = None
mock_storage.return_value = MagicMock()
mock_task.delay = MagicMock()
mock_task.apply_async = MagicMock()
mock_fitz_open.return_value.__enter__.return_value = MagicMock()
mock_worker_status.return_value = CeleryWorkerStatus(
healthy=True,
Expand All @@ -234,7 +234,7 @@ def test_pdf_ingestion_success(
assert data["status"] == "pending"

# Verify task was queued
mock_task.delay.assert_called_once()
mock_task.apply_async.assert_called_once()

@patch("app.routes.ingest.fitz.open")
@patch("app.routes.ingest.inspect_celery_workers")
Expand All @@ -254,7 +254,7 @@ def test_custom_doc_id_respected(

mock_require_embedding.return_value = None
mock_storage.return_value = MagicMock()
mock_task.delay = MagicMock()
mock_task.apply_async = MagicMock()
mock_fitz_open.return_value.__enter__.return_value = MagicMock()
mock_worker_status.return_value = CeleryWorkerStatus(
healthy=True,
Expand Down Expand Up @@ -334,7 +334,7 @@ def test_process_preview_queues_ingestion_with_metadata_overrides(
storage_client = MagicMock()
storage_client.get_raw.return_value = b"%PDF-1.4 staged content"
mock_storage.return_value = storage_client
mock_task.delay = MagicMock()
mock_task.apply_async = MagicMock()
mock_worker_status.return_value = CeleryWorkerStatus(
healthy=True,
worker_count=1,
Expand Down Expand Up @@ -382,9 +382,11 @@ def test_process_preview_queues_ingestion_with_metadata_overrides(
assert data["status"] == "pending"
assert "job_id" in data

mock_task.delay.assert_called_once()
_, kwargs = mock_task.delay.call_args
assert kwargs["metadata_overrides"] == {"department": "engineering"}
mock_task.apply_async.assert_called_once()
_, call_kwargs = mock_task.apply_async.call_args
# enqueue() forwards task kwargs under apply_async(kwargs=...)
task_kwargs = call_kwargs["kwargs"]
assert task_kwargs["metadata_overrides"] == {"department": "engineering"}


# =============================================================================
Expand Down
Loading